pluggable JsonIndexDataset

Summary: Make dataset type and args configurable on JsonIndexDatasetMapProvider.

Reviewed By: davnov134

Differential Revision: D36666705

fbshipit-source-id: 4d0a3781d9a956504f51f1c7134c04edf1eb2846
This commit is contained in:
Jeremy Reizenstein 2022-06-10 12:22:46 -07:00 committed by Facebook GitHub Bot
parent 1d43251391
commit 6275283202
8 changed files with 117 additions and 84 deletions

View File

@ -24,12 +24,13 @@ data_source_args:
- 10
dataset_map_provider_JsonIndexDatasetMapProvider_args:
dataset_root: ${oc.env:CO3D_DATASET_ROOT}
load_point_clouds: false
mask_depths: false
mask_images: false
n_frames_per_sequence: -1
test_on_train: true
test_restrict_sequence_id: 0
dataset_JsonIndexDataset_args:
load_point_clouds: false
mask_depths: false
mask_images: false
generic_model_args:
loss_weights:
loss_mask_bce: 1.0

View File

@ -20,9 +20,6 @@ data_source_args:
dataset_map_provider_JsonIndexDatasetMapProvider_args:
assert_single_seq: false
task_str: multisequence
load_point_clouds: false
mask_depths: false
mask_images: false
n_frames_per_sequence: -1
test_on_train: true
test_restrict_sequence_id: 0

View File

@ -286,24 +286,35 @@ data_source_args:
category: ???
task_str: singlesequence
dataset_root: ''
limit_to: -1
limit_sequences_to: -1
n_frames_per_sequence: -1
test_on_train: false
load_point_clouds: false
mask_images: false
mask_depths: false
restrict_sequence_name: []
test_restrict_sequence_id: -1
assert_single_seq: false
only_test_set: false
aux_dataset_kwargs:
box_crop: true
box_crop_context: 0.3
image_width: 800
image_height: 800
remove_empty_masks: true
dataset_class_type: JsonIndexDataset
path_manager_factory_class_type: PathManagerFactory
dataset_JsonIndexDataset_args:
limit_to: 0
limit_sequences_to: 0
exclude_sequence: []
limit_category_to: []
load_images: true
load_depths: true
load_depth_masks: true
load_masks: true
load_point_clouds: false
max_points: 0
mask_images: false
mask_depths: false
image_height: 800
image_width: 800
box_crop: true
box_crop_mask_thr: 0.4
box_crop_context: 0.3
remove_empty_masks: true
seed: 0
sort_frames: false
path_manager_factory_PathManagerFactory_args:
silence_logs: true
data_loader_map_provider_SequenceDataLoaderMapProvider_args:

View File

@ -26,8 +26,6 @@ DATA_DIR = Path(__file__).resolve().parent
DEBUG: bool = False
# TODO:
# - sort out path_manager config. Here we monkeypatch to avoid
# the problem.
# - add enough files to skateboard_first_5 that this works on RE.
# - share common code with PyTorch3D tests?
# - deal with the temporary output files this test creates
@ -54,7 +52,7 @@ class TestExperiment(unittest.TestCase):
dataset_args.category = "skateboard"
dataset_args.test_restrict_sequence_id = 0
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
dataset_args.limit_sequences_to = 5
dataset_args.dataset_JsonIndexDataset_args.limit_sequences_to = 5
dataloader_args.dataset_len = 1
cfg.solver_args.max_epochs = 2

View File

@ -13,7 +13,6 @@ import os
import random
import warnings
from collections import defaultdict
from dataclasses import dataclass
from itertools import islice
from pathlib import Path
from typing import (
@ -31,6 +30,7 @@ from typing import (
import numpy as np
import torch
from PIL import Image
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from pytorch3d.io import IO
from pytorch3d.renderer.cameras import PerspectiveCameras
from pytorch3d.structures.pointclouds import Pointclouds
@ -47,8 +47,8 @@ class FrameAnnotsEntry(TypedDict):
frame_annotation: types.FrameAnnotation
@dataclass(eq=False)
class JsonIndexDataset(DatasetBase):
@registry.register
class JsonIndexDataset(DatasetBase, ReplaceableBase):
"""
A dataset with annotations in json files like the Common Objects in 3D
(CO3D) dataset.

View File

@ -7,11 +7,14 @@
import json
import os
from dataclasses import field
from typing import Any, Dict, List, Sequence
from typing import Dict, List, Sequence, Tuple, Type
from omegaconf import DictConfig
from pytorch3d.implicitron.tools.config import registry, run_auto_creation
from omegaconf import DictConfig, open_dict
from pytorch3d.implicitron.tools.config import (
expand_args_fields,
registry,
run_auto_creation,
)
from .dataset_map_provider import (
DatasetMap,
@ -20,6 +23,7 @@ from .dataset_map_provider import (
Task,
)
from .json_index_dataset import JsonIndexDataset
from .utils import (
DATASET_TYPE_KNOWN,
DATASET_TYPE_TEST,
@ -28,22 +32,6 @@ from .utils import (
)
# TODO from dataset.dataset_configs import DATASET_CONFIGS
DATASET_CONFIGS: Dict[str, Dict[str, Any]] = {
"default": {
"box_crop": True,
"box_crop_context": 0.3,
"image_width": 800,
"image_height": 800,
"remove_empty_masks": True,
}
}
def _make_default_config() -> DictConfig:
return DictConfig(DATASET_CONFIGS["default"])
# fmt: off
CO3D_CATEGORIES: List[str] = list(reversed([
"baseballbat", "banana", "bicycle", "microwave", "tv",
@ -62,6 +50,21 @@ CO3D_CATEGORIES: List[str] = list(reversed([
_CO3D_DATASET_ROOT: str = os.getenv("CO3D_DATASET_ROOT", "")
# _NEED_CONTROL is a list of those elements of JsonIndexDataset which
# are not directly specified for it in the config but come from the
# DatasetMapProvider.
_NEED_CONTROL: Tuple[str, ...] = (
"dataset_root",
"eval_batches",
"n_frames_per_sequence",
"path_manager",
"pick_sequence",
"subsets",
"frame_annotations_file",
"sequence_annotations_file",
"subset_lists_file",
)
@registry.register
class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
@ -73,16 +76,10 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
category: The object category of the dataset.
task_str: "multisequence" or "singlesequence".
dataset_root: The root folder of the dataset.
limit_to: Limit the dataset to the first #limit_to frames.
limit_sequences_to: Limit the dataset to the first
#limit_sequences_to sequences.
n_frames_per_sequence: Randomly sample #n_frames_per_sequence frames
in each sequence.
test_on_train: Construct validation and test datasets from
the training subset.
load_point_clouds: Enable returning scene point clouds from the dataset.
mask_images: Mask the loaded images with segmentation masks.
mask_depths: Mask the loaded depths with segmentation masks.
restrict_sequence_name: Restrict the dataset sequences to the ones
present in the given list of names.
test_restrict_sequence_id: The ID of the loaded sequence.
@ -90,32 +87,45 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
assert_single_seq: Assert that only frames from a single sequence
are present in all generated datasets.
only_test_set: Load only the test set.
aux_dataset_kwargs: Specifies additional arguments to the
JsonIndexDataset constructor call.
dataset_class_type: name of class (JsonIndexDataset or a subclass)
to use for the dataset.
dataset_X_args (e.g. dataset_JsonIndexDataset_args): arguments passed
to all the dataset constructors.
"""
category: str
task_str: str = "singlesequence"
dataset_root: str = _CO3D_DATASET_ROOT
limit_to: int = -1
limit_sequences_to: int = -1
n_frames_per_sequence: int = -1
test_on_train: bool = False
load_point_clouds: bool = False
mask_images: bool = False
mask_depths: bool = False
restrict_sequence_name: Sequence[str] = ()
test_restrict_sequence_id: int = -1
assert_single_seq: bool = False
only_test_set: bool = False
aux_dataset_kwargs: DictConfig = field(default_factory=_make_default_config)
dataset: JsonIndexDataset
dataset_class_type: str = "JsonIndexDataset"
path_manager_factory: PathManagerFactory
path_manager_factory_class_type: str = "PathManagerFactory"
def __post_init__(self):
run_auto_creation(self)
@classmethod
def dataset_tweak_args(cls, type, args: DictConfig) -> None:
"""
Called by get_default_args(JsonIndexDatasetMapProvider) to
not expose certain fields of each dataset class.
"""
with open_dict(args):
for key in _NEED_CONTROL:
del args[key]
def get_dataset_map(self) -> DatasetMap:
def create_dataset(self):
"""
Prevent the member named dataset from being created.
"""
return
def __post_init__(self):
super().__init__()
run_auto_creation(self)
if self.only_test_set and self.test_on_train:
raise ValueError("Cannot have only_test_set and test_on_train")
@ -135,16 +145,11 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
)
common_kwargs = {
"dataset_root": self.dataset_root,
"limit_to": self.limit_to,
"limit_sequences_to": self.limit_sequences_to,
"load_point_clouds": self.load_point_clouds,
"mask_images": self.mask_images,
"mask_depths": self.mask_depths,
"path_manager": path_manager,
"frame_annotations_file": frame_file,
"sequence_annotations_file": sequence_file,
"subset_lists_file": subset_lists_file,
**self.aux_dataset_kwargs,
**getattr(self, f"dataset_{self.dataset_class_type}_args"),
}
# This maps the common names of the dataset subsets ("train"/"val"/"test")
@ -204,9 +209,13 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
# overwrite the restrict_sequence_name
restrict_sequence_name = [eval_sequence_name]
dataset_type: Type[JsonIndexDataset] = registry.get(
JsonIndexDataset, self.dataset_class_type
)
expand_args_fields(dataset_type)
train_dataset = None
if not self.only_test_set:
train_dataset = JsonIndexDataset(
train_dataset = dataset_type(
n_frames_per_sequence=self.n_frames_per_sequence,
subsets=set_names_mapping["train"],
pick_sequence=restrict_sequence_name,
@ -216,13 +225,13 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
assert train_dataset is not None
val_dataset = test_dataset = train_dataset
else:
val_dataset = JsonIndexDataset(
val_dataset = dataset_type(
n_frames_per_sequence=-1,
subsets=set_names_mapping["val"],
pick_sequence=restrict_sequence_name,
**common_kwargs,
)
test_dataset = JsonIndexDataset(
test_dataset = dataset_type(
n_frames_per_sequence=-1,
subsets=set_names_mapping["test"],
pick_sequence=restrict_sequence_name,
@ -235,19 +244,25 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
test_dataset.eval_batches = test_dataset.seq_frame_index_to_dataset_index(
eval_batch_index
)
datasets = DatasetMap(train=train_dataset, val=val_dataset, test=test_dataset)
dataset_map = DatasetMap(
train=train_dataset, val=val_dataset, test=test_dataset
)
if self.assert_single_seq:
# check there's only one sequence in all datasets
sequence_names = {
sequence_name
for dset in datasets.iter_datasets()
for dset in dataset_map.iter_datasets()
for sequence_name in dset.sequence_names()
}
if len(sequence_names) > 1:
raise ValueError("Multiple sequences loaded but expected one")
return datasets
self.dataset_map = dataset_map
def get_dataset_map(self) -> DatasetMap:
# pyre-ignore[16]
return self.dataset_map
def get_task(self) -> Task:
return Task(self.task_str)

View File

@ -106,8 +106,8 @@ def evaluate_dbir_for_category(
"assert_single_seq": task == Task.SINGLE_SEQUENCE,
"task_str": task.value,
"test_on_train": False,
"load_point_clouds": True,
"test_restrict_sequence_id": single_sequence_id,
"dataset_JsonIndexDataset_args": {"load_point_clouds": True},
}
data_source = ImplicitronDataSource(
dataset_map_provider_JsonIndexDatasetMapProvider_args=dataset_map_provider_args

View File

@ -4,24 +4,35 @@ dataset_map_provider_JsonIndexDatasetMapProvider_args:
category: ???
task_str: singlesequence
dataset_root: ''
limit_to: -1
limit_sequences_to: -1
n_frames_per_sequence: -1
test_on_train: false
load_point_clouds: false
mask_images: false
mask_depths: false
restrict_sequence_name: []
test_restrict_sequence_id: -1
assert_single_seq: false
only_test_set: false
aux_dataset_kwargs:
box_crop: true
box_crop_context: 0.3
image_width: 800
image_height: 800
remove_empty_masks: true
dataset_class_type: JsonIndexDataset
path_manager_factory_class_type: PathManagerFactory
dataset_JsonIndexDataset_args:
limit_to: 0
limit_sequences_to: 0
exclude_sequence: []
limit_category_to: []
load_images: true
load_depths: true
load_depth_masks: true
load_masks: true
load_point_clouds: false
max_points: 0
mask_images: false
mask_depths: false
image_height: 800
image_width: 800
box_crop: true
box_crop_mask_thr: 0.4
box_crop_context: 0.3
remove_empty_masks: true
seed: 0
sort_frames: false
path_manager_factory_PathManagerFactory_args:
silence_logs: true
data_loader_map_provider_SequenceDataLoaderMapProvider_args: