Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-12-21 14:50:36 +08:00
CO3Dv2 multi-category extension
Summary: Allows loading multiple dataset categories at once. The categories are provided as a comma-separated list of category names.

Reviewed By: bottler, shapovalov

Differential Revision: D40803297

fbshipit-source-id: 863938be3aa6ffefe9e563aede4a2e9e66aeeaa8
Committed by: Facebook GitHub Bot
Parent: c54e048666
Commit: e4a3298149
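Usage sketch (not part of the diff): with this change, a single provider can load several CO3Dv2 categories at once by passing a comma-separated `category` string. The `dataset_root` path below is a placeholder, and calling `expand_args_fields` first is the usual Implicitron way to make the configurable class constructible.

from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
    JsonIndexDatasetMapProviderV2,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

# make the configurable dataclass constructible
expand_args_fields(JsonIndexDatasetMapProviderV2)
provider = JsonIndexDatasetMapProviderV2(
    category="apple,car,orange",  # comma-separated list of categories
    subset_name="manyview_dev_0",
    dataset_root="/path/to/co3dv2",  # placeholder path to a local CO3Dv2 copy
)
dataset_map = provider.get_dataset_map()  # train/val/test joined across categories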
@@ -9,6 +9,7 @@ from dataclasses import dataclass, field, fields
 from typing import (
     Any,
     ClassVar,
     Dict,
+    Iterable,
     Iterator,
     List,
@@ -259,6 +260,12 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
         """
         raise ValueError("This dataset does not contain videos.")

+    def join(self, other_datasets: Iterable["DatasetBase"]) -> None:
+        """
+        Joins the current dataset with a list of other datasets of the same type.
+        """
+        raise NotImplementedError()
+
     def get_eval_batches(self) -> Optional[List[List[int]]]:
         return None

@@ -267,6 +274,22 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
         # pyre-ignore[16]
         return self._seq_to_idx.keys()

+    def category_to_sequence_names(self) -> Dict[str, List[str]]:
+        """
+        Returns a dict mapping from each dataset category to a list of its
+        sequence names.
+
+        Returns:
+            category_to_sequence_names: Dict {category_i: [..., sequence_name_j, ...]}
+        """
+        c2seq = defaultdict(list)
+        for sequence_name in self.sequence_names():
+            first_frame_idx = next(self.sequence_indices_in_order(sequence_name))
+            # crashes without overriding __getitem__
+            sequence_category = self[first_frame_idx].sequence_category
+            c2seq[sequence_category].append(sequence_name)
+        return dict(c2seq)
+
     def sequence_frames_in_order(
         self, seq_name: str
     ) -> Iterator[Tuple[float, int, int]]:
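For illustration (not part of the diff), the grouping performed by `category_to_sequence_names` is the standard `defaultdict(list)` idiom; a self-contained toy version with made-up names:

from collections import defaultdict

# toy (category, sequence_name) pairs standing in for dataset frames
pairs = [("apple", "seq_a"), ("car", "seq_b"), ("apple", "seq_c")]
c2seq = defaultdict(list)
for category, sequence_name in pairs:
    c2seq[category].append(sequence_name)
assert dict(c2seq) == {"apple": ["seq_a", "seq_c"], "car": ["seq_b"]}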
@@ -7,7 +7,7 @@
 import logging
 import os
 from dataclasses import dataclass
-from typing import Iterator, Optional
+from typing import Iterable, Iterator, Optional

 from iopath.common.file_io import PathManager
 from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
@@ -51,6 +51,34 @@ class DatasetMap:
         if self.test is not None:
             yield self.test

+    def join(self, other_dataset_maps: Iterable["DatasetMap"]) -> None:
+        """
+        Joins the current DatasetMap with other dataset maps from the input list.
+
+        For each subset of the dataset maps (train/val/test), subsets that
+        are None are omitted from the join.
+
+        Note that the train/val/test datasets of the current dataset map are
+        modified in-place.
+
+        Args:
+            other_dataset_maps: The list of dataset maps to be joined into the
+                current dataset map.
+        """
+        for set_ in ["train", "val", "test"]:
+            dataset_list = [
+                getattr(self, set_),
+                *[getattr(dmap, set_) for dmap in other_dataset_maps],
+            ]
+            dataset_list = [d for d in dataset_list if d is not None]
+            if len(dataset_list) == 0:
+                setattr(self, set_, None)
+                continue
+            d0 = dataset_list[0]
+            if len(dataset_list) > 1:
+                d0.join(dataset_list[1:])
+            setattr(self, set_, d0)
+

 class DatasetMapProviderBase(ReplaceableBase):
     """
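A toy sketch of the per-subset join semantics above, using a minimal stand-in for a dataset type with a compatible `join` method (all names here are hypothetical):

from dataclasses import dataclass, field
from typing import List


@dataclass
class _ToyDataset:
    items: List[int] = field(default_factory=list)

    def join(self, others) -> None:
        for other in others:
            self.items.extend(other.items)


# mimic one iteration of DatasetMap.join for the "train" slot:
# None entries are dropped, the rest merge in-place into the first dataset
train_a, train_b = _ToyDataset([1, 2]), _ToyDataset([3])
dataset_list = [d for d in [train_a, train_b, None] if d is not None]
if len(dataset_list) > 1:
    dataset_list[0].join(dataset_list[1:])
assert train_a.items == [1, 2, 3]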
@@ -19,6 +19,8 @@ from pathlib import Path
 from typing import (
     Any,
     ClassVar,
+    Dict,
+    Iterable,
     List,
     Optional,
     Sequence,
@@ -188,7 +190,44 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
             self.eval_batch_index
         )

-    def is_filtered(self):
+    def join(self, other_datasets: Iterable[DatasetBase]) -> None:
+        """
+        Join the dataset with other JsonIndexDataset objects.
+
+        Args:
+            other_datasets: A list of JsonIndexDataset objects to be joined
+                into the current dataset.
+        """
+        if not all(isinstance(d, JsonIndexDataset) for d in other_datasets):
+            raise ValueError("This function can only join a list of JsonIndexDataset")
+        # pyre-ignore[16]
+        self.frame_annots.extend([fa for d in other_datasets for fa in d.frame_annots])
+        # pyre-ignore[16]
+        self.seq_annots.update(
+            # https://gist.github.com/treyhunner/f35292e676efa0be1728
+            functools.reduce(
+                lambda a, b: {**a, **b},
+                [d.seq_annots for d in other_datasets],  # pyre-ignore[16]
+            )
+        )
+        all_eval_batches = [
+            self.eval_batches,
+            # pyre-ignore
+            *[d.eval_batches for d in other_datasets],
+        ]
+        if not (
+            all(ba is None for ba in all_eval_batches)
+            or all(ba is not None for ba in all_eval_batches)
+        ):
+            raise ValueError(
+                "When joining datasets, either all joined datasets have to have their"
+                " eval_batches defined, or all should have their eval batches undefined."
+            )
+        if self.eval_batches is not None:
+            self.eval_batches = sum(all_eval_batches, [])
+        self._invalidate_indexes(filter_seq_annots=True)
+
+    def is_filtered(self) -> bool:
         """
         Returns `True` in case the dataset has been filtered and thus some frame annotations
         stored on the disk might be missing in the dataset object.
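The `functools.reduce` call above folds the per-dataset `seq_annots` dicts into a single dict (later entries win on key collisions); a minimal standalone version of the same idiom with toy data:

import functools

seq_annots_per_dataset = [{"seq1": "apple"}, {"seq2": "car"}, {"seq3": "orange"}]
merged = functools.reduce(lambda a, b: {**a, **b}, seq_annots_per_dataset)
assert merged == {"seq1": "apple", "seq2": "car", "seq3": "orange"}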
@@ -211,6 +250,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
         seq_frame_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
         allow_missing_indices: bool = False,
         remove_missing_indices: bool = False,
+        suppress_missing_index_warning: bool = True,
     ) -> List[List[Union[Optional[int], int]]]:
         """
         Obtain indices into the dataset object given a list of frame ids.
@@ -228,6 +268,11 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
                 If `False`, returns `None` in place of `seq_frame_index` entries that
                 are not present in the dataset.
                 If `True`, removes missing indices from the returned indices.
+            suppress_missing_index_warning:
+                Active if `allow_missing_indices==True`. Suppresses a warning message
+                in case an entry from `seq_frame_index` is missing in the dataset
+                (expected in certain cases - e.g. when setting
+                `self.remove_empty_masks=True`).

         Returns:
             dataset_idx: Indices of dataset entries corresponding to `seq_frame_index`.
@@ -254,7 +299,8 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
                 )
                 if not allow_missing_indices:
                     raise IndexError(msg)
-                warnings.warn(msg)
+                if not suppress_missing_index_warning:
+                    warnings.warn(msg)
                 return idx
             if path is not None:
                 # Check that the loaded frame path is consistent
@@ -288,6 +334,21 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
         frame_index: List[Union[Tuple[str, int], Tuple[str, int, str]]],
         allow_missing_indices: bool = True,
     ) -> "JsonIndexDataset":
+        """
+        Generate a dataset subset given the list of frames specified in `frame_index`.
+
+        Args:
+            frame_index: The list of frame identifiers (as stored in the metadata)
+                specified as `List[Tuple[sequence_name: str, frame_number: int]]`. Optionally,
+                image paths relative to the dataset_root can be specified as well:
+                `List[Tuple[sequence_name: str, frame_number: int, image_path: str]]`;
+                in the latter case, an error is raised if `image_path` does not match
+                the stored path.
+            allow_missing_indices: If `False`, throws an IndexError upon reaching the first
+                entry from `frame_index` which is missing in the dataset.
+                Otherwise, generates a subset of the frame entries that actually
+                exist in the dataset.
+        """
         # Get the indices into the frame annots.
         dataset_indices = self.seq_frame_index_to_dataset_index(
             [frame_index],
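Usage sketch for the documented method (not part of the diff); `dataset` is an already-constructed JsonIndexDataset, and the sequence name and frame numbers are made up:

# select two specific frames into a new dataset object
subset = dataset.subset_from_frame_index(
    [
        ("110_13051_23361", 0),  # (sequence_name, frame_number)
        ("110_13051_23361", 10),
    ],
    allow_missing_indices=True,  # skip entries absent from the dataset
)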
@@ -838,6 +899,13 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
         )
         return out

+    def category_to_sequence_names(self) -> Dict[str, List[str]]:
+        c2seq = defaultdict(list)
+        # pyre-ignore
+        for sequence_name, sa in self.seq_annots.items():
+            c2seq[sa.category].append(sequence_name)
+        return dict(c2seq)
+
     def get_eval_batches(self) -> Optional[List[List[int]]]:
         return self.eval_batches

@@ -8,6 +8,7 @@
 import copy
 import json
 import logging
+import multiprocessing
 import os
 import warnings
 from collections import defaultdict
@@ -30,6 +31,7 @@ from pytorch3d.implicitron.tools.config import (
 )

 from pytorch3d.renderer.cameras import CamerasBase
+from tqdm import tqdm


 _CO3DV2_DATASET_ROOT: str = os.getenv("CO3DV2_DATASET_ROOT", "")
@@ -147,7 +149,8 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
     (test frames can repeat across batches).

     Args:
-        category: The object category of the dataset.
+        category: Dataset categories to load expressed as a string of comma-separated
+            category names (e.g. `"apple,car,orange"`).
         subset_name: The name of the dataset subset. For CO3Dv2, these include
             e.g. "manyview_dev_0", "fewview_test", ...
         dataset_root: The root folder of the dataset.
@@ -173,6 +176,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
     test_on_train: bool = False
     only_test_set: bool = False
     load_eval_batches: bool = True
+    num_load_workers: int = 4

     n_known_frames_for_test: int = 0

@@ -189,11 +193,33 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
         if self.only_test_set and self.test_on_train:
             raise ValueError("Cannot have only_test_set and test_on_train")

-        frame_file = os.path.join(
-            self.dataset_root, self.category, "frame_annotations.jgz"
-        )
+        if "," in self.category:
+            # a comma-separated list of categories to load
+            categories = [c.strip() for c in self.category.split(",")]
+            logger.info(f"Loading a list of categories: {str(categories)}.")
+            with multiprocessing.Pool(
+                processes=min(self.num_load_workers, len(categories))
+            ) as pool:
+                category_dataset_maps = list(
+                    tqdm(
+                        pool.imap(self._load_category, categories),
+                        total=len(categories),
+                    )
+                )
+            dataset_map = category_dataset_maps[0]
+            dataset_map.join(category_dataset_maps[1:])
+
+        else:
+            # one category to load
+            dataset_map = self._load_category(self.category)
+
+        self.dataset_map = dataset_map
+
+    def _load_category(self, category: str) -> DatasetMap:
+
+        frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz")
         sequence_file = os.path.join(
-            self.dataset_root, self.category, "sequence_annotations.jgz"
+            self.dataset_root, category, "sequence_annotations.jgz"
         )

         path_manager = self.path_manager_factory.get()
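The parallel per-category load above is the standard `multiprocessing.Pool.imap` pattern; a self-contained sketch with a stand-in worker (the worker must be picklable, hence defined at module level):

import multiprocessing


def _load_one(category: str) -> str:  # stand-in for _load_category
    return f"loaded:{category}"


if __name__ == "__main__":
    categories = ["apple", "car", "orange"]
    with multiprocessing.Pool(processes=min(4, len(categories))) as pool:
        # imap preserves the input order, unlike imap_unordered
        results = list(pool.imap(_load_one, categories))
    assert results == ["loaded:apple", "loaded:car", "loaded:orange"]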
@@ -232,7 +258,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]

         dataset = dataset_type(**common_dataset_kwargs)

-        available_subset_names = self._get_available_subset_names()
+        available_subset_names = self._get_available_subset_names(category)
         logger.debug(f"Available subset names: {str(available_subset_names)}.")
         if self.subset_name not in available_subset_names:
             raise ValueError(
@@ -242,20 +268,20 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]

         # load the list of train/val/test frames
         subset_mapping = self._load_annotation_json(
-            os.path.join(
-                self.category, "set_lists", f"set_lists_{self.subset_name}.json"
-            )
+            os.path.join(category, "set_lists", f"set_lists_{self.subset_name}.json")
         )

         # load the evaluation batches
         if self.load_eval_batches:
             eval_batch_index = self._load_annotation_json(
                 os.path.join(
-                    self.category,
+                    category,
                     "eval_batches",
                     f"eval_batches_{self.subset_name}.json",
                 )
             )
         else:
             eval_batch_index = None

         train_dataset = None
         if not self.only_test_set:
@@ -313,9 +339,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
                 )
                 logger.info(f"# eval batches: {len(test_dataset.eval_batches)}")

-        self.dataset_map = DatasetMap(
-            train=train_dataset, val=val_dataset, test=test_dataset
-        )
+        return DatasetMap(train=train_dataset, val=val_dataset, test=test_dataset)

     @classmethod
     def dataset_tweak_args(cls, type, args: DictConfig) -> None:
@@ -381,10 +405,10 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
             data = json.load(f)
         return data

-    def _get_available_subset_names(self):
+    def _get_available_subset_names(self, category: str):
         return get_available_subset_names(
             self.dataset_root,
-            self.category,
+            category,
             path_manager=self.path_manager_factory.get(),
         )


@@ -6,8 +6,9 @@


 import warnings
+from collections import Counter
 from dataclasses import dataclass, field
-from typing import Iterable, Iterator, List, Sequence, Tuple
+from typing import Dict, Iterable, Iterator, List, Sequence, Tuple

 import numpy as np
 from torch.utils.data.sampler import Sampler
@@ -42,8 +43,17 @@ class SceneBatchSampler(Sampler[List[int]]):
     # same but for timestamps if they are available
     consecutive_frames_max_gap_seconds: float = 0.1

+    # if True, the sampler first reads from the dataset the mapping between
+    # sequence names and their categories.
+    # During batch sampling, the sampler ensures uniform distribution over the categories
+    # of the sampled sequences.
+    category_aware: bool = True
+
     seq_names: List[str] = field(init=False)

+    category_to_sequence_names: Dict[str, List[str]] = field(init=False)
+    categories: List[str] = field(init=False)
+
     def __post_init__(self) -> None:
         if self.batch_size <= 0:
             raise ValueError(
@@ -56,6 +66,10 @@ class SceneBatchSampler(Sampler[List[int]]):

         self.seq_names = list(self.dataset.sequence_names())

+        if self.category_aware:
+            self.category_to_sequence_names = self.dataset.category_to_sequence_names()
+            self.categories = list(self.category_to_sequence_names.keys())
+
     def __len__(self) -> int:
         return self.num_batches

@@ -67,7 +81,25 @@ class SceneBatchSampler(Sampler[List[int]]):
     def _sample_batch(self, batch_idx) -> List[int]:
         n_per_seq = np.random.choice(self.images_per_seq_options)
         n_seqs = -(-self.batch_size // n_per_seq)  # round up
-        chosen_seq = _capped_random_choice(self.seq_names, n_seqs, replace=False)
+
+        if self.category_aware:
+            # first sample categories at random, these can be repeated in the batch
+            chosen_cat = _capped_random_choice(self.categories, n_seqs, replace=True)
+            # then randomly sample a set of unique sequences within each category
+            chosen_seq = []
+            for cat, n_per_category in Counter(chosen_cat).items():
+                category_chosen_seq = _capped_random_choice(
+                    self.category_to_sequence_names[cat],
+                    n_per_category,
+                    replace=False,
+                )
+                chosen_seq.extend([str(s) for s in category_chosen_seq])
+        else:
+            chosen_seq = _capped_random_choice(
+                self.seq_names,
+                n_seqs,
+                replace=False,
+            )
+
         if self.sample_consecutive_frames:
             frame_idx = []
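A standalone sketch of the category-aware sampling added above, with a toy `_capped_random_choice` approximating the real helper (which avoids asking numpy for more unique items than exist):

from collections import Counter

import numpy as np


def _capped_random_choice(x, size, replace: bool = True):
    # with replacement: any size; without: cap at the population size
    if replace:
        return np.random.choice(x, size=size, replace=True)
    return np.random.choice(x, size=min(size, len(x)), replace=False)


category_to_sequence_names = {"apple": ["a1", "a2"], "car": ["c1"]}
categories = list(category_to_sequence_names)

n_seqs = 3
# categories may repeat within a batch ...
chosen_cat = _capped_random_choice(categories, n_seqs, replace=True)
chosen_seq = []
for cat, n_per_category in Counter(chosen_cat).items():
    # ... but sequences within a category are sampled without replacement
    chosen_seq.extend(
        str(s)
        for s in _capped_random_choice(
            category_to_sequence_names[cat], n_per_category, replace=False
        )
    )
print(chosen_seq)  # e.g. ['a1', 'a2', 'c1'] (random)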