CO3Dv2 multi-category extension
Summary: Allows loading of multiple categories, provided as a comma-separated list of category names.
Reviewed By: bottler, shapovalov
Differential Revision: D40803297
fbshipit-source-id: 863938be3aa6ffefe9e563aede4a2e9e66aeeaa8
parent c54e048666
commit e4a3298149
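Usage sketch, not part of the diff: with this change, the provider's `category` field accepts a comma-separated list of categories. The module path and the expand_args_fields step below follow the usual PyTorch3D Implicitron layout; the category string and dataset root are made-up values.

# A minimal sketch, assuming the standard Implicitron module layout;
# the category string and dataset_root are illustrative only.
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
    JsonIndexDatasetMapProviderV2,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

# Configurable classes are typically expanded before direct construction.
expand_args_fields(JsonIndexDatasetMapProviderV2)

provider = JsonIndexDatasetMapProviderV2(
    category="apple,car,orange",  # several categories, comma-separated
    subset_name="fewview_test",
    dataset_root="/path/to/co3dv2",
)
dataset_map = provider.get_dataset_map()  # per-category maps, joined into one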
@@ -62,6 +62,7 @@ data_source_ImplicitronDataSource_args:
   test_on_train: false
   only_test_set: false
   load_eval_batches: true
+  num_load_workers: 4
   n_known_frames_for_test: 0
   dataset_class_type: JsonIndexDataset
   path_manager_factory_class_type: PathManagerFactory
@@ -9,6 +9,7 @@ from dataclasses import dataclass, field, fields
 from typing import (
     Any,
     ClassVar,
     Dict,
+    Iterable,
     Iterator,
     List,
@@ -259,6 +260,12 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
         """
         raise ValueError("This dataset does not contain videos.")

+    def join(self, other_datasets: Iterable["DatasetBase"]) -> None:
+        """
+        Joins the current dataset with a list of other datasets of the same type.
+        """
+        raise NotImplementedError()
+
     def get_eval_batches(self) -> Optional[List[List[int]]]:
         return None

@@ -267,6 +274,22 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
         # pyre-ignore[16]
         return self._seq_to_idx.keys()

+    def category_to_sequence_names(self) -> Dict[str, List[str]]:
+        """
+        Returns a dict mapping from each dataset category to a list of its
+        sequence names.
+
+        Returns:
+            category_to_sequence_names: Dict {category_i: [..., sequence_name_j, ...]}
+        """
+        c2seq = defaultdict(list)
+        for sequence_name in self.sequence_names():
+            first_frame_idx = next(self.sequence_indices_in_order(sequence_name))
+            # crashes without overriding __getitem__
+            sequence_category = self[first_frame_idx].sequence_category
+            c2seq[sequence_category].append(sequence_name)
+        return dict(c2seq)
+
     def sequence_frames_in_order(
         self, seq_name: str
     ) -> Iterator[Tuple[float, int, int]]:
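The base-class implementation above groups sequences by the category of their first frame, at the cost of one __getitem__ call per sequence. The same grouping in isolation, on made-up data:

from collections import defaultdict

# Toy illustration of the grouping performed by the default
# category_to_sequence_names(): sequences keyed by their category.
frames = [
    {"sequence_name": "seq_a", "sequence_category": "apple"},
    {"sequence_name": "seq_b", "sequence_category": "car"},
    {"sequence_name": "seq_c", "sequence_category": "apple"},
]

c2seq = defaultdict(list)
for frame in frames:
    c2seq[frame["sequence_category"]].append(frame["sequence_name"])

assert dict(c2seq) == {"apple": ["seq_a", "seq_c"], "car": ["seq_b"]}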
@@ -7,7 +7,7 @@
 import logging
 import os
 from dataclasses import dataclass
-from typing import Iterator, Optional
+from typing import Iterable, Iterator, Optional

 from iopath.common.file_io import PathManager
 from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
@@ -51,6 +51,34 @@ class DatasetMap:
         if self.test is not None:
             yield self.test

+    def join(self, other_dataset_maps: Iterable["DatasetMap"]) -> None:
+        """
+        Joins the current DatasetMap with other dataset maps from the input list.
+
+        For each subset of each dataset map (train/val/test), the function
+        omits joining the subsets that are None.
+
+        Note that the train/val/test datasets of the current dataset map will be
+        modified in-place.
+
+        Args:
+            other_dataset_maps: The list of dataset maps to be joined into the
+                current dataset map.
+        """
+        for set_ in ["train", "val", "test"]:
+            dataset_list = [
+                getattr(self, set_),
+                *[getattr(dmap, set_) for dmap in other_dataset_maps],
+            ]
+            dataset_list = [d for d in dataset_list if d is not None]
+            if len(dataset_list) == 0:
+                setattr(self, set_, None)
+                continue
+            d0 = dataset_list[0]
+            if len(dataset_list) > 1:
+                d0.join(dataset_list[1:])
+            setattr(self, set_, d0)
+

 class DatasetMapProviderBase(ReplaceableBase):
     """
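DatasetMap.join merges subset-by-subset in place, skipping subsets that are None everywhere. A runnable sketch of the same control flow, with plain lists standing in for datasets (a real DatasetBase subclass would merge via its own join()):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class TinyMap:
    # Stand-in for DatasetMap, with lists in place of datasets.
    train: Optional[List[int]] = None
    val: Optional[List[int]] = None
    test: Optional[List[int]] = None

    def join(self, others: List["TinyMap"]) -> None:
        for set_ in ["train", "val", "test"]:
            parts = [getattr(self, set_), *[getattr(m, set_) for m in others]]
            parts = [p for p in parts if p is not None]  # omit missing subsets
            if not parts:
                setattr(self, set_, None)
                continue
            joined = parts[0]
            for p in parts[1:]:
                joined.extend(p)  # lists join by extension; datasets via .join()
            setattr(self, set_, joined)

a = TinyMap(train=[1, 2], test=[5])
b = TinyMap(train=[3], val=None, test=[6])
a.join([b])
assert a.train == [1, 2, 3] and a.val is None and a.test == [5, 6]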
@@ -19,6 +19,8 @@ from pathlib import Path
 from typing import (
     Any,
     ClassVar,
+    Dict,
+    Iterable,
     List,
     Optional,
     Sequence,
@@ -188,7 +190,44 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
             self.eval_batch_index
         )

-    def is_filtered(self):
+    def join(self, other_datasets: Iterable[DatasetBase]) -> None:
+        """
+        Join the dataset with other JsonIndexDataset objects.
+
+        Args:
+            other_datasets: A list of JsonIndexDataset objects to be joined
+                into the current dataset.
+        """
+        if not all(isinstance(d, JsonIndexDataset) for d in other_datasets):
+            raise ValueError("This function can only join a list of JsonIndexDataset")
+        # pyre-ignore[16]
+        self.frame_annots.extend([fa for d in other_datasets for fa in d.frame_annots])
+        # pyre-ignore[16]
+        self.seq_annots.update(
+            # https://gist.github.com/treyhunner/f35292e676efa0be1728
+            functools.reduce(
+                lambda a, b: {**a, **b},
+                [d.seq_annots for d in other_datasets],  # pyre-ignore[16]
+            )
+        )
+        all_eval_batches = [
+            self.eval_batches,
+            # pyre-ignore
+            *[d.eval_batches for d in other_datasets],
+        ]
+        if not (
+            all(ba is None for ba in all_eval_batches)
+            or all(ba is not None for ba in all_eval_batches)
+        ):
+            raise ValueError(
+                "When joining datasets, either all joined datasets have to have their"
+                " eval_batches defined, or all should have their eval batches undefined."
+            )
+        if self.eval_batches is not None:
+            self.eval_batches = sum(all_eval_batches, [])
+        self._invalidate_indexes(filter_seq_annots=True)
+
+    def is_filtered(self) -> bool:
         """
         Returns `True` in case the dataset has been filtered and thus some frame annotations
         stored on the disk might be missing in the dataset object.
@@ -211,6 +250,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
         seq_frame_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
         allow_missing_indices: bool = False,
         remove_missing_indices: bool = False,
+        suppress_missing_index_warning: bool = True,
     ) -> List[List[Union[Optional[int], int]]]:
         """
         Obtain indices into the dataset object given a list of frame ids.
@@ -228,6 +268,11 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
                 If `False`, returns `None` in place of `seq_frame_index` entries that
                 are not present in the dataset.
                 If `True` removes missing indices from the returned indices.
+            suppress_missing_index_warning:
+                Active if `allow_missing_indices==True`. Suppresses a warning message
+                in case an entry from `seq_frame_index` is missing in the dataset
+                (expected in certain cases - e.g. when setting
+                `self.remove_empty_masks=True`).

         Returns:
             dataset_idx: Indices of dataset entries corresponding to `seq_frame_index`.
@@ -254,6 +299,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
             )
             if not allow_missing_indices:
                 raise IndexError(msg)
-            warnings.warn(msg)
+            if not suppress_missing_index_warning:
+                warnings.warn(msg)
             return idx
         if path is not None:
@@ -288,6 +334,21 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
         frame_index: List[Union[Tuple[str, int], Tuple[str, int, str]]],
         allow_missing_indices: bool = True,
     ) -> "JsonIndexDataset":
+        """
+        Generate a dataset subset given the list of frames specified in `frame_index`.
+
+        Args:
+            frame_index: The list of frame identifiers (as stored in the metadata)
+                specified as `List[Tuple[sequence_name: str, frame_number: int]]`.
+                Optionally, image paths relative to the dataset_root can be
+                specified as well: `List[Tuple[sequence_name: str, frame_number: int,
+                image_path: str]]`; in the latter case, an error is raised if an
+                image_path does not match the stored path.
+            allow_missing_indices: If `False`, throws an IndexError upon reaching the
+                first entry from `frame_index` which is missing in the dataset.
+                Otherwise, generates a subset consisting of the frame entries that
+                actually exist in the dataset.
+        """
         # Get the indices into the frame annots.
         dataset_indices = self.seq_frame_index_to_dataset_index(
             [frame_index],
@@ -838,6 +899,13 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
             )
         return out

+    def category_to_sequence_names(self) -> Dict[str, List[str]]:
+        c2seq = defaultdict(list)
+        # pyre-ignore
+        for sequence_name, sa in self.seq_annots.items():
+            c2seq[sa.category].append(sequence_name)
+        return dict(c2seq)
+
     def get_eval_batches(self) -> Optional[List[List[int]]]:
         return self.eval_batches

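Two idioms in JsonIndexDataset.join are worth spelling out: the reduce-based dict merge applied to seq_annots (later dicts win on key collisions) and the sum(..., []) flattening applied to eval_batches. Demonstrated on made-up data:

import functools

# seq_annots-style merge: the linked gist's reduce-based dict merge.
seq_annots_per_dataset = [
    {"seq_a": "annot_a"},
    {"seq_b": "annot_b"},
    {"seq_c": "annot_c"},
]
merged = functools.reduce(lambda a, b: {**a, **b}, seq_annots_per_dataset)
assert merged == {"seq_a": "annot_a", "seq_b": "annot_b", "seq_c": "annot_c"}

# eval_batches-style concatenation: summing with an empty-list start
# flattens one level, producing a single list of batches.
all_eval_batches = [[[0, 1]], [[2, 3]], [[4, 5]]]
assert sum(all_eval_batches, []) == [[0, 1], [2, 3], [4, 5]]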
@@ -8,6 +8,7 @@
 import copy
 import json
 import logging
+import multiprocessing
 import os
 import warnings
 from collections import defaultdict
@@ -30,6 +31,7 @@ from pytorch3d.implicitron.tools.config import (
 )

 from pytorch3d.renderer.cameras import CamerasBase
+from tqdm import tqdm


 _CO3DV2_DATASET_ROOT: str = os.getenv("CO3DV2_DATASET_ROOT", "")
@@ -147,7 +149,8 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
     (test frames can repeat across batches).

     Args:
-        category: The object category of the dataset.
+        category: Dataset categories to load expressed as a string of comma-separated
+            category names (e.g. `"apple,car,orange"`).
         subset_name: The name of the dataset subset. For CO3Dv2, these include
             e.g. "manyview_dev_0", "fewview_test", ...
         dataset_root: The root folder of the dataset.
@@ -173,6 +176,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
     test_on_train: bool = False
     only_test_set: bool = False
     load_eval_batches: bool = True
+    num_load_workers: int = 4

     n_known_frames_for_test: int = 0

@@ -189,11 +193,33 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
         if self.only_test_set and self.test_on_train:
             raise ValueError("Cannot have only_test_set and test_on_train")

-        frame_file = os.path.join(
-            self.dataset_root, self.category, "frame_annotations.jgz"
-        )
+        if "," in self.category:
+            # a comma-separated list of categories to load
+            categories = [c.strip() for c in self.category.split(",")]
+            logger.info(f"Loading a list of categories: {str(categories)}.")
+            with multiprocessing.Pool(
+                processes=min(self.num_load_workers, len(categories))
+            ) as pool:
+                category_dataset_maps = list(
+                    tqdm(
+                        pool.imap(self._load_category, categories),
+                        total=len(categories),
+                    )
+                )
+            dataset_map = category_dataset_maps[0]
+            dataset_map.join(category_dataset_maps[1:])
+
+        else:
+            # one category to load
+            dataset_map = self._load_category(self.category)
+
+        self.dataset_map = dataset_map
+
+    def _load_category(self, category: str) -> DatasetMap:
+
+        frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz")
         sequence_file = os.path.join(
-            self.dataset_root, self.category, "sequence_annotations.jgz"
+            self.dataset_root, category, "sequence_annotations.jgz"
         )

         path_manager = self.path_manager_factory.get()
@@ -232,7 +258,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]

         dataset = dataset_type(**common_dataset_kwargs)

-        available_subset_names = self._get_available_subset_names()
+        available_subset_names = self._get_available_subset_names(category)
         logger.debug(f"Available subset names: {str(available_subset_names)}.")
         if self.subset_name not in available_subset_names:
             raise ValueError(
@@ -242,20 +268,20 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]

         # load the list of train/val/test frames
         subset_mapping = self._load_annotation_json(
-            os.path.join(
-                self.category, "set_lists", f"set_lists_{self.subset_name}.json"
-            )
+            os.path.join(category, "set_lists", f"set_lists_{self.subset_name}.json")
         )

         # load the evaluation batches
         if self.load_eval_batches:
             eval_batch_index = self._load_annotation_json(
                 os.path.join(
-                    self.category,
+                    category,
                     "eval_batches",
                     f"eval_batches_{self.subset_name}.json",
                 )
             )
         else:
             eval_batch_index = None

         train_dataset = None
         if not self.only_test_set:
@@ -313,9 +339,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
             )
             logger.info(f"# eval batches: {len(test_dataset.eval_batches)}")

-        self.dataset_map = DatasetMap(
-            train=train_dataset, val=val_dataset, test=test_dataset
-        )
+        return DatasetMap(train=train_dataset, val=val_dataset, test=test_dataset)

     @classmethod
     def dataset_tweak_args(cls, type, args: DictConfig) -> None:
@@ -381,10 +405,10 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
             data = json.load(f)
         return data

-    def _get_available_subset_names(self):
+    def _get_available_subset_names(self, category: str):
         return get_available_subset_names(
             self.dataset_root,
-            self.category,
+            category,
             path_manager=self.path_manager_factory.get(),
         )

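The multi-category branch above fans per-category loads out over a process pool. The same pattern in isolation, with a stand-in loader in place of _load_category (names and values here are illustrative):

import multiprocessing

from tqdm import tqdm


def load_category(category: str) -> str:
    # Stand-in for JsonIndexDatasetMapProviderV2._load_category, which
    # builds a DatasetMap for a single category.
    return f"dataset_map({category})"


if __name__ == "__main__":
    categories = ["apple", "car", "orange"]
    num_load_workers = 4
    # One worker per category, capped at num_load_workers; pool.imap yields
    # results in input order, and tqdm reports progress as they arrive.
    with multiprocessing.Pool(
        processes=min(num_load_workers, len(categories))
    ) as pool:
        maps = list(
            tqdm(pool.imap(load_category, categories), total=len(categories))
        )
    print(maps)  # ['dataset_map(apple)', 'dataset_map(car)', 'dataset_map(orange)']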
@@ -6,8 +6,9 @@


 import warnings
+from collections import Counter
 from dataclasses import dataclass, field
-from typing import Iterable, Iterator, List, Sequence, Tuple
+from typing import Dict, Iterable, Iterator, List, Sequence, Tuple

 import numpy as np
 from torch.utils.data.sampler import Sampler
@@ -42,8 +43,17 @@ class SceneBatchSampler(Sampler[List[int]]):
     # same but for timestamps if they are available
     consecutive_frames_max_gap_seconds: float = 0.1

+    # if True, the sampler first reads from the dataset the mapping between
+    # sequence names and their categories.
+    # During batch sampling, the sampler ensures uniform distribution over the
+    # categories of the sampled sequences.
+    category_aware: bool = True
+
     seq_names: List[str] = field(init=False)

+    category_to_sequence_names: Dict[str, List[str]] = field(init=False)
+    categories: List[str] = field(init=False)
+
     def __post_init__(self) -> None:
         if self.batch_size <= 0:
             raise ValueError(
@@ -56,6 +66,10 @@ class SceneBatchSampler(Sampler[List[int]]):

         self.seq_names = list(self.dataset.sequence_names())

+        if self.category_aware:
+            self.category_to_sequence_names = self.dataset.category_to_sequence_names()
+            self.categories = list(self.category_to_sequence_names.keys())
+
     def __len__(self) -> int:
         return self.num_batches

@@ -67,7 +81,25 @@ class SceneBatchSampler(Sampler[List[int]]):
     def _sample_batch(self, batch_idx) -> List[int]:
         n_per_seq = np.random.choice(self.images_per_seq_options)
         n_seqs = -(-self.batch_size // n_per_seq)  # round up
-        chosen_seq = _capped_random_choice(self.seq_names, n_seqs, replace=False)
+
+        if self.category_aware:
+            # first sample categories at random, these can be repeated in the batch
+            chosen_cat = _capped_random_choice(self.categories, n_seqs, replace=True)
+            # then randomly sample a set of unique sequences within each category
+            chosen_seq = []
+            for cat, n_per_category in Counter(chosen_cat).items():
+                category_chosen_seq = _capped_random_choice(
+                    self.category_to_sequence_names[cat],
+                    n_per_category,
+                    replace=False,
+                )
+                chosen_seq.extend([str(s) for s in category_chosen_seq])
+        else:
+            chosen_seq = _capped_random_choice(
+                self.seq_names,
+                n_seqs,
+                replace=False,
+            )

         if self.sample_consecutive_frames:
             frame_idx = []
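The category-aware branch of _sample_batch draws in two stages: categories uniformly with replacement, then unique sequences within each drawn category, so every category is equally likely regardless of how many sequences it owns. A standalone sketch; capped_random_choice is an assumed stand-in for the private _capped_random_choice helper:

from collections import Counter

import numpy as np


def capped_random_choice(x, k, replace=True):
    # Assumed behaviour of the private _capped_random_choice helper:
    # when sampling without replacement, cap the sample size at len(x).
    size = k if replace else min(k, len(x))
    return np.random.choice(x, size=size, replace=replace)


category_to_sequence_names = {
    "apple": ["apple0", "apple1", "apple2"],
    "car": ["car0", "car1"],
}
categories = list(category_to_sequence_names)
n_seqs = 4

# Stage 1: draw categories uniformly; repeats across the batch are allowed.
chosen_cat = capped_random_choice(categories, n_seqs, replace=True)
# Stage 2: within each drawn category, draw that many unique sequences.
chosen_seq = []
for cat, n_per_category in Counter(chosen_cat).items():
    picked = capped_random_choice(
        category_to_sequence_names[cat], n_per_category, replace=False
    )
    chosen_seq.extend(str(s) for s in picked)
print(chosen_seq)  # e.g. ['apple2', 'apple0', 'car1']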
@@ -49,6 +49,7 @@ dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
   test_on_train: false
   only_test_set: false
   load_eval_batches: true
+  num_load_workers: 4
   n_known_frames_for_test: 0
   dataset_class_type: JsonIndexDataset
   path_manager_factory_class_type: PathManagerFactory
@@ -11,17 +11,20 @@ from dataclasses import dataclass
 from itertools import product

 import numpy as np
+
+import torch
 from pytorch3d.implicitron.dataset.data_loader_map_provider import (
     DoublePoolBatchSampler,
 )

-from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
+from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
 from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler


 @dataclass
 class MockFrameAnnotation:
     frame_number: int
+    sequence_name: str = "sequence"
     frame_timestamp: float = 0.0


@@ -41,6 +44,9 @@ class MockDataset(DatasetBase):
         self.frame_annots = [
             {"frame_annotation": MockFrameAnnotation(no)} for no in frame_nos * num_seq
         ]
+        for seq_name, idx in self._seq_to_idx.items():
+            for i in idx:
+                self.frame_annots[i]["frame_annotation"].sequence_name = seq_name

     def get_frame_numbers_and_timestamps(self, idxs):
         out = []
@@ -51,6 +57,16 @@ class MockDataset(DatasetBase):
             )
         return out

+    def __getitem__(self, index: int):
+        fa = self.frame_annots[index]["frame_annotation"]
+        fd = FrameData(
+            sequence_name=fa.sequence_name,
+            sequence_category="default_category",
+            frame_number=torch.LongTensor([fa.frame_number]),
+            frame_timestamp=torch.LongTensor([fa.frame_timestamp]),
+        )
+        return fd
+

 class TestSceneBatchSampler(unittest.TestCase):
     def setUp(self):
@@ -41,22 +41,73 @@ class TestJsonIndexDatasetProviderV2(unittest.TestCase):
         categories = ["A", "B"]
         subset_name = "test"
         eval_batch_size = 5
+        n_frames = 8 * 3
+        n_sequences = 5
+        n_eval_batches = 10
         with tempfile.TemporaryDirectory() as tmpd:
             _make_random_json_dataset_map_provider_v2_data(
                 tmpd,
                 categories,
                 eval_batch_size=eval_batch_size,
+                n_frames=n_frames,
+                n_sequences=n_sequences,
+                n_eval_batches=n_eval_batches,
             )
             for n_known_frames_for_test in [0, 2]:
-                for category in categories:
-                    dataset_provider = JsonIndexDatasetMapProviderV2(
+                dataset_providers = {
+                    category: JsonIndexDatasetMapProviderV2(
                         category=category,
                         subset_name="test",
                         dataset_root=tmpd,
                         n_known_frames_for_test=n_known_frames_for_test,
                     )
+                    for category in [*categories, ",".join(sorted(categories))]
+                }
+                for category, dataset_provider in dataset_providers.items():
                     dataset_map = dataset_provider.get_dataset_map()
                     for set_ in ["train", "val", "test"]:
                         dataset = getattr(dataset_map, set_)
+
+                        cat2seq = dataset.category_to_sequence_names()
+                        self.assertEqual(",".join(sorted(cat2seq.keys())), category)
+
                         if not (n_known_frames_for_test != 0 and set_ == "test"):
                             # check the lengths only in case we do not have the
                             # n_known_frames_for_test set
+                            expected_dataset_len = n_frames * n_sequences // 3
+                            if "," in category:
+                                # multicategory json index dataset, sum the lengths of
+                                # category-specific ones
+                                expected_dataset_len = sum(
+                                    len(
+                                        getattr(
+                                            dataset_providers[c].get_dataset_map(), set_
+                                        )
+                                    )
+                                    for c in categories
+                                )
+                                self.assertEqual(
+                                    sum(len(s) for s in cat2seq.values()),
+                                    n_sequences * len(categories),
+                                )
+                                self.assertEqual(len(cat2seq), len(categories))
+                            else:
+                                self.assertEqual(
+                                    len(cat2seq[category]),
+                                    n_sequences,
+                                )
+                                self.assertEqual(len(cat2seq), 1)
+                            self.assertEqual(len(dataset), expected_dataset_len)
+
+                        if set_ == "test":
+                            # check the number of eval batches
+                            expected_n_eval_batches = n_eval_batches
+                            if "," in category:
+                                expected_n_eval_batches *= len(categories)
+                            self.assertTrue(
+                                len(dataset.get_eval_batches())
+                                == expected_n_eval_batches
+                            )
                         if set_ in ["train", "val"]:
                             dataloader = torch.utils.data.DataLoader(
                                 getattr(dataset_map, set_),
@@ -80,6 +131,7 @@ class TestJsonIndexDatasetProviderV2(unittest.TestCase):
             dataset_provider.get_category_to_subset_name_list()
         )
         category_to_subset_list_ = {c: [subset_name] for c in categories}
+
        self.assertTrue(category_to_subset_list == category_to_subset_list_)


@@ -88,6 +140,7 @@ def _make_random_json_dataset_map_provider_v2_data(
     categories: List[str],
     n_frames: int = 8,
     n_sequences: int = 5,
+    n_eval_batches: int = 10,
     H: int = 50,
     W: int = 30,
     subset_name: str = "test",
@@ -100,7 +153,7 @@ def _make_random_json_dataset_map_provider_v2_data(
     sequence_annotations = []
     frame_index = []
     for seq_i in range(n_sequences):
-        seq_name = str(seq_i)
+        seq_name = category + str(seq_i)
         for i in range(n_frames):
             # generate and store image
             imdir = os.path.join(root, category, seq_name, "images")
@@ -165,7 +218,8 @@ def _make_random_json_dataset_map_provider_v2_data(
         json.dump(set_list, f)

     eval_batches = [
-        random.sample(test_frame_index, eval_batch_size) for _ in range(10)
+        random.sample(test_frame_index, eval_batch_size)
+        for _ in range(n_eval_batches)
     ]

     eval_b_dir = os.path.join(root, category, "eval_batches")
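A note on the expected-length arithmetic in the test above: the synthetic data generator appears to split each sequence's frames evenly across train/val/test (hence the // 3), so with n_frames = 8 * 3 = 24 frames per sequence and n_sequences = 5, each split holds 24 * 5 / 3 = 40 frames per category. For the joined provider built from both categories, the expected length is the sum over categories (80 here), and the eval-batch count scales the same way (n_eval_batches = 10 per category, 20 for the pair).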