mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
CO3Dv2 multi-category extension
Summary: Allows loading of multiple categories. The categories are provided as a comma-separated list of category names.

Reviewed By: bottler, shapovalov

Differential Revision: D40803297

fbshipit-source-id: 863938be3aa6ffefe9e563aede4a2e9e66aeeaa8
This commit is contained in:
parent
c54e048666
commit
e4a3298149
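
A minimal usage sketch of the provider after this change (the dataset root, category names, and subset name below are illustrative):

    from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
        JsonIndexDatasetMapProviderV2,
    )

    # a comma-separated category string triggers the new multi-category path
    provider = JsonIndexDatasetMapProviderV2(
        category="apple,car,orange",
        subset_name="fewview_test",
        dataset_root="/path/to/CO3DV2",  # hypothetical local CO3Dv2 root
    )
    dataset_map = provider.get_dataset_map()  # train/val/test joined across categories
    print(dataset_map.train.category_to_sequence_names().keys())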
@@ -62,6 +62,7 @@ data_source_ImplicitronDataSource_args:
     test_on_train: false
     only_test_set: false
     load_eval_batches: true
+    num_load_workers: 4
     n_known_frames_for_test: 0
     dataset_class_type: JsonIndexDataset
     path_manager_factory_class_type: PathManagerFactory
@@ -9,6 +9,7 @@ from dataclasses import dataclass, field, fields
 from typing import (
     Any,
     ClassVar,
+    Dict,
     Iterable,
     Iterator,
     List,
@@ -259,6 +260,12 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
         """
         raise ValueError("This dataset does not contain videos.")

+    def join(self, other_datasets: Iterable["DatasetBase"]) -> None:
+        """
+        Joins the current dataset with a list of other datasets of the same type.
+        """
+        raise NotImplementedError()
+
     def get_eval_batches(self) -> Optional[List[List[int]]]:
         return None

@@ -267,6 +274,22 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
         # pyre-ignore[16]
         return self._seq_to_idx.keys()

+    def category_to_sequence_names(self) -> Dict[str, List[str]]:
+        """
+        Returns a dict mapping from each dataset category to a list of its
+        sequence names.
+
+        Returns:
+            category_to_sequence_names: Dict {category_i: [..., sequence_name_j, ...]}
+        """
+        c2seq = defaultdict(list)
+        for sequence_name in self.sequence_names():
+            first_frame_idx = next(self.sequence_indices_in_order(sequence_name))
+            # crashes without overriding __getitem__
+            sequence_category = self[first_frame_idx].sequence_category
+            c2seq[sequence_category].append(sequence_name)
+        return dict(c2seq)
+
     def sequence_frames_in_order(
         self, seq_name: str
     ) -> Iterator[Tuple[float, int, int]]:
@@ -7,7 +7,7 @@
 import logging
 import os
 from dataclasses import dataclass
-from typing import Iterator, Optional
+from typing import Iterable, Iterator, Optional

 from iopath.common.file_io import PathManager
 from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
@@ -51,6 +51,34 @@ class DatasetMap:
         if self.test is not None:
             yield self.test

+    def join(self, other_dataset_maps: Iterable["DatasetMap"]) -> None:
+        """
+        Joins the current DatasetMap with other dataset maps from the input list.
+
+        For each subset of each dataset map (train/val/test), the function
+        omits joining the subsets that are None.
+
+        Note the train/val/test datasets of the current dataset map will be
+        modified in-place.
+
+        Args:
+            other_dataset_maps: The list of dataset maps to be joined into the
+                current dataset map.
+        """
+        for set_ in ["train", "val", "test"]:
+            dataset_list = [
+                getattr(self, set_),
+                *[getattr(dmap, set_) for dmap in other_dataset_maps],
+            ]
+            dataset_list = [d for d in dataset_list if d is not None]
+            if len(dataset_list) == 0:
+                setattr(self, set_, None)
+                continue
+            d0 = dataset_list[0]
+            if len(dataset_list) > 1:
+                d0.join(dataset_list[1:])
+            setattr(self, set_, d0)
+

 class DatasetMapProviderBase(ReplaceableBase):
     """
@@ -19,6 +19,8 @@ from pathlib import Path
 from typing import (
     Any,
     ClassVar,
+    Dict,
+    Iterable,
     List,
     Optional,
     Sequence,
@@ -188,7 +190,44 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
             self.eval_batch_index
         )

-    def is_filtered(self):
+    def join(self, other_datasets: Iterable[DatasetBase]) -> None:
+        """
+        Join the dataset with other JsonIndexDataset objects.
+
+        Args:
+            other_datasets: A list of JsonIndexDataset objects to be joined
+                into the current dataset.
+        """
+        if not all(isinstance(d, JsonIndexDataset) for d in other_datasets):
+            raise ValueError("This function can only join a list of JsonIndexDataset")
+        # pyre-ignore[16]
+        self.frame_annots.extend([fa for d in other_datasets for fa in d.frame_annots])
+        # pyre-ignore[16]
+        self.seq_annots.update(
+            # https://gist.github.com/treyhunner/f35292e676efa0be1728
+            functools.reduce(
+                lambda a, b: {**a, **b},
+                [d.seq_annots for d in other_datasets],  # pyre-ignore[16]
+            )
+        )
+        all_eval_batches = [
+            self.eval_batches,
+            # pyre-ignore
+            *[d.eval_batches for d in other_datasets],
+        ]
+        if not (
+            all(ba is None for ba in all_eval_batches)
+            or all(ba is not None for ba in all_eval_batches)
+        ):
+            raise ValueError(
+                "When joining datasets, either all joined datasets have to have their"
+                " eval_batches defined, or all should have their eval batches undefined."
+            )
+        if self.eval_batches is not None:
+            self.eval_batches = sum(all_eval_batches, [])
+        self._invalidate_indexes(filter_seq_annots=True)
+
+    def is_filtered(self) -> bool:
         """
         Returns `True` in case the dataset has been filtered and thus some frame annotations
         stored on the disk might be missing in the dataset object.
@@ -211,6 +250,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
         seq_frame_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
         allow_missing_indices: bool = False,
         remove_missing_indices: bool = False,
+        suppress_missing_index_warning: bool = True,
     ) -> List[List[Union[Optional[int], int]]]:
         """
         Obtain indices into the dataset object given a list of frame ids.
@@ -228,6 +268,11 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
                 If `False`, returns `None` in place of `seq_frame_index` entries that
                 are not present in the dataset.
                 If `True` removes missing indices from the returned indices.
+            suppress_missing_index_warning:
+                Active if `allow_missing_indices==True`. Suppresses a warning message
+                in case an entry from `seq_frame_index` is missing in the dataset
+                (expected in certain cases - e.g. when setting
+                `self.remove_empty_masks=True`).

         Returns:
             dataset_idx: Indices of dataset entries corresponding to `seq_frame_index`.
@@ -254,7 +299,8 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
                 )
                 if not allow_missing_indices:
                     raise IndexError(msg)
-                warnings.warn(msg)
+                if not suppress_missing_index_warning:
+                    warnings.warn(msg)
                 return idx
             if path is not None:
                 # Check that the loaded frame path is consistent
@@ -288,6 +334,21 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
         frame_index: List[Union[Tuple[str, int], Tuple[str, int, str]]],
         allow_missing_indices: bool = True,
     ) -> "JsonIndexDataset":
+        """
+        Generate a dataset subset given the list of frames specified in `frame_index`.
+
+        Args:
+            frame_index: The list of frame identifiers (as stored in the metadata)
+                specified as `List[Tuple[sequence_name: str, frame_number: int]]`. Optionally,
+                image paths relative to the dataset_root can be specified as well:
+                `List[Tuple[sequence_name: str, frame_number: int, image_path: str]]`;
+                in the latter case, if image_path does not match the stored path, an error
+                is raised.
+            allow_missing_indices: If `False`, throws an IndexError upon reaching the first
+                entry from `frame_index` which is missing in the dataset.
+                Otherwise, generates a subset consisting of frame entries that actually
+                exist in the dataset.
+        """
         # Get the indices into the frame annots.
         dataset_indices = self.seq_frame_index_to_dataset_index(
             [frame_index],
@@ -838,6 +899,13 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
             )
         return out

+    def category_to_sequence_names(self) -> Dict[str, List[str]]:
+        c2seq = defaultdict(list)
+        # pyre-ignore
+        for sequence_name, sa in self.seq_annots.items():
+            c2seq[sa.category].append(sequence_name)
+        return dict(c2seq)
+
     def get_eval_batches(self) -> Optional[List[List[int]]]:
         return self.eval_batches

@@ -8,6 +8,7 @@
 import copy
 import json
 import logging
+import multiprocessing
 import os
 import warnings
 from collections import defaultdict
@@ -30,6 +31,7 @@ from pytorch3d.implicitron.tools.config import (
 )

 from pytorch3d.renderer.cameras import CamerasBase
+from tqdm import tqdm


 _CO3DV2_DATASET_ROOT: str = os.getenv("CO3DV2_DATASET_ROOT", "")
@@ -147,7 +149,8 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
         (test frames can repeat across batches).

     Args:
-        category: The object category of the dataset.
+        category: Dataset categories to load expressed as a string of comma-separated
+            category names (e.g. `"apple,car,orange"`).
         subset_name: The name of the dataset subset. For CO3Dv2, these include
             e.g. "manyview_dev_0", "fewview_test", ...
         dataset_root: The root folder of the dataset.
@@ -173,6 +176,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
     test_on_train: bool = False
     only_test_set: bool = False
     load_eval_batches: bool = True
+    num_load_workers: int = 4

     n_known_frames_for_test: int = 0

@@ -189,11 +193,33 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
         if self.only_test_set and self.test_on_train:
             raise ValueError("Cannot have only_test_set and test_on_train")

-        frame_file = os.path.join(
-            self.dataset_root, self.category, "frame_annotations.jgz"
-        )
+        if "," in self.category:
+            # a comma-separated list of categories to load
+            categories = [c.strip() for c in self.category.split(",")]
+            logger.info(f"Loading a list of categories: {str(categories)}.")
+            with multiprocessing.Pool(
+                processes=min(self.num_load_workers, len(categories))
+            ) as pool:
+                category_dataset_maps = list(
+                    tqdm(
+                        pool.imap(self._load_category, categories),
+                        total=len(categories),
+                    )
+                )
+            dataset_map = category_dataset_maps[0]
+            dataset_map.join(category_dataset_maps[1:])
+
+        else:
+            # one category to load
+            dataset_map = self._load_category(self.category)
+
+        self.dataset_map = dataset_map
+
+    def _load_category(self, category: str) -> DatasetMap:
+
+        frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz")
         sequence_file = os.path.join(
-            self.dataset_root, self.category, "sequence_annotations.jgz"
+            self.dataset_root, category, "sequence_annotations.jgz"
         )

         path_manager = self.path_manager_factory.get()
@@ -232,7 +258,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]

         dataset = dataset_type(**common_dataset_kwargs)

-        available_subset_names = self._get_available_subset_names()
+        available_subset_names = self._get_available_subset_names(category)
         logger.debug(f"Available subset names: {str(available_subset_names)}.")
         if self.subset_name not in available_subset_names:
             raise ValueError(
@@ -242,20 +268,20 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]

         # load the list of train/val/test frames
         subset_mapping = self._load_annotation_json(
-            os.path.join(
-                self.category, "set_lists", f"set_lists_{self.subset_name}.json"
-            )
+            os.path.join(category, "set_lists", f"set_lists_{self.subset_name}.json")
         )

         # load the evaluation batches
         if self.load_eval_batches:
             eval_batch_index = self._load_annotation_json(
                 os.path.join(
-                    self.category,
+                    category,
                     "eval_batches",
                     f"eval_batches_{self.subset_name}.json",
                 )
             )
+        else:
+            eval_batch_index = None

         train_dataset = None
         if not self.only_test_set:
@@ -313,9 +339,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
             )
             logger.info(f"# eval batches: {len(test_dataset.eval_batches)}")

-        self.dataset_map = DatasetMap(
-            train=train_dataset, val=val_dataset, test=test_dataset
-        )
+        return DatasetMap(train=train_dataset, val=val_dataset, test=test_dataset)

     @classmethod
     def dataset_tweak_args(cls, type, args: DictConfig) -> None:
@@ -381,10 +405,10 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
             data = json.load(f)
         return data

-    def _get_available_subset_names(self):
+    def _get_available_subset_names(self, category: str):
         return get_available_subset_names(
             self.dataset_root,
-            self.category,
+            category,
             path_manager=self.path_manager_factory.get(),
         )

@@ -6,8 +6,9 @@


 import warnings
+from collections import Counter
 from dataclasses import dataclass, field
-from typing import Iterable, Iterator, List, Sequence, Tuple
+from typing import Dict, Iterable, Iterator, List, Sequence, Tuple

 import numpy as np
 from torch.utils.data.sampler import Sampler
@@ -42,8 +43,17 @@ class SceneBatchSampler(Sampler[List[int]]):
     # same but for timestamps if they are available
     consecutive_frames_max_gap_seconds: float = 0.1

+    # if True, the sampler first reads from the dataset the mapping between
+    # sequence names and their categories.
+    # During batch sampling, the sampler ensures uniform distribution over the categories
+    # of the sampled sequences.
+    category_aware: bool = True
+
     seq_names: List[str] = field(init=False)

+    category_to_sequence_names: Dict[str, List[str]] = field(init=False)
+    categories: List[str] = field(init=False)
+
     def __post_init__(self) -> None:
         if self.batch_size <= 0:
             raise ValueError(
@@ -56,6 +66,10 @@ class SceneBatchSampler(Sampler[List[int]]):

         self.seq_names = list(self.dataset.sequence_names())

+        if self.category_aware:
+            self.category_to_sequence_names = self.dataset.category_to_sequence_names()
+            self.categories = list(self.category_to_sequence_names.keys())
+
     def __len__(self) -> int:
         return self.num_batches

@@ -67,7 +81,25 @@ class SceneBatchSampler(Sampler[List[int]]):
     def _sample_batch(self, batch_idx) -> List[int]:
         n_per_seq = np.random.choice(self.images_per_seq_options)
         n_seqs = -(-self.batch_size // n_per_seq)  # round up
-        chosen_seq = _capped_random_choice(self.seq_names, n_seqs, replace=False)
+
+        if self.category_aware:
+            # first sample categories at random, these can be repeated in the batch
+            chosen_cat = _capped_random_choice(self.categories, n_seqs, replace=True)
+            # then randomly sample a set of unique sequences within each category
+            chosen_seq = []
+            for cat, n_per_category in Counter(chosen_cat).items():
+                category_chosen_seq = _capped_random_choice(
+                    self.category_to_sequence_names[cat],
+                    n_per_category,
+                    replace=False,
+                )
+                chosen_seq.extend([str(s) for s in category_chosen_seq])
+        else:
+            chosen_seq = _capped_random_choice(
+                self.seq_names,
+                n_seqs,
+                replace=False,
+            )
+
         if self.sample_consecutive_frames:
             frame_idx = []
@@ -49,6 +49,7 @@ dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
   test_on_train: false
   only_test_set: false
   load_eval_batches: true
+  num_load_workers: 4
   n_known_frames_for_test: 0
   dataset_class_type: JsonIndexDataset
   path_manager_factory_class_type: PathManagerFactory
@@ -11,17 +11,20 @@ from dataclasses import dataclass
 from itertools import product

 import numpy as np

+import torch
+
 from pytorch3d.implicitron.dataset.data_loader_map_provider import (
     DoublePoolBatchSampler,
 )

-from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
+from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
 from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler


 @dataclass
 class MockFrameAnnotation:
     frame_number: int
+    sequence_name: str = "sequence"
     frame_timestamp: float = 0.0

@@ -41,6 +44,9 @@ class MockDataset(DatasetBase):
         self.frame_annots = [
             {"frame_annotation": MockFrameAnnotation(no)} for no in frame_nos * num_seq
         ]
+        for seq_name, idx in self._seq_to_idx.items():
+            for i in idx:
+                self.frame_annots[i]["frame_annotation"].sequence_name = seq_name

     def get_frame_numbers_and_timestamps(self, idxs):
         out = []
@@ -51,6 +57,16 @@ class MockDataset(DatasetBase):
             )
         return out

+    def __getitem__(self, index: int):
+        fa = self.frame_annots[index]["frame_annotation"]
+        fd = FrameData(
+            sequence_name=fa.sequence_name,
+            sequence_category="default_category",
+            frame_number=torch.LongTensor([fa.frame_number]),
+            frame_timestamp=torch.LongTensor([fa.frame_timestamp]),
+        )
+        return fd
+

 class TestSceneBatchSampler(unittest.TestCase):
     def setUp(self):
@@ -41,22 +41,73 @@ class TestJsonIndexDatasetProviderV2(unittest.TestCase):
         categories = ["A", "B"]
         subset_name = "test"
         eval_batch_size = 5
+        n_frames = 8 * 3
+        n_sequences = 5
+        n_eval_batches = 10
         with tempfile.TemporaryDirectory() as tmpd:
             _make_random_json_dataset_map_provider_v2_data(
                 tmpd,
                 categories,
                 eval_batch_size=eval_batch_size,
+                n_frames=n_frames,
+                n_sequences=n_sequences,
+                n_eval_batches=n_eval_batches,
             )
             for n_known_frames_for_test in [0, 2]:
-                for category in categories:
-                    dataset_provider = JsonIndexDatasetMapProviderV2(
+                dataset_providers = {
+                    category: JsonIndexDatasetMapProviderV2(
                         category=category,
                         subset_name="test",
                         dataset_root=tmpd,
                         n_known_frames_for_test=n_known_frames_for_test,
                     )
+                    for category in [*categories, ",".join(sorted(categories))]
+                }
+                for category, dataset_provider in dataset_providers.items():
                     dataset_map = dataset_provider.get_dataset_map()
                     for set_ in ["train", "val", "test"]:
+                        dataset = getattr(dataset_map, set_)
+
+                        cat2seq = dataset.category_to_sequence_names()
+                        self.assertEqual(",".join(sorted(cat2seq.keys())), category)
+
+                        if not (n_known_frames_for_test != 0 and set_ == "test"):
+                            # check the lengths only in case we do not have the
+                            # n_known_frames_for_test set
+                            expected_dataset_len = n_frames * n_sequences // 3
+                            if "," in category:
+                                # multicategory json index dataset, sum the lengths of
+                                # category-specific ones
+                                expected_dataset_len = sum(
+                                    len(
+                                        getattr(
+                                            dataset_providers[c].get_dataset_map(), set_
+                                        )
+                                    )
+                                    for c in categories
+                                )
+                                self.assertEqual(
+                                    sum(len(s) for s in cat2seq.values()),
+                                    n_sequences * len(categories),
+                                )
+                                self.assertEqual(len(cat2seq), len(categories))
+                            else:
+                                self.assertEqual(
+                                    len(cat2seq[category]),
+                                    n_sequences,
+                                )
+                                self.assertEqual(len(cat2seq), 1)
+                            self.assertEqual(len(dataset), expected_dataset_len)
+
+                        if set_ == "test":
+                            # check the number of eval batches
+                            expected_n_eval_batches = n_eval_batches
+                            if "," in category:
+                                expected_n_eval_batches *= len(categories)
+                            self.assertTrue(
+                                len(dataset.get_eval_batches())
+                                == expected_n_eval_batches
+                            )
                         if set_ in ["train", "val"]:
                             dataloader = torch.utils.data.DataLoader(
                                 getattr(dataset_map, set_),
@@ -80,6 +131,7 @@ class TestJsonIndexDatasetProviderV2(unittest.TestCase):
                 dataset_provider.get_category_to_subset_name_list()
             )
             category_to_subset_list_ = {c: [subset_name] for c in categories}
+
             self.assertTrue(category_to_subset_list == category_to_subset_list_)


@@ -88,6 +140,7 @@ def _make_random_json_dataset_map_provider_v2_data(
     categories: List[str],
     n_frames: int = 8,
     n_sequences: int = 5,
+    n_eval_batches: int = 10,
     H: int = 50,
     W: int = 30,
     subset_name: str = "test",
@@ -100,7 +153,7 @@ def _make_random_json_dataset_map_provider_v2_data(
     sequence_annotations = []
     frame_index = []
     for seq_i in range(n_sequences):
-        seq_name = str(seq_i)
+        seq_name = category + str(seq_i)
         for i in range(n_frames):
             # generate and store image
             imdir = os.path.join(root, category, seq_name, "images")
@@ -165,7 +218,8 @@ def _make_random_json_dataset_map_provider_v2_data(
             json.dump(set_list, f)

         eval_batches = [
-            random.sample(test_frame_index, eval_batch_size) for _ in range(10)
+            random.sample(test_frame_index, eval_batch_size)
+            for _ in range(n_eval_batches)
         ]

         eval_b_dir = os.path.join(root, category, "eval_batches")
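
The category-aware branch of SceneBatchSampler._sample_batch above first draws categories with replacement, then unique sequences within each drawn category. A self-contained sketch of the same sampling scheme, with numpy's choice standing in for pytorch3d's _capped_random_choice helper (whose cap on the draw size is reproduced here with min):

    from collections import Counter

    import numpy as np

    def sample_sequences(category_to_seqs, n_seqs):
        categories = list(category_to_seqs.keys())
        # categories may repeat within one batch, hence replace=True
        chosen_cat = np.random.choice(categories, n_seqs, replace=True)
        chosen_seq = []
        for cat, n_per_category in Counter(chosen_cat).items():
            seqs = category_to_seqs[cat]
            # sequences within one category are unique (replace=False);
            # cap the draw at the number of available sequences
            n = min(n_per_category, len(seqs))
            chosen_seq.extend(str(s) for s in np.random.choice(seqs, n, replace=False))
        return chosen_seq

    # e.g. draw 3 sequences, roughly uniform over the two categories
    print(sample_sequences({"apple": ["a0", "a1"], "car": ["c0", "c1"]}, 3))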