pluggable JsonIndexDataset

Summary: Make dataset type and args configurable on JsonIndexDatasetMapProvider.

Reviewed By: davnov134

Differential Revision: D36666705

fbshipit-source-id: 4d0a3781d9a956504f51f1c7134c04edf1eb2846
Jeremy Reizenstein authored on 2022-06-10 12:22:46 -07:00, committed by Facebook GitHub Bot
parent 1d43251391
commit 6275283202
8 changed files with 117 additions and 84 deletions
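
In effect, the dataset class becomes pluggable through the Implicitron registry: JsonIndexDataset is registered as a ReplaceableBase, and JsonIndexDatasetMapProvider instantiates whichever registered class is named by its new dataset_class_type field, passing it the matching dataset_<ClassName>_args config node. A minimal sketch of how a custom dataset could be plugged in after this change; MyJsonIndexDataset, its box_jitter field, and the config values below are hypothetical and not part of this diff:

    # Sketch only: assumes this commit; MyJsonIndexDataset and box_jitter are
    # hypothetical names used to illustrate the new plugin point.
    from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
    from pytorch3d.implicitron.tools.config import registry


    @registry.register
    class MyJsonIndexDataset(JsonIndexDataset):
        # Extra configurable field. The subclass must be registered (i.e. this
        # module imported) before the provider's config is expanded, so that a
        # dataset_MyJsonIndexDataset_args node is generated for it.
        box_jitter: float = 0.0


    # The provider can then select the subclass from config, e.g.:
    #   dataset_map_provider_JsonIndexDatasetMapProvider_args:
    #     dataset_class_type: MyJsonIndexDataset
    #     dataset_MyJsonIndexDataset_args:
    #       box_jitter: 0.1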

View File

@@ -24,12 +24,13 @@ data_source_args:
     - 10
   dataset_map_provider_JsonIndexDatasetMapProvider_args:
     dataset_root: ${oc.env:CO3D_DATASET_ROOT}
-    load_point_clouds: false
-    mask_depths: false
-    mask_images: false
     n_frames_per_sequence: -1
     test_on_train: true
     test_restrict_sequence_id: 0
+    dataset_JsonIndexDataset_args:
+      load_point_clouds: false
+      mask_depths: false
+      mask_images: false
 generic_model_args:
   loss_weights:
     loss_mask_bce: 1.0

View File

@@ -20,9 +20,6 @@ data_source_args:
   dataset_map_provider_JsonIndexDatasetMapProvider_args:
     assert_single_seq: false
     task_str: multisequence
-    load_point_clouds: false
-    mask_depths: false
-    mask_images: false
     n_frames_per_sequence: -1
     test_on_train: true
     test_restrict_sequence_id: 0

View File

@@ -286,24 +286,35 @@ data_source_args:
     category: ???
     task_str: singlesequence
     dataset_root: ''
-    limit_to: -1
-    limit_sequences_to: -1
     n_frames_per_sequence: -1
     test_on_train: false
-    load_point_clouds: false
-    mask_images: false
-    mask_depths: false
     restrict_sequence_name: []
     test_restrict_sequence_id: -1
     assert_single_seq: false
     only_test_set: false
-    aux_dataset_kwargs:
-      box_crop: true
-      box_crop_context: 0.3
-      image_width: 800
-      image_height: 800
-      remove_empty_masks: true
+    dataset_class_type: JsonIndexDataset
     path_manager_factory_class_type: PathManagerFactory
+    dataset_JsonIndexDataset_args:
+      limit_to: 0
+      limit_sequences_to: 0
+      exclude_sequence: []
+      limit_category_to: []
+      load_images: true
+      load_depths: true
+      load_depth_masks: true
+      load_masks: true
+      load_point_clouds: false
+      max_points: 0
+      mask_images: false
+      mask_depths: false
+      image_height: 800
+      image_width: 800
+      box_crop: true
+      box_crop_mask_thr: 0.4
+      box_crop_context: 0.3
+      remove_empty_masks: true
+      seed: 0
+      sort_frames: false
     path_manager_factory_PathManagerFactory_args:
       silence_logs: true
   data_loader_map_provider_SequenceDataLoaderMapProvider_args:

View File

@@ -26,8 +26,6 @@ DATA_DIR = Path(__file__).resolve().parent
 DEBUG: bool = False

 # TODO:
-# - sort out path_manager config. Here we monkeypatch to avoid
-#   the problem.
 # - add enough files to skateboard_first_5 that this works on RE.
 # - share common code with PyTorch3D tests?
 # - deal with the temporary output files this test creates
@@ -54,7 +52,7 @@ class TestExperiment(unittest.TestCase):
         dataset_args.category = "skateboard"
         dataset_args.test_restrict_sequence_id = 0
         dataset_args.dataset_root = "manifold://co3d/tree/extracted"
-        dataset_args.limit_sequences_to = 5
+        dataset_args.dataset_JsonIndexDataset_args.limit_sequences_to = 5
         dataloader_args.dataset_len = 1
         cfg.solver_args.max_epochs = 2

View File

@@ -13,7 +13,6 @@ import os
 import random
 import warnings
 from collections import defaultdict
-from dataclasses import dataclass
 from itertools import islice
 from pathlib import Path
 from typing import (
@@ -31,6 +30,7 @@ from typing import (
 import numpy as np
 import torch
 from PIL import Image
+from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
 from pytorch3d.io import IO
 from pytorch3d.renderer.cameras import PerspectiveCameras
 from pytorch3d.structures.pointclouds import Pointclouds
@@ -47,8 +47,8 @@ class FrameAnnotsEntry(TypedDict):
     frame_annotation: types.FrameAnnotation


-@dataclass(eq=False)
-class JsonIndexDataset(DatasetBase):
+@registry.register
+class JsonIndexDataset(DatasetBase, ReplaceableBase):
     """
     A dataset with annotations in json files like the Common Objects in 3D
     (CO3D) dataset.

View File

@@ -7,11 +7,14 @@
 import json
 import os

-from dataclasses import field
-from typing import Any, Dict, List, Sequence
+from typing import Dict, List, Sequence, Tuple, Type

-from omegaconf import DictConfig
-from pytorch3d.implicitron.tools.config import registry, run_auto_creation
+from omegaconf import DictConfig, open_dict
+from pytorch3d.implicitron.tools.config import (
+    expand_args_fields,
+    registry,
+    run_auto_creation,
+)

 from .dataset_map_provider import (
     DatasetMap,
@@ -20,6 +23,7 @@ from .dataset_map_provider import (
     Task,
 )
 from .json_index_dataset import JsonIndexDataset
+
 from .utils import (
     DATASET_TYPE_KNOWN,
     DATASET_TYPE_TEST,
@@ -28,22 +32,6 @@ from .utils import (
 )


-# TODO from dataset.dataset_configs import DATASET_CONFIGS
-DATASET_CONFIGS: Dict[str, Dict[str, Any]] = {
-    "default": {
-        "box_crop": True,
-        "box_crop_context": 0.3,
-        "image_width": 800,
-        "image_height": 800,
-        "remove_empty_masks": True,
-    }
-}
-
-
-def _make_default_config() -> DictConfig:
-    return DictConfig(DATASET_CONFIGS["default"])
-
-
 # fmt: off
 CO3D_CATEGORIES: List[str] = list(reversed([
     "baseballbat", "banana", "bicycle", "microwave", "tv",
@@ -62,6 +50,21 @@ CO3D_CATEGORIES: List[str] = list(reversed([

 _CO3D_DATASET_ROOT: str = os.getenv("CO3D_DATASET_ROOT", "")

+# _NEED_CONTROL is a list of those elements of JsonIndexDataset which
+# are not directly specified for it in the config but come from the
+# DatasetMapProvider.
+_NEED_CONTROL: Tuple[str, ...] = (
+    "dataset_root",
+    "eval_batches",
+    "n_frames_per_sequence",
+    "path_manager",
+    "pick_sequence",
+    "subsets",
+    "frame_annotations_file",
+    "sequence_annotations_file",
+    "subset_lists_file",
+)
+

 @registry.register
 class JsonIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
@@ -73,16 +76,10 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
         category: The object category of the dataset.
         task_str: "multisequence" or "singlesequence".
         dataset_root: The root folder of the dataset.
-        limit_to: Limit the dataset to the first #limit_to frames.
-        limit_sequences_to: Limit the dataset to the first
-            #limit_sequences_to sequences.
         n_frames_per_sequence: Randomly sample #n_frames_per_sequence frames
             in each sequence.
         test_on_train: Construct validation and test datasets from
             the training subset.
-        load_point_clouds: Enable returning scene point clouds from the dataset.
-        mask_images: Mask the loaded images with segmentation masks.
-        mask_depths: Mask the loaded depths with segmentation masks.
         restrict_sequence_name: Restrict the dataset sequences to the ones
             present in the given list of names.
         test_restrict_sequence_id: The ID of the loaded sequence.
@@ -90,32 +87,45 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
         assert_single_seq: Assert that only frames from a single sequence
             are present in all generated datasets.
         only_test_set: Load only the test set.
-        aux_dataset_kwargs: Specifies additional arguments to the
-            JsonIndexDataset constructor call.
+        dataset_class_type: name of class (JsonIndexDataset or a subclass)
+            to use for the dataset.
+        dataset_X_args (e.g. dataset_JsonIndexDataset_args): arguments passed
+            to all the dataset constructors.
     """

     category: str
     task_str: str = "singlesequence"
     dataset_root: str = _CO3D_DATASET_ROOT
-    limit_to: int = -1
-    limit_sequences_to: int = -1
     n_frames_per_sequence: int = -1
     test_on_train: bool = False
-    load_point_clouds: bool = False
-    mask_images: bool = False
-    mask_depths: bool = False
     restrict_sequence_name: Sequence[str] = ()
     test_restrict_sequence_id: int = -1
     assert_single_seq: bool = False
     only_test_set: bool = False
-    aux_dataset_kwargs: DictConfig = field(default_factory=_make_default_config)
+    dataset: JsonIndexDataset
+    dataset_class_type: str = "JsonIndexDataset"
     path_manager_factory: PathManagerFactory
     path_manager_factory_class_type: str = "PathManagerFactory"

-    def __post_init__(self):
-        run_auto_creation(self)
+    @classmethod
+    def dataset_tweak_args(cls, type, args: DictConfig) -> None:
+        """
+        Called by get_default_args(JsonIndexDatasetMapProvider) to
+        not expose certain fields of each dataset class.
+        """
+        with open_dict(args):
+            for key in _NEED_CONTROL:
+                del args[key]

-    def get_dataset_map(self) -> DatasetMap:
+    def create_dataset(self):
+        """
+        Prevent the member named dataset from being created.
+        """
+        return
+
+    def __post_init__(self):
+        super().__init__()
+        run_auto_creation(self)
         if self.only_test_set and self.test_on_train:
             raise ValueError("Cannot have only_test_set and test_on_train")
@@ -135,16 +145,11 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
         )

         common_kwargs = {
             "dataset_root": self.dataset_root,
-            "limit_to": self.limit_to,
-            "limit_sequences_to": self.limit_sequences_to,
-            "load_point_clouds": self.load_point_clouds,
-            "mask_images": self.mask_images,
-            "mask_depths": self.mask_depths,
             "path_manager": path_manager,
             "frame_annotations_file": frame_file,
             "sequence_annotations_file": sequence_file,
             "subset_lists_file": subset_lists_file,
-            **self.aux_dataset_kwargs,
+            **getattr(self, f"dataset_{self.dataset_class_type}_args"),
         }

         # This maps the common names of the dataset subsets ("train"/"val"/"test")
@@ -204,9 +209,13 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
             # overwrite the restrict_sequence_name
             restrict_sequence_name = [eval_sequence_name]

+        dataset_type: Type[JsonIndexDataset] = registry.get(
+            JsonIndexDataset, self.dataset_class_type
+        )
+        expand_args_fields(dataset_type)
+
         train_dataset = None
         if not self.only_test_set:
-            train_dataset = JsonIndexDataset(
+            train_dataset = dataset_type(
                 n_frames_per_sequence=self.n_frames_per_sequence,
                 subsets=set_names_mapping["train"],
                 pick_sequence=restrict_sequence_name,
@@ -216,13 +225,13 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
             assert train_dataset is not None
             val_dataset = test_dataset = train_dataset
         else:
-            val_dataset = JsonIndexDataset(
+            val_dataset = dataset_type(
                 n_frames_per_sequence=-1,
                 subsets=set_names_mapping["val"],
                 pick_sequence=restrict_sequence_name,
                 **common_kwargs,
             )
-            test_dataset = JsonIndexDataset(
+            test_dataset = dataset_type(
                 n_frames_per_sequence=-1,
                 subsets=set_names_mapping["test"],
                 pick_sequence=restrict_sequence_name,
@@ -235,19 +244,25 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
             test_dataset.eval_batches = test_dataset.seq_frame_index_to_dataset_index(
                 eval_batch_index
             )

-        datasets = DatasetMap(train=train_dataset, val=val_dataset, test=test_dataset)
+        dataset_map = DatasetMap(
+            train=train_dataset, val=val_dataset, test=test_dataset
+        )

         if self.assert_single_seq:
             # check there's only one sequence in all datasets
             sequence_names = {
                 sequence_name
-                for dset in datasets.iter_datasets()
+                for dset in dataset_map.iter_datasets()
                 for sequence_name in dset.sequence_names()
             }
             if len(sequence_names) > 1:
                 raise ValueError("Multiple sequences loaded but expected one")

-        return datasets
+        self.dataset_map = dataset_map
+
+    def get_dataset_map(self) -> DatasetMap:
+        # pyre-ignore[16]
+        return self.dataset_map

     def get_task(self) -> Task:
         return Task(self.task_str)

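For orientation, here is roughly how the reworked provider would be driven; the module path, the get_default_args round trip, and all concrete values are assumptions for illustration, and actually running this needs a CO3D-style dataset on disk:

    # Rough usage sketch, not part of this diff; paths and values are illustrative.
    from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
        JsonIndexDatasetMapProvider,
    )
    from pytorch3d.implicitron.tools.config import get_default_args

    args = get_default_args(JsonIndexDatasetMapProvider)
    args.category = "skateboard"
    args.dataset_root = "/path/to/co3d"
    # Dataset-level options now live under the nested dataset_<ClassName>_args node;
    # the fields listed in _NEED_CONTROL (dataset_root, subsets, ...) are stripped
    # from that node by dataset_tweak_args and supplied by the provider itself.
    args.dataset_JsonIndexDataset_args.load_point_clouds = True

    provider = JsonIndexDatasetMapProvider(**args)
    dataset_map = provider.get_dataset_map()  # built once in __post_init__
    train_dataset = dataset_map.train
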
View File

@@ -106,8 +106,8 @@ def evaluate_dbir_for_category(
         "assert_single_seq": task == Task.SINGLE_SEQUENCE,
         "task_str": task.value,
         "test_on_train": False,
-        "load_point_clouds": True,
         "test_restrict_sequence_id": single_sequence_id,
+        "dataset_JsonIndexDataset_args": {"load_point_clouds": True},
     }
     data_source = ImplicitronDataSource(
         dataset_map_provider_JsonIndexDatasetMapProvider_args=dataset_map_provider_args

View File

@@ -4,24 +4,35 @@ dataset_map_provider_JsonIndexDatasetMapProvider_args:
   category: ???
   task_str: singlesequence
   dataset_root: ''
-  limit_to: -1
-  limit_sequences_to: -1
   n_frames_per_sequence: -1
   test_on_train: false
-  load_point_clouds: false
-  mask_images: false
-  mask_depths: false
   restrict_sequence_name: []
   test_restrict_sequence_id: -1
   assert_single_seq: false
   only_test_set: false
-  aux_dataset_kwargs:
-    box_crop: true
-    box_crop_context: 0.3
-    image_width: 800
-    image_height: 800
-    remove_empty_masks: true
+  dataset_class_type: JsonIndexDataset
   path_manager_factory_class_type: PathManagerFactory
+  dataset_JsonIndexDataset_args:
+    limit_to: 0
+    limit_sequences_to: 0
+    exclude_sequence: []
+    limit_category_to: []
+    load_images: true
+    load_depths: true
+    load_depth_masks: true
+    load_masks: true
+    load_point_clouds: false
+    max_points: 0
+    mask_images: false
+    mask_depths: false
+    image_height: 800
+    image_width: 800
+    box_crop: true
+    box_crop_mask_thr: 0.4
+    box_crop_context: 0.3
+    remove_empty_masks: true
+    seed: 0
+    sort_frames: false
   path_manager_factory_PathManagerFactory_args:
     silence_logs: true
 data_loader_map_provider_SequenceDataLoaderMapProvider_args: