More padding options in DataLoader

Summary: Add facilities for dataloading non-sequential scenes.

Reviewed By: shapovalov

Differential Revision: D37291277

fbshipit-source-id: 0a33e3727b44c4f0cba3a2abe9b12f40d2a20447
Jeremy Reizenstein, 2022-07-06 07:13:41 -07:00, committed by Facebook GitHub Bot
parent 0dce883241 · commit 771cf8a328
16 changed files with 449 additions and 70 deletions


@ -6,22 +6,12 @@ architecture: generic
visualize_interval: 0
visdom_port: 8097
data_source_args:
data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
dataset_map_provider_class_type: JsonIndexDatasetMapProvider
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
batch_size: 10
-dataset_len: 1000
-dataset_len_val: 1
dataset_length_train: 1000
dataset_length_val: 1
num_workers: 8
-images_per_seq_options:
-- 2
-- 3
-- 4
-- 5
-- 6
-- 7
-- 8
-- 9
-- 10
dataset_map_provider_JsonIndexDatasetMapProvider_args:
dataset_root: ${oc.env:CO3D_DATASET_ROOT}
n_frames_per_sequence: -1

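For reference, the renamed options can also be set programmatically rather than in YAML. A minimal sketch using get_default_args, mirroring the tests later in this diff; the values are illustrative:

from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.tools.config import get_default_args

args = get_default_args(ImplicitronDataSource)
args.data_loader_map_provider_class_type = "SequenceDataLoaderMapProvider"
loader_args = args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
loader_args.batch_size = 10
loader_args.dataset_length_train = 1000  # was dataset_len
loader_args.dataset_length_val = 1  # was dataset_len_val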

@ -4,8 +4,8 @@ defaults:
data_source_args:
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
batch_size: 10
-dataset_len: 1000
-dataset_len_val: 1
dataset_length_train: 1000
dataset_length_val: 1
num_workers: 8
images_per_seq_options:
- 2


@ -4,11 +4,9 @@ defaults:
data_source_args:
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
batch_size: 1
-dataset_len: 1000
-dataset_len_val: 1
dataset_length_train: 1000
dataset_length_val: 1
num_workers: 8
-images_per_seq_options:
-- 2
dataset_map_provider_JsonIndexDatasetMapProvider_args:
assert_single_seq: true
n_frames_per_sequence: -1


@ -4,8 +4,8 @@ defaults:
data_source_args:
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
batch_size: 10
-dataset_len: 1000
-dataset_len_val: 1
dataset_length_train: 1000
dataset_length_val: 1
num_workers: 8
images_per_seq_options:
- 2


@ -345,10 +345,13 @@ data_source_args:
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
batch_size: 1
num_workers: 0
-dataset_len: 1000
-dataset_len_val: 1
-images_per_seq_options:
-- 2
dataset_length_train: 0
dataset_length_val: 0
dataset_length_test: 0
train_conditioning_type: SAME
val_conditioning_type: SAME
test_conditioning_type: KNOWN
images_per_seq_options: []
sample_consecutive_frames: false
consecutive_frames_max_gap: 0
consecutive_frames_max_gap_seconds: 0.1


@ -55,7 +55,7 @@ class TestExperiment(unittest.TestCase):
dataset_args.test_restrict_sequence_id = 0
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
dataset_args.dataset_JsonIndexDataset_args.limit_sequences_to = 5
-dataloader_args.dataset_len = 1
dataloader_args.dataset_length_train = 1
cfg.solver_args.max_epochs = 2
device = torch.device("cuda:0")


@ -5,14 +5,23 @@
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
-from typing import Optional, Tuple
from enum import Enum
from typing import Iterator, List, Optional, Tuple
import torch
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from torch.utils.data import (
BatchSampler,
ConcatDataset,
DataLoader,
RandomSampler,
Sampler,
)
from .dataset_base import DatasetBase, FrameData
from .dataset_map_provider import DatasetMap
from .scene_batch_sampler import SceneBatchSampler
from .utils import is_known_frame_scalar
@dataclass
@ -27,13 +36,11 @@ class DataLoaderMap:
test: a data loader for final evaluation
"""
-train: Optional[torch.utils.data.DataLoader[FrameData]]
-val: Optional[torch.utils.data.DataLoader[FrameData]]
-test: Optional[torch.utils.data.DataLoader[FrameData]]
train: Optional[DataLoader[FrameData]]
val: Optional[DataLoader[FrameData]]
test: Optional[DataLoader[FrameData]]
-def __getitem__(
-self, split: str
-) -> Optional[torch.utils.data.DataLoader[FrameData]]:
def __getitem__(self, split: str) -> Optional[DataLoader[FrameData]]:
"""
Get one of the data loaders by key (name of data split)
"""
@ -54,17 +61,155 @@ class DataLoaderMapProviderBase(ReplaceableBase):
raise NotImplementedError()
class DoublePoolBatchSampler(Sampler[List[int]]):
"""
Batch sampler for making random batches of a single frame
from one list and a number of known frames from another list.
"""
def __init__(
self,
first_indices: List[int],
rest_indices: List[int],
batch_size: int,
replacement: bool,
num_batches: Optional[int] = None,
) -> None:
"""
Args:
first_indices: indexes of dataset items to use as the first element
of each batch.
rest_indices: indexes of dataset items to use as the subsequent
elements of each batch. Not used if batch_size==1.
batch_size: The common size of any batch.
replacement: Whether the sampling of first items is with replacement.
num_batches: The number of batches in an epoch. If 0 or None,
one epoch is the length of `first_indices`.
"""
self.first_indices = first_indices
self.rest_indices = rest_indices
self.batch_size = batch_size
self.replacement = replacement
self.num_batches = None if num_batches == 0 else num_batches
if batch_size - 1 > len(rest_indices):
raise ValueError(
f"Cannot make up ({batch_size})-batches from {len(self.rest_indices)}"
)
# copied from RandomSampler
seed = int(torch.empty((), dtype=torch.int64).random_().item())
self.generator = torch.Generator()
self.generator.manual_seed(seed)
def __len__(self) -> int:
if self.num_batches is not None:
return self.num_batches
return len(self.first_indices)
def __iter__(self) -> Iterator[List[int]]:
num_batches = self.num_batches
if self.replacement:
i_first = torch.randint(
len(self.first_indices),
size=(len(self),),
generator=self.generator,
)
elif num_batches is not None:
n_copies = 1 + (num_batches - 1) // len(self.first_indices)
raw_indices = [
torch.randperm(len(self.first_indices), generator=self.generator)
for _ in range(n_copies)
]
i_first = torch.concat(raw_indices)[:num_batches]
else:
i_first = torch.randperm(len(self.first_indices), generator=self.generator)
first_indices = [self.first_indices[i] for i in i_first]
if self.batch_size == 1:
for first_index in first_indices:
yield [first_index]
return
for first_index in first_indices:
# Consider using this class in a program which sets the seed. This use
# of randperm means that rerunning with a higher batch_size results in
# batches whose first elements match those of the first run.
i_rest = torch.randperm(
len(self.rest_indices),
generator=self.generator,
)[: self.batch_size - 1]
yield [first_index] + [self.rest_indices[i] for i in i_rest]
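As a usage sketch (the index pools here are invented): every batch starts with an element of first_indices and is padded from rest_indices.

sampler = DoublePoolBatchSampler(
    first_indices=[0, 1, 2],  # e.g. unseen frames
    rest_indices=[3, 4, 5, 6, 7],  # e.g. known frames
    batch_size=3,
    replacement=False,
)
for batch in sampler:  # one epoch: one batch per first index
    assert batch[0] in [0, 1, 2]
    assert all(i in [3, 4, 5, 6, 7] for i in batch[1:])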
class BatchConditioningType(Enum):
"""
Ways to add conditioning frames for the val and test batches.
SAME: Use the corresponding dataset for all elements of each batch
without regard to frame type.
TRAIN: Use the corresponding dataset for the first element of each
batch, and the training dataset for the extra conditioning
elements. No regard to frame type.
KNOWN: Use frames from the corresponding dataset but separate them
according to their frame_type. Each batch will contain one UNSEEN
frame followed by many KNOWN frames.
"""
SAME = "same"
TRAIN = "train"
KNOWN = "known"
@registry.register
class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
"""
-The default implementation of DataLoaderMapProviderBase.
Default implementation of DataLoaderMapProviderBase.
If a dataset returns batches from get_eval_batches(), then
they will be what the corresponding dataloader returns,
independently of any of the fields on this class.
If conditioning is not required, then the batch size should
be set to 1, and most of the fields do not matter.
If conditioning is required, each batch will contain one main
frame first (the frame to predict); the rest of the elements
are for conditioning.
If images_per_seq_options is left empty, the conditioning
frames are picked according to the given conditioning type,
without regard to the order of frames in a scene or to
which scene each frame belongs.
If images_per_seq_options is given, then the conditioning types
must be SAME and the remaining fields are used.
Members:
batch_size: The size of the batch of the data loader.
-num_workers: Number data-loading threads.
-dataset_len: The number of batches in a training epoch.
-dataset_len_val: The number of batches in a validation epoch.
-images_per_seq_options: Possible numbers of images sampled per sequence.
num_workers: Number of data-loading threads in each data loader.
dataset_length_train: The number of batches in a training epoch. Or 0 to mean
an epoch is the length of the training set.
dataset_length_val: The number of batches in a validation epoch. Or 0 to mean
an epoch is the length of the validation set.
dataset_length_test: The number of batches in a testing epoch. Or 0 to mean
an epoch is the length of the test set.
train_conditioning_type: Whether the train data loader should use
only known frames for conditioning.
Only used if batch_size>1 and train dataset is
present and does not return eval_batches.
val_conditioning_type: Whether the val data loader should use
training frames or known frames for conditioning.
Only used if batch_size>1 and val dataset is
present and does not return eval_batches.
test_conditioning_type: Whether the test data loader should use
training frames or known frames for conditioning.
Only used if batch_size>1 and test dataset is
present and does not return eval_batches.
images_per_seq_options: Possible numbers of frames sampled per sequence in a batch.
If a conditioning_type is KNOWN or TRAIN, then this must be left at its initial
value. Empty (the default) means that we are not careful about which frames
come from which scene.
sample_consecutive_frames: if True, will sample a contiguous interval of frames
in the sequence. It first sorts the frames by timestamps when available,
otherwise by frame numbers, finds the connected segments within the sequence
@ -84,9 +229,13 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
batch_size: int = 1
num_workers: int = 0
-dataset_len: int = 1000
-dataset_len_val: int = 1
-images_per_seq_options: Tuple[int, ...] = (2,)
dataset_length_train: int = 0
dataset_length_val: int = 0
dataset_length_test: int = 0
train_conditioning_type: BatchConditioningType = BatchConditioningType.SAME
val_conditioning_type: BatchConditioningType = BatchConditioningType.SAME
test_conditioning_type: BatchConditioningType = BatchConditioningType.KNOWN
images_per_seq_options: Tuple[int, ...] = ()
sample_consecutive_frames: bool = False
consecutive_frames_max_gap: int = 0
consecutive_frames_max_gap_seconds: float = 0.1
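For illustration, a sketch of how these fields interact, assuming the class has been expanded with expand_args_fields as in the tests in this diff and that a datasets map exists: a non-empty images_per_seq_options forces scene-aware batching, which only supports SAME conditioning.

from pytorch3d.implicitron.tools.config import expand_args_fields

expand_args_fields(SequenceDataLoaderMapProvider)
provider = SequenceDataLoaderMapProvider(
    batch_size=4,
    images_per_seq_options=(2, 3),
    val_conditioning_type=BatchConditioningType.KNOWN,
)
# provider.get_data_loader_map(datasets) would raise ValueError when building
# the val loader: KNOWN cannot be combined with images_per_seq_options.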
@ -95,17 +244,73 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
"""
Returns a collection of data loaders for a given collection of datasets.
"""
return DataLoaderMap(
train=self._make_data_loader(
datasets.train,
self.dataset_length_train,
datasets.train,
self.train_conditioning_type,
),
val=self._make_data_loader(
datasets.val,
self.dataset_length_val,
datasets.train,
self.val_conditioning_type,
),
test=self._make_data_loader(
datasets.test,
self.dataset_length_test,
datasets.train,
self.test_conditioning_type,
),
)
def _make_data_loader(
self,
dataset: Optional[DatasetBase],
num_batches: int,
train_dataset: Optional[DatasetBase],
conditioning_type: BatchConditioningType,
) -> Optional[DataLoader[FrameData]]:
"""
Returns the dataloader for a dataset.
Args:
dataset: the dataset
num_batches: possible ceiling on number of batches per epoch
train_dataset: the training dataset, used if conditioning_type==TRAIN
conditioning_type: source for padding of batches
"""
if dataset is None:
return None
data_loader_kwargs = {
"num_workers": self.num_workers,
"collate_fn": FrameData.collate,
"collate_fn": dataset.frame_data_type.collate,
}
-def train_or_val_loader(
-dataset: Optional[DatasetBase], num_batches: int
-) -> Optional[torch.utils.data.DataLoader]:
-if dataset is None:
-return None
eval_batches = dataset.get_eval_batches()
if eval_batches is not None:
return DataLoader(
dataset,
batch_sampler=eval_batches,
**data_loader_kwargs,
)
scenes_matter = len(self.images_per_seq_options) > 0
if scenes_matter and conditioning_type != BatchConditioningType.SAME:
raise ValueError(
f"{conditioning_type} cannot be used with images_per_seq "
+ str(self.images_per_seq_options)
)
if self.batch_size == 1 or (
not scenes_matter and conditioning_type == BatchConditioningType.SAME
):
return self._simple_loader(dataset, num_batches, data_loader_kwargs)
if scenes_matter:
assert conditioning_type == BatchConditioningType.SAME
batch_sampler = SceneBatchSampler(
dataset,
self.batch_size,
@ -115,25 +320,115 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
consecutive_frames_max_gap=self.consecutive_frames_max_gap,
consecutive_frames_max_gap_seconds=self.consecutive_frames_max_gap_seconds,
)
-return torch.utils.data.DataLoader(
return DataLoader(
dataset,
batch_sampler=batch_sampler,
**data_loader_kwargs,
)
-train_data_loader = train_or_val_loader(datasets.train, self.dataset_len)
-val_data_loader = train_or_val_loader(datasets.val, self.dataset_len_val)
-test_dataset = datasets.test
-if test_dataset is not None:
-test_data_loader = torch.utils.data.DataLoader(
-test_dataset,
-batch_sampler=test_dataset.get_eval_batches(),
-**data_loader_kwargs,
if conditioning_type == BatchConditioningType.TRAIN:
return self._train_loader(
dataset, train_dataset, num_batches, data_loader_kwargs
)
-else:
-test_data_loader = None
-return DataLoaderMap(
-train=train_data_loader, val=val_data_loader, test=test_data_loader
assert conditioning_type == BatchConditioningType.KNOWN
return self._known_loader(dataset, num_batches, data_loader_kwargs)
def _simple_loader(
self,
dataset: DatasetBase,
num_batches: int,
data_loader_kwargs: dict,
) -> DataLoader[FrameData]:
"""
Return a simple loader for frames in the dataset.
This is equivalent to
DataLoader(dataset, batch_size=self.batch_size, shuffle=True, drop_last=True, **data_loader_kwargs)
except that the number of batches in an epoch is fixed by num_batches (when it is positive).
Args:
dataset: the dataset
num_batches: possible ceiling on number of batches per epoch
data_loader_kwargs: common args for dataloader
"""
if num_batches > 0:
num_samples = self.batch_size * num_batches
else:
num_samples = None
sampler = RandomSampler(dataset, replacement=False, num_samples=num_samples)
batch_sampler = BatchSampler(sampler, self.batch_size, drop_last=True)
return DataLoader(
dataset,
batch_sampler=batch_sampler,
**data_loader_kwargs,
)
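A quick standalone check of the epoch arithmetic used here, with plain torch and illustrative sizes (this assumes a torch version that, like the code above, accepts num_samples together with replacement=False): batch_size * num_batches samples with drop_last=True give exactly num_batches batches.

import torch
from torch.utils.data import BatchSampler, RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(25))
sampler = RandomSampler(dataset, replacement=False, num_samples=20)
batches = list(BatchSampler(sampler, batch_size=10, drop_last=True))
assert len(batches) == 2  # 20 samples // 10 per batch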
def _train_loader(
self,
dataset: DatasetBase,
train_dataset: Optional[DatasetBase],
num_batches: int,
data_loader_kwargs: dict,
) -> DataLoader[FrameData]:
"""
Return the loader for TRAIN conditioning.
Args:
dataset: the dataset
train_dataset: the training dataset
num_batches: possible ceiling on number of batches per epoch
data_loader_kwargs: common args for dataloader
"""
if train_dataset is None:
raise ValueError("No training data for conditioning.")
length = len(dataset)
first_indices = list(range(length))
rest_indices = list(range(length, length + len(train_dataset)))
sampler = DoublePoolBatchSampler(
first_indices=first_indices,
rest_indices=rest_indices,
batch_size=self.batch_size,
replacement=True,
num_batches=num_batches,
)
return DataLoader(
ConcatDataset([dataset, train_dataset]),
batch_sampler=sampler,
**data_loader_kwargs,
)
def _known_loader(
self,
dataset: DatasetBase,
num_batches: int,
data_loader_kwargs: dict,
) -> DataLoader[FrameData]:
"""
Return the loader for KNOWN conditioning.
Args:
dataset: the dataset
num_batches: possible ceiling on number of batches per epoch
data_loader_kwargs: common args for dataloader
"""
first_indices, rest_indices = [], []
for idx in range(len(dataset)):
frame_type = dataset[idx].frame_type
assert isinstance(frame_type, str)
if is_known_frame_scalar(frame_type):
rest_indices.append(idx)
else:
first_indices.append(idx)
sampler = DoublePoolBatchSampler(
first_indices=first_indices,
rest_indices=rest_indices,
batch_size=self.batch_size,
replacement=True,
num_batches=num_batches,
)
return DataLoader(
dataset,
batch_sampler=sampler,
**data_loader_kwargs,
)

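To make the KNOWN partition concrete, a tiny sketch of the loop above with invented frame types:

frame_types = ["unseen", "known", "known", "unseen"]
first_indices, rest_indices = [], []
for idx, frame_type in enumerate(frame_types):
    (rest_indices if is_known_frame_scalar(frame_type) else first_indices).append(idx)
assert first_indices == [0, 3] and rest_indices == [1, 2]
# Every batch then starts with an unseen frame, padded with known frames.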

@ -8,6 +8,7 @@ from collections import defaultdict
from dataclasses import dataclass, field, fields
from typing import (
Any,
ClassVar,
Iterable,
Iterator,
List,
@ -15,6 +16,7 @@ from typing import (
Optional,
Sequence,
Tuple,
Type,
Union,
)
@ -289,3 +291,7 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
"""
for _, _, idx in self.sequence_frames_in_order(seq_name):
yield idx
# frame_data_type is the actual type of frames returned by the dataset.
# Collation uses its classmethod `collate`
frame_data_type: ClassVar[Type[FrameData]] = FrameData

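For illustration, a dataset subclass can swap in its own FrameData subtype through this hook; the subclass names and the extra field below are hypothetical:

from dataclasses import dataclass
from typing import ClassVar, Type

from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData

@dataclass
class MyFrameData(FrameData):
    my_score: float = 0.0  # hypothetical extra per-frame field

class MyDataset(DatasetBase):
    # DataLoaders built from this dataset will collate MyFrameData batches,
    # because collation goes through frame_data_type.collate.
    frame_data_type: ClassVar[Type[FrameData]] = MyFrameData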

@ -16,6 +16,14 @@ DATASET_TYPE_KNOWN = "known"
DATASET_TYPE_UNKNOWN = "unseen"
def is_known_frame_scalar(frame_type: str) -> bool:
"""
Given the frame type of a single frame, return whether
the frame is a known frame.
"""
return frame_type.endswith(DATASET_TYPE_KNOWN)
def is_known_frame(
frame_type: List[str], device: Optional[str] = None
) -> torch.BoolTensor:
@ -25,7 +33,7 @@ def is_known_frame(
"""
# pyre-fixme[7]: Expected `BoolTensor` but got `Tensor`.
return torch.tensor(
-[ft.endswith(DATASET_TYPE_KNOWN) for ft in frame_type],
[is_known_frame_scalar(ft) for ft in frame_type],
dtype=torch.bool,
device=device,
)

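Usage sketch for the refactored helpers, with frame type strings like those used elsewhere in this diff:

from pytorch3d.implicitron.dataset.utils import is_known_frame, is_known_frame_scalar

frame_types = ["train_known", "train_unseen", "test_known"]
mask = is_known_frame(frame_types)  # tensor([True, False, True])
known = [ft for ft in frame_types if is_known_frame_scalar(ft)]
assert known == ["train_known", "test_known"]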

@ -69,7 +69,7 @@ def get_implicitron_sequence_pointcloud(
batch_size=len(sequence_dataset),
shuffle=False,
num_workers=num_workers,
-collate_fn=FrameData.collate,
collate_fn=dataset.frame_data_type.collate,
)
frame_data = next(iter(loader)) # there's only one batch


@ -12,7 +12,7 @@ from typing import Any, cast, Dict, List, Optional, Tuple
import lpips
import torch
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
-from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
CO3D_CATEGORIES,
@ -207,7 +207,7 @@ def _get_all_source_cameras(
shuffle=False,
batch_size=len(dataset_for_loader),
num_workers=num_workers,
-collate_fn=FrameData.collate,
collate_fn=dataset.frame_data_type.collate,
)
is_known = is_known_frame(all_frame_data.frame_type)
source_cameras = all_frame_data.camera[torch.where(is_known)[0]]


@ -52,10 +52,13 @@ dataset_map_provider_LlffDatasetMapProvider_args:
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
batch_size: 1
num_workers: 0
-dataset_len: 1000
-dataset_len_val: 1
-images_per_seq_options:
-- 2
dataset_length_train: 0
dataset_length_val: 0
dataset_length_test: 0
train_conditioning_type: SAME
val_conditioning_type: SAME
test_conditioning_type: KNOWN
images_per_seq_options: []
sample_consecutive_frames: false
consecutive_frames_max_gap: 0
consecutive_frames_max_gap_seconds: 0.1


@ -8,6 +8,11 @@
import unittest
from collections import defaultdict
from dataclasses import dataclass
from itertools import product
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
DoublePoolBatchSampler,
)
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
@ -214,3 +219,30 @@ def _count_by_quotient(indices, divisor):
counter[i // divisor] += 1
return counter
class TestRandomSampling(unittest.TestCase):
def test_double_pool_batch_sampler(self):
unknown_idxs = [2, 3, 4, 5, 8]
known_idxs = [2, 9, 10, 11, 12, 13, 14, 15, 16, 17]
for replacement, num_batches in product([True, False], [None, 4, 5, 6, 30]):
with self.subTest(f"{replacement}, {num_batches}"):
sampler = DoublePoolBatchSampler(
first_indices=unknown_idxs,
rest_indices=known_idxs,
batch_size=4,
replacement=replacement,
num_batches=num_batches,
)
for _ in range(6):
epoch = list(sampler)
self.assertEqual(len(epoch), num_batches or len(unknown_idxs))
for batch in epoch:
self.assertEqual(len(batch), 4)
self.assertIn(batch[0], unknown_idxs)
for i in batch[1:]:
self.assertIn(i, known_idxs)
if not replacement and 4 != num_batches:
self.assertEqual(
{batch[0] for batch in epoch}, set(unknown_idxs)
)


@ -10,11 +10,12 @@ import unittest
from pytorch3d.implicitron.dataset.blender_dataset_map_provider import (
BlenderDatasetMapProvider,
)
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.llff_dataset_map_provider import (
LlffDatasetMapProvider,
)
-from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
from tests.common_testing import TestCaseMixin
@ -102,3 +103,23 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
self.assertEqual(dataset_map.test[batch[0]].frame_type, "unseen")
for i in batch[1:]:
self.assertEqual(dataset_map.test[i].frame_type, "known")
def test_loaders(self):
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "BlenderDatasetMapProvider"
args.data_loader_map_provider_class_type = "RandomDataLoaderMapProvider"
dataset_args = args.dataset_map_provider_BlenderDatasetMapProvider_args
dataset_args.object_name = "lego"
dataset_args.base_dir = "manifold://co3d/tree/nerf_data/nerf_synthetic/lego"
data_source = ImplicitronDataSource(**args)
_, data_loaders = data_source.get_datasets_and_dataloaders()
for i in data_loaders.train:
self.assertEqual(i.frame_type, ["known"])
self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
for i in data_loaders.val:
self.assertEqual(i.frame_type, ["unseen"])
self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
for i in data_loaders.test:
self.assertEqual(i.frame_type, ["unseen"])
self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))


@ -8,6 +8,7 @@ import os
import unittest
import unittest.mock
import torch
from omegaconf import OmegaConf
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
@ -21,6 +22,7 @@ DEBUG: bool = False
class TestDataSource(unittest.TestCase):
def setUp(self):
self.maxDiff = None
torch.manual_seed(42)
def _test_omegaconf_generic_failure(self):
# OmegaConf possible bug - this is why we need _GenericWorkaround
@ -56,3 +58,23 @@ class TestDataSource(unittest.TestCase):
if DEBUG:
(DATA_DIR / "data_source.yaml").write_text(yaml)
self.assertEqual(yaml, (DATA_DIR / "data_source.yaml").read_text())
def test_default(self):
if os.environ.get("INSIDE_RE_WORKER") is not None:
return
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
args.data_loader_map_provider_class_type = "SequenceDataLoaderMapProvider"
dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
dataset_args.category = "skateboard"
dataset_args.test_restrict_sequence_id = 0
dataset_args.n_frames_per_sequence = -1
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
data_source = ImplicitronDataSource(**args)
_, data_loaders = data_source.get_datasets_and_dataloaders()
self.assertEqual(len(data_loaders.train), 81)
for i in data_loaders.train:
self.assertEqual(i.frame_type, ["test_known"])
break


@ -44,6 +44,7 @@ class TestEvaluation(unittest.TestCase):
frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
self.image_size = 64
expand_args_fields(JsonIndexDataset)
self.dataset = JsonIndexDataset(
frame_annotations_file=frame_file,
sequence_annotations_file=sequence_file,