mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 13:50:35 +08:00
more padding options in Dataloader
Summary: Add facilities for dataloading non-sequential scenes. Reviewed By: shapovalov Differential Revision: D37291277 fbshipit-source-id: 0a33e3727b44c4f0cba3a2abe9b12f40d2a20447
This commit is contained in:
committed by
Facebook GitHub Bot
parent
0dce883241
commit
771cf8a328
@@ -52,10 +52,13 @@ dataset_map_provider_LlffDatasetMapProvider_args:
|
||||
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
||||
batch_size: 1
|
||||
num_workers: 0
|
||||
dataset_len: 1000
|
||||
dataset_len_val: 1
|
||||
images_per_seq_options:
|
||||
- 2
|
||||
dataset_length_train: 0
|
||||
dataset_length_val: 0
|
||||
dataset_length_test: 0
|
||||
train_conditioning_type: SAME
|
||||
val_conditioning_type: SAME
|
||||
test_conditioning_type: KNOWN
|
||||
images_per_seq_options: []
|
||||
sample_consecutive_frames: false
|
||||
consecutive_frames_max_gap: 0
|
||||
consecutive_frames_max_gap_seconds: 0.1
|
||||
|
||||
@@ -8,6 +8,11 @@
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from itertools import product
|
||||
|
||||
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
|
||||
DoublePoolBatchSampler,
|
||||
)
|
||||
|
||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
||||
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
|
||||
@@ -214,3 +219,30 @@ def _count_by_quotient(indices, divisor):
|
||||
counter[i // divisor] += 1
|
||||
|
||||
return counter
|
||||
|
||||
|
||||
class TestRandomSampling(unittest.TestCase):
|
||||
def test_double_pool_batch_sampler(self):
|
||||
unknown_idxs = [2, 3, 4, 5, 8]
|
||||
known_idxs = [2, 9, 10, 11, 12, 13, 14, 15, 16, 17]
|
||||
for replacement, num_batches in product([True, False], [None, 4, 5, 6, 30]):
|
||||
with self.subTest(f"{replacement}, {num_batches}"):
|
||||
sampler = DoublePoolBatchSampler(
|
||||
first_indices=unknown_idxs,
|
||||
rest_indices=known_idxs,
|
||||
batch_size=4,
|
||||
replacement=replacement,
|
||||
num_batches=num_batches,
|
||||
)
|
||||
for _ in range(6):
|
||||
epoch = list(sampler)
|
||||
self.assertEqual(len(epoch), num_batches or len(unknown_idxs))
|
||||
for batch in epoch:
|
||||
self.assertEqual(len(batch), 4)
|
||||
self.assertIn(batch[0], unknown_idxs)
|
||||
for i in batch[1:]:
|
||||
self.assertIn(i, known_idxs)
|
||||
if not replacement and 4 != num_batches:
|
||||
self.assertEqual(
|
||||
{batch[0] for batch in epoch}, set(unknown_idxs)
|
||||
)
|
||||
|
||||
@@ -10,11 +10,12 @@ import unittest
|
||||
from pytorch3d.implicitron.dataset.blender_dataset_map_provider import (
|
||||
BlenderDatasetMapProvider,
|
||||
)
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData
|
||||
from pytorch3d.implicitron.dataset.llff_dataset_map_provider import (
|
||||
LlffDatasetMapProvider,
|
||||
)
|
||||
from pytorch3d.implicitron.tools.config import expand_args_fields
|
||||
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
|
||||
from tests.common_testing import TestCaseMixin
|
||||
|
||||
|
||||
@@ -102,3 +103,23 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(dataset_map.test[batch[0]].frame_type, "unseen")
|
||||
for i in batch[1:]:
|
||||
self.assertEqual(dataset_map.test[i].frame_type, "known")
|
||||
|
||||
def test_loaders(self):
|
||||
args = get_default_args(ImplicitronDataSource)
|
||||
args.dataset_map_provider_class_type = "BlenderDatasetMapProvider"
|
||||
args.data_loader_map_provider_class_type = "RandomDataLoaderMapProvider"
|
||||
dataset_args = args.dataset_map_provider_BlenderDatasetMapProvider_args
|
||||
dataset_args.object_name = "lego"
|
||||
dataset_args.base_dir = "manifold://co3d/tree/nerf_data/nerf_synthetic/lego"
|
||||
|
||||
data_source = ImplicitronDataSource(**args)
|
||||
_, data_loaders = data_source.get_datasets_and_dataloaders()
|
||||
for i in data_loaders.train:
|
||||
self.assertEqual(i.frame_type, ["known"])
|
||||
self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
|
||||
for i in data_loaders.val:
|
||||
self.assertEqual(i.frame_type, ["unseen"])
|
||||
self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
|
||||
for i in data_loaders.test:
|
||||
self.assertEqual(i.frame_type, ["unseen"])
|
||||
self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
|
||||
|
||||
@@ -8,6 +8,7 @@ import os
|
||||
import unittest
|
||||
import unittest.mock
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||
@@ -21,6 +22,7 @@ DEBUG: bool = False
|
||||
class TestDataSource(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.maxDiff = None
|
||||
torch.manual_seed(42)
|
||||
|
||||
def _test_omegaconf_generic_failure(self):
|
||||
# OmegaConf possible bug - this is why we need _GenericWorkaround
|
||||
@@ -56,3 +58,23 @@ class TestDataSource(unittest.TestCase):
|
||||
if DEBUG:
|
||||
(DATA_DIR / "data_source.yaml").write_text(yaml)
|
||||
self.assertEqual(yaml, (DATA_DIR / "data_source.yaml").read_text())
|
||||
|
||||
def test_default(self):
|
||||
if os.environ.get("INSIDE_RE_WORKER") is not None:
|
||||
return
|
||||
args = get_default_args(ImplicitronDataSource)
|
||||
args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
|
||||
args.data_loader_map_provider_class_type = "SequenceDataLoaderMapProvider"
|
||||
dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||
dataset_args.category = "skateboard"
|
||||
dataset_args.test_restrict_sequence_id = 0
|
||||
dataset_args.n_frames_per_sequence = -1
|
||||
|
||||
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
||||
|
||||
data_source = ImplicitronDataSource(**args)
|
||||
_, data_loaders = data_source.get_datasets_and_dataloaders()
|
||||
self.assertEqual(len(data_loaders.train), 81)
|
||||
for i in data_loaders.train:
|
||||
self.assertEqual(i.frame_type, ["test_known"])
|
||||
break
|
||||
|
||||
@@ -44,6 +44,7 @@ class TestEvaluation(unittest.TestCase):
|
||||
frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
|
||||
sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
|
||||
self.image_size = 64
|
||||
expand_args_fields(JsonIndexDataset)
|
||||
self.dataset = JsonIndexDataset(
|
||||
frame_annotations_file=frame_file,
|
||||
sequence_annotations_file=sequence_file,
|
||||
|
||||
Reference in New Issue
Block a user