CO3Dv2 multi-category extension

Summary:
Adds support for loading multiple dataset categories at once.
The categories are passed as a comma-separated list of category names.
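
A minimal usage sketch (the dataset root is a placeholder; direct construction mirrors the new unit test):

```python
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
    JsonIndexDatasetMapProviderV2,
)

# the comma-separated `category` string triggers the new multi-category path;
# "/path/to/co3dv2" is a placeholder for a real CO3Dv2 dataset root
provider = JsonIndexDatasetMapProviderV2(
    category="apple,car,orange",
    subset_name="fewview_test",
    dataset_root="/path/to/co3dv2",
)
dataset_map = provider.get_dataset_map()  # train/val/test joined across categories
```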

Reviewed By: bottler, shapovalov

Differential Revision: D40803297

fbshipit-source-id: 863938be3aa6ffefe9e563aede4a2e9e66aeeaa8
David Novotny 2022-11-02 13:55:25 -07:00 committed by Facebook GitHub Bot
parent c54e048666
commit e4a3298149
9 changed files with 272 additions and 25 deletions


@ -62,6 +62,7 @@ data_source_ImplicitronDataSource_args:
test_on_train: false
only_test_set: false
load_eval_batches: true
num_load_workers: 4
n_known_frames_for_test: 0
dataset_class_type: JsonIndexDataset
path_manager_factory_class_type: PathManagerFactory


@ -9,6 +9,7 @@ from dataclasses import dataclass, field, fields
from typing import (
Any,
ClassVar,
Dict,
Iterable,
Iterator,
List,
@ -259,6 +260,12 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
"""
raise ValueError("This dataset does not contain videos.")
def join(self, other_datasets: Iterable["DatasetBase"]) -> None:
"""
Joins the current dataset with a list of other datasets of the same type.
"""
raise NotImplementedError()
def get_eval_batches(self) -> Optional[List[List[int]]]:
return None
@ -267,6 +274,22 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
# pyre-ignore[16]
return self._seq_to_idx.keys()
def category_to_sequence_names(self) -> Dict[str, List[str]]:
"""
Returns a dict mapping from each dataset category to a list of its
sequence names.
Returns:
category_to_sequence_names: Dict {category_i: [..., sequence_name_j, ...]}
"""
c2seq = defaultdict(list)
for sequence_name in self.sequence_names():
first_frame_idx = next(self.sequence_indices_in_order(sequence_name))
# this lookup fails unless the subclass overrides __getitem__
sequence_category = self[first_frame_idx].sequence_category
c2seq[sequence_category].append(sequence_name)
return dict(c2seq)
def sequence_frames_in_order(
self, seq_name: str
) -> Iterator[Tuple[float, int, int]]:
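
A hedged usage sketch of the new mapping (assumes `dataset` is a `DatasetBase` subclass that overrides `__getitem__`, as the default implementation requires):

```python
# group sequences by category, e.g. to balance batch sampling across categories
c2seq = dataset.category_to_sequence_names()
for category, seq_names in c2seq.items():
    print(f"{category}: {len(seq_names)} sequences")
```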


@ -7,7 +7,7 @@
import logging
import os
from dataclasses import dataclass
from typing import Iterator, Optional
from typing import Iterable, Iterator, Optional
from iopath.common.file_io import PathManager
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
@ -51,6 +51,34 @@ class DatasetMap:
if self.test is not None:
yield self.test
def join(self, other_dataset_maps: Iterable["DatasetMap"]) -> None:
"""
Joins the current DatasetMap with the other dataset maps in the input list.
For each subset (train/val/test), entries that are None are skipped
when joining.
Note that the train/val/test datasets of the current dataset map are
modified in-place.
Args:
other_dataset_maps: The list of dataset maps to be joined into the
current dataset map.
"""
for set_ in ["train", "val", "test"]:
dataset_list = [
getattr(self, set_),
*[getattr(dmap, set_) for dmap in other_dataset_maps],
]
dataset_list = [d for d in dataset_list if d is not None]
if len(dataset_list) == 0:
setattr(self, set_, None)
continue
d0 = dataset_list[0]
if len(dataset_list) > 1:
d0.join(dataset_list[1:])
setattr(self, set_, d0)
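
A minimal sketch of the in-place join (the `providers` list is hypothetical; each `get_dataset_map()` call returns a `DatasetMap`):

```python
maps = [provider.get_dataset_map() for provider in providers]
merged = maps[0]
# train/val/test subsets are merged in-place; subsets that are None are skipped
merged.join(maps[1:])
```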
class DatasetMapProviderBase(ReplaceableBase):
"""


@ -19,6 +19,8 @@ from pathlib import Path
from typing import (
Any,
ClassVar,
Dict,
Iterable,
List,
Optional,
Sequence,
@ -188,7 +190,44 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
self.eval_batch_index
)
def is_filtered(self):
def join(self, other_datasets: Iterable[DatasetBase]) -> None:
"""
Join the dataset with other JsonIndexDataset objects.
Args:
other_datasets: A list of JsonIndexDataset objects to be joined
into the current dataset.
"""
if not all(isinstance(d, JsonIndexDataset) for d in other_datasets):
raise ValueError("This function can only join a list of JsonIndexDataset")
# pyre-ignore[16]
self.frame_annots.extend([fa for d in other_datasets for fa in d.frame_annots])
# pyre-ignore[16]
self.seq_annots.update(
# https://gist.github.com/treyhunner/f35292e676efa0be1728
functools.reduce(
lambda a, b: {**a, **b},
[d.seq_annots for d in other_datasets], # pyre-ignore[16]
)
)
all_eval_batches = [
self.eval_batches,
# pyre-ignore
*[d.eval_batches for d in other_datasets],
]
if not (
all(ba is None for ba in all_eval_batches)
or all(ba is not None for ba in all_eval_batches)
):
raise ValueError(
"When joining datasets, either all joined datasets have to have their"
" eval_batches defined, or all should have their eval batches undefined."
)
if self.eval_batches is not None:
self.eval_batches = sum(all_eval_batches, [])
self._invalidate_indexes(filter_seq_annots=True)
def is_filtered(self) -> bool:
"""
Returns `True` in case the dataset has been filtered and thus some frame annotations
stored on the disk might be missing in the dataset object.
@ -211,6 +250,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
seq_frame_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
allow_missing_indices: bool = False,
remove_missing_indices: bool = False,
suppress_missing_index_warning: bool = True,
) -> List[List[Union[Optional[int], int]]]:
"""
Obtain indices into the dataset object given a list of frame ids.
@ -228,6 +268,11 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
If `False`, returns `None` in place of `seq_frame_index` entries that
are not present in the dataset.
If `True` removes missing indices from the returned indices.
suppress_missing_index_warning:
Active if `allow_missing_indices==True`. Suppresses a warning message
in case an entry from `seq_frame_index` is missing in the dataset
(expected in certain cases, e.g. when setting
`self.remove_empty_masks=True`).
Returns:
dataset_idx: Indices of dataset entries corresponding to `seq_frame_index`.
@ -254,7 +299,8 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
)
if not allow_missing_indices:
raise IndexError(msg)
warnings.warn(msg)
if not suppress_missing_index_warning:
warnings.warn(msg)
return idx
if path is not None:
# Check that the loaded frame path is consistent
@ -288,6 +334,21 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
frame_index: List[Union[Tuple[str, int], Tuple[str, int, str]]],
allow_missing_indices: bool = True,
) -> "JsonIndexDataset":
"""
Generate a dataset subset given the list of frames specified in `frame_index`.
Args:
frame_index: The list of frame identifiers (as stored in the metadata)
specified as `List[Tuple[sequence_name:str, frame_number:int]]`. Optionally,
image paths relative to the dataset_root can be specified as well:
`List[Tuple[sequence_name:str, frame_number:int, image_path:str]]`;
in the latter case, if an image_path does not match the stored path, an error
is raised.
allow_missing_indices: If `False`, throws an IndexError upon reaching the first
entry from `frame_index` which is missing in the dataset.
Otherwise, generates a subset consisting of the frame entries that actually
exist in the dataset.
"""
# Get the indices into the frame annots.
dataset_indices = self.seq_frame_index_to_dataset_index(
[frame_index],
@ -838,6 +899,13 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
)
return out
def category_to_sequence_names(self) -> Dict[str, List[str]]:
c2seq = defaultdict(list)
# pyre-ignore
for sequence_name, sa in self.seq_annots.items():
c2seq[sa.category].append(sequence_name)
return dict(c2seq)
def get_eval_batches(self) -> Optional[List[List[int]]]:
return self.eval_batches
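
A hedged sketch of the indexing helpers above (the sequence names and frame numbers are hypothetical):

```python
# map (sequence_name, frame_number) pairs to dataset indices; with
# allow_missing_indices=True, missing pairs yield None instead of raising,
# and the warning is now suppressed by default
indices = dataset.seq_frame_index_to_dataset_index(
    [[("seq_0", 0), ("seq_0", 5)]],
    allow_missing_indices=True,
)

# carve out a sub-dataset restricted to the given frames
subset = dataset.subset_from_frame_index([("seq_0", 0), ("seq_0", 5)])
```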


@ -8,6 +8,7 @@
import copy
import json
import logging
import multiprocessing
import os
import warnings
from collections import defaultdict
@ -30,6 +31,7 @@ from pytorch3d.implicitron.tools.config import (
)
from pytorch3d.renderer.cameras import CamerasBase
from tqdm import tqdm
_CO3DV2_DATASET_ROOT: str = os.getenv("CO3DV2_DATASET_ROOT", "")
@ -147,7 +149,8 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
(test frames can repeat across batches).
Args:
category: The object category of the dataset.
category: The dataset categories to load, given as a comma-separated string
of category names (e.g. `"apple,car,orange"`).
subset_name: The name of the dataset subset. For CO3Dv2, these include
e.g. "manyview_dev_0", "fewview_test", ...
dataset_root: The root folder of the dataset.
@ -173,6 +176,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
test_on_train: bool = False
only_test_set: bool = False
load_eval_batches: bool = True
num_load_workers: int = 4
n_known_frames_for_test: int = 0
@ -189,11 +193,33 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
if self.only_test_set and self.test_on_train:
raise ValueError("Cannot have only_test_set and test_on_train")
frame_file = os.path.join(
self.dataset_root, self.category, "frame_annotations.jgz"
)
if "," in self.category:
# a comma-separated list of categories to load
categories = [c.strip() for c in self.category.split(",")]
logger.info(f"Loading a list of categories: {str(categories)}.")
with multiprocessing.Pool(
processes=min(self.num_load_workers, len(categories))
) as pool:
category_dataset_maps = list(
tqdm(
pool.imap(self._load_category, categories),
total=len(categories),
)
)
dataset_map = category_dataset_maps[0]
dataset_map.join(category_dataset_maps[1:])
else:
# one category to load
dataset_map = self._load_category(self.category)
self.dataset_map = dataset_map
def _load_category(self, category: str) -> DatasetMap:
frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz")
sequence_file = os.path.join(
self.dataset_root, self.category, "sequence_annotations.jgz"
self.dataset_root, category, "sequence_annotations.jgz"
)
path_manager = self.path_manager_factory.get()
@ -232,7 +258,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
dataset = dataset_type(**common_dataset_kwargs)
available_subset_names = self._get_available_subset_names()
available_subset_names = self._get_available_subset_names(category)
logger.debug(f"Available subset names: {str(available_subset_names)}.")
if self.subset_name not in available_subset_names:
raise ValueError(
@ -242,20 +268,20 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
# load the list of train/val/test frames
subset_mapping = self._load_annotation_json(
os.path.join(
self.category, "set_lists", f"set_lists_{self.subset_name}.json"
)
os.path.join(category, "set_lists", f"set_lists_{self.subset_name}.json")
)
# load the evaluation batches
if self.load_eval_batches:
eval_batch_index = self._load_annotation_json(
os.path.join(
self.category,
category,
"eval_batches",
f"eval_batches_{self.subset_name}.json",
)
)
else:
eval_batch_index = None
train_dataset = None
if not self.only_test_set:
@ -313,9 +339,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
)
logger.info(f"# eval batches: {len(test_dataset.eval_batches)}")
self.dataset_map = DatasetMap(
train=train_dataset, val=val_dataset, test=test_dataset
)
return DatasetMap(train=train_dataset, val=val_dataset, test=test_dataset)
@classmethod
def dataset_tweak_args(cls, type, args: DictConfig) -> None:
@ -381,10 +405,10 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
data = json.load(f)
return data
def _get_available_subset_names(self):
def _get_available_subset_names(self, category: str):
return get_available_subset_names(
self.dataset_root,
self.category,
category,
path_manager=self.path_manager_factory.get(),
)
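
The provider parallelizes per-category loading with an ordered `pool.imap` and folds the per-category results into the first one. A standalone sketch of the same pattern, with a toy `_load` standing in for `_load_category` (the real function returns a `DatasetMap`):

```python
import multiprocessing


def _load(category: str) -> list:
    # toy stand-in for JsonIndexDatasetMapProviderV2._load_category
    return [f"{category}_frame_{i}" for i in range(2)]


if __name__ == "__main__":
    categories = ["apple", "car", "orange"]
    with multiprocessing.Pool(processes=min(4, len(categories))) as pool:
        # imap preserves input order, so results line up with `categories`
        parts = list(pool.imap(_load, categories))
    merged = parts[0]
    for extra in parts[1:]:  # mirrors dataset_map.join(category_dataset_maps[1:])
        merged.extend(extra)
    print(merged)
```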


@ -6,8 +6,9 @@
import warnings
from collections import Counter
from dataclasses import dataclass, field
from typing import Iterable, Iterator, List, Sequence, Tuple
from typing import Dict, Iterable, Iterator, List, Sequence, Tuple
import numpy as np
from torch.utils.data.sampler import Sampler
@ -42,8 +43,17 @@ class SceneBatchSampler(Sampler[List[int]]):
# same but for timestamps if they are available
consecutive_frames_max_gap_seconds: float = 0.1
# if True, the sampler first reads the mapping between sequence names and
# their categories from the dataset.
# During batch sampling, the sampler then ensures a uniform distribution over
# the categories of the sampled sequences.
category_aware: bool = True
seq_names: List[str] = field(init=False)
category_to_sequence_names: Dict[str, List[str]] = field(init=False)
categories: List[str] = field(init=False)
def __post_init__(self) -> None:
if self.batch_size <= 0:
raise ValueError(
@ -56,6 +66,10 @@ class SceneBatchSampler(Sampler[List[int]]):
self.seq_names = list(self.dataset.sequence_names())
if self.category_aware:
self.category_to_sequence_names = self.dataset.category_to_sequence_names()
self.categories = list(self.category_to_sequence_names.keys())
def __len__(self) -> int:
return self.num_batches
@ -67,7 +81,25 @@ class SceneBatchSampler(Sampler[List[int]]):
def _sample_batch(self, batch_idx) -> List[int]:
n_per_seq = np.random.choice(self.images_per_seq_options)
n_seqs = -(-self.batch_size // n_per_seq) # round up
chosen_seq = _capped_random_choice(self.seq_names, n_seqs, replace=False)
if self.category_aware:
# first sample categories at random; these can repeat within the batch
chosen_cat = _capped_random_choice(self.categories, n_seqs, replace=True)
# then randomly sample a set of unique sequences within each category
chosen_seq = []
for cat, n_per_category in Counter(chosen_cat).items():
category_chosen_seq = _capped_random_choice(
self.category_to_sequence_names[cat],
n_per_category,
replace=False,
)
chosen_seq.extend([str(s) for s in category_chosen_seq])
else:
chosen_seq = _capped_random_choice(
self.seq_names,
n_seqs,
replace=False,
)
if self.sample_consecutive_frames:
frame_idx = []
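
To make the two-stage, category-aware sampling concrete, here is a self-contained sketch; `_capped_random_choice` is only approximated here (the real helper lives elsewhere in scene_batch_sampler.py) and the category map is toy data:

```python
from collections import Counter

import numpy as np


def _capped_random_choice(x, size, replace=True):
    # approximate stand-in: never request more unique items than are available
    n = size if replace else min(len(x), size)
    return np.random.choice(x, size=n, replace=replace)


category_to_seqs = {"apple": ["a0", "a1"], "car": ["c0", "c1", "c2"]}
categories = list(category_to_seqs)
n_seqs = 4

# stage 1: draw categories with replacement (a batch may repeat a category)
chosen_cat = _capped_random_choice(categories, n_seqs, replace=True)

# stage 2: draw unique sequences within each chosen category
chosen_seq = []
for cat, n_per_category in Counter(chosen_cat).items():
    picked = _capped_random_choice(
        category_to_seqs[cat], n_per_category, replace=False
    )
    chosen_seq.extend(str(s) for s in picked)
```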


@ -49,6 +49,7 @@ dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
test_on_train: false
only_test_set: false
load_eval_batches: true
num_load_workers: 4
n_known_frames_for_test: 0
dataset_class_type: JsonIndexDataset
path_manager_factory_class_type: PathManagerFactory


@ -11,17 +11,20 @@ from dataclasses import dataclass
from itertools import product
import numpy as np
import torch
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
DoublePoolBatchSampler,
)
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
@dataclass
class MockFrameAnnotation:
frame_number: int
sequence_name: str = "sequence"
frame_timestamp: float = 0.0
@ -41,6 +44,9 @@ class MockDataset(DatasetBase):
self.frame_annots = [
{"frame_annotation": MockFrameAnnotation(no)} for no in frame_nos * num_seq
]
for seq_name, idx in self._seq_to_idx.items():
for i in idx:
self.frame_annots[i]["frame_annotation"].sequence_name = seq_name
def get_frame_numbers_and_timestamps(self, idxs):
out = []
@ -51,6 +57,16 @@ class MockDataset(DatasetBase):
)
return out
def __getitem__(self, index: int):
fa = self.frame_annots[index]["frame_annotation"]
fd = FrameData(
sequence_name=fa.sequence_name,
sequence_category="default_category",
frame_number=torch.LongTensor([fa.frame_number]),
frame_timestamp=torch.FloatTensor([fa.frame_timestamp]),
)
return fd
class TestSceneBatchSampler(unittest.TestCase):
def setUp(self):


@ -41,22 +41,73 @@ class TestJsonIndexDatasetProviderV2(unittest.TestCase):
categories = ["A", "B"]
subset_name = "test"
eval_batch_size = 5
n_frames = 8 * 3
n_sequences = 5
n_eval_batches = 10
with tempfile.TemporaryDirectory() as tmpd:
_make_random_json_dataset_map_provider_v2_data(
tmpd,
categories,
eval_batch_size=eval_batch_size,
n_frames=n_frames,
n_sequences=n_sequences,
n_eval_batches=n_eval_batches,
)
for n_known_frames_for_test in [0, 2]:
for category in categories:
dataset_provider = JsonIndexDatasetMapProviderV2(
dataset_providers = {
category: JsonIndexDatasetMapProviderV2(
category=category,
subset_name="test",
dataset_root=tmpd,
n_known_frames_for_test=n_known_frames_for_test,
)
for category in [*categories, ",".join(sorted(categories))]
}
for category, dataset_provider in dataset_providers.items():
dataset_map = dataset_provider.get_dataset_map()
for set_ in ["train", "val", "test"]:
dataset = getattr(dataset_map, set_)
cat2seq = dataset.category_to_sequence_names()
self.assertEqual(",".join(sorted(cat2seq.keys())), category)
if not (n_known_frames_for_test != 0 and set_ == "test"):
# check the lengths only when n_known_frames_for_test does not
# add extra known frames to the test set
expected_dataset_len = n_frames * n_sequences // 3
if "," in category:
# multi-category json index dataset: sum the lengths of
# the category-specific ones
expected_dataset_len = sum(
len(
getattr(
dataset_providers[c].get_dataset_map(), set_
)
)
for c in categories
)
self.assertEqual(
sum(len(s) for s in cat2seq.values()),
n_sequences * len(categories),
)
self.assertEqual(len(cat2seq), len(categories))
else:
self.assertEqual(
len(cat2seq[category]),
n_sequences,
)
self.assertEqual(len(cat2seq), 1)
self.assertEqual(len(dataset), expected_dataset_len)
if set_ == "test":
# check the number of eval batches
expected_n_eval_batches = n_eval_batches
if "," in category:
expected_n_eval_batches *= len(categories)
self.assertEqual(
len(dataset.get_eval_batches()),
expected_n_eval_batches,
)
if set_ in ["train", "val"]:
dataloader = torch.utils.data.DataLoader(
getattr(dataset_map, set_),
@ -80,6 +131,7 @@ class TestJsonIndexDatasetProviderV2(unittest.TestCase):
dataset_provider.get_category_to_subset_name_list()
)
category_to_subset_list_ = {c: [subset_name] for c in categories}
self.assertTrue(category_to_subset_list == category_to_subset_list_)
@ -88,6 +140,7 @@ def _make_random_json_dataset_map_provider_v2_data(
categories: List[str],
n_frames: int = 8,
n_sequences: int = 5,
n_eval_batches: int = 10,
H: int = 50,
W: int = 30,
subset_name: str = "test",
@ -100,7 +153,7 @@ def _make_random_json_dataset_map_provider_v2_data(
sequence_annotations = []
frame_index = []
for seq_i in range(n_sequences):
seq_name = str(seq_i)
seq_name = category + str(seq_i)
for i in range(n_frames):
# generate and store image
imdir = os.path.join(root, category, seq_name, "images")
@ -165,7 +218,8 @@ def _make_random_json_dataset_map_provider_v2_data(
json.dump(set_list, f)
eval_batches = [
random.sample(test_frame_index, eval_batch_size) for _ in range(10)
random.sample(test_frame_index, eval_batch_size)
for _ in range(n_eval_batches)
]
eval_b_dir = os.path.join(root, category, "eval_batches")