mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
more padding options in Dataloader
Summary: Add facilities for dataloading non-sequential scenes. Reviewed By: shapovalov Differential Revision: D37291277 fbshipit-source-id: 0a33e3727b44c4f0cba3a2abe9b12f40d2a20447
This commit is contained in:
parent
0dce883241
commit
771cf8a328
@ -6,22 +6,12 @@ architecture: generic
|
||||
visualize_interval: 0
|
||||
visdom_port: 8097
|
||||
data_source_args:
|
||||
data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
|
||||
dataset_map_provider_class_type: JsonIndexDatasetMapProvider
|
||||
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
||||
batch_size: 10
|
||||
dataset_len: 1000
|
||||
dataset_len_val: 1
|
||||
dataset_length_train: 1000
|
||||
dataset_length_val: 1
|
||||
num_workers: 8
|
||||
images_per_seq_options:
|
||||
- 2
|
||||
- 3
|
||||
- 4
|
||||
- 5
|
||||
- 6
|
||||
- 7
|
||||
- 8
|
||||
- 9
|
||||
- 10
|
||||
dataset_map_provider_JsonIndexDatasetMapProvider_args:
|
||||
dataset_root: ${oc.env:CO3D_DATASET_ROOT}
|
||||
n_frames_per_sequence: -1
|
||||
|
@ -4,8 +4,8 @@ defaults:
|
||||
data_source_args:
|
||||
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
||||
batch_size: 10
|
||||
dataset_len: 1000
|
||||
dataset_len_val: 1
|
||||
dataset_length_train: 1000
|
||||
dataset_length_val: 1
|
||||
num_workers: 8
|
||||
images_per_seq_options:
|
||||
- 2
|
||||
|
@ -4,11 +4,9 @@ defaults:
|
||||
data_source_args:
|
||||
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
||||
batch_size: 1
|
||||
dataset_len: 1000
|
||||
dataset_len_val: 1
|
||||
dataset_length_train: 1000
|
||||
dataset_length_val: 1
|
||||
num_workers: 8
|
||||
images_per_seq_options:
|
||||
- 2
|
||||
dataset_map_provider_JsonIndexDatasetMapProvider_args:
|
||||
assert_single_seq: true
|
||||
n_frames_per_sequence: -1
|
||||
|
@ -4,8 +4,8 @@ defaults:
|
||||
data_source_args:
|
||||
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
||||
batch_size: 10
|
||||
dataset_len: 1000
|
||||
dataset_len_val: 1
|
||||
dataset_length_train: 1000
|
||||
dataset_length_val: 1
|
||||
num_workers: 8
|
||||
images_per_seq_options:
|
||||
- 2
|
||||
|
@ -345,10 +345,13 @@ data_source_args:
|
||||
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
||||
batch_size: 1
|
||||
num_workers: 0
|
||||
dataset_len: 1000
|
||||
dataset_len_val: 1
|
||||
images_per_seq_options:
|
||||
- 2
|
||||
dataset_length_train: 0
|
||||
dataset_length_val: 0
|
||||
dataset_length_test: 0
|
||||
train_conditioning_type: SAME
|
||||
val_conditioning_type: SAME
|
||||
test_conditioning_type: KNOWN
|
||||
images_per_seq_options: []
|
||||
sample_consecutive_frames: false
|
||||
consecutive_frames_max_gap: 0
|
||||
consecutive_frames_max_gap_seconds: 0.1
|
||||
|
@ -55,7 +55,7 @@ class TestExperiment(unittest.TestCase):
|
||||
dataset_args.test_restrict_sequence_id = 0
|
||||
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
||||
dataset_args.dataset_JsonIndexDataset_args.limit_sequences_to = 5
|
||||
dataloader_args.dataset_len = 1
|
||||
dataloader_args.dataset_length_train = 1
|
||||
cfg.solver_args.max_epochs = 2
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
|
@ -5,14 +5,23 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
from enum import Enum
|
||||
from typing import Iterator, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||
from torch.utils.data import (
|
||||
BatchSampler,
|
||||
ChainDataset,
|
||||
DataLoader,
|
||||
RandomSampler,
|
||||
Sampler,
|
||||
)
|
||||
|
||||
from .dataset_base import DatasetBase, FrameData
|
||||
from .dataset_map_provider import DatasetMap
|
||||
from .scene_batch_sampler import SceneBatchSampler
|
||||
from .utils import is_known_frame_scalar
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -27,13 +36,11 @@ class DataLoaderMap:
|
||||
test: a data loader for final evaluation
|
||||
"""
|
||||
|
||||
train: Optional[torch.utils.data.DataLoader[FrameData]]
|
||||
val: Optional[torch.utils.data.DataLoader[FrameData]]
|
||||
test: Optional[torch.utils.data.DataLoader[FrameData]]
|
||||
train: Optional[DataLoader[FrameData]]
|
||||
val: Optional[DataLoader[FrameData]]
|
||||
test: Optional[DataLoader[FrameData]]
|
||||
|
||||
def __getitem__(
|
||||
self, split: str
|
||||
) -> Optional[torch.utils.data.DataLoader[FrameData]]:
|
||||
def __getitem__(self, split: str) -> Optional[DataLoader[FrameData]]:
|
||||
"""
|
||||
Get one of the data loaders by key (name of data split)
|
||||
"""
|
||||
@ -54,17 +61,155 @@ class DataLoaderMapProviderBase(ReplaceableBase):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DoublePoolBatchSampler(Sampler[List[int]]):
|
||||
"""
|
||||
Batch sampler for making random batches of a single frame
|
||||
from one list and a number of known frames from another list.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
first_indices: List[int],
|
||||
rest_indices: List[int],
|
||||
batch_size: int,
|
||||
replacement: bool,
|
||||
num_batches: Optional[int] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
first_indices: indexes of dataset items to use as the first element
|
||||
of each batch.
|
||||
rest_indices: indexes of dataset items to use as the subsequent
|
||||
elements of each batch. Not used if batch_size==1.
|
||||
batch_size: The common size of any batch.
|
||||
replacement: Whether the sampling of first items is with replacement.
|
||||
num_batches: The number of batches in an epoch. If 0 or None,
|
||||
one epoch is the length of `first_indices`.
|
||||
"""
|
||||
self.first_indices = first_indices
|
||||
self.rest_indices = rest_indices
|
||||
self.batch_size = batch_size
|
||||
self.replacement = replacement
|
||||
self.num_batches = None if num_batches == 0 else num_batches
|
||||
|
||||
if batch_size - 1 > len(rest_indices):
|
||||
raise ValueError(
|
||||
f"Cannot make up ({batch_size})-batches from {len(self.rest_indices)}"
|
||||
)
|
||||
|
||||
# copied from RandomSampler
|
||||
seed = int(torch.empty((), dtype=torch.int64).random_().item())
|
||||
self.generator = torch.Generator()
|
||||
self.generator.manual_seed(seed)
|
||||
|
||||
def __len__(self) -> int:
|
||||
if self.num_batches is not None:
|
||||
return self.num_batches
|
||||
return len(self.first_indices)
|
||||
|
||||
def __iter__(self) -> Iterator[List[int]]:
|
||||
num_batches = self.num_batches
|
||||
if self.replacement:
|
||||
i_first = torch.randint(
|
||||
len(self.first_indices),
|
||||
size=(len(self),),
|
||||
generator=self.generator,
|
||||
)
|
||||
elif num_batches is not None:
|
||||
n_copies = 1 + (num_batches - 1) // len(self.first_indices)
|
||||
raw_indices = [
|
||||
torch.randperm(len(self.first_indices), generator=self.generator)
|
||||
for _ in range(n_copies)
|
||||
]
|
||||
i_first = torch.concat(raw_indices)[:num_batches]
|
||||
else:
|
||||
i_first = torch.randperm(len(self.first_indices), generator=self.generator)
|
||||
first_indices = [self.first_indices[i] for i in i_first]
|
||||
|
||||
if self.batch_size == 1:
|
||||
for first_index in first_indices:
|
||||
yield [first_index]
|
||||
return
|
||||
|
||||
for first_index in first_indices:
|
||||
# Consider using this class in a program which sets the seed. This use
|
||||
# of randperm means that rerunning with a higher batch_size
|
||||
# results in batches whose first elements as the first run.
|
||||
i_rest = torch.randperm(
|
||||
len(self.rest_indices),
|
||||
generator=self.generator,
|
||||
)[: self.batch_size - 1]
|
||||
yield [first_index] + [self.rest_indices[i] for i in i_rest]
|
||||
|
||||
|
||||
class BatchConditioningType(Enum):
|
||||
"""
|
||||
Ways to add conditioning frames for the val and test batches.
|
||||
|
||||
SAME: Use the corresponding dataset for all elements of val batches
|
||||
without regard to frame type.
|
||||
TRAIN: Use the corresponding dataset for the first element of each
|
||||
batch, and the training dataset for the extra conditioning
|
||||
elements. No regard to frame type.
|
||||
KNOWN: Use frames from the corresponding dataset but separate them
|
||||
according to their frame_type. Each batch will contain one UNSEEN
|
||||
frame followed by many KNOWN frames.
|
||||
"""
|
||||
|
||||
SAME = "same"
|
||||
TRAIN = "train"
|
||||
KNOWN = "known"
|
||||
|
||||
|
||||
@registry.register
|
||||
class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
|
||||
"""
|
||||
The default implementation of DataLoaderMapProviderBase.
|
||||
Default implementation of DataLoaderMapProviderBase.
|
||||
|
||||
If a dataset returns batches from get_eval_batches(), then
|
||||
they will be what the corresponding dataloader returns,
|
||||
independently of any of the fields on this class.
|
||||
|
||||
If conditioning is not required, then the batch size should
|
||||
be set as 1, and most of the fields do not matter.
|
||||
|
||||
If conditioning is required, each batch will contain one main
|
||||
frame first to predict and the, rest of the elements are for
|
||||
conditioning.
|
||||
|
||||
If images_per_seq_options is left empty, the conditioning
|
||||
frames are picked according to the conditioning type given.
|
||||
This does not have regard to the order of frames in a
|
||||
scene, or which frames belong to what scene.
|
||||
|
||||
If images_per_seq_options is given, then the conditioning types
|
||||
must be SAME and the remaining fields are used.
|
||||
|
||||
Members:
|
||||
batch_size: The size of the batch of the data loader.
|
||||
num_workers: Number data-loading threads.
|
||||
dataset_len: The number of batches in a training epoch.
|
||||
dataset_len_val: The number of batches in a validation epoch.
|
||||
images_per_seq_options: Possible numbers of images sampled per sequence.
|
||||
num_workers: Number of data-loading threads in each data loader.
|
||||
dataset_length_train: The number of batches in a training epoch. Or 0 to mean
|
||||
an epoch is the length of the training set.
|
||||
dataset_length_val: The number of batches in a validation epoch. Or 0 to mean
|
||||
an epoch is the length of the validation set.
|
||||
dataset_length_test: The number of batches in a testing epoch. Or 0 to mean
|
||||
an epoch is the length of the test set.
|
||||
train_conditioning_type: Whether the train data loader should use
|
||||
only known frames for conditioning.
|
||||
Only used if batch_size>1 and train dataset is
|
||||
present and does not return eval_batches.
|
||||
val_conditioning_type: Whether the val data loader should use
|
||||
training frames or known frames for conditioning.
|
||||
Only used if batch_size>1 and val dataset is
|
||||
present and does not return eval_batches.
|
||||
test_conditioning_type: Whether the test data loader should use
|
||||
training frames or known frames for conditioning.
|
||||
Only used if batch_size>1 and test dataset is
|
||||
present and does not return eval_batches.
|
||||
images_per_seq_options: Possible numbers of frames sampled per sequence in a batch.
|
||||
If a conditioning_type is KNOWN or TRAIN, then this must be left at its initial
|
||||
value. Empty (the default) means that we are not careful about which frames
|
||||
come from which scene.
|
||||
sample_consecutive_frames: if True, will sample a contiguous interval of frames
|
||||
in the sequence. It first sorts the frames by timestimps when available,
|
||||
otherwise by frame numbers, finds the connected segments within the sequence
|
||||
@ -84,9 +229,13 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
|
||||
|
||||
batch_size: int = 1
|
||||
num_workers: int = 0
|
||||
dataset_len: int = 1000
|
||||
dataset_len_val: int = 1
|
||||
images_per_seq_options: Tuple[int, ...] = (2,)
|
||||
dataset_length_train: int = 0
|
||||
dataset_length_val: int = 0
|
||||
dataset_length_test: int = 0
|
||||
train_conditioning_type: BatchConditioningType = BatchConditioningType.SAME
|
||||
val_conditioning_type: BatchConditioningType = BatchConditioningType.SAME
|
||||
test_conditioning_type: BatchConditioningType = BatchConditioningType.KNOWN
|
||||
images_per_seq_options: Tuple[int, ...] = ()
|
||||
sample_consecutive_frames: bool = False
|
||||
consecutive_frames_max_gap: int = 0
|
||||
consecutive_frames_max_gap_seconds: float = 0.1
|
||||
@ -95,17 +244,73 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
|
||||
"""
|
||||
Returns a collection of data loaders for a given collection of datasets.
|
||||
"""
|
||||
return DataLoaderMap(
|
||||
train=self._make_data_loader(
|
||||
datasets.train,
|
||||
self.dataset_length_train,
|
||||
datasets.train,
|
||||
self.train_conditioning_type,
|
||||
),
|
||||
val=self._make_data_loader(
|
||||
datasets.val,
|
||||
self.dataset_length_val,
|
||||
datasets.train,
|
||||
self.val_conditioning_type,
|
||||
),
|
||||
test=self._make_data_loader(
|
||||
datasets.test,
|
||||
self.dataset_length_test,
|
||||
datasets.train,
|
||||
self.test_conditioning_type,
|
||||
),
|
||||
)
|
||||
|
||||
def _make_data_loader(
|
||||
self,
|
||||
dataset: Optional[DatasetBase],
|
||||
num_batches: int,
|
||||
train_dataset: Optional[DatasetBase],
|
||||
conditioning_type: BatchConditioningType,
|
||||
) -> Optional[DataLoader[FrameData]]:
|
||||
"""
|
||||
Returns the dataloader for a dataset.
|
||||
|
||||
Args:
|
||||
dataset: the dataset
|
||||
num_batches: possible ceiling on number of batches per epoch
|
||||
train_dataset: the training dataset, used if conditioning_type==TRAIN
|
||||
conditioning_type: source for padding of batches
|
||||
"""
|
||||
if dataset is None:
|
||||
return None
|
||||
|
||||
data_loader_kwargs = {
|
||||
"num_workers": self.num_workers,
|
||||
"collate_fn": FrameData.collate,
|
||||
"collate_fn": dataset.frame_data_type.collate,
|
||||
}
|
||||
|
||||
def train_or_val_loader(
|
||||
dataset: Optional[DatasetBase], num_batches: int
|
||||
) -> Optional[torch.utils.data.DataLoader]:
|
||||
if dataset is None:
|
||||
return None
|
||||
eval_batches = dataset.get_eval_batches()
|
||||
if eval_batches is not None:
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_sampler=eval_batches,
|
||||
**data_loader_kwargs,
|
||||
)
|
||||
|
||||
scenes_matter = len(self.images_per_seq_options) > 0
|
||||
if scenes_matter and conditioning_type != BatchConditioningType.SAME:
|
||||
raise ValueError(
|
||||
f"{conditioning_type} cannot be used with images_per_seq "
|
||||
+ str(self.images_per_seq_options)
|
||||
)
|
||||
|
||||
if self.batch_size == 1 or (
|
||||
not scenes_matter and conditioning_type == BatchConditioningType.SAME
|
||||
):
|
||||
return self._simple_loader(dataset, num_batches, data_loader_kwargs)
|
||||
|
||||
if scenes_matter:
|
||||
assert conditioning_type == BatchConditioningType.SAME
|
||||
batch_sampler = SceneBatchSampler(
|
||||
dataset,
|
||||
self.batch_size,
|
||||
@ -115,25 +320,115 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
|
||||
consecutive_frames_max_gap=self.consecutive_frames_max_gap,
|
||||
consecutive_frames_max_gap_seconds=self.consecutive_frames_max_gap_seconds,
|
||||
)
|
||||
return torch.utils.data.DataLoader(
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_sampler=batch_sampler,
|
||||
**data_loader_kwargs,
|
||||
)
|
||||
|
||||
train_data_loader = train_or_val_loader(datasets.train, self.dataset_len)
|
||||
val_data_loader = train_or_val_loader(datasets.val, self.dataset_len_val)
|
||||
|
||||
test_dataset = datasets.test
|
||||
if test_dataset is not None:
|
||||
test_data_loader = torch.utils.data.DataLoader(
|
||||
test_dataset,
|
||||
batch_sampler=test_dataset.get_eval_batches(),
|
||||
**data_loader_kwargs,
|
||||
if conditioning_type == BatchConditioningType.TRAIN:
|
||||
return self._train_loader(
|
||||
dataset, train_dataset, num_batches, data_loader_kwargs
|
||||
)
|
||||
else:
|
||||
test_data_loader = None
|
||||
|
||||
return DataLoaderMap(
|
||||
train=train_data_loader, val=val_data_loader, test=test_data_loader
|
||||
assert conditioning_type == BatchConditioningType.KNOWN
|
||||
return self._known_loader(dataset, num_batches, data_loader_kwargs)
|
||||
|
||||
def _simple_loader(
|
||||
self,
|
||||
dataset: DatasetBase,
|
||||
num_batches: int,
|
||||
data_loader_kwargs: dict,
|
||||
) -> DataLoader[FrameData]:
|
||||
"""
|
||||
Return a simple loader for frames in the dataset.
|
||||
|
||||
This is equivalent to
|
||||
Dataloader(dataset, batch_size=self.batch_size, **data_loader_kwargs)
|
||||
except that num_batches is fixed.
|
||||
|
||||
Args:
|
||||
dataset: the dataset
|
||||
num_batches: possible ceiling on number of batches per epoch
|
||||
data_loader_kwargs: common args for dataloader
|
||||
"""
|
||||
if num_batches > 0:
|
||||
num_samples = self.batch_size * num_batches
|
||||
else:
|
||||
num_samples = None
|
||||
sampler = RandomSampler(dataset, replacement=False, num_samples=num_samples)
|
||||
batch_sampler = BatchSampler(sampler, self.batch_size, drop_last=True)
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_sampler=batch_sampler,
|
||||
**data_loader_kwargs,
|
||||
)
|
||||
|
||||
def _train_loader(
|
||||
self,
|
||||
dataset: DatasetBase,
|
||||
train_dataset: Optional[DatasetBase],
|
||||
num_batches: int,
|
||||
data_loader_kwargs: dict,
|
||||
) -> DataLoader[FrameData]:
|
||||
"""
|
||||
Return the loader for TRAIN conditioning.
|
||||
|
||||
Args:
|
||||
dataset: the dataset
|
||||
train_dataset: the training dataset
|
||||
num_batches: possible ceiling on number of batches per epoch
|
||||
data_loader_kwargs: common args for dataloader
|
||||
"""
|
||||
if train_dataset is None:
|
||||
raise ValueError("No training data for conditioning.")
|
||||
length = len(dataset)
|
||||
first_indices = list(range(length))
|
||||
rest_indices = list(range(length, length + len(train_dataset)))
|
||||
sampler = DoublePoolBatchSampler(
|
||||
first_indices=first_indices,
|
||||
rest_indices=rest_indices,
|
||||
batch_size=self.batch_size,
|
||||
replacement=True,
|
||||
num_batches=num_batches,
|
||||
)
|
||||
return DataLoader(
|
||||
ChainDataset([dataset, train_dataset]),
|
||||
batch_sampler=sampler,
|
||||
**data_loader_kwargs,
|
||||
)
|
||||
|
||||
def _known_loader(
|
||||
self,
|
||||
dataset: DatasetBase,
|
||||
num_batches: int,
|
||||
data_loader_kwargs: dict,
|
||||
) -> DataLoader[FrameData]:
|
||||
"""
|
||||
Return the loader for KNOWN conditioning.
|
||||
|
||||
Args:
|
||||
dataset: the dataset
|
||||
num_batches: possible ceiling on number of batches per epoch
|
||||
data_loader_kwargs: common args for dataloader
|
||||
"""
|
||||
first_indices, rest_indices = [], []
|
||||
for idx in range(len(dataset)):
|
||||
frame_type = dataset[idx].frame_type
|
||||
assert isinstance(frame_type, str)
|
||||
if is_known_frame_scalar(frame_type):
|
||||
rest_indices.append(idx)
|
||||
else:
|
||||
first_indices.append(idx)
|
||||
sampler = DoublePoolBatchSampler(
|
||||
first_indices=first_indices,
|
||||
rest_indices=rest_indices,
|
||||
batch_size=self.batch_size,
|
||||
replacement=True,
|
||||
num_batches=num_batches,
|
||||
)
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_sampler=sampler,
|
||||
**data_loader_kwargs,
|
||||
)
|
||||
|
@ -8,6 +8,7 @@ from collections import defaultdict
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
@ -15,6 +16,7 @@ from typing import (
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
@ -289,3 +291,7 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
|
||||
"""
|
||||
for _, _, idx in self.sequence_frames_in_order(seq_name):
|
||||
yield idx
|
||||
|
||||
# frame_data_type is the actual type of frames returned by the dataset.
|
||||
# Collation uses its classmethod `collate`
|
||||
frame_data_type: ClassVar[Type[FrameData]] = FrameData
|
||||
|
@ -16,6 +16,14 @@ DATASET_TYPE_KNOWN = "known"
|
||||
DATASET_TYPE_UNKNOWN = "unseen"
|
||||
|
||||
|
||||
def is_known_frame_scalar(frame_type: str) -> bool:
|
||||
"""
|
||||
Given a single frame type corresponding to a single frame, return whether
|
||||
the frame is a known frame.
|
||||
"""
|
||||
return frame_type.endswith(DATASET_TYPE_KNOWN)
|
||||
|
||||
|
||||
def is_known_frame(
|
||||
frame_type: List[str], device: Optional[str] = None
|
||||
) -> torch.BoolTensor:
|
||||
@ -25,7 +33,7 @@ def is_known_frame(
|
||||
"""
|
||||
# pyre-fixme[7]: Expected `BoolTensor` but got `Tensor`.
|
||||
return torch.tensor(
|
||||
[ft.endswith(DATASET_TYPE_KNOWN) for ft in frame_type],
|
||||
[is_known_frame_scalar(ft) for ft in frame_type],
|
||||
dtype=torch.bool,
|
||||
device=device,
|
||||
)
|
||||
|
@ -69,7 +69,7 @@ def get_implicitron_sequence_pointcloud(
|
||||
batch_size=len(sequence_dataset),
|
||||
shuffle=False,
|
||||
num_workers=num_workers,
|
||||
collate_fn=FrameData.collate,
|
||||
collate_fn=dataset.frame_data_type.collate,
|
||||
)
|
||||
|
||||
frame_data = next(iter(loader)) # there's only one batch
|
||||
|
@ -12,7 +12,7 @@ from typing import Any, cast, Dict, List, Optional, Tuple
|
||||
import lpips
|
||||
import torch
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
|
||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
|
||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
|
||||
CO3D_CATEGORIES,
|
||||
@ -207,7 +207,7 @@ def _get_all_source_cameras(
|
||||
shuffle=False,
|
||||
batch_size=len(dataset_for_loader),
|
||||
num_workers=num_workers,
|
||||
collate_fn=FrameData.collate,
|
||||
collate_fn=dataset.frame_data_type.collate,
|
||||
)
|
||||
is_known = is_known_frame(all_frame_data.frame_type)
|
||||
source_cameras = all_frame_data.camera[torch.where(is_known)[0]]
|
||||
|
@ -52,10 +52,13 @@ dataset_map_provider_LlffDatasetMapProvider_args:
|
||||
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
||||
batch_size: 1
|
||||
num_workers: 0
|
||||
dataset_len: 1000
|
||||
dataset_len_val: 1
|
||||
images_per_seq_options:
|
||||
- 2
|
||||
dataset_length_train: 0
|
||||
dataset_length_val: 0
|
||||
dataset_length_test: 0
|
||||
train_conditioning_type: SAME
|
||||
val_conditioning_type: SAME
|
||||
test_conditioning_type: KNOWN
|
||||
images_per_seq_options: []
|
||||
sample_consecutive_frames: false
|
||||
consecutive_frames_max_gap: 0
|
||||
consecutive_frames_max_gap_seconds: 0.1
|
||||
|
@ -8,6 +8,11 @@
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from itertools import product
|
||||
|
||||
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
|
||||
DoublePoolBatchSampler,
|
||||
)
|
||||
|
||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
||||
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
|
||||
@ -214,3 +219,30 @@ def _count_by_quotient(indices, divisor):
|
||||
counter[i // divisor] += 1
|
||||
|
||||
return counter
|
||||
|
||||
|
||||
class TestRandomSampling(unittest.TestCase):
|
||||
def test_double_pool_batch_sampler(self):
|
||||
unknown_idxs = [2, 3, 4, 5, 8]
|
||||
known_idxs = [2, 9, 10, 11, 12, 13, 14, 15, 16, 17]
|
||||
for replacement, num_batches in product([True, False], [None, 4, 5, 6, 30]):
|
||||
with self.subTest(f"{replacement}, {num_batches}"):
|
||||
sampler = DoublePoolBatchSampler(
|
||||
first_indices=unknown_idxs,
|
||||
rest_indices=known_idxs,
|
||||
batch_size=4,
|
||||
replacement=replacement,
|
||||
num_batches=num_batches,
|
||||
)
|
||||
for _ in range(6):
|
||||
epoch = list(sampler)
|
||||
self.assertEqual(len(epoch), num_batches or len(unknown_idxs))
|
||||
for batch in epoch:
|
||||
self.assertEqual(len(batch), 4)
|
||||
self.assertIn(batch[0], unknown_idxs)
|
||||
for i in batch[1:]:
|
||||
self.assertIn(i, known_idxs)
|
||||
if not replacement and 4 != num_batches:
|
||||
self.assertEqual(
|
||||
{batch[0] for batch in epoch}, set(unknown_idxs)
|
||||
)
|
||||
|
@ -10,11 +10,12 @@ import unittest
|
||||
from pytorch3d.implicitron.dataset.blender_dataset_map_provider import (
|
||||
BlenderDatasetMapProvider,
|
||||
)
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData
|
||||
from pytorch3d.implicitron.dataset.llff_dataset_map_provider import (
|
||||
LlffDatasetMapProvider,
|
||||
)
|
||||
from pytorch3d.implicitron.tools.config import expand_args_fields
|
||||
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
|
||||
from tests.common_testing import TestCaseMixin
|
||||
|
||||
|
||||
@ -102,3 +103,23 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(dataset_map.test[batch[0]].frame_type, "unseen")
|
||||
for i in batch[1:]:
|
||||
self.assertEqual(dataset_map.test[i].frame_type, "known")
|
||||
|
||||
def test_loaders(self):
|
||||
args = get_default_args(ImplicitronDataSource)
|
||||
args.dataset_map_provider_class_type = "BlenderDatasetMapProvider"
|
||||
args.data_loader_map_provider_class_type = "RandomDataLoaderMapProvider"
|
||||
dataset_args = args.dataset_map_provider_BlenderDatasetMapProvider_args
|
||||
dataset_args.object_name = "lego"
|
||||
dataset_args.base_dir = "manifold://co3d/tree/nerf_data/nerf_synthetic/lego"
|
||||
|
||||
data_source = ImplicitronDataSource(**args)
|
||||
_, data_loaders = data_source.get_datasets_and_dataloaders()
|
||||
for i in data_loaders.train:
|
||||
self.assertEqual(i.frame_type, ["known"])
|
||||
self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
|
||||
for i in data_loaders.val:
|
||||
self.assertEqual(i.frame_type, ["unseen"])
|
||||
self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
|
||||
for i in data_loaders.test:
|
||||
self.assertEqual(i.frame_type, ["unseen"])
|
||||
self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
|
||||
|
@ -8,6 +8,7 @@ import os
|
||||
import unittest
|
||||
import unittest.mock
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||
@ -21,6 +22,7 @@ DEBUG: bool = False
|
||||
class TestDataSource(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.maxDiff = None
|
||||
torch.manual_seed(42)
|
||||
|
||||
def _test_omegaconf_generic_failure(self):
|
||||
# OmegaConf possible bug - this is why we need _GenericWorkaround
|
||||
@ -56,3 +58,23 @@ class TestDataSource(unittest.TestCase):
|
||||
if DEBUG:
|
||||
(DATA_DIR / "data_source.yaml").write_text(yaml)
|
||||
self.assertEqual(yaml, (DATA_DIR / "data_source.yaml").read_text())
|
||||
|
||||
def test_default(self):
|
||||
if os.environ.get("INSIDE_RE_WORKER") is not None:
|
||||
return
|
||||
args = get_default_args(ImplicitronDataSource)
|
||||
args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
|
||||
args.data_loader_map_provider_class_type = "SequenceDataLoaderMapProvider"
|
||||
dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||
dataset_args.category = "skateboard"
|
||||
dataset_args.test_restrict_sequence_id = 0
|
||||
dataset_args.n_frames_per_sequence = -1
|
||||
|
||||
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
||||
|
||||
data_source = ImplicitronDataSource(**args)
|
||||
_, data_loaders = data_source.get_datasets_and_dataloaders()
|
||||
self.assertEqual(len(data_loaders.train), 81)
|
||||
for i in data_loaders.train:
|
||||
self.assertEqual(i.frame_type, ["test_known"])
|
||||
break
|
||||
|
@ -44,6 +44,7 @@ class TestEvaluation(unittest.TestCase):
|
||||
frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
|
||||
sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
|
||||
self.image_size = 64
|
||||
expand_args_fields(JsonIndexDataset)
|
||||
self.dataset = JsonIndexDataset(
|
||||
frame_annotations_file=frame_file,
|
||||
sequence_annotations_file=sequence_file,
|
||||
|
Loading…
x
Reference in New Issue
Block a user