mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
pluggable JsonIndexDataset
Summary: Make dataset type and args configurable on JsonIndexDatasetMapProvider. Reviewed By: davnov134 Differential Revision: D36666705 fbshipit-source-id: 4d0a3781d9a956504f51f1c7134c04edf1eb2846
This commit is contained in:
parent
1d43251391
commit
6275283202
@ -24,12 +24,13 @@ data_source_args:
|
||||
- 10
|
||||
dataset_map_provider_JsonIndexDatasetMapProvider_args:
|
||||
dataset_root: ${oc.env:CO3D_DATASET_ROOT}
|
||||
load_point_clouds: false
|
||||
mask_depths: false
|
||||
mask_images: false
|
||||
n_frames_per_sequence: -1
|
||||
test_on_train: true
|
||||
test_restrict_sequence_id: 0
|
||||
dataset_JsonIndexDataset_args:
|
||||
load_point_clouds: false
|
||||
mask_depths: false
|
||||
mask_images: false
|
||||
generic_model_args:
|
||||
loss_weights:
|
||||
loss_mask_bce: 1.0
|
||||
|
@ -20,9 +20,6 @@ data_source_args:
|
||||
dataset_map_provider_JsonIndexDatasetMapProvider_args:
|
||||
assert_single_seq: false
|
||||
task_str: multisequence
|
||||
load_point_clouds: false
|
||||
mask_depths: false
|
||||
mask_images: false
|
||||
n_frames_per_sequence: -1
|
||||
test_on_train: true
|
||||
test_restrict_sequence_id: 0
|
||||
|
@ -286,24 +286,35 @@ data_source_args:
|
||||
category: ???
|
||||
task_str: singlesequence
|
||||
dataset_root: ''
|
||||
limit_to: -1
|
||||
limit_sequences_to: -1
|
||||
n_frames_per_sequence: -1
|
||||
test_on_train: false
|
||||
load_point_clouds: false
|
||||
mask_images: false
|
||||
mask_depths: false
|
||||
restrict_sequence_name: []
|
||||
test_restrict_sequence_id: -1
|
||||
assert_single_seq: false
|
||||
only_test_set: false
|
||||
aux_dataset_kwargs:
|
||||
box_crop: true
|
||||
box_crop_context: 0.3
|
||||
image_width: 800
|
||||
image_height: 800
|
||||
remove_empty_masks: true
|
||||
dataset_class_type: JsonIndexDataset
|
||||
path_manager_factory_class_type: PathManagerFactory
|
||||
dataset_JsonIndexDataset_args:
|
||||
limit_to: 0
|
||||
limit_sequences_to: 0
|
||||
exclude_sequence: []
|
||||
limit_category_to: []
|
||||
load_images: true
|
||||
load_depths: true
|
||||
load_depth_masks: true
|
||||
load_masks: true
|
||||
load_point_clouds: false
|
||||
max_points: 0
|
||||
mask_images: false
|
||||
mask_depths: false
|
||||
image_height: 800
|
||||
image_width: 800
|
||||
box_crop: true
|
||||
box_crop_mask_thr: 0.4
|
||||
box_crop_context: 0.3
|
||||
remove_empty_masks: true
|
||||
seed: 0
|
||||
sort_frames: false
|
||||
path_manager_factory_PathManagerFactory_args:
|
||||
silence_logs: true
|
||||
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
||||
|
@ -26,8 +26,6 @@ DATA_DIR = Path(__file__).resolve().parent
|
||||
DEBUG: bool = False
|
||||
|
||||
# TODO:
|
||||
# - sort out path_manager config. Here we monkeypatch to avoid
|
||||
# the problem.
|
||||
# - add enough files to skateboard_first_5 that this works on RE.
|
||||
# - share common code with PyTorch3D tests?
|
||||
# - deal with the temporary output files this test creates
|
||||
@ -54,7 +52,7 @@ class TestExperiment(unittest.TestCase):
|
||||
dataset_args.category = "skateboard"
|
||||
dataset_args.test_restrict_sequence_id = 0
|
||||
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
||||
dataset_args.limit_sequences_to = 5
|
||||
dataset_args.dataset_JsonIndexDataset_args.limit_sequences_to = 5
|
||||
dataloader_args.dataset_len = 1
|
||||
cfg.solver_args.max_epochs = 2
|
||||
|
||||
|
@ -13,7 +13,6 @@ import os
|
||||
import random
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from itertools import islice
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
@ -31,6 +30,7 @@ from typing import (
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||
from pytorch3d.io import IO
|
||||
from pytorch3d.renderer.cameras import PerspectiveCameras
|
||||
from pytorch3d.structures.pointclouds import Pointclouds
|
||||
@ -47,8 +47,8 @@ class FrameAnnotsEntry(TypedDict):
|
||||
frame_annotation: types.FrameAnnotation
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class JsonIndexDataset(DatasetBase):
|
||||
@registry.register
|
||||
class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
||||
"""
|
||||
A dataset with annotations in json files like the Common Objects in 3D
|
||||
(CO3D) dataset.
|
||||
|
@ -7,11 +7,14 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import field
|
||||
from typing import Any, Dict, List, Sequence
|
||||
from typing import Dict, List, Sequence, Tuple, Type
|
||||
|
||||
from omegaconf import DictConfig
|
||||
from pytorch3d.implicitron.tools.config import registry, run_auto_creation
|
||||
from omegaconf import DictConfig, open_dict
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
expand_args_fields,
|
||||
registry,
|
||||
run_auto_creation,
|
||||
)
|
||||
|
||||
from .dataset_map_provider import (
|
||||
DatasetMap,
|
||||
@ -20,6 +23,7 @@ from .dataset_map_provider import (
|
||||
Task,
|
||||
)
|
||||
from .json_index_dataset import JsonIndexDataset
|
||||
|
||||
from .utils import (
|
||||
DATASET_TYPE_KNOWN,
|
||||
DATASET_TYPE_TEST,
|
||||
@ -28,22 +32,6 @@ from .utils import (
|
||||
)
|
||||
|
||||
|
||||
# TODO from dataset.dataset_configs import DATASET_CONFIGS
|
||||
DATASET_CONFIGS: Dict[str, Dict[str, Any]] = {
|
||||
"default": {
|
||||
"box_crop": True,
|
||||
"box_crop_context": 0.3,
|
||||
"image_width": 800,
|
||||
"image_height": 800,
|
||||
"remove_empty_masks": True,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _make_default_config() -> DictConfig:
|
||||
return DictConfig(DATASET_CONFIGS["default"])
|
||||
|
||||
|
||||
# fmt: off
|
||||
CO3D_CATEGORIES: List[str] = list(reversed([
|
||||
"baseballbat", "banana", "bicycle", "microwave", "tv",
|
||||
@ -62,6 +50,21 @@ CO3D_CATEGORIES: List[str] = list(reversed([
|
||||
|
||||
_CO3D_DATASET_ROOT: str = os.getenv("CO3D_DATASET_ROOT", "")
|
||||
|
||||
# _NEED_CONTROL is a list of those elements of JsonIndexDataset which
|
||||
# are not directly specified for it in the config but come from the
|
||||
# DatasetMapProvider.
|
||||
_NEED_CONTROL: Tuple[str, ...] = (
|
||||
"dataset_root",
|
||||
"eval_batches",
|
||||
"n_frames_per_sequence",
|
||||
"path_manager",
|
||||
"pick_sequence",
|
||||
"subsets",
|
||||
"frame_annotations_file",
|
||||
"sequence_annotations_file",
|
||||
"subset_lists_file",
|
||||
)
|
||||
|
||||
|
||||
@registry.register
|
||||
class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
@ -73,16 +76,10 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
category: The object category of the dataset.
|
||||
task_str: "multisequence" or "singlesequence".
|
||||
dataset_root: The root folder of the dataset.
|
||||
limit_to: Limit the dataset to the first #limit_to frames.
|
||||
limit_sequences_to: Limit the dataset to the first
|
||||
#limit_sequences_to sequences.
|
||||
n_frames_per_sequence: Randomly sample #n_frames_per_sequence frames
|
||||
in each sequence.
|
||||
test_on_train: Construct validation and test datasets from
|
||||
the training subset.
|
||||
load_point_clouds: Enable returning scene point clouds from the dataset.
|
||||
mask_images: Mask the loaded images with segmentation masks.
|
||||
mask_depths: Mask the loaded depths with segmentation masks.
|
||||
restrict_sequence_name: Restrict the dataset sequences to the ones
|
||||
present in the given list of names.
|
||||
test_restrict_sequence_id: The ID of the loaded sequence.
|
||||
@ -90,32 +87,45 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
assert_single_seq: Assert that only frames from a single sequence
|
||||
are present in all generated datasets.
|
||||
only_test_set: Load only the test set.
|
||||
aux_dataset_kwargs: Specifies additional arguments to the
|
||||
JsonIndexDataset constructor call.
|
||||
dataset_class_type: name of class (JsonIndexDataset or a subclass)
|
||||
to use for the dataset.
|
||||
dataset_X_args (e.g. dataset_JsonIndexDataset_args): arguments passed
|
||||
to all the dataset constructors.
|
||||
"""
|
||||
|
||||
category: str
|
||||
task_str: str = "singlesequence"
|
||||
dataset_root: str = _CO3D_DATASET_ROOT
|
||||
limit_to: int = -1
|
||||
limit_sequences_to: int = -1
|
||||
n_frames_per_sequence: int = -1
|
||||
test_on_train: bool = False
|
||||
load_point_clouds: bool = False
|
||||
mask_images: bool = False
|
||||
mask_depths: bool = False
|
||||
restrict_sequence_name: Sequence[str] = ()
|
||||
test_restrict_sequence_id: int = -1
|
||||
assert_single_seq: bool = False
|
||||
only_test_set: bool = False
|
||||
aux_dataset_kwargs: DictConfig = field(default_factory=_make_default_config)
|
||||
dataset: JsonIndexDataset
|
||||
dataset_class_type: str = "JsonIndexDataset"
|
||||
path_manager_factory: PathManagerFactory
|
||||
path_manager_factory_class_type: str = "PathManagerFactory"
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
@classmethod
|
||||
def dataset_tweak_args(cls, type, args: DictConfig) -> None:
|
||||
"""
|
||||
Called by get_default_args(JsonIndexDatasetMapProvider) to
|
||||
not expose certain fields of each dataset class.
|
||||
"""
|
||||
with open_dict(args):
|
||||
for key in _NEED_CONTROL:
|
||||
del args[key]
|
||||
|
||||
def get_dataset_map(self) -> DatasetMap:
|
||||
def create_dataset(self):
|
||||
"""
|
||||
Prevent the member named dataset from being created.
|
||||
"""
|
||||
return
|
||||
|
||||
def __post_init__(self):
|
||||
super().__init__()
|
||||
run_auto_creation(self)
|
||||
if self.only_test_set and self.test_on_train:
|
||||
raise ValueError("Cannot have only_test_set and test_on_train")
|
||||
|
||||
@ -135,16 +145,11 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
)
|
||||
common_kwargs = {
|
||||
"dataset_root": self.dataset_root,
|
||||
"limit_to": self.limit_to,
|
||||
"limit_sequences_to": self.limit_sequences_to,
|
||||
"load_point_clouds": self.load_point_clouds,
|
||||
"mask_images": self.mask_images,
|
||||
"mask_depths": self.mask_depths,
|
||||
"path_manager": path_manager,
|
||||
"frame_annotations_file": frame_file,
|
||||
"sequence_annotations_file": sequence_file,
|
||||
"subset_lists_file": subset_lists_file,
|
||||
**self.aux_dataset_kwargs,
|
||||
**getattr(self, f"dataset_{self.dataset_class_type}_args"),
|
||||
}
|
||||
|
||||
# This maps the common names of the dataset subsets ("train"/"val"/"test")
|
||||
@ -204,9 +209,13 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
# overwrite the restrict_sequence_name
|
||||
restrict_sequence_name = [eval_sequence_name]
|
||||
|
||||
dataset_type: Type[JsonIndexDataset] = registry.get(
|
||||
JsonIndexDataset, self.dataset_class_type
|
||||
)
|
||||
expand_args_fields(dataset_type)
|
||||
train_dataset = None
|
||||
if not self.only_test_set:
|
||||
train_dataset = JsonIndexDataset(
|
||||
train_dataset = dataset_type(
|
||||
n_frames_per_sequence=self.n_frames_per_sequence,
|
||||
subsets=set_names_mapping["train"],
|
||||
pick_sequence=restrict_sequence_name,
|
||||
@ -216,13 +225,13 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
assert train_dataset is not None
|
||||
val_dataset = test_dataset = train_dataset
|
||||
else:
|
||||
val_dataset = JsonIndexDataset(
|
||||
val_dataset = dataset_type(
|
||||
n_frames_per_sequence=-1,
|
||||
subsets=set_names_mapping["val"],
|
||||
pick_sequence=restrict_sequence_name,
|
||||
**common_kwargs,
|
||||
)
|
||||
test_dataset = JsonIndexDataset(
|
||||
test_dataset = dataset_type(
|
||||
n_frames_per_sequence=-1,
|
||||
subsets=set_names_mapping["test"],
|
||||
pick_sequence=restrict_sequence_name,
|
||||
@ -235,19 +244,25 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
test_dataset.eval_batches = test_dataset.seq_frame_index_to_dataset_index(
|
||||
eval_batch_index
|
||||
)
|
||||
datasets = DatasetMap(train=train_dataset, val=val_dataset, test=test_dataset)
|
||||
dataset_map = DatasetMap(
|
||||
train=train_dataset, val=val_dataset, test=test_dataset
|
||||
)
|
||||
|
||||
if self.assert_single_seq:
|
||||
# check there's only one sequence in all datasets
|
||||
sequence_names = {
|
||||
sequence_name
|
||||
for dset in datasets.iter_datasets()
|
||||
for dset in dataset_map.iter_datasets()
|
||||
for sequence_name in dset.sequence_names()
|
||||
}
|
||||
if len(sequence_names) > 1:
|
||||
raise ValueError("Multiple sequences loaded but expected one")
|
||||
|
||||
return datasets
|
||||
self.dataset_map = dataset_map
|
||||
|
||||
def get_dataset_map(self) -> DatasetMap:
|
||||
# pyre-ignore[16]
|
||||
return self.dataset_map
|
||||
|
||||
def get_task(self) -> Task:
|
||||
return Task(self.task_str)
|
||||
|
@ -106,8 +106,8 @@ def evaluate_dbir_for_category(
|
||||
"assert_single_seq": task == Task.SINGLE_SEQUENCE,
|
||||
"task_str": task.value,
|
||||
"test_on_train": False,
|
||||
"load_point_clouds": True,
|
||||
"test_restrict_sequence_id": single_sequence_id,
|
||||
"dataset_JsonIndexDataset_args": {"load_point_clouds": True},
|
||||
}
|
||||
data_source = ImplicitronDataSource(
|
||||
dataset_map_provider_JsonIndexDatasetMapProvider_args=dataset_map_provider_args
|
||||
|
@ -4,24 +4,35 @@ dataset_map_provider_JsonIndexDatasetMapProvider_args:
|
||||
category: ???
|
||||
task_str: singlesequence
|
||||
dataset_root: ''
|
||||
limit_to: -1
|
||||
limit_sequences_to: -1
|
||||
n_frames_per_sequence: -1
|
||||
test_on_train: false
|
||||
load_point_clouds: false
|
||||
mask_images: false
|
||||
mask_depths: false
|
||||
restrict_sequence_name: []
|
||||
test_restrict_sequence_id: -1
|
||||
assert_single_seq: false
|
||||
only_test_set: false
|
||||
aux_dataset_kwargs:
|
||||
box_crop: true
|
||||
box_crop_context: 0.3
|
||||
image_width: 800
|
||||
image_height: 800
|
||||
remove_empty_masks: true
|
||||
dataset_class_type: JsonIndexDataset
|
||||
path_manager_factory_class_type: PathManagerFactory
|
||||
dataset_JsonIndexDataset_args:
|
||||
limit_to: 0
|
||||
limit_sequences_to: 0
|
||||
exclude_sequence: []
|
||||
limit_category_to: []
|
||||
load_images: true
|
||||
load_depths: true
|
||||
load_depth_masks: true
|
||||
load_masks: true
|
||||
load_point_clouds: false
|
||||
max_points: 0
|
||||
mask_images: false
|
||||
mask_depths: false
|
||||
image_height: 800
|
||||
image_width: 800
|
||||
box_crop: true
|
||||
box_crop_mask_thr: 0.4
|
||||
box_crop_context: 0.3
|
||||
remove_empty_masks: true
|
||||
seed: 0
|
||||
sort_frames: false
|
||||
path_manager_factory_PathManagerFactory_args:
|
||||
silence_logs: true
|
||||
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
||||
|
Loading…
x
Reference in New Issue
Block a user