Mirror of https://github.com/facebookresearch/pytorch3d.git
commit 771cf8a328 (parent 0dce883241)

more padding options in Dataloader

Summary: Add facilities for dataloading non-sequential scenes.
Reviewed By: shapovalov
Differential Revision: D37291277
fbshipit-source-id: 0a33e3727b44c4f0cba3a2abe9b12f40d2a20447
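For orientation before the per-file hunks: the provider's length fields are renamed (dataset_len to dataset_length_train, dataset_len_val to dataset_length_val, plus a new dataset_length_test), and per-split conditioning types are introduced. A minimal sketch of driving the new options through the implicitron config system (the argument names come from the hunks below; the exact override style, including assigning the enum values as strings, is an assumption):

    from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
    from pytorch3d.implicitron.tools.config import get_default_args

    args = get_default_args(ImplicitronDataSource)
    loader_args = args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
    loader_args.batch_size = 10
    # 0 now means "one epoch = the length of the dataset".
    loader_args.dataset_length_train = 1000
    loader_args.dataset_length_val = 0
    # New: how val/test batches are padded with conditioning frames.
    loader_args.val_conditioning_type = "SAME"
    loader_args.test_conditioning_type = "KNOWN"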
@@ -6,22 +6,12 @@ architecture: generic
 visualize_interval: 0
 visdom_port: 8097
 data_source_args:
+  data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
   dataset_map_provider_class_type: JsonIndexDatasetMapProvider
   data_loader_map_provider_SequenceDataLoaderMapProvider_args:
-    batch_size: 10
-    dataset_len: 1000
-    dataset_len_val: 1
+    dataset_length_train: 1000
+    dataset_length_val: 1
     num_workers: 8
-    images_per_seq_options:
-    - 2
-    - 3
-    - 4
-    - 5
-    - 6
-    - 7
-    - 8
-    - 9
-    - 10
   dataset_map_provider_JsonIndexDatasetMapProvider_args:
     dataset_root: ${oc.env:CO3D_DATASET_ROOT}
     n_frames_per_sequence: -1
@@ -4,8 +4,8 @@ defaults:
 data_source_args:
   data_loader_map_provider_SequenceDataLoaderMapProvider_args:
     batch_size: 10
-    dataset_len: 1000
-    dataset_len_val: 1
+    dataset_length_train: 1000
+    dataset_length_val: 1
     num_workers: 8
     images_per_seq_options:
     - 2
@@ -4,11 +4,9 @@ defaults:
 data_source_args:
   data_loader_map_provider_SequenceDataLoaderMapProvider_args:
     batch_size: 1
-    dataset_len: 1000
-    dataset_len_val: 1
+    dataset_length_train: 1000
+    dataset_length_val: 1
     num_workers: 8
-    images_per_seq_options:
-    - 2
   dataset_map_provider_JsonIndexDatasetMapProvider_args:
     assert_single_seq: true
     n_frames_per_sequence: -1
@@ -4,8 +4,8 @@ defaults:
 data_source_args:
   data_loader_map_provider_SequenceDataLoaderMapProvider_args:
     batch_size: 10
-    dataset_len: 1000
-    dataset_len_val: 1
+    dataset_length_train: 1000
+    dataset_length_val: 1
     num_workers: 8
     images_per_seq_options:
     - 2
@@ -345,10 +345,13 @@ data_source_args:
   data_loader_map_provider_SequenceDataLoaderMapProvider_args:
     batch_size: 1
     num_workers: 0
-    dataset_len: 1000
-    dataset_len_val: 1
-    images_per_seq_options:
-    - 2
+    dataset_length_train: 0
+    dataset_length_val: 0
+    dataset_length_test: 0
+    train_conditioning_type: SAME
+    val_conditioning_type: SAME
+    test_conditioning_type: KNOWN
+    images_per_seq_options: []
     sample_consecutive_frames: false
     consecutive_frames_max_gap: 0
     consecutive_frames_max_gap_seconds: 0.1
@@ -55,7 +55,7 @@ class TestExperiment(unittest.TestCase):
         dataset_args.test_restrict_sequence_id = 0
         dataset_args.dataset_root = "manifold://co3d/tree/extracted"
         dataset_args.dataset_JsonIndexDataset_args.limit_sequences_to = 5
-        dataloader_args.dataset_len = 1
+        dataloader_args.dataset_length_train = 1
         cfg.solver_args.max_epochs = 2

         device = torch.device("cuda:0")
@@ -5,14 +5,23 @@
 # LICENSE file in the root directory of this source tree.

 from dataclasses import dataclass
-from typing import Optional, Tuple
+from enum import Enum
+from typing import Iterator, List, Optional, Tuple

 import torch
 from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
+from torch.utils.data import (
+    BatchSampler,
+    ChainDataset,
+    DataLoader,
+    RandomSampler,
+    Sampler,
+)

 from .dataset_base import DatasetBase, FrameData
 from .dataset_map_provider import DatasetMap
 from .scene_batch_sampler import SceneBatchSampler
+from .utils import is_known_frame_scalar


 @dataclass
@@ -27,13 +36,11 @@ class DataLoaderMap:
     test: a data loader for final evaluation
     """

-    train: Optional[torch.utils.data.DataLoader[FrameData]]
-    val: Optional[torch.utils.data.DataLoader[FrameData]]
-    test: Optional[torch.utils.data.DataLoader[FrameData]]
+    train: Optional[DataLoader[FrameData]]
+    val: Optional[DataLoader[FrameData]]
+    test: Optional[DataLoader[FrameData]]

-    def __getitem__(
-        self, split: str
-    ) -> Optional[torch.utils.data.DataLoader[FrameData]]:
+    def __getitem__(self, split: str) -> Optional[DataLoader[FrameData]]:
         """
         Get one of the data loaders by key (name of data split)
         """
@@ -54,17 +61,155 @@ class DataLoaderMapProviderBase(ReplaceableBase):
         raise NotImplementedError()


+class DoublePoolBatchSampler(Sampler[List[int]]):
+    """
+    Batch sampler for making random batches of a single frame
+    from one list and a number of known frames from another list.
+    """
+
+    def __init__(
+        self,
+        first_indices: List[int],
+        rest_indices: List[int],
+        batch_size: int,
+        replacement: bool,
+        num_batches: Optional[int] = None,
+    ) -> None:
+        """
+        Args:
+            first_indices: indexes of dataset items to use as the first element
+                of each batch.
+            rest_indices: indexes of dataset items to use as the subsequent
+                elements of each batch. Not used if batch_size==1.
+            batch_size: The common size of any batch.
+            replacement: Whether the sampling of first items is with replacement.
+            num_batches: The number of batches in an epoch. If 0 or None,
+                one epoch is the length of `first_indices`.
+        """
+        self.first_indices = first_indices
+        self.rest_indices = rest_indices
+        self.batch_size = batch_size
+        self.replacement = replacement
+        self.num_batches = None if num_batches == 0 else num_batches
+
+        if batch_size - 1 > len(rest_indices):
+            raise ValueError(
+                f"Cannot make up ({batch_size})-batches from {len(self.rest_indices)}"
+            )
+
+        # copied from RandomSampler
+        seed = int(torch.empty((), dtype=torch.int64).random_().item())
+        self.generator = torch.Generator()
+        self.generator.manual_seed(seed)
+
+    def __len__(self) -> int:
+        if self.num_batches is not None:
+            return self.num_batches
+        return len(self.first_indices)
+
+    def __iter__(self) -> Iterator[List[int]]:
+        num_batches = self.num_batches
+        if self.replacement:
+            i_first = torch.randint(
+                len(self.first_indices),
+                size=(len(self),),
+                generator=self.generator,
+            )
+        elif num_batches is not None:
+            n_copies = 1 + (num_batches - 1) // len(self.first_indices)
+            raw_indices = [
+                torch.randperm(len(self.first_indices), generator=self.generator)
+                for _ in range(n_copies)
+            ]
+            i_first = torch.concat(raw_indices)[:num_batches]
+        else:
+            i_first = torch.randperm(len(self.first_indices), generator=self.generator)
+        first_indices = [self.first_indices[i] for i in i_first]
+
+        if self.batch_size == 1:
+            for first_index in first_indices:
+                yield [first_index]
+            return
+
+        for first_index in first_indices:
+            # Consider using this class in a program which sets the seed. This use
+            # of randperm means that rerunning with a higher batch_size
+            # results in batches whose first elements match the first run.
+            i_rest = torch.randperm(
+                len(self.rest_indices),
+                generator=self.generator,
+            )[: self.batch_size - 1]
+            yield [first_index] + [self.rest_indices[i] for i in i_rest]
+
+
+class BatchConditioningType(Enum):
+    """
+    Ways to add conditioning frames for the val and test batches.
+
+    SAME: Use the corresponding dataset for all elements of val batches
+        without regard to frame type.
+    TRAIN: Use the corresponding dataset for the first element of each
+        batch, and the training dataset for the extra conditioning
+        elements. No regard to frame type.
+    KNOWN: Use frames from the corresponding dataset but separate them
+        according to their frame_type. Each batch will contain one UNSEEN
+        frame followed by many KNOWN frames.
+    """
+
+    SAME = "same"
+    TRAIN = "train"
+    KNOWN = "known"
+
+
 @registry.register
 class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
     """
-    The default implementation of DataLoaderMapProviderBase.
+    Default implementation of DataLoaderMapProviderBase.

+    If a dataset returns batches from get_eval_batches(), then
+    they will be what the corresponding dataloader returns,
+    independently of any of the fields on this class.
+
+    If conditioning is not required, then the batch size should
+    be set as 1, and most of the fields do not matter.
+
+    If conditioning is required, each batch will contain one main
+    frame first to predict, and the rest of the elements are for
+    conditioning.
+
+    If images_per_seq_options is left empty, the conditioning
+    frames are picked according to the conditioning type given.
+    This does not have regard to the order of frames in a
+    scene, or which frames belong to what scene.
+
+    If images_per_seq_options is given, then the conditioning types
+    must be SAME and the remaining fields are used.
+
     Members:
         batch_size: The size of the batch of the data loader.
-        num_workers: Number data-loading threads.
-        dataset_len: The number of batches in a training epoch.
-        dataset_len_val: The number of batches in a validation epoch.
-        images_per_seq_options: Possible numbers of images sampled per sequence.
+        num_workers: Number of data-loading threads in each data loader.
+        dataset_length_train: The number of batches in a training epoch. Or 0 to mean
+            an epoch is the length of the training set.
+        dataset_length_val: The number of batches in a validation epoch. Or 0 to mean
+            an epoch is the length of the validation set.
+        dataset_length_test: The number of batches in a testing epoch. Or 0 to mean
+            an epoch is the length of the test set.
+        train_conditioning_type: Whether the train data loader should use
+            only known frames for conditioning.
+            Only used if batch_size>1 and train dataset is
+            present and does not return eval_batches.
+        val_conditioning_type: Whether the val data loader should use
+            training frames or known frames for conditioning.
+            Only used if batch_size>1 and val dataset is
+            present and does not return eval_batches.
+        test_conditioning_type: Whether the test data loader should use
+            training frames or known frames for conditioning.
+            Only used if batch_size>1 and test dataset is
+            present and does not return eval_batches.
+        images_per_seq_options: Possible numbers of frames sampled per sequence in a batch.
+            If a conditioning_type is KNOWN or TRAIN, then this must be left at its initial
+            value. Empty (the default) means that we are not careful about which frames
+            come from which scene.
         sample_consecutive_frames: if True, will sample a contiguous interval of frames
             in the sequence. It first sorts the frames by timestamps when available,
             otherwise by frame numbers, finds the connected segments within the sequence
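A standalone sketch of the new sampler on a toy dataset (the index lists are made up for illustration; DoublePoolBatchSampler is the class added above):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from pytorch3d.implicitron.dataset.data_loader_map_provider import (
        DoublePoolBatchSampler,
    )

    dataset = TensorDataset(torch.arange(20))
    sampler = DoublePoolBatchSampler(
        first_indices=[0, 1, 2],          # frames to predict: one starts each batch
        rest_indices=list(range(3, 20)),  # frames usable as conditioning
        batch_size=4,
        replacement=False,
    )
    # One epoch = len(first_indices) = 3 batches, each of 1 + 3 indices.
    for batch in DataLoader(dataset, batch_sampler=sampler):
        print(batch)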
@@ -84,9 +229,13 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):

     batch_size: int = 1
     num_workers: int = 0
-    dataset_len: int = 1000
-    dataset_len_val: int = 1
-    images_per_seq_options: Tuple[int, ...] = (2,)
+    dataset_length_train: int = 0
+    dataset_length_val: int = 0
+    dataset_length_test: int = 0
+    train_conditioning_type: BatchConditioningType = BatchConditioningType.SAME
+    val_conditioning_type: BatchConditioningType = BatchConditioningType.SAME
+    test_conditioning_type: BatchConditioningType = BatchConditioningType.KNOWN
+    images_per_seq_options: Tuple[int, ...] = ()
     sample_consecutive_frames: bool = False
     consecutive_frames_max_gap: int = 0
     consecutive_frames_max_gap_seconds: float = 0.1
@@ -95,17 +244,73 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
         """
         Returns a collection of data loaders for a given collection of datasets.
         """
+        return DataLoaderMap(
+            train=self._make_data_loader(
+                datasets.train,
+                self.dataset_length_train,
+                datasets.train,
+                self.train_conditioning_type,
+            ),
+            val=self._make_data_loader(
+                datasets.val,
+                self.dataset_length_val,
+                datasets.train,
+                self.val_conditioning_type,
+            ),
+            test=self._make_data_loader(
+                datasets.test,
+                self.dataset_length_test,
+                datasets.train,
+                self.test_conditioning_type,
+            ),
+        )
+
+    def _make_data_loader(
+        self,
+        dataset: Optional[DatasetBase],
+        num_batches: int,
+        train_dataset: Optional[DatasetBase],
+        conditioning_type: BatchConditioningType,
+    ) -> Optional[DataLoader[FrameData]]:
+        """
+        Returns the dataloader for a dataset.
+
+        Args:
+            dataset: the dataset
+            num_batches: possible ceiling on number of batches per epoch
+            train_dataset: the training dataset, used if conditioning_type==TRAIN
+            conditioning_type: source for padding of batches
+        """
+        if dataset is None:
+            return None
+
         data_loader_kwargs = {
             "num_workers": self.num_workers,
-            "collate_fn": FrameData.collate,
+            "collate_fn": dataset.frame_data_type.collate,
         }

-        def train_or_val_loader(
-            dataset: Optional[DatasetBase], num_batches: int
-        ) -> Optional[torch.utils.data.DataLoader]:
-            if dataset is None:
-                return None
+        eval_batches = dataset.get_eval_batches()
+        if eval_batches is not None:
+            return DataLoader(
+                dataset,
+                batch_sampler=eval_batches,
+                **data_loader_kwargs,
+            )
+
+        scenes_matter = len(self.images_per_seq_options) > 0
+        if scenes_matter and conditioning_type != BatchConditioningType.SAME:
+            raise ValueError(
+                f"{conditioning_type} cannot be used with images_per_seq "
+                + str(self.images_per_seq_options)
+            )
+
+        if self.batch_size == 1 or (
+            not scenes_matter and conditioning_type == BatchConditioningType.SAME
+        ):
+            return self._simple_loader(dataset, num_batches, data_loader_kwargs)
+
+        if scenes_matter:
+            assert conditioning_type == BatchConditioningType.SAME
             batch_sampler = SceneBatchSampler(
                 dataset,
                 self.batch_size,
@@ -115,25 +320,115 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
                 consecutive_frames_max_gap=self.consecutive_frames_max_gap,
                 consecutive_frames_max_gap_seconds=self.consecutive_frames_max_gap_seconds,
             )
-            return torch.utils.data.DataLoader(
+            return DataLoader(
                 dataset,
                 batch_sampler=batch_sampler,
                 **data_loader_kwargs,
             )

-        train_data_loader = train_or_val_loader(datasets.train, self.dataset_len)
-        val_data_loader = train_or_val_loader(datasets.val, self.dataset_len_val)
-
-        test_dataset = datasets.test
-        if test_dataset is not None:
-            test_data_loader = torch.utils.data.DataLoader(
-                test_dataset,
-                batch_sampler=test_dataset.get_eval_batches(),
-                **data_loader_kwargs,
-            )
-        else:
-            test_data_loader = None
-
-        return DataLoaderMap(
-            train=train_data_loader, val=val_data_loader, test=test_data_loader
-        )
+        if conditioning_type == BatchConditioningType.TRAIN:
+            return self._train_loader(
+                dataset, train_dataset, num_batches, data_loader_kwargs
+            )
+
+        assert conditioning_type == BatchConditioningType.KNOWN
+        return self._known_loader(dataset, num_batches, data_loader_kwargs)
+
+    def _simple_loader(
+        self,
+        dataset: DatasetBase,
+        num_batches: int,
+        data_loader_kwargs: dict,
+    ) -> DataLoader[FrameData]:
+        """
+        Return a simple loader for frames in the dataset.
+
+        This is equivalent to
+            DataLoader(dataset, batch_size=self.batch_size, **data_loader_kwargs)
+        except that num_batches is fixed.
+
+        Args:
+            dataset: the dataset
+            num_batches: possible ceiling on number of batches per epoch
+            data_loader_kwargs: common args for dataloader
+        """
+        if num_batches > 0:
+            num_samples = self.batch_size * num_batches
+        else:
+            num_samples = None
+        sampler = RandomSampler(dataset, replacement=False, num_samples=num_samples)
+        batch_sampler = BatchSampler(sampler, self.batch_size, drop_last=True)
+        return DataLoader(
+            dataset,
+            batch_sampler=batch_sampler,
+            **data_loader_kwargs,
+        )
+
+    def _train_loader(
+        self,
+        dataset: DatasetBase,
+        train_dataset: Optional[DatasetBase],
+        num_batches: int,
+        data_loader_kwargs: dict,
+    ) -> DataLoader[FrameData]:
+        """
+        Return the loader for TRAIN conditioning.
+
+        Args:
+            dataset: the dataset
+            train_dataset: the training dataset
+            num_batches: possible ceiling on number of batches per epoch
+            data_loader_kwargs: common args for dataloader
+        """
+        if train_dataset is None:
+            raise ValueError("No training data for conditioning.")
+        length = len(dataset)
+        first_indices = list(range(length))
+        rest_indices = list(range(length, length + len(train_dataset)))
+        sampler = DoublePoolBatchSampler(
+            first_indices=first_indices,
+            rest_indices=rest_indices,
+            batch_size=self.batch_size,
+            replacement=True,
+            num_batches=num_batches,
+        )
+        return DataLoader(
+            ChainDataset([dataset, train_dataset]),
+            batch_sampler=sampler,
+            **data_loader_kwargs,
+        )
+
+    def _known_loader(
+        self,
+        dataset: DatasetBase,
+        num_batches: int,
+        data_loader_kwargs: dict,
+    ) -> DataLoader[FrameData]:
+        """
+        Return the loader for KNOWN conditioning.
+
+        Args:
+            dataset: the dataset
+            num_batches: possible ceiling on number of batches per epoch
+            data_loader_kwargs: common args for dataloader
+        """
+        first_indices, rest_indices = [], []
+        for idx in range(len(dataset)):
+            frame_type = dataset[idx].frame_type
+            assert isinstance(frame_type, str)
+            if is_known_frame_scalar(frame_type):
+                rest_indices.append(idx)
+            else:
+                first_indices.append(idx)
+        sampler = DoublePoolBatchSampler(
+            first_indices=first_indices,
+            rest_indices=rest_indices,
+            batch_size=self.batch_size,
+            replacement=True,
+            num_batches=num_batches,
+        )
+        return DataLoader(
+            dataset,
+            batch_sampler=sampler,
+            **data_loader_kwargs,
+        )
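The epoch cap in _simple_loader is plain PyTorch: RandomSampler with num_samples set bounds the epoch, and BatchSampler groups it into batches. A self-contained sketch with toy sizes (it assumes a PyTorch version that, like the code above, permits num_samples together with replacement=False):

    import torch
    from torch.utils.data import BatchSampler, DataLoader, RandomSampler, TensorDataset

    dataset = TensorDataset(torch.arange(100))
    batch_size, num_batches = 10, 3  # cap the epoch at 3 batches

    sampler = RandomSampler(
        dataset, replacement=False, num_samples=batch_size * num_batches
    )
    batch_sampler = BatchSampler(sampler, batch_size, drop_last=True)
    loader = DataLoader(dataset, batch_sampler=batch_sampler)
    assert len(list(loader)) == num_batches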
@@ -8,6 +8,7 @@ from collections import defaultdict
 from dataclasses import dataclass, field, fields
 from typing import (
     Any,
+    ClassVar,
     Iterable,
     Iterator,
     List,
@@ -15,6 +16,7 @@ from typing import (
     Optional,
     Sequence,
     Tuple,
+    Type,
     Union,
 )

@@ -289,3 +291,7 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
         """
         for _, _, idx in self.sequence_frames_in_order(seq_name):
             yield idx
+
+    # frame_data_type is the actual type of frames returned by the dataset.
+    # Collation uses its classmethod `collate`
+    frame_data_type: ClassVar[Type[FrameData]] = FrameData
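The new frame_data_type class variable is what lets the loaders above collate with dataset.frame_data_type.collate instead of the hard-coded FrameData.collate. A hypothetical sketch of a dataset swapping in its own frame type (MyFrameData, my_extra_field and MyDataset are made-up names):

    from dataclasses import dataclass

    from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData


    @dataclass
    class MyFrameData(FrameData):
        my_extra_field: float = 0.0  # hypothetical extra per-frame annotation


    class MyDataset(DatasetBase):
        # Batches from this dataset are now collated by MyFrameData.collate.
        frame_data_type = MyFrameData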
@@ -16,6 +16,14 @@ DATASET_TYPE_KNOWN = "known"
 DATASET_TYPE_UNKNOWN = "unseen"


+def is_known_frame_scalar(frame_type: str) -> bool:
+    """
+    Given a single frame type corresponding to a single frame, return whether
+    the frame is a known frame.
+    """
+    return frame_type.endswith(DATASET_TYPE_KNOWN)
+
+
 def is_known_frame(
     frame_type: List[str], device: Optional[str] = None
 ) -> torch.BoolTensor:
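Frame types are strings such as "test_known" or "test_unseen" (built from the DATASET_TYPE_* constants above), so the scalar helper is a suffix check, and is_known_frame maps it over a batch. A quick sketch:

    from pytorch3d.implicitron.dataset.utils import is_known_frame, is_known_frame_scalar

    assert is_known_frame_scalar("test_known")       # conditioning frame
    assert not is_known_frame_scalar("test_unseen")  # frame to be predicted
    print(is_known_frame(["test_known", "test_unseen"]))  # tensor([ True, False])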
@@ -25,7 +33,7 @@ def is_known_frame(
     """
     # pyre-fixme[7]: Expected `BoolTensor` but got `Tensor`.
     return torch.tensor(
-        [ft.endswith(DATASET_TYPE_KNOWN) for ft in frame_type],
+        [is_known_frame_scalar(ft) for ft in frame_type],
         dtype=torch.bool,
         device=device,
     )
@@ -69,7 +69,7 @@ def get_implicitron_sequence_pointcloud(
         batch_size=len(sequence_dataset),
         shuffle=False,
         num_workers=num_workers,
-        collate_fn=FrameData.collate,
+        collate_fn=dataset.frame_data_type.collate,
     )

     frame_data = next(iter(loader))  # there's only one batch
@@ -12,7 +12,7 @@ from typing import Any, cast, Dict, List, Optional, Tuple
 import lpips
 import torch
 from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
-from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
+from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
 from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
 from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
     CO3D_CATEGORIES,
@@ -207,7 +207,7 @@ def _get_all_source_cameras(
         shuffle=False,
         batch_size=len(dataset_for_loader),
         num_workers=num_workers,
-        collate_fn=FrameData.collate,
+        collate_fn=dataset.frame_data_type.collate,
     )
     is_known = is_known_frame(all_frame_data.frame_type)
     source_cameras = all_frame_data.camera[torch.where(is_known)[0]]
@@ -52,10 +52,13 @@ dataset_map_provider_LlffDatasetMapProvider_args:
 data_loader_map_provider_SequenceDataLoaderMapProvider_args:
   batch_size: 1
   num_workers: 0
-  dataset_len: 1000
-  dataset_len_val: 1
-  images_per_seq_options:
-  - 2
+  dataset_length_train: 0
+  dataset_length_val: 0
+  dataset_length_test: 0
+  train_conditioning_type: SAME
+  val_conditioning_type: SAME
+  test_conditioning_type: KNOWN
+  images_per_seq_options: []
   sample_consecutive_frames: false
   consecutive_frames_max_gap: 0
   consecutive_frames_max_gap_seconds: 0.1
@@ -8,6 +8,11 @@
 import unittest
 from collections import defaultdict
 from dataclasses import dataclass
+from itertools import product

+from pytorch3d.implicitron.dataset.data_loader_map_provider import (
+    DoublePoolBatchSampler,
+)
 from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
 from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
@@ -214,3 +219,30 @@ def _count_by_quotient(indices, divisor):
         counter[i // divisor] += 1

     return counter
+
+
+class TestRandomSampling(unittest.TestCase):
+    def test_double_pool_batch_sampler(self):
+        unknown_idxs = [2, 3, 4, 5, 8]
+        known_idxs = [2, 9, 10, 11, 12, 13, 14, 15, 16, 17]
+        for replacement, num_batches in product([True, False], [None, 4, 5, 6, 30]):
+            with self.subTest(f"{replacement}, {num_batches}"):
+                sampler = DoublePoolBatchSampler(
+                    first_indices=unknown_idxs,
+                    rest_indices=known_idxs,
+                    batch_size=4,
+                    replacement=replacement,
+                    num_batches=num_batches,
+                )
+                for _ in range(6):
+                    epoch = list(sampler)
+                    self.assertEqual(len(epoch), num_batches or len(unknown_idxs))
+                    for batch in epoch:
+                        self.assertEqual(len(batch), 4)
+                        self.assertIn(batch[0], unknown_idxs)
+                        for i in batch[1:]:
+                            self.assertIn(i, known_idxs)
+                    if not replacement and 4 != num_batches:
+                        self.assertEqual(
+                            {batch[0] for batch in epoch}, set(unknown_idxs)
+                        )
@@ -10,11 +10,12 @@ import unittest
 from pytorch3d.implicitron.dataset.blender_dataset_map_provider import (
     BlenderDatasetMapProvider,
 )
+from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
 from pytorch3d.implicitron.dataset.dataset_base import FrameData
 from pytorch3d.implicitron.dataset.llff_dataset_map_provider import (
     LlffDatasetMapProvider,
 )
-from pytorch3d.implicitron.tools.config import expand_args_fields
+from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
 from tests.common_testing import TestCaseMixin


@@ -102,3 +103,23 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
             self.assertEqual(dataset_map.test[batch[0]].frame_type, "unseen")
             for i in batch[1:]:
                 self.assertEqual(dataset_map.test[i].frame_type, "known")
+
+    def test_loaders(self):
+        args = get_default_args(ImplicitronDataSource)
+        args.dataset_map_provider_class_type = "BlenderDatasetMapProvider"
+        args.data_loader_map_provider_class_type = "RandomDataLoaderMapProvider"
+        dataset_args = args.dataset_map_provider_BlenderDatasetMapProvider_args
+        dataset_args.object_name = "lego"
+        dataset_args.base_dir = "manifold://co3d/tree/nerf_data/nerf_synthetic/lego"
+
+        data_source = ImplicitronDataSource(**args)
+        _, data_loaders = data_source.get_datasets_and_dataloaders()
+        for i in data_loaders.train:
+            self.assertEqual(i.frame_type, ["known"])
+            self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
+        for i in data_loaders.val:
+            self.assertEqual(i.frame_type, ["unseen"])
+            self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
+        for i in data_loaders.test:
+            self.assertEqual(i.frame_type, ["unseen"])
+            self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
@@ -8,6 +8,7 @@ import os
 import unittest
 import unittest.mock

+import torch
 from omegaconf import OmegaConf
 from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
 from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
@@ -21,6 +22,7 @@ DEBUG: bool = False
 class TestDataSource(unittest.TestCase):
     def setUp(self):
         self.maxDiff = None
+        torch.manual_seed(42)

     def _test_omegaconf_generic_failure(self):
         # OmegaConf possible bug - this is why we need _GenericWorkaround
@@ -56,3 +58,23 @@ class TestDataSource(unittest.TestCase):
         if DEBUG:
             (DATA_DIR / "data_source.yaml").write_text(yaml)
         self.assertEqual(yaml, (DATA_DIR / "data_source.yaml").read_text())
+
+    def test_default(self):
+        if os.environ.get("INSIDE_RE_WORKER") is not None:
+            return
+        args = get_default_args(ImplicitronDataSource)
+        args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
+        args.data_loader_map_provider_class_type = "SequenceDataLoaderMapProvider"
+        dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
+        dataset_args.category = "skateboard"
+        dataset_args.test_restrict_sequence_id = 0
+        dataset_args.n_frames_per_sequence = -1
+
+        dataset_args.dataset_root = "manifold://co3d/tree/extracted"
+
+        data_source = ImplicitronDataSource(**args)
+        _, data_loaders = data_source.get_datasets_and_dataloaders()
+        self.assertEqual(len(data_loaders.train), 81)
+        for i in data_loaders.train:
+            self.assertEqual(i.frame_type, ["test_known"])
+            break
@@ -44,6 +44,7 @@ class TestEvaluation(unittest.TestCase):
         frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
         sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
         self.image_size = 64
+        expand_args_fields(JsonIndexDataset)
         self.dataset = JsonIndexDataset(
             frame_annotations_file=frame_file,
             sequence_annotations_file=sequence_file,