From 771cf8a328e48da9d369385240ef14f38409c2ee Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Wed, 6 Jul 2022 07:13:41 -0700 Subject: [PATCH] more padding options in Dataloader Summary: Add facilities for dataloading non-sequential scenes. Reviewed By: shapovalov Differential Revision: D37291277 fbshipit-source-id: 0a33e3727b44c4f0cba3a2abe9b12f40d2a20447 --- .../configs/repro_base.yaml | 16 +- .../configs/repro_multiseq_base.yaml | 4 +- .../configs/repro_singleseq_base.yaml | 6 +- .../configs/repro_singleseq_wce_base.yaml | 4 +- .../implicitron_trainer/tests/experiment.yaml | 11 +- .../tests/test_experiment.py | 2 +- .../dataset/data_loader_map_provider.py | 365 ++++++++++++++++-- pytorch3d/implicitron/dataset/dataset_base.py | 6 + pytorch3d/implicitron/dataset/utils.py | 10 +- pytorch3d/implicitron/dataset/visualize.py | 2 +- pytorch3d/implicitron/eval_demo.py | 4 +- tests/implicitron/data/data_source.yaml | 11 +- tests/implicitron/test_batch_sampler.py | 32 ++ tests/implicitron/test_data_llff.py | 23 +- tests/implicitron/test_data_source.py | 22 ++ tests/implicitron/test_evaluation.py | 1 + 16 files changed, 449 insertions(+), 70 deletions(-) diff --git a/projects/implicitron_trainer/configs/repro_base.yaml b/projects/implicitron_trainer/configs/repro_base.yaml index a3d5859f..8c2c4066 100644 --- a/projects/implicitron_trainer/configs/repro_base.yaml +++ b/projects/implicitron_trainer/configs/repro_base.yaml @@ -6,22 +6,12 @@ architecture: generic visualize_interval: 0 visdom_port: 8097 data_source_args: + data_loader_map_provider_class_type: SequenceDataLoaderMapProvider dataset_map_provider_class_type: JsonIndexDatasetMapProvider data_loader_map_provider_SequenceDataLoaderMapProvider_args: - batch_size: 10 - dataset_len: 1000 - dataset_len_val: 1 + dataset_length_train: 1000 + dataset_length_val: 1 num_workers: 8 - images_per_seq_options: - - 2 - - 3 - - 4 - - 5 - - 6 - - 7 - - 8 - - 9 - - 10 dataset_map_provider_JsonIndexDatasetMapProvider_args: 
dataset_root: ${oc.env:CO3D_DATASET_ROOT} n_frames_per_sequence: -1 diff --git a/projects/implicitron_trainer/configs/repro_multiseq_base.yaml b/projects/implicitron_trainer/configs/repro_multiseq_base.yaml index 31e9d083..b1788c73 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_base.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_base.yaml @@ -4,8 +4,8 @@ defaults: data_source_args: data_loader_map_provider_SequenceDataLoaderMapProvider_args: batch_size: 10 - dataset_len: 1000 - dataset_len_val: 1 + dataset_length_train: 1000 + dataset_length_val: 1 num_workers: 8 images_per_seq_options: - 2 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_base.yaml b/projects/implicitron_trainer/configs/repro_singleseq_base.yaml index bde1abb5..977d5dd9 100644 --- a/projects/implicitron_trainer/configs/repro_singleseq_base.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_base.yaml @@ -4,11 +4,9 @@ defaults: data_source_args: data_loader_map_provider_SequenceDataLoaderMapProvider_args: batch_size: 1 - dataset_len: 1000 - dataset_len_val: 1 + dataset_length_train: 1000 + dataset_length_val: 1 num_workers: 8 - images_per_seq_options: - - 2 dataset_map_provider_JsonIndexDatasetMapProvider_args: assert_single_seq: true n_frames_per_sequence: -1 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_wce_base.yaml b/projects/implicitron_trainer/configs/repro_singleseq_wce_base.yaml index b714490b..4c11aba9 100644 --- a/projects/implicitron_trainer/configs/repro_singleseq_wce_base.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_wce_base.yaml @@ -4,8 +4,8 @@ defaults: data_source_args: data_loader_map_provider_SequenceDataLoaderMapProvider_args: batch_size: 10 - dataset_len: 1000 - dataset_len_val: 1 + dataset_length_train: 1000 + dataset_length_val: 1 num_workers: 8 images_per_seq_options: - 2 diff --git a/projects/implicitron_trainer/tests/experiment.yaml 
b/projects/implicitron_trainer/tests/experiment.yaml index 3eefa12e..86e9e1dd 100644 --- a/projects/implicitron_trainer/tests/experiment.yaml +++ b/projects/implicitron_trainer/tests/experiment.yaml @@ -345,10 +345,13 @@ data_source_args: data_loader_map_provider_SequenceDataLoaderMapProvider_args: batch_size: 1 num_workers: 0 - dataset_len: 1000 - dataset_len_val: 1 - images_per_seq_options: - - 2 + dataset_length_train: 0 + dataset_length_val: 0 + dataset_length_test: 0 + train_conditioning_type: SAME + val_conditioning_type: SAME + test_conditioning_type: KNOWN + images_per_seq_options: [] sample_consecutive_frames: false consecutive_frames_max_gap: 0 consecutive_frames_max_gap_seconds: 0.1 diff --git a/projects/implicitron_trainer/tests/test_experiment.py b/projects/implicitron_trainer/tests/test_experiment.py index ee41ce9f..47348a59 100644 --- a/projects/implicitron_trainer/tests/test_experiment.py +++ b/projects/implicitron_trainer/tests/test_experiment.py @@ -55,7 +55,7 @@ class TestExperiment(unittest.TestCase): dataset_args.test_restrict_sequence_id = 0 dataset_args.dataset_root = "manifold://co3d/tree/extracted" dataset_args.dataset_JsonIndexDataset_args.limit_sequences_to = 5 - dataloader_args.dataset_len = 1 + dataloader_args.dataset_length_train = 1 cfg.solver_args.max_epochs = 2 device = torch.device("cuda:0") diff --git a/pytorch3d/implicitron/dataset/data_loader_map_provider.py b/pytorch3d/implicitron/dataset/data_loader_map_provider.py index 9ff53bdc..0b159c66 100644 --- a/pytorch3d/implicitron/dataset/data_loader_map_provider.py +++ b/pytorch3d/implicitron/dataset/data_loader_map_provider.py @@ -5,14 +5,23 @@ # LICENSE file in the root directory of this source tree. 
from dataclasses import dataclass -from typing import Optional, Tuple +from enum import Enum +from typing import Iterator, List, Optional, Tuple import torch from pytorch3d.implicitron.tools.config import registry, ReplaceableBase +from torch.utils.data import ( + BatchSampler, + ChainDataset, + DataLoader, + RandomSampler, + Sampler, +) from .dataset_base import DatasetBase, FrameData from .dataset_map_provider import DatasetMap from .scene_batch_sampler import SceneBatchSampler +from .utils import is_known_frame_scalar @dataclass @@ -27,13 +36,11 @@ class DataLoaderMap: test: a data loader for final evaluation """ - train: Optional[torch.utils.data.DataLoader[FrameData]] - val: Optional[torch.utils.data.DataLoader[FrameData]] - test: Optional[torch.utils.data.DataLoader[FrameData]] + train: Optional[DataLoader[FrameData]] + val: Optional[DataLoader[FrameData]] + test: Optional[DataLoader[FrameData]] - def __getitem__( - self, split: str - ) -> Optional[torch.utils.data.DataLoader[FrameData]]: + def __getitem__(self, split: str) -> Optional[DataLoader[FrameData]]: """ Get one of the data loaders by key (name of data split) """ @@ -54,17 +61,155 @@ class DataLoaderMapProviderBase(ReplaceableBase): raise NotImplementedError() +class DoublePoolBatchSampler(Sampler[List[int]]): + """ + Batch sampler for making random batches of a single frame + from one list and a number of known frames from another list. + """ + + def __init__( + self, + first_indices: List[int], + rest_indices: List[int], + batch_size: int, + replacement: bool, + num_batches: Optional[int] = None, + ) -> None: + """ + Args: + first_indices: indexes of dataset items to use as the first element + of each batch. + rest_indices: indexes of dataset items to use as the subsequent + elements of each batch. Not used if batch_size==1. + batch_size: The common size of any batch. + replacement: Whether the sampling of first items is with replacement. + num_batches: The number of batches in an epoch. 
If 0 or None, + one epoch is the length of `first_indices`. + """ + self.first_indices = first_indices + self.rest_indices = rest_indices + self.batch_size = batch_size + self.replacement = replacement + self.num_batches = None if num_batches == 0 else num_batches + + if batch_size - 1 > len(rest_indices): + raise ValueError( + f"Cannot make up ({batch_size})-batches from {len(self.rest_indices)}" + ) + + # copied from RandomSampler + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + self.generator = torch.Generator() + self.generator.manual_seed(seed) + + def __len__(self) -> int: + if self.num_batches is not None: + return self.num_batches + return len(self.first_indices) + + def __iter__(self) -> Iterator[List[int]]: + num_batches = self.num_batches + if self.replacement: + i_first = torch.randint( + len(self.first_indices), + size=(len(self),), + generator=self.generator, + ) + elif num_batches is not None: + n_copies = 1 + (num_batches - 1) // len(self.first_indices) + raw_indices = [ + torch.randperm(len(self.first_indices), generator=self.generator) + for _ in range(n_copies) + ] + i_first = torch.concat(raw_indices)[:num_batches] + else: + i_first = torch.randperm(len(self.first_indices), generator=self.generator) + first_indices = [self.first_indices[i] for i in i_first] + + if self.batch_size == 1: + for first_index in first_indices: + yield [first_index] + return + + for first_index in first_indices: + # Consider using this class in a program which sets the seed. This use + # of randperm means that rerunning with a higher batch_size + # results in batches whose first elements are the same as in the first run. + i_rest = torch.randperm( + len(self.rest_indices), + generator=self.generator, + )[: self.batch_size - 1] + yield [first_index] + [self.rest_indices[i] for i in i_rest] + + +class BatchConditioningType(Enum): + """ + Ways to add conditioning frames for the val and test batches.
+ + SAME: Use the corresponding dataset for all elements of val batches + without regard to frame type. + TRAIN: Use the corresponding dataset for the first element of each + batch, and the training dataset for the extra conditioning + elements. No regard to frame type. + KNOWN: Use frames from the corresponding dataset but separate them + according to their frame_type. Each batch will contain one UNSEEN + frame followed by many KNOWN frames. + """ + + SAME = "same" + TRAIN = "train" + KNOWN = "known" + + @registry.register class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase): """ - The default implementation of DataLoaderMapProviderBase. + Default implementation of DataLoaderMapProviderBase. + + If a dataset returns batches from get_eval_batches(), then + they will be what the corresponding dataloader returns, + independently of any of the fields on this class. + + If conditioning is not required, then the batch size should + be set as 1, and most of the fields do not matter. + + If conditioning is required, each batch will contain one main + frame first to predict, and the rest of the elements are for + conditioning. + + If images_per_seq_options is left empty, the conditioning + frames are picked according to the conditioning type given. + This does not have regard to the order of frames in a + scene, or which frames belong to what scene. + + If images_per_seq_options is given, then the conditioning types + must be SAME and the remaining fields are used. Members: batch_size: The size of the batch of the data loader. - num_workers: Number data-loading threads. - dataset_len: The number of batches in a training epoch. - dataset_len_val: The number of batches in a validation epoch. - images_per_seq_options: Possible numbers of images sampled per sequence. + num_workers: Number of data-loading threads in each data loader. + dataset_length_train: The number of batches in a training epoch. Or 0 to mean + an epoch is the length of the training set.
+ dataset_length_val: The number of batches in a validation epoch. Or 0 to mean + an epoch is the length of the validation set. + dataset_length_test: The number of batches in a testing epoch. Or 0 to mean + an epoch is the length of the test set. + train_conditioning_type: Whether the train data loader should use + only known frames for conditioning. + Only used if batch_size>1 and train dataset is + present and does not return eval_batches. + val_conditioning_type: Whether the val data loader should use + training frames or known frames for conditioning. + Only used if batch_size>1 and val dataset is + present and does not return eval_batches. + test_conditioning_type: Whether the test data loader should use + training frames or known frames for conditioning. + Only used if batch_size>1 and test dataset is + present and does not return eval_batches. + images_per_seq_options: Possible numbers of frames sampled per sequence in a batch. + If a conditioning_type is KNOWN or TRAIN, then this must be left at its initial + value. Empty (the default) means that we are not careful about which frames + come from which scene. sample_consecutive_frames: if True, will sample a contiguous interval of frames in the sequence. It first sorts the frames by timestimps when available, otherwise by frame numbers, finds the connected segments within the sequence @@ -84,9 +229,13 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase): batch_size: int = 1 num_workers: int = 0 - dataset_len: int = 1000 - dataset_len_val: int = 1 - images_per_seq_options: Tuple[int, ...] = (2,) + dataset_length_train: int = 0 + dataset_length_val: int = 0 + dataset_length_test: int = 0 + train_conditioning_type: BatchConditioningType = BatchConditioningType.SAME + val_conditioning_type: BatchConditioningType = BatchConditioningType.SAME + test_conditioning_type: BatchConditioningType = BatchConditioningType.KNOWN + images_per_seq_options: Tuple[int, ...] 
= () sample_consecutive_frames: bool = False consecutive_frames_max_gap: int = 0 consecutive_frames_max_gap_seconds: float = 0.1 @@ -95,17 +244,73 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase): """ Returns a collection of data loaders for a given collection of datasets. """ + return DataLoaderMap( + train=self._make_data_loader( + datasets.train, + self.dataset_length_train, + datasets.train, + self.train_conditioning_type, + ), + val=self._make_data_loader( + datasets.val, + self.dataset_length_val, + datasets.train, + self.val_conditioning_type, + ), + test=self._make_data_loader( + datasets.test, + self.dataset_length_test, + datasets.train, + self.test_conditioning_type, + ), + ) + + def _make_data_loader( + self, + dataset: Optional[DatasetBase], + num_batches: int, + train_dataset: Optional[DatasetBase], + conditioning_type: BatchConditioningType, + ) -> Optional[DataLoader[FrameData]]: + """ + Returns the dataloader for a dataset. + + Args: + dataset: the dataset + num_batches: possible ceiling on number of batches per epoch + train_dataset: the training dataset, used if conditioning_type==TRAIN + conditioning_type: source for padding of batches + """ + if dataset is None: + return None data_loader_kwargs = { "num_workers": self.num_workers, - "collate_fn": FrameData.collate, + "collate_fn": dataset.frame_data_type.collate, } - def train_or_val_loader( - dataset: Optional[DatasetBase], num_batches: int - ) -> Optional[torch.utils.data.DataLoader]: - if dataset is None: - return None + eval_batches = dataset.get_eval_batches() + if eval_batches is not None: + return DataLoader( + dataset, + batch_sampler=eval_batches, + **data_loader_kwargs, + ) + + scenes_matter = len(self.images_per_seq_options) > 0 + if scenes_matter and conditioning_type != BatchConditioningType.SAME: + raise ValueError( + f"{conditioning_type} cannot be used with images_per_seq " + + str(self.images_per_seq_options) + ) + + if self.batch_size == 1 or ( + not 
scenes_matter and conditioning_type == BatchConditioningType.SAME + ): + return self._simple_loader(dataset, num_batches, data_loader_kwargs) + + if scenes_matter: + assert conditioning_type == BatchConditioningType.SAME batch_sampler = SceneBatchSampler( dataset, self.batch_size, @@ -115,25 +320,115 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase): consecutive_frames_max_gap=self.consecutive_frames_max_gap, consecutive_frames_max_gap_seconds=self.consecutive_frames_max_gap_seconds, ) - return torch.utils.data.DataLoader( + return DataLoader( dataset, batch_sampler=batch_sampler, **data_loader_kwargs, ) - train_data_loader = train_or_val_loader(datasets.train, self.dataset_len) - val_data_loader = train_or_val_loader(datasets.val, self.dataset_len_val) - - test_dataset = datasets.test - if test_dataset is not None: - test_data_loader = torch.utils.data.DataLoader( - test_dataset, - batch_sampler=test_dataset.get_eval_batches(), - **data_loader_kwargs, + if conditioning_type == BatchConditioningType.TRAIN: + return self._train_loader( + dataset, train_dataset, num_batches, data_loader_kwargs ) - else: - test_data_loader = None - return DataLoaderMap( - train=train_data_loader, val=val_data_loader, test=test_data_loader + assert conditioning_type == BatchConditioningType.KNOWN + return self._known_loader(dataset, num_batches, data_loader_kwargs) + + def _simple_loader( + self, + dataset: DatasetBase, + num_batches: int, + data_loader_kwargs: dict, + ) -> DataLoader[FrameData]: + """ + Return a simple loader for frames in the dataset. + + This is equivalent to + Dataloader(dataset, batch_size=self.batch_size, **data_loader_kwargs) + except that num_batches is fixed. 
+ + Args: + dataset: the dataset + num_batches: possible ceiling on number of batches per epoch + data_loader_kwargs: common args for dataloader + """ + if num_batches > 0: + num_samples = self.batch_size * num_batches + else: + num_samples = None + sampler = RandomSampler(dataset, replacement=False, num_samples=num_samples) + batch_sampler = BatchSampler(sampler, self.batch_size, drop_last=True) + return DataLoader( + dataset, + batch_sampler=batch_sampler, + **data_loader_kwargs, + ) + + def _train_loader( + self, + dataset: DatasetBase, + train_dataset: Optional[DatasetBase], + num_batches: int, + data_loader_kwargs: dict, + ) -> DataLoader[FrameData]: + """ + Return the loader for TRAIN conditioning. + + Args: + dataset: the dataset + train_dataset: the training dataset + num_batches: possible ceiling on number of batches per epoch + data_loader_kwargs: common args for dataloader + """ + if train_dataset is None: + raise ValueError("No training data for conditioning.") + length = len(dataset) + first_indices = list(range(length)) + rest_indices = list(range(length, length + len(train_dataset))) + sampler = DoublePoolBatchSampler( + first_indices=first_indices, + rest_indices=rest_indices, + batch_size=self.batch_size, + replacement=True, + num_batches=num_batches, + ) + return DataLoader( + ChainDataset([dataset, train_dataset]), + batch_sampler=sampler, + **data_loader_kwargs, + ) + + def _known_loader( + self, + dataset: DatasetBase, + num_batches: int, + data_loader_kwargs: dict, + ) -> DataLoader[FrameData]: + """ + Return the loader for KNOWN conditioning. 
+ + Args: + dataset: the dataset + num_batches: possible ceiling on number of batches per epoch + data_loader_kwargs: common args for dataloader + """ + first_indices, rest_indices = [], [] + for idx in range(len(dataset)): + frame_type = dataset[idx].frame_type + assert isinstance(frame_type, str) + if is_known_frame_scalar(frame_type): + rest_indices.append(idx) + else: + first_indices.append(idx) + sampler = DoublePoolBatchSampler( + first_indices=first_indices, + rest_indices=rest_indices, + batch_size=self.batch_size, + replacement=True, + num_batches=num_batches, + ) + return DataLoader( + dataset, + batch_sampler=sampler, + **data_loader_kwargs, ) diff --git a/pytorch3d/implicitron/dataset/dataset_base.py b/pytorch3d/implicitron/dataset/dataset_base.py index 4b5501ae..da2a8083 100644 --- a/pytorch3d/implicitron/dataset/dataset_base.py +++ b/pytorch3d/implicitron/dataset/dataset_base.py @@ -8,6 +8,7 @@ from collections import defaultdict from dataclasses import dataclass, field, fields from typing import ( Any, + ClassVar, Iterable, Iterator, List, @@ -15,6 +16,7 @@ from typing import ( Optional, Sequence, Tuple, + Type, Union, ) @@ -289,3 +291,7 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]): """ for _, _, idx in self.sequence_frames_in_order(seq_name): yield idx + + # frame_data_type is the actual type of frames returned by the dataset. + # Collation uses its classmethod `collate` + frame_data_type: ClassVar[Type[FrameData]] = FrameData diff --git a/pytorch3d/implicitron/dataset/utils.py b/pytorch3d/implicitron/dataset/utils.py index e2057080..05252aff 100644 --- a/pytorch3d/implicitron/dataset/utils.py +++ b/pytorch3d/implicitron/dataset/utils.py @@ -16,6 +16,14 @@ DATASET_TYPE_KNOWN = "known" DATASET_TYPE_UNKNOWN = "unseen" +def is_known_frame_scalar(frame_type: str) -> bool: + """ + Given a single frame type corresponding to a single frame, return whether + the frame is a known frame. 
+ """ + return frame_type.endswith(DATASET_TYPE_KNOWN) + + def is_known_frame( frame_type: List[str], device: Optional[str] = None ) -> torch.BoolTensor: @@ -25,7 +33,7 @@ def is_known_frame( """ # pyre-fixme[7]: Expected `BoolTensor` but got `Tensor`. return torch.tensor( - [ft.endswith(DATASET_TYPE_KNOWN) for ft in frame_type], + [is_known_frame_scalar(ft) for ft in frame_type], dtype=torch.bool, device=device, ) diff --git a/pytorch3d/implicitron/dataset/visualize.py b/pytorch3d/implicitron/dataset/visualize.py index 8a4be469..6d0be036 100644 --- a/pytorch3d/implicitron/dataset/visualize.py +++ b/pytorch3d/implicitron/dataset/visualize.py @@ -69,7 +69,7 @@ def get_implicitron_sequence_pointcloud( batch_size=len(sequence_dataset), shuffle=False, num_workers=num_workers, - collate_fn=FrameData.collate, + collate_fn=dataset.frame_data_type.collate, ) frame_data = next(iter(loader)) # there's only one batch diff --git a/pytorch3d/implicitron/eval_demo.py b/pytorch3d/implicitron/eval_demo.py index 89d1d132..2efd6459 100644 --- a/pytorch3d/implicitron/eval_demo.py +++ b/pytorch3d/implicitron/eval_demo.py @@ -12,7 +12,7 @@ from typing import Any, cast, Dict, List, Optional, Tuple import lpips import torch from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task -from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData +from pytorch3d.implicitron.dataset.dataset_base import DatasetBase from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import ( CO3D_CATEGORIES, @@ -207,7 +207,7 @@ def _get_all_source_cameras( shuffle=False, batch_size=len(dataset_for_loader), num_workers=num_workers, - collate_fn=FrameData.collate, + collate_fn=dataset.frame_data_type.collate, ) is_known = is_known_frame(all_frame_data.frame_type) source_cameras = all_frame_data.camera[torch.where(is_known)[0]] diff --git a/tests/implicitron/data/data_source.yaml 
b/tests/implicitron/data/data_source.yaml index 12113ff5..2c84545e 100644 --- a/tests/implicitron/data/data_source.yaml +++ b/tests/implicitron/data/data_source.yaml @@ -52,10 +52,13 @@ dataset_map_provider_LlffDatasetMapProvider_args: data_loader_map_provider_SequenceDataLoaderMapProvider_args: batch_size: 1 num_workers: 0 - dataset_len: 1000 - dataset_len_val: 1 - images_per_seq_options: - - 2 + dataset_length_train: 0 + dataset_length_val: 0 + dataset_length_test: 0 + train_conditioning_type: SAME + val_conditioning_type: SAME + test_conditioning_type: KNOWN + images_per_seq_options: [] sample_consecutive_frames: false consecutive_frames_max_gap: 0 consecutive_frames_max_gap_seconds: 0.1 diff --git a/tests/implicitron/test_batch_sampler.py b/tests/implicitron/test_batch_sampler.py index 571c95f8..2b64804c 100644 --- a/tests/implicitron/test_batch_sampler.py +++ b/tests/implicitron/test_batch_sampler.py @@ -8,6 +8,11 @@ import unittest from collections import defaultdict from dataclasses import dataclass +from itertools import product + +from pytorch3d.implicitron.dataset.data_loader_map_provider import ( + DoublePoolBatchSampler, +) from pytorch3d.implicitron.dataset.dataset_base import DatasetBase from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler @@ -214,3 +219,30 @@ def _count_by_quotient(indices, divisor): counter[i // divisor] += 1 return counter + + +class TestRandomSampling(unittest.TestCase): + def test_double_pool_batch_sampler(self): + unknown_idxs = [2, 3, 4, 5, 8] + known_idxs = [2, 9, 10, 11, 12, 13, 14, 15, 16, 17] + for replacement, num_batches in product([True, False], [None, 4, 5, 6, 30]): + with self.subTest(f"{replacement}, {num_batches}"): + sampler = DoublePoolBatchSampler( + first_indices=unknown_idxs, + rest_indices=known_idxs, + batch_size=4, + replacement=replacement, + num_batches=num_batches, + ) + for _ in range(6): + epoch = list(sampler) + self.assertEqual(len(epoch), num_batches or len(unknown_idxs)) + 
for batch in epoch: + self.assertEqual(len(batch), 4) + self.assertIn(batch[0], unknown_idxs) + for i in batch[1:]: + self.assertIn(i, known_idxs) + if not replacement and 4 != num_batches: + self.assertEqual( + {batch[0] for batch in epoch}, set(unknown_idxs) + ) diff --git a/tests/implicitron/test_data_llff.py b/tests/implicitron/test_data_llff.py index e77a1f1d..688a1bc6 100644 --- a/tests/implicitron/test_data_llff.py +++ b/tests/implicitron/test_data_llff.py @@ -10,11 +10,12 @@ import unittest from pytorch3d.implicitron.dataset.blender_dataset_map_provider import ( BlenderDatasetMapProvider, ) +from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource from pytorch3d.implicitron.dataset.dataset_base import FrameData from pytorch3d.implicitron.dataset.llff_dataset_map_provider import ( LlffDatasetMapProvider, ) -from pytorch3d.implicitron.tools.config import expand_args_fields +from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args from tests.common_testing import TestCaseMixin @@ -102,3 +103,23 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase): self.assertEqual(dataset_map.test[batch[0]].frame_type, "unseen") for i in batch[1:]: self.assertEqual(dataset_map.test[i].frame_type, "known") + + def test_loaders(self): + args = get_default_args(ImplicitronDataSource) + args.dataset_map_provider_class_type = "BlenderDatasetMapProvider" + args.data_loader_map_provider_class_type = "RandomDataLoaderMapProvider" + dataset_args = args.dataset_map_provider_BlenderDatasetMapProvider_args + dataset_args.object_name = "lego" + dataset_args.base_dir = "manifold://co3d/tree/nerf_data/nerf_synthetic/lego" + + data_source = ImplicitronDataSource(**args) + _, data_loaders = data_source.get_datasets_and_dataloaders() + for i in data_loaders.train: + self.assertEqual(i.frame_type, ["known"]) + self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800)) + for i in data_loaders.val: + self.assertEqual(i.frame_type, ["unseen"]) + 
self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800)) + for i in data_loaders.test: + self.assertEqual(i.frame_type, ["unseen"]) + self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800)) diff --git a/tests/implicitron/test_data_source.py b/tests/implicitron/test_data_source.py index e5823e05..cf5033b9 100644 --- a/tests/implicitron/test_data_source.py +++ b/tests/implicitron/test_data_source.py @@ -8,6 +8,7 @@ import os import unittest import unittest.mock +import torch from omegaconf import OmegaConf from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset @@ -21,6 +22,7 @@ DEBUG: bool = False class TestDataSource(unittest.TestCase): def setUp(self): self.maxDiff = None + torch.manual_seed(42) def _test_omegaconf_generic_failure(self): # OmegaConf possible bug - this is why we need _GenericWorkaround @@ -56,3 +58,23 @@ class TestDataSource(unittest.TestCase): if DEBUG: (DATA_DIR / "data_source.yaml").write_text(yaml) self.assertEqual(yaml, (DATA_DIR / "data_source.yaml").read_text()) + + def test_default(self): + if os.environ.get("INSIDE_RE_WORKER") is not None: + return + args = get_default_args(ImplicitronDataSource) + args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider" + args.data_loader_map_provider_class_type = "SequenceDataLoaderMapProvider" + dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args + dataset_args.category = "skateboard" + dataset_args.test_restrict_sequence_id = 0 + dataset_args.n_frames_per_sequence = -1 + + dataset_args.dataset_root = "manifold://co3d/tree/extracted" + + data_source = ImplicitronDataSource(**args) + _, data_loaders = data_source.get_datasets_and_dataloaders() + self.assertEqual(len(data_loaders.train), 81) + for i in data_loaders.train: + self.assertEqual(i.frame_type, ["test_known"]) + break diff --git a/tests/implicitron/test_evaluation.py b/tests/implicitron/test_evaluation.py index 
ed43dca6..73feee60 100644 --- a/tests/implicitron/test_evaluation.py +++ b/tests/implicitron/test_evaluation.py @@ -44,6 +44,7 @@ class TestEvaluation(unittest.TestCase): frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz") sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz") self.image_size = 64 + expand_args_fields(JsonIndexDataset) self.dataset = JsonIndexDataset( frame_annotations_file=frame_file, sequence_annotations_file=sequence_file,