more padding options in Dataloader

Summary: Add facilities for dataloading non-sequential scenes.

Reviewed By: shapovalov

Differential Revision: D37291277

fbshipit-source-id: 0a33e3727b44c4f0cba3a2abe9b12f40d2a20447

parent 0dce883241
commit 771cf8a328
@@ -6,22 +6,12 @@ architecture: generic
 visualize_interval: 0
 visdom_port: 8097
 data_source_args:
   data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
   dataset_map_provider_class_type: JsonIndexDatasetMapProvider
   data_loader_map_provider_SequenceDataLoaderMapProvider_args:
     batch_size: 10
-    dataset_len: 1000
-    dataset_len_val: 1
+    dataset_length_train: 1000
+    dataset_length_val: 1
     num_workers: 8
-    images_per_seq_options:
-    - 2
-    - 3
-    - 4
-    - 5
-    - 6
-    - 7
-    - 8
-    - 9
-    - 10
   dataset_map_provider_JsonIndexDatasetMapProvider_args:
     dataset_root: ${oc.env:CO3D_DATASET_ROOT}
     n_frames_per_sequence: -1
@@ -4,8 +4,8 @@ defaults:
 data_source_args:
   data_loader_map_provider_SequenceDataLoaderMapProvider_args:
     batch_size: 10
-    dataset_len: 1000
-    dataset_len_val: 1
+    dataset_length_train: 1000
+    dataset_length_val: 1
     num_workers: 8
     images_per_seq_options:
     - 2
@@ -4,11 +4,9 @@ defaults:
 data_source_args:
   data_loader_map_provider_SequenceDataLoaderMapProvider_args:
     batch_size: 1
-    dataset_len: 1000
-    dataset_len_val: 1
+    dataset_length_train: 1000
+    dataset_length_val: 1
     num_workers: 8
-    images_per_seq_options:
-    - 2
   dataset_map_provider_JsonIndexDatasetMapProvider_args:
     assert_single_seq: true
     n_frames_per_sequence: -1
@@ -4,8 +4,8 @@ defaults:
 data_source_args:
   data_loader_map_provider_SequenceDataLoaderMapProvider_args:
     batch_size: 10
-    dataset_len: 1000
-    dataset_len_val: 1
+    dataset_length_train: 1000
+    dataset_length_val: 1
     num_workers: 8
     images_per_seq_options:
     - 2
@@ -345,10 +345,13 @@ data_source_args:
   data_loader_map_provider_SequenceDataLoaderMapProvider_args:
     batch_size: 1
     num_workers: 0
-    dataset_len: 1000
-    dataset_len_val: 1
-    images_per_seq_options:
-    - 2
+    dataset_length_train: 0
+    dataset_length_val: 0
+    dataset_length_test: 0
+    train_conditioning_type: SAME
+    val_conditioning_type: SAME
+    test_conditioning_type: KNOWN
+    images_per_seq_options: []
     sample_consecutive_frames: false
     consecutive_frames_max_gap: 0
     consecutive_frames_max_gap_seconds: 0.1
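
These generated defaults map one-to-one onto the new fields of SequenceDataLoaderMapProvider. As a minimal sketch of opting into the new options from Python (reusing get_default_args and ImplicitronDataSource exactly as the tests later in this diff do; the override values here are illustrative only):

    from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
    from pytorch3d.implicitron.tools.config import get_default_args

    args = get_default_args(ImplicitronDataSource)
    args.data_loader_map_provider_class_type = "SequenceDataLoaderMapProvider"
    loader_args = args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
    loader_args.batch_size = 5  # one frame to predict plus four padding frames
    loader_args.dataset_length_train = 1000  # 0 would mean the full split length
    loader_args.test_conditioning_type = "KNOWN"  # pad test batches with known frames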
@@ -55,7 +55,7 @@ class TestExperiment(unittest.TestCase):
         dataset_args.test_restrict_sequence_id = 0
         dataset_args.dataset_root = "manifold://co3d/tree/extracted"
         dataset_args.dataset_JsonIndexDataset_args.limit_sequences_to = 5
-        dataloader_args.dataset_len = 1
+        dataloader_args.dataset_length_train = 1
         cfg.solver_args.max_epochs = 2
 
         device = torch.device("cuda:0")
@@ -5,14 +5,23 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from enum import Enum
+from typing import Iterator, List, Optional, Tuple
 
 import torch
 from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
+from torch.utils.data import (
+    BatchSampler,
+    ChainDataset,
+    DataLoader,
+    RandomSampler,
+    Sampler,
+)
 
 from .dataset_base import DatasetBase, FrameData
 from .dataset_map_provider import DatasetMap
 from .scene_batch_sampler import SceneBatchSampler
+from .utils import is_known_frame_scalar
 
 
 @dataclass
@@ -27,13 +36,11 @@ class DataLoaderMap:
         test: a data loader for final evaluation
     """
 
-    train: Optional[torch.utils.data.DataLoader[FrameData]]
-    val: Optional[torch.utils.data.DataLoader[FrameData]]
-    test: Optional[torch.utils.data.DataLoader[FrameData]]
+    train: Optional[DataLoader[FrameData]]
+    val: Optional[DataLoader[FrameData]]
+    test: Optional[DataLoader[FrameData]]
 
-    def __getitem__(
-        self, split: str
-    ) -> Optional[torch.utils.data.DataLoader[FrameData]]:
+    def __getitem__(self, split: str) -> Optional[DataLoader[FrameData]]:
         """
         Get one of the data loaders by key (name of data split)
         """
@@ -54,17 +61,155 @@ class DataLoaderMapProviderBase(ReplaceableBase):
         raise NotImplementedError()
 
 
+class DoublePoolBatchSampler(Sampler[List[int]]):
+    """
+    Batch sampler for making random batches of a single frame
+    from one list and a number of known frames from another list.
+    """
+
+    def __init__(
+        self,
+        first_indices: List[int],
+        rest_indices: List[int],
+        batch_size: int,
+        replacement: bool,
+        num_batches: Optional[int] = None,
+    ) -> None:
+        """
+        Args:
+            first_indices: indexes of dataset items to use as the first element
+                        of each batch.
+            rest_indices: indexes of dataset items to use as the subsequent
+                        elements of each batch. Not used if batch_size==1.
+            batch_size: The common size of any batch.
+            replacement: Whether the sampling of first items is with replacement.
+            num_batches: The number of batches in an epoch. If 0 or None,
+                        one epoch is the length of `first_indices`.
+        """
+        self.first_indices = first_indices
+        self.rest_indices = rest_indices
+        self.batch_size = batch_size
+        self.replacement = replacement
+        self.num_batches = None if num_batches == 0 else num_batches
+
+        if batch_size - 1 > len(rest_indices):
+            raise ValueError(
+                f"Cannot make up ({batch_size})-batches from {len(self.rest_indices)}"
+            )
+
+        # copied from RandomSampler
+        seed = int(torch.empty((), dtype=torch.int64).random_().item())
+        self.generator = torch.Generator()
+        self.generator.manual_seed(seed)
+
+    def __len__(self) -> int:
+        if self.num_batches is not None:
+            return self.num_batches
+        return len(self.first_indices)
+
+    def __iter__(self) -> Iterator[List[int]]:
+        num_batches = self.num_batches
+        if self.replacement:
+            i_first = torch.randint(
+                len(self.first_indices),
+                size=(len(self),),
+                generator=self.generator,
+            )
+        elif num_batches is not None:
+            n_copies = 1 + (num_batches - 1) // len(self.first_indices)
+            raw_indices = [
+                torch.randperm(len(self.first_indices), generator=self.generator)
+                for _ in range(n_copies)
+            ]
+            i_first = torch.concat(raw_indices)[:num_batches]
+        else:
+            i_first = torch.randperm(len(self.first_indices), generator=self.generator)
+        first_indices = [self.first_indices[i] for i in i_first]
+
+        if self.batch_size == 1:
+            for first_index in first_indices:
+                yield [first_index]
+            return
+
+        for first_index in first_indices:
+            # Consider using this class in a program which sets the seed. This use
+            # of randperm means that rerunning with a higher batch_size
+            # results in batches whose first elements match those of the first run.
+            i_rest = torch.randperm(
+                len(self.rest_indices),
+                generator=self.generator,
+            )[: self.batch_size - 1]
+            yield [first_index] + [self.rest_indices[i] for i in i_rest]
+
+
+class BatchConditioningType(Enum):
+    """
+    Ways to add conditioning frames for the val and test batches.
+
+    SAME: Use the corresponding dataset for all elements of val batches
+        without regard to frame type.
+    TRAIN: Use the corresponding dataset for the first element of each
+        batch, and the training dataset for the extra conditioning
+        elements. No regard to frame type.
+    KNOWN: Use frames from the corresponding dataset but separate them
+        according to their frame_type. Each batch will contain one UNSEEN
+        frame followed by many KNOWN frames.
+    """
+
+    SAME = "same"
+    TRAIN = "train"
+    KNOWN = "known"
+
+
 @registry.register
 class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
     """
-    The default implementation of DataLoaderMapProviderBase.
+    Default implementation of DataLoaderMapProviderBase.
+
+    If a dataset returns batches from get_eval_batches(), then
+    they will be what the corresponding dataloader returns,
+    independently of any of the fields on this class.
+
+    If conditioning is not required, then the batch size should
+    be set as 1, and most of the fields do not matter.
+
+    If conditioning is required, each batch will contain one main
+    frame first to predict, and the rest of the elements are for
+    conditioning.
+
+    If images_per_seq_options is left empty, the conditioning
+    frames are picked according to the conditioning type given.
+    This does not have regard to the order of frames in a
+    scene, or which frames belong to what scene.
+
+    If images_per_seq_options is given, then the conditioning types
+    must be SAME and the remaining fields are used.
 
     Members:
         batch_size: The size of the batch of the data loader.
-        num_workers: Number data-loading threads.
-        dataset_len: The number of batches in a training epoch.
-        dataset_len_val: The number of batches in a validation epoch.
-        images_per_seq_options: Possible numbers of images sampled per sequence.
+        num_workers: Number of data-loading threads in each data loader.
+        dataset_length_train: The number of batches in a training epoch. Or 0 to mean
+            an epoch is the length of the training set.
+        dataset_length_val: The number of batches in a validation epoch. Or 0 to mean
+            an epoch is the length of the validation set.
+        dataset_length_test: The number of batches in a testing epoch. Or 0 to mean
+            an epoch is the length of the test set.
+        train_conditioning_type: Whether the train data loader should use
+            only known frames for conditioning.
+            Only used if batch_size>1 and train dataset is
+            present and does not return eval_batches.
+        val_conditioning_type: Whether the val data loader should use
+            training frames or known frames for conditioning.
+            Only used if batch_size>1 and val dataset is
+            present and does not return eval_batches.
+        test_conditioning_type: Whether the test data loader should use
+            training frames or known frames for conditioning.
+            Only used if batch_size>1 and test dataset is
+            present and does not return eval_batches.
+        images_per_seq_options: Possible numbers of frames sampled per sequence in a batch.
+            If a conditioning_type is KNOWN or TRAIN, then this must be left at its initial
+            value. Empty (the default) means that we are not careful about which frames
+            come from which scene.
         sample_consecutive_frames: if True, will sample a contiguous interval of frames
             in the sequence. It first sorts the frames by timestamps when available,
             otherwise by frame numbers, finds the connected segments within the sequence
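
To make the sampler's behaviour concrete, a small usage sketch (the import path matches the new test added below; the index lists here are arbitrary):

    from pytorch3d.implicitron.dataset.data_loader_map_provider import (
        DoublePoolBatchSampler,
    )

    sampler = DoublePoolBatchSampler(
        first_indices=[0, 1, 2],        # candidates for the frame to predict
        rest_indices=[10, 11, 12, 13],  # candidates for conditioning frames
        batch_size=4,
        replacement=False,
        num_batches=None,  # one epoch visits each first index once
    )
    for batch in sampler:
        # e.g. [1, 12, 10, 13]: batch[0] comes from first_indices and the
        # remaining three are drawn without repetition from rest_indices.
        assert batch[0] in [0, 1, 2] and len(batch) == 4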
@@ -84,9 +229,13 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
 
     batch_size: int = 1
     num_workers: int = 0
-    dataset_len: int = 1000
-    dataset_len_val: int = 1
-    images_per_seq_options: Tuple[int, ...] = (2,)
+    dataset_length_train: int = 0
+    dataset_length_val: int = 0
+    dataset_length_test: int = 0
+    train_conditioning_type: BatchConditioningType = BatchConditioningType.SAME
+    val_conditioning_type: BatchConditioningType = BatchConditioningType.SAME
+    test_conditioning_type: BatchConditioningType = BatchConditioningType.KNOWN
+    images_per_seq_options: Tuple[int, ...] = ()
     sample_consecutive_frames: bool = False
     consecutive_frames_max_gap: int = 0
     consecutive_frames_max_gap_seconds: float = 0.1
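
A sketch of setting these fields directly in Python, on the pattern of the expand_args_fields usage elsewhere in this diff (values are illustrative; `datasets` would be a DatasetMap):

    from pytorch3d.implicitron.dataset.data_loader_map_provider import (
        BatchConditioningType,
        SequenceDataLoaderMapProvider,
    )
    from pytorch3d.implicitron.tools.config import expand_args_fields

    expand_args_fields(SequenceDataLoaderMapProvider)
    provider = SequenceDataLoaderMapProvider(
        batch_size=5,  # one main frame plus four conditioning frames
        dataset_length_train=1000,
        test_conditioning_type=BatchConditioningType.KNOWN,
    )
    # loaders = provider.get_data_loaders(datasets)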
@@ -95,17 +244,73 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
         """
         Returns a collection of data loaders for a given collection of datasets.
         """
+        return DataLoaderMap(
+            train=self._make_data_loader(
+                datasets.train,
+                self.dataset_length_train,
+                datasets.train,
+                self.train_conditioning_type,
+            ),
+            val=self._make_data_loader(
+                datasets.val,
+                self.dataset_length_val,
+                datasets.train,
+                self.val_conditioning_type,
+            ),
+            test=self._make_data_loader(
+                datasets.test,
+                self.dataset_length_test,
+                datasets.train,
+                self.test_conditioning_type,
+            ),
+        )
+
+    def _make_data_loader(
+        self,
+        dataset: Optional[DatasetBase],
+        num_batches: int,
+        train_dataset: Optional[DatasetBase],
+        conditioning_type: BatchConditioningType,
+    ) -> Optional[DataLoader[FrameData]]:
+        """
+        Returns the dataloader for a dataset.
+
+        Args:
+            dataset: the dataset
+            num_batches: possible ceiling on number of batches per epoch
+            train_dataset: the training dataset, used if conditioning_type==TRAIN
+            conditioning_type: source for padding of batches
+        """
+        if dataset is None:
+            return None
+
         data_loader_kwargs = {
             "num_workers": self.num_workers,
-            "collate_fn": FrameData.collate,
+            "collate_fn": dataset.frame_data_type.collate,
         }
 
-        def train_or_val_loader(
-            dataset: Optional[DatasetBase], num_batches: int
-        ) -> Optional[torch.utils.data.DataLoader]:
-            if dataset is None:
-                return None
+        eval_batches = dataset.get_eval_batches()
+        if eval_batches is not None:
+            return DataLoader(
+                dataset,
+                batch_sampler=eval_batches,
+                **data_loader_kwargs,
+            )
+
+        scenes_matter = len(self.images_per_seq_options) > 0
+        if scenes_matter and conditioning_type != BatchConditioningType.SAME:
+            raise ValueError(
+                f"{conditioning_type} cannot be used with images_per_seq "
+                + str(self.images_per_seq_options)
+            )
+
+        if self.batch_size == 1 or (
+            not scenes_matter and conditioning_type == BatchConditioningType.SAME
+        ):
+            return self._simple_loader(dataset, num_batches, data_loader_kwargs)
+
+        if scenes_matter:
+            assert conditioning_type == BatchConditioningType.SAME
             batch_sampler = SceneBatchSampler(
                 dataset,
                 self.batch_size,
@@ -115,25 +320,115 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
                 consecutive_frames_max_gap=self.consecutive_frames_max_gap,
                 consecutive_frames_max_gap_seconds=self.consecutive_frames_max_gap_seconds,
             )
-            return torch.utils.data.DataLoader(
+            return DataLoader(
                 dataset,
                 batch_sampler=batch_sampler,
                 **data_loader_kwargs,
             )
 
-        train_data_loader = train_or_val_loader(datasets.train, self.dataset_len)
-        val_data_loader = train_or_val_loader(datasets.val, self.dataset_len_val)
-
-        test_dataset = datasets.test
-        if test_dataset is not None:
-            test_data_loader = torch.utils.data.DataLoader(
-                test_dataset,
-                batch_sampler=test_dataset.get_eval_batches(),
-                **data_loader_kwargs,
-            )
-        else:
-            test_data_loader = None
-
-        return DataLoaderMap(
-            train=train_data_loader, val=val_data_loader, test=test_data_loader
-        )
+        if conditioning_type == BatchConditioningType.TRAIN:
+            return self._train_loader(
+                dataset, train_dataset, num_batches, data_loader_kwargs
+            )
+
+        assert conditioning_type == BatchConditioningType.KNOWN
+        return self._known_loader(dataset, num_batches, data_loader_kwargs)
+
+    def _simple_loader(
+        self,
+        dataset: DatasetBase,
+        num_batches: int,
+        data_loader_kwargs: dict,
+    ) -> DataLoader[FrameData]:
+        """
+        Return a simple loader for frames in the dataset.
+
+        This is equivalent to
+            DataLoader(dataset, batch_size=self.batch_size, **data_loader_kwargs)
+        except that num_batches is fixed.
+
+        Args:
+            dataset: the dataset
+            num_batches: possible ceiling on number of batches per epoch
+            data_loader_kwargs: common args for dataloader
+        """
+        if num_batches > 0:
+            num_samples = self.batch_size * num_batches
+        else:
+            num_samples = None
+        sampler = RandomSampler(dataset, replacement=False, num_samples=num_samples)
+        batch_sampler = BatchSampler(sampler, self.batch_size, drop_last=True)
+        return DataLoader(
+            dataset,
+            batch_sampler=batch_sampler,
+            **data_loader_kwargs,
+        )
+
+    def _train_loader(
+        self,
+        dataset: DatasetBase,
+        train_dataset: Optional[DatasetBase],
+        num_batches: int,
+        data_loader_kwargs: dict,
+    ) -> DataLoader[FrameData]:
+        """
+        Return the loader for TRAIN conditioning.
+
+        Args:
+            dataset: the dataset
+            train_dataset: the training dataset
+            num_batches: possible ceiling on number of batches per epoch
+            data_loader_kwargs: common args for dataloader
+        """
+        if train_dataset is None:
+            raise ValueError("No training data for conditioning.")
+        length = len(dataset)
+        first_indices = list(range(length))
+        rest_indices = list(range(length, length + len(train_dataset)))
+        sampler = DoublePoolBatchSampler(
+            first_indices=first_indices,
+            rest_indices=rest_indices,
+            batch_size=self.batch_size,
+            replacement=True,
+            num_batches=num_batches,
+        )
+        return DataLoader(
+            ChainDataset([dataset, train_dataset]),
+            batch_sampler=sampler,
+            **data_loader_kwargs,
+        )
+
+    def _known_loader(
+        self,
+        dataset: DatasetBase,
+        num_batches: int,
+        data_loader_kwargs: dict,
+    ) -> DataLoader[FrameData]:
+        """
+        Return the loader for KNOWN conditioning.
+
+        Args:
+            dataset: the dataset
+            num_batches: possible ceiling on number of batches per epoch
+            data_loader_kwargs: common args for dataloader
+        """
+        first_indices, rest_indices = [], []
+        for idx in range(len(dataset)):
+            frame_type = dataset[idx].frame_type
+            assert isinstance(frame_type, str)
+            if is_known_frame_scalar(frame_type):
+                rest_indices.append(idx)
+            else:
+                first_indices.append(idx)
+        sampler = DoublePoolBatchSampler(
+            first_indices=first_indices,
+            rest_indices=rest_indices,
+            batch_size=self.batch_size,
+            replacement=True,
+            num_batches=num_batches,
+        )
+        return DataLoader(
+            dataset,
+            batch_sampler=sampler,
+            **data_loader_kwargs,
+        )
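
The TRAIN path relies on an index convention: items of `dataset` occupy indices [0, len(dataset)) and the training items are logically appended after them, which is why rest_indices is offset by len(dataset). A toy illustration of that arithmetic, with plain lists standing in for the two datasets:

    val_items = ["v0", "v1"]          # stands in for `dataset`
    train_items = ["t0", "t1", "t2"]  # stands in for `train_dataset`
    combined = val_items + train_items

    length = len(val_items)
    first_indices = list(range(length))                            # [0, 1]
    rest_indices = list(range(length, length + len(train_items)))  # [2, 3, 4]

    # A batch mixes one main item with conditioning items from the train pool.
    batch = [first_indices[0], rest_indices[1], rest_indices[2]]
    assert [combined[i] for i in batch] == ["v0", "t1", "t2"]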
@@ -8,6 +8,7 @@ from collections import defaultdict
 from dataclasses import dataclass, field, fields
 from typing import (
     Any,
+    ClassVar,
     Iterable,
     Iterator,
     List,
@@ -15,6 +16,7 @@ from typing import (
     Optional,
     Sequence,
     Tuple,
+    Type,
     Union,
 )
@@ -289,3 +291,7 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
         """
         for _, _, idx in self.sequence_frames_in_order(seq_name):
             yield idx
+
+    # frame_data_type is the actual type of frames returned by the dataset.
+    # Collation uses its classmethod `collate`
+    frame_data_type: ClassVar[Type[FrameData]] = FrameData
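
This class variable lets a dataset subclass substitute its own FrameData subtype; the collate_fn changes elsewhere in this diff then pick up the matching collate automatically. A hypothetical sketch (FrameDataWithCaption and CaptionedDataset are invented names for illustration):

    from dataclasses import dataclass
    from typing import ClassVar, Type

    from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData

    @dataclass
    class FrameDataWithCaption(FrameData):  # hypothetical FrameData subtype
        caption: str = ""

    class CaptionedDataset(DatasetBase):  # hypothetical dataset
        # loaders built from this dataset collate with FrameDataWithCaption.collate
        frame_data_type: ClassVar[Type[FrameData]] = FrameDataWithCaption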
@@ -16,6 +16,14 @@ DATASET_TYPE_KNOWN = "known"
 DATASET_TYPE_UNKNOWN = "unseen"
 
 
+def is_known_frame_scalar(frame_type: str) -> bool:
+    """
+    Given a single frame type corresponding to a single frame, return whether
+    the frame is a known frame.
+    """
+    return frame_type.endswith(DATASET_TYPE_KNOWN)
+
+
 def is_known_frame(
     frame_type: List[str], device: Optional[str] = None
 ) -> torch.BoolTensor:
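
The scalar helper mirrors the list-based is_known_frame below by checking the suffix of the frame type string, so split-qualified types like "test_known" (seen in the tests in this diff) also count as known:

    from pytorch3d.implicitron.dataset.utils import is_known_frame_scalar

    assert is_known_frame_scalar("known")
    assert is_known_frame_scalar("test_known")  # suffix match
    assert not is_known_frame_scalar("unseen")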
@@ -25,7 +33,7 @@ def is_known_frame(
     """
     # pyre-fixme[7]: Expected `BoolTensor` but got `Tensor`.
     return torch.tensor(
-        [ft.endswith(DATASET_TYPE_KNOWN) for ft in frame_type],
+        [is_known_frame_scalar(ft) for ft in frame_type],
         dtype=torch.bool,
         device=device,
     )
@@ -69,7 +69,7 @@ def get_implicitron_sequence_pointcloud(
         batch_size=len(sequence_dataset),
         shuffle=False,
         num_workers=num_workers,
-        collate_fn=FrameData.collate,
+        collate_fn=dataset.frame_data_type.collate,
     )
 
     frame_data = next(iter(loader))  # there's only one batch
@@ -12,7 +12,7 @@ from typing import Any, cast, Dict, List, Optional, Tuple
 import lpips
 import torch
 from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
-from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
+from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
 from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
 from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
     CO3D_CATEGORIES,
@@ -207,7 +207,7 @@ def _get_all_source_cameras(
         shuffle=False,
         batch_size=len(dataset_for_loader),
         num_workers=num_workers,
-        collate_fn=FrameData.collate,
+        collate_fn=dataset.frame_data_type.collate,
     )
     is_known = is_known_frame(all_frame_data.frame_type)
     source_cameras = all_frame_data.camera[torch.where(is_known)[0]]
@@ -52,10 +52,13 @@ dataset_map_provider_LlffDatasetMapProvider_args:
 data_loader_map_provider_SequenceDataLoaderMapProvider_args:
   batch_size: 1
   num_workers: 0
-  dataset_len: 1000
-  dataset_len_val: 1
-  images_per_seq_options:
-  - 2
+  dataset_length_train: 0
+  dataset_length_val: 0
+  dataset_length_test: 0
+  train_conditioning_type: SAME
+  val_conditioning_type: SAME
+  test_conditioning_type: KNOWN
+  images_per_seq_options: []
   sample_consecutive_frames: false
   consecutive_frames_max_gap: 0
   consecutive_frames_max_gap_seconds: 0.1
@@ -8,6 +8,11 @@
 import unittest
 from collections import defaultdict
 from dataclasses import dataclass
+from itertools import product
 
+from pytorch3d.implicitron.dataset.data_loader_map_provider import (
+    DoublePoolBatchSampler,
+)
+
 from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
 from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
@@ -214,3 +219,30 @@ def _count_by_quotient(indices, divisor):
         counter[i // divisor] += 1
 
     return counter
+
+
+class TestRandomSampling(unittest.TestCase):
+    def test_double_pool_batch_sampler(self):
+        unknown_idxs = [2, 3, 4, 5, 8]
+        known_idxs = [2, 9, 10, 11, 12, 13, 14, 15, 16, 17]
+        for replacement, num_batches in product([True, False], [None, 4, 5, 6, 30]):
+            with self.subTest(f"{replacement}, {num_batches}"):
+                sampler = DoublePoolBatchSampler(
+                    first_indices=unknown_idxs,
+                    rest_indices=known_idxs,
+                    batch_size=4,
+                    replacement=replacement,
+                    num_batches=num_batches,
+                )
+                for _ in range(6):
+                    epoch = list(sampler)
+                    self.assertEqual(len(epoch), num_batches or len(unknown_idxs))
+                    for batch in epoch:
+                        self.assertEqual(len(batch), 4)
+                        self.assertIn(batch[0], unknown_idxs)
+                        for i in batch[1:]:
+                            self.assertIn(i, known_idxs)
+                    if not replacement and 4 != num_batches:
+                        self.assertEqual(
+                            {batch[0] for batch in epoch}, set(unknown_idxs)
+                        )
@@ -10,11 +10,12 @@ import unittest
 from pytorch3d.implicitron.dataset.blender_dataset_map_provider import (
     BlenderDatasetMapProvider,
 )
+from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
 from pytorch3d.implicitron.dataset.dataset_base import FrameData
 from pytorch3d.implicitron.dataset.llff_dataset_map_provider import (
     LlffDatasetMapProvider,
 )
-from pytorch3d.implicitron.tools.config import expand_args_fields
+from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
 from tests.common_testing import TestCaseMixin
@@ -102,3 +103,23 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
             self.assertEqual(dataset_map.test[batch[0]].frame_type, "unseen")
             for i in batch[1:]:
                 self.assertEqual(dataset_map.test[i].frame_type, "known")
+
+    def test_loaders(self):
+        args = get_default_args(ImplicitronDataSource)
+        args.dataset_map_provider_class_type = "BlenderDatasetMapProvider"
+        args.data_loader_map_provider_class_type = "RandomDataLoaderMapProvider"
+        dataset_args = args.dataset_map_provider_BlenderDatasetMapProvider_args
+        dataset_args.object_name = "lego"
+        dataset_args.base_dir = "manifold://co3d/tree/nerf_data/nerf_synthetic/lego"
+
+        data_source = ImplicitronDataSource(**args)
+        _, data_loaders = data_source.get_datasets_and_dataloaders()
+        for i in data_loaders.train:
+            self.assertEqual(i.frame_type, ["known"])
+            self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
+        for i in data_loaders.val:
+            self.assertEqual(i.frame_type, ["unseen"])
+            self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
+        for i in data_loaders.test:
+            self.assertEqual(i.frame_type, ["unseen"])
+            self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
@@ -8,6 +8,7 @@ import os
 import unittest
 import unittest.mock
 
+import torch
 from omegaconf import OmegaConf
 from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
 from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
@@ -21,6 +22,7 @@ DEBUG: bool = False
 class TestDataSource(unittest.TestCase):
     def setUp(self):
         self.maxDiff = None
+        torch.manual_seed(42)
 
     def _test_omegaconf_generic_failure(self):
         # OmegaConf possible bug - this is why we need _GenericWorkaround
@@ -56,3 +58,23 @@ class TestDataSource(unittest.TestCase):
             if DEBUG:
                 (DATA_DIR / "data_source.yaml").write_text(yaml)
             self.assertEqual(yaml, (DATA_DIR / "data_source.yaml").read_text())
+
+    def test_default(self):
+        if os.environ.get("INSIDE_RE_WORKER") is not None:
+            return
+        args = get_default_args(ImplicitronDataSource)
+        args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
+        args.data_loader_map_provider_class_type = "SequenceDataLoaderMapProvider"
+        dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
+        dataset_args.category = "skateboard"
+        dataset_args.test_restrict_sequence_id = 0
+        dataset_args.n_frames_per_sequence = -1
+
+        dataset_args.dataset_root = "manifold://co3d/tree/extracted"
+
+        data_source = ImplicitronDataSource(**args)
+        _, data_loaders = data_source.get_datasets_and_dataloaders()
+        self.assertEqual(len(data_loaders.train), 81)
+        for i in data_loaders.train:
+            self.assertEqual(i.frame_type, ["test_known"])
+            break
@@ -44,6 +44,7 @@ class TestEvaluation(unittest.TestCase):
         frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
         sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
         self.image_size = 64
+        expand_args_fields(JsonIndexDataset)
         self.dataset = JsonIndexDataset(
             frame_annotations_file=frame_file,
             sequence_annotations_file=sequence_file,