mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
dataset_map_provider
Summary: replace dataset_zoo with a pluggable DatasetMapProvider. The logic is now in annotated_file_dataset_map_provider. Reviewed By: shapovalov Differential Revision: D36443965 fbshipit-source-id: 9087649802810055e150b2fbfcc3c197a761f28a
This commit is contained in:
parent
69c6d06ed8
commit
79c61a2d86
@ -66,7 +66,7 @@ If you have a custom `experiment.py` script (as in the Option 2 above), replace
|
||||
To run training, pass a yaml config file, followed by a list of overridden arguments.
|
||||
For example, to train NeRF on the first skateboard sequence from CO3D dataset, you can run:
|
||||
```shell
|
||||
dataset_args=data_source_args.dataset_args
|
||||
dataset_args=data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf $dataset_args.dataset_root=<DATASET_ROOT> $dataset_args.category='skateboard' $dataset_args.test_restrict_sequence_id=0 test_when_finished=True exp_dir=<CHECKPOINT_DIR>
|
||||
```
|
||||
|
||||
@ -85,7 +85,7 @@ To run evaluation on the latest checkpoint after (or during) training, simply ad
|
||||
|
||||
E.g. for executing the evaluation on the NeRF skateboard sequence, you can run:
|
||||
```shell
|
||||
dataset_args=data_source_args.dataset_args
|
||||
dataset_args=data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf $dataset_args.dataset_root=<CO3D_DATASET_ROOT> $dataset_args.category='skateboard' $dataset_args.test_restrict_sequence_id=0 exp_dir=<CHECKPOINT_DIR> eval_only=True
|
||||
```
|
||||
Evaluation prints the metrics to `stdout` and dumps them to a json file in `exp_dir`.
|
||||
@ -236,7 +236,7 @@ generic_model_args: GenericModel
|
||||
╘== ReductionFeatureAggregator
|
||||
solver_args: init_optimizer
|
||||
data_source_args: ImplicitronDataSource
|
||||
└-- dataset_args
|
||||
└-- dataset_map_provider_*_args
|
||||
└-- dataloader_args
|
||||
```
|
||||
|
||||
|
@ -6,6 +6,7 @@ architecture: generic
|
||||
visualize_interval: 0
|
||||
visdom_port: 8097
|
||||
data_source_args:
|
||||
dataset_provider_class_type: JsonIndexDatasetMapProvider
|
||||
dataloader_args:
|
||||
batch_size: 10
|
||||
dataset_len: 1000
|
||||
@ -21,7 +22,7 @@ data_source_args:
|
||||
- 8
|
||||
- 9
|
||||
- 10
|
||||
dataset_args:
|
||||
dataset_map_provider_JsonIndexDatasetMapProvider_args:
|
||||
dataset_root: ${oc.env:CO3D_DATASET_ROOT}
|
||||
load_point_clouds: false
|
||||
mask_depths: false
|
||||
|
@ -17,9 +17,9 @@ data_source_args:
|
||||
- 8
|
||||
- 9
|
||||
- 10
|
||||
dataset_args:
|
||||
dataset_map_provider_JsonIndexDatasetMapProvider_args:
|
||||
assert_single_seq: false
|
||||
dataset_name: co3d_multisequence
|
||||
task_str: multisequence
|
||||
load_point_clouds: false
|
||||
mask_depths: false
|
||||
mask_images: false
|
||||
|
@ -9,8 +9,8 @@ data_source_args:
|
||||
num_workers: 8
|
||||
images_per_seq_options:
|
||||
- 2
|
||||
dataset_args:
|
||||
dataset_name: co3d_singlesequence
|
||||
dataset_map_provider_JsonIndexDatasetMapProvider_args:
|
||||
dataset_name: singlesequence
|
||||
assert_single_seq: true
|
||||
n_frames_per_sequence: -1
|
||||
test_restrict_sequence_id: 0
|
||||
|
@ -67,7 +67,7 @@ from pytorch3d.implicitron.dataset import utils as ds_utils
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
|
||||
from pytorch3d.implicitron.dataset.dataloader_zoo import Dataloaders
|
||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData
|
||||
from pytorch3d.implicitron.dataset.dataset_zoo import Datasets
|
||||
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
|
||||
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
|
||||
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
|
||||
from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
|
||||
@ -552,7 +552,7 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
|
||||
def _eval_and_dump(
|
||||
cfg,
|
||||
task: Task,
|
||||
datasets: Datasets,
|
||||
datasets: DatasetMap,
|
||||
dataloaders: Dataloaders,
|
||||
model,
|
||||
stats,
|
||||
|
@ -24,8 +24,7 @@ import torch.nn.functional as Fu
|
||||
from experiment import init_model
|
||||
from omegaconf import OmegaConf
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData, ImplicitronDatasetBase
|
||||
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
|
||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData
|
||||
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
|
||||
from pytorch3d.implicitron.dataset.utils import is_train_frame
|
||||
from pytorch3d.implicitron.models.base_model import EvaluationMode
|
||||
@ -296,7 +295,7 @@ def export_scenes(
|
||||
output_directory: Optional[str] = None,
|
||||
render_size: Tuple[int, int] = (512, 512),
|
||||
video_size: Optional[Tuple[int, int]] = None,
|
||||
split: str = "train", # train | test
|
||||
split: str = "train", # train | val | test
|
||||
n_source_views: int = 9,
|
||||
n_eval_cameras: int = 40,
|
||||
visdom_server="http://127.0.0.1",
|
||||
@ -324,14 +323,15 @@ def export_scenes(
|
||||
config.gpu_idx = gpu_idx
|
||||
config.exp_dir = exp_dir
|
||||
# important so that the CO3D dataset gets loaded in full
|
||||
config.data_source_args.dataset_args.test_on_train = False
|
||||
dataset_args = (
|
||||
config.data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||
)
|
||||
dataset_args.test_on_train = False
|
||||
# Set the rendering image size
|
||||
config.generic_model_args.render_image_width = render_size[0]
|
||||
config.generic_model_args.render_image_height = render_size[1]
|
||||
if restrict_sequence_name is not None:
|
||||
config.data_source_args.dataset_args.restrict_sequence_name = (
|
||||
restrict_sequence_name
|
||||
)
|
||||
dataset_args.restrict_sequence_name = restrict_sequence_name
|
||||
|
||||
# Set up the CUDA env for the visualization
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
@ -344,8 +344,8 @@ def export_scenes(
|
||||
|
||||
# Setup the dataset
|
||||
datasource = ImplicitronDataSource(**config.data_source_args)
|
||||
datasets = dataset_zoo(**datasource.dataset_args)
|
||||
dataset: Optional[ImplicitronDatasetBase] = getattr(datasets, split, None)
|
||||
dataset_map = datasource.dataset_map_provider.get_dataset_map()
|
||||
dataset = dataset_map[split]
|
||||
if dataset is None:
|
||||
raise ValueError(f"{split} dataset not provided")
|
||||
|
||||
|
@ -4,19 +4,18 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Tuple
|
||||
|
||||
from omegaconf import DictConfig
|
||||
from pytorch3d.implicitron.tools.config import get_default_args_field, ReplaceableBase
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
get_default_args_field,
|
||||
ReplaceableBase,
|
||||
run_auto_creation,
|
||||
)
|
||||
|
||||
from . import json_index_dataset_map_provider # noqa
|
||||
from .dataloader_zoo import dataloader_zoo, Dataloaders
|
||||
from .dataset_zoo import dataset_zoo, Datasets
|
||||
|
||||
|
||||
class Task(Enum):
|
||||
SINGLE_SEQUENCE = "singlesequence"
|
||||
MULTI_SEQUENCE = "multisequence"
|
||||
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
|
||||
|
||||
|
||||
class DataSourceBase(ReplaceableBase):
|
||||
@ -25,24 +24,31 @@ class DataSourceBase(ReplaceableBase):
|
||||
and DataLoader configuration.
|
||||
"""
|
||||
|
||||
def get_datasets_and_dataloaders(self) -> Tuple[Datasets, Dataloaders]:
|
||||
def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, Dataloaders]:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class ImplicitronDataSource(DataSourceBase):
|
||||
class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
|
||||
"""
|
||||
Represents the data used in Implicitron. This is the only implementation
|
||||
of DataSourceBase provided.
|
||||
|
||||
Members:
|
||||
dataset_map_provider_class_type: identifies type for dataset_map_provider.
|
||||
e.g. JsonIndexDatasetMapProvider for Co3D.
|
||||
"""
|
||||
|
||||
dataset_args: DictConfig = get_default_args_field(dataset_zoo)
|
||||
dataset_map_provider: DatasetMapProviderBase
|
||||
dataset_map_provider_class_type: str
|
||||
dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
|
||||
|
||||
def get_datasets_and_dataloaders(self) -> Tuple[Datasets, Dataloaders]:
|
||||
datasets = dataset_zoo(**self.dataset_args)
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
|
||||
def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, Dataloaders]:
|
||||
datasets = self.dataset_map_provider.get_dataset_map()
|
||||
dataloaders = dataloader_zoo(datasets, **self.dataloader_args)
|
||||
return datasets, dataloaders
|
||||
|
||||
def get_task(self) -> Task:
|
||||
eval_task = self.dataset_args["dataset_name"].split("_")[-1]
|
||||
return Task(eval_task)
|
||||
return self.dataset_map_provider.get_task()
|
||||
|
@ -11,7 +11,7 @@ import torch
|
||||
from pytorch3d.implicitron.tools.config import enable_get_default_args
|
||||
|
||||
from .dataset_base import FrameData, ImplicitronDatasetBase
|
||||
from .dataset_zoo import Datasets
|
||||
from .dataset_map_provider import DatasetMap
|
||||
from .scene_batch_sampler import SceneBatchSampler
|
||||
|
||||
|
||||
@ -33,7 +33,7 @@ class Dataloaders:
|
||||
|
||||
|
||||
def dataloader_zoo(
|
||||
datasets: Datasets,
|
||||
datasets: DatasetMap,
|
||||
batch_size: int = 1,
|
||||
num_workers: int = 0,
|
||||
dataset_len: int = 1000,
|
||||
@ -49,7 +49,6 @@ def dataloader_zoo(
|
||||
Args:
|
||||
datasets: A dictionary containing the
|
||||
`"dataset_subset_name": torch_dataset_object` key, value pairs.
|
||||
dataset_name: The name of the returned dataset.
|
||||
batch_size: The size of the batch of the dataloader.
|
||||
num_workers: Number data-loading threads.
|
||||
dataset_len: The number of batches in a training epoch.
|
||||
|
71
pytorch3d/implicitron/dataset/dataset_map_provider.py
Normal file
71
pytorch3d/implicitron/dataset/dataset_map_provider.py
Normal file
@ -0,0 +1,71 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Iterator, Optional
|
||||
|
||||
from pytorch3d.implicitron.tools.config import ReplaceableBase
|
||||
|
||||
from .dataset_base import ImplicitronDatasetBase
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetMap:
|
||||
"""
|
||||
A collection of datasets for implicitron.
|
||||
|
||||
Members:
|
||||
|
||||
train: a dataset for training
|
||||
val: a dataset for validating during training
|
||||
test: a dataset for final evaluation
|
||||
"""
|
||||
|
||||
train: Optional[ImplicitronDatasetBase]
|
||||
val: Optional[ImplicitronDatasetBase]
|
||||
test: Optional[ImplicitronDatasetBase]
|
||||
|
||||
def __getitem__(self, split: str) -> Optional[ImplicitronDatasetBase]:
|
||||
"""
|
||||
Get one of the datasets by key (name of data split)
|
||||
"""
|
||||
if split not in ["train", "val", "test"]:
|
||||
raise ValueError(f"{split} was not a valid split name (train/val/test)")
|
||||
return getattr(self, split)
|
||||
|
||||
def iter_datasets(self) -> Iterator[ImplicitronDatasetBase]:
|
||||
"""
|
||||
Iterator over all datasets.
|
||||
"""
|
||||
if self.train is not None:
|
||||
yield self.train
|
||||
if self.val is not None:
|
||||
yield self.val
|
||||
if self.test is not None:
|
||||
yield self.test
|
||||
|
||||
|
||||
class Task(Enum):
|
||||
SINGLE_SEQUENCE = "singlesequence"
|
||||
MULTI_SEQUENCE = "multisequence"
|
||||
|
||||
|
||||
class DatasetMapProviderBase(ReplaceableBase):
|
||||
"""
|
||||
Base class for a provider of training / validation and testing
|
||||
dataset objects.
|
||||
"""
|
||||
|
||||
def get_dataset_map(self) -> DatasetMap:
|
||||
"""
|
||||
Returns:
|
||||
An object containing the torch.Dataset objects in train/val/test fields.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_task(self) -> Task:
|
||||
raise NotImplementedError()
|
@ -7,13 +7,13 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Iterator, List, Optional, Sequence
|
||||
from dataclasses import field
|
||||
from typing import Any, Dict, List, Sequence
|
||||
|
||||
from iopath.common.file_io import PathManager
|
||||
from pytorch3d.implicitron.tools.config import enable_get_default_args
|
||||
from omegaconf import DictConfig
|
||||
from pytorch3d.implicitron.tools.config import registry
|
||||
|
||||
from .dataset_base import ImplicitronDatasetBase
|
||||
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
|
||||
from .implicitron_dataset import ImplicitronDataset
|
||||
from .utils import (
|
||||
DATASET_TYPE_KNOWN,
|
||||
@ -34,6 +34,11 @@ DATASET_CONFIGS: Dict[str, Dict[str, Any]] = {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _make_default_config() -> DictConfig:
|
||||
return DictConfig(DATASET_CONFIGS["default"])
|
||||
|
||||
|
||||
# fmt: off
|
||||
CO3D_CATEGORIES: List[str] = list(reversed([
|
||||
"baseballbat", "banana", "bicycle", "microwave", "tv",
|
||||
@ -53,59 +58,16 @@ CO3D_CATEGORIES: List[str] = list(reversed([
|
||||
_CO3D_DATASET_ROOT: str = os.getenv("CO3D_DATASET_ROOT", "")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Datasets:
|
||||
@registry.register
|
||||
class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
"""
|
||||
A provider of datasets for implicitron.
|
||||
|
||||
Members:
|
||||
|
||||
train: a dataset for training
|
||||
val: a dataset for validating during training
|
||||
test: a dataset for final evaluation
|
||||
"""
|
||||
|
||||
train: Optional[ImplicitronDatasetBase]
|
||||
val: Optional[ImplicitronDatasetBase]
|
||||
test: Optional[ImplicitronDatasetBase]
|
||||
|
||||
def iter_datasets(self) -> Iterator[ImplicitronDatasetBase]:
|
||||
"""
|
||||
Iterator over all datasets.
|
||||
"""
|
||||
if self.train is not None:
|
||||
yield self.train
|
||||
if self.val is not None:
|
||||
yield self.val
|
||||
if self.test is not None:
|
||||
yield self.test
|
||||
|
||||
|
||||
def dataset_zoo(
|
||||
dataset_name: str = "co3d_singlesequence",
|
||||
dataset_root: str = _CO3D_DATASET_ROOT,
|
||||
category: str = "DEFAULT",
|
||||
limit_to: int = -1,
|
||||
limit_sequences_to: int = -1,
|
||||
n_frames_per_sequence: int = -1,
|
||||
test_on_train: bool = False,
|
||||
load_point_clouds: bool = False,
|
||||
mask_images: bool = False,
|
||||
mask_depths: bool = False,
|
||||
restrict_sequence_name: Sequence[str] = (),
|
||||
test_restrict_sequence_id: int = -1,
|
||||
assert_single_seq: bool = False,
|
||||
only_test_set: bool = False,
|
||||
aux_dataset_kwargs: dict = DATASET_CONFIGS["default"],
|
||||
path_manager: Optional[PathManager] = None,
|
||||
) -> Datasets:
|
||||
"""
|
||||
Generates the training / validation and testing dataset objects.
|
||||
Generates the training / validation and testing dataset objects for
|
||||
a dataset laid out on disk like Co3D, with annotations in json files.
|
||||
|
||||
Args:
|
||||
dataset_name: The name of the returned dataset.
|
||||
dataset_root: The root folder of the dataset.
|
||||
category: The object category of the dataset.
|
||||
task_str: "multisequence" or "singlesequence".
|
||||
dataset_root: The root folder of the dataset.
|
||||
limit_to: Limit the dataset to the first #limit_to frames.
|
||||
limit_sequences_to: Limit the dataset to the first
|
||||
#limit_sequences_to sequences.
|
||||
@ -119,58 +81,78 @@ def dataset_zoo(
|
||||
restrict_sequence_name: Restrict the dataset sequences to the ones
|
||||
present in the given list of names.
|
||||
test_restrict_sequence_id: The ID of the loaded sequence.
|
||||
Active for dataset_name='co3d_singlesequence'.
|
||||
Active for task_str='singlesequence'.
|
||||
assert_single_seq: Assert that only frames from a single sequence
|
||||
are present in all generated datasets.
|
||||
only_test_set: Load only the test set.
|
||||
aux_dataset_kwargs: Specifies additional arguments to the
|
||||
ImplicitronDataset constructor call.
|
||||
|
||||
Returns:
|
||||
datasets: A dictionary containing the
|
||||
`"dataset_subset_name": torch_dataset_object` key, value pairs.
|
||||
path_manager: Optional[PathManager] for interpreting paths
|
||||
"""
|
||||
if only_test_set and test_on_train:
|
||||
raise ValueError("Cannot have only_test_set and test_on_train")
|
||||
|
||||
# TODO:
|
||||
# - implement loading multiple categories
|
||||
category: str
|
||||
task_str: str = "singlesequence"
|
||||
dataset_root: str = _CO3D_DATASET_ROOT
|
||||
limit_to: int = -1
|
||||
limit_sequences_to: int = -1
|
||||
n_frames_per_sequence: int = -1
|
||||
test_on_train: bool = False
|
||||
load_point_clouds: bool = False
|
||||
mask_images: bool = False
|
||||
mask_depths: bool = False
|
||||
restrict_sequence_name: Sequence[str] = ()
|
||||
test_restrict_sequence_id: int = -1
|
||||
assert_single_seq: bool = False
|
||||
only_test_set: bool = False
|
||||
aux_dataset_kwargs: DictConfig = field(default_factory=_make_default_config)
|
||||
path_manager: Any = None
|
||||
|
||||
if dataset_name in ["co3d_singlesequence", "co3d_multisequence"]:
|
||||
frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
|
||||
sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
|
||||
subset_lists_file = os.path.join(dataset_root, category, "set_lists.json")
|
||||
def get_dataset_map(self) -> DatasetMap:
|
||||
if self.only_test_set and self.test_on_train:
|
||||
raise ValueError("Cannot have only_test_set and test_on_train")
|
||||
|
||||
# TODO:
|
||||
# - implement loading multiple categories
|
||||
|
||||
frame_file = os.path.join(
|
||||
self.dataset_root, self.category, "frame_annotations.jgz"
|
||||
)
|
||||
sequence_file = os.path.join(
|
||||
self.dataset_root, self.category, "sequence_annotations.jgz"
|
||||
)
|
||||
subset_lists_file = os.path.join(
|
||||
self.dataset_root, self.category, "set_lists.json"
|
||||
)
|
||||
common_kwargs = {
|
||||
"dataset_root": dataset_root,
|
||||
"limit_to": limit_to,
|
||||
"limit_sequences_to": limit_sequences_to,
|
||||
"load_point_clouds": load_point_clouds,
|
||||
"mask_images": mask_images,
|
||||
"mask_depths": mask_depths,
|
||||
"path_manager": path_manager,
|
||||
"dataset_root": self.dataset_root,
|
||||
"limit_to": self.limit_to,
|
||||
"limit_sequences_to": self.limit_sequences_to,
|
||||
"load_point_clouds": self.load_point_clouds,
|
||||
"mask_images": self.mask_images,
|
||||
"mask_depths": self.mask_depths,
|
||||
"path_manager": self.path_manager,
|
||||
"frame_annotations_file": frame_file,
|
||||
"sequence_annotations_file": sequence_file,
|
||||
"subset_lists_file": subset_lists_file,
|
||||
**aux_dataset_kwargs,
|
||||
**self.aux_dataset_kwargs,
|
||||
}
|
||||
|
||||
# This maps the common names of the dataset subsets ("train"/"val"/"test")
|
||||
# to the names of the subsets in the CO3D dataset.
|
||||
set_names_mapping = _get_co3d_set_names_mapping(
|
||||
dataset_name,
|
||||
test_on_train,
|
||||
only_test_set,
|
||||
self.get_task(),
|
||||
self.test_on_train,
|
||||
self.only_test_set,
|
||||
)
|
||||
|
||||
# load the evaluation batches
|
||||
task = dataset_name.split("_")[-1]
|
||||
batch_indices_path = os.path.join(
|
||||
dataset_root,
|
||||
category,
|
||||
f"eval_batches_{task}.json",
|
||||
self.dataset_root,
|
||||
self.category,
|
||||
f"eval_batches_{self.task_str}.json",
|
||||
)
|
||||
if path_manager is not None:
|
||||
batch_indices_path = path_manager.get_local_path(batch_indices_path)
|
||||
if self.path_manager is not None:
|
||||
batch_indices_path = self.path_manager.get_local_path(batch_indices_path)
|
||||
if not os.path.isfile(batch_indices_path):
|
||||
# The batch indices file does not exist.
|
||||
# Most probably the user has not specified the root folder.
|
||||
@ -181,25 +163,31 @@ def dataset_zoo(
|
||||
|
||||
with open(batch_indices_path, "r") as f:
|
||||
eval_batch_index = json.load(f)
|
||||
restrict_sequence_name = self.restrict_sequence_name
|
||||
|
||||
if task == "singlesequence":
|
||||
assert (
|
||||
test_restrict_sequence_id is not None and test_restrict_sequence_id >= 0
|
||||
), (
|
||||
"Please specify an integer id 'test_restrict_sequence_id'"
|
||||
+ " of the sequence considered for 'singlesequence'"
|
||||
+ " training and evaluation."
|
||||
)
|
||||
assert len(restrict_sequence_name) == 0, (
|
||||
"For the 'singlesequence' task, the restrict_sequence_name has"
|
||||
" to be unset while test_restrict_sequence_id has to be set to an"
|
||||
" integer defining the order of the evaluation sequence."
|
||||
)
|
||||
if self.get_task() == Task.SINGLE_SEQUENCE:
|
||||
if (
|
||||
self.test_restrict_sequence_id is None
|
||||
or self.test_restrict_sequence_id < 0
|
||||
):
|
||||
raise ValueError(
|
||||
"Please specify an integer id 'test_restrict_sequence_id'"
|
||||
+ " of the sequence considered for 'singlesequence'"
|
||||
+ " training and evaluation."
|
||||
)
|
||||
if len(self.restrict_sequence_name) > 0:
|
||||
raise ValueError(
|
||||
"For the 'singlesequence' task, the restrict_sequence_name has"
|
||||
" to be unset while test_restrict_sequence_id has to be set to an"
|
||||
" integer defining the order of the evaluation sequence."
|
||||
)
|
||||
# a sort-stable set() equivalent:
|
||||
eval_batches_sequence_names = list(
|
||||
{b[0][0]: None for b in eval_batch_index}.keys()
|
||||
)
|
||||
eval_sequence_name = eval_batches_sequence_names[test_restrict_sequence_id]
|
||||
eval_sequence_name = eval_batches_sequence_names[
|
||||
self.test_restrict_sequence_id
|
||||
]
|
||||
eval_batch_index = [
|
||||
b for b in eval_batch_index if b[0][0] == eval_sequence_name
|
||||
]
|
||||
@ -207,14 +195,14 @@ def dataset_zoo(
|
||||
restrict_sequence_name = [eval_sequence_name]
|
||||
|
||||
train_dataset = None
|
||||
if not only_test_set:
|
||||
if not self.only_test_set:
|
||||
train_dataset = ImplicitronDataset(
|
||||
n_frames_per_sequence=n_frames_per_sequence,
|
||||
n_frames_per_sequence=self.n_frames_per_sequence,
|
||||
subsets=set_names_mapping["train"],
|
||||
pick_sequence=restrict_sequence_name,
|
||||
**common_kwargs,
|
||||
)
|
||||
if test_on_train:
|
||||
if self.test_on_train:
|
||||
assert train_dataset is not None
|
||||
val_dataset = test_dataset = train_dataset
|
||||
else:
|
||||
@ -237,29 +225,26 @@ def dataset_zoo(
|
||||
test_dataset.eval_batches = test_dataset.seq_frame_index_to_dataset_index(
|
||||
eval_batch_index
|
||||
)
|
||||
datasets = Datasets(train=train_dataset, val=val_dataset, test=test_dataset)
|
||||
datasets = DatasetMap(train=train_dataset, val=val_dataset, test=test_dataset)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported dataset: {dataset_name}")
|
||||
if self.assert_single_seq:
|
||||
# check there's only one sequence in all datasets
|
||||
sequence_names = {
|
||||
sequence_name
|
||||
for dset in datasets.iter_datasets()
|
||||
for sequence_name in dset.sequence_names()
|
||||
}
|
||||
if len(sequence_names) > 1:
|
||||
raise ValueError("Multiple sequences loaded but expected one")
|
||||
|
||||
if assert_single_seq:
|
||||
# check there's only one sequence in all datasets
|
||||
sequence_names = {
|
||||
sequence_name
|
||||
for dset in datasets.iter_datasets()
|
||||
for sequence_name in dset.sequence_names()
|
||||
}
|
||||
if len(sequence_names) > 1:
|
||||
raise ValueError("Multiple sequences loaded but expected one")
|
||||
return datasets
|
||||
|
||||
return datasets
|
||||
|
||||
|
||||
enable_get_default_args(dataset_zoo)
|
||||
def get_task(self) -> Task:
|
||||
return Task(self.task_str)
|
||||
|
||||
|
||||
def _get_co3d_set_names_mapping(
|
||||
dataset_name: str,
|
||||
task: Task,
|
||||
test_on_train: bool,
|
||||
only_test: bool,
|
||||
) -> Dict[str, List[str]]:
|
||||
@ -273,7 +258,7 @@ def _get_co3d_set_names_mapping(
|
||||
- val (if not test_on_train)
|
||||
- test (if not test_on_train)
|
||||
"""
|
||||
single_seq = dataset_name == "co3d_singlesequence"
|
||||
single_seq = task == Task.SINGLE_SEQUENCE
|
||||
|
||||
if only_test:
|
||||
set_names_mapping = {}
|
@ -12,11 +12,12 @@ from typing import Any, cast, Dict, List, Optional, Tuple
|
||||
import lpips
|
||||
import torch
|
||||
from iopath.common.file_io import PathManager
|
||||
from pytorch3d.implicitron.dataset.data_source import Task
|
||||
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
|
||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData, ImplicitronDatasetBase
|
||||
from pytorch3d.implicitron.dataset.dataset_zoo import CO3D_CATEGORIES, dataset_zoo
|
||||
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
|
||||
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
|
||||
CO3D_CATEGORIES,
|
||||
)
|
||||
from pytorch3d.implicitron.dataset.utils import is_known_frame
|
||||
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
|
||||
aggregate_nvs_results,
|
||||
@ -101,23 +102,21 @@ def evaluate_dbir_for_category(
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
dataset_name = {
|
||||
Task.SINGLE_SEQUENCE: "co3d_singlesequence",
|
||||
Task.MULTI_SEQUENCE: "co3d_multisequence",
|
||||
}[task]
|
||||
|
||||
datasets = dataset_zoo(
|
||||
category=category,
|
||||
dataset_root=os.environ["CO3D_DATASET_ROOT"],
|
||||
assert_single_seq=task == Task.SINGLE_SEQUENCE,
|
||||
dataset_name=dataset_name,
|
||||
test_on_train=False,
|
||||
load_point_clouds=True,
|
||||
test_restrict_sequence_id=single_sequence_id,
|
||||
path_manager=path_manager,
|
||||
dataset_map_provider_args = {
|
||||
"category": category,
|
||||
"dataset_root": os.environ["CO3D_DATASET_ROOT"],
|
||||
"assert_single_seq": task == Task.SINGLE_SEQUENCE,
|
||||
"task_str": task.value,
|
||||
"test_on_train": False,
|
||||
"load_point_clouds": True,
|
||||
"test_restrict_sequence_id": single_sequence_id,
|
||||
"path_manager": path_manager,
|
||||
}
|
||||
data_source = ImplicitronDataSource(
|
||||
dataset_map_provider_JsonIndexDatasetMapProvider_args=dataset_map_provider_args
|
||||
)
|
||||
|
||||
dataloaders = dataloader_zoo(datasets)
|
||||
datasets, dataloaders = data_source.get_datasets_and_dataloaders()
|
||||
|
||||
test_dataset = datasets.test
|
||||
test_dataloader = dataloaders.test
|
||||
|
33
tests/implicitron/data/data_source.yaml
Normal file
33
tests/implicitron/data/data_source.yaml
Normal file
@ -0,0 +1,33 @@
|
||||
dataset_map_provider_class_type: ???
|
||||
dataloader_args:
|
||||
batch_size: 1
|
||||
num_workers: 0
|
||||
dataset_len: 1000
|
||||
dataset_len_val: 1
|
||||
images_per_seq_options:
|
||||
- 2
|
||||
sample_consecutive_frames: false
|
||||
consecutive_frames_max_gap: 0
|
||||
consecutive_frames_max_gap_seconds: 0.1
|
||||
dataset_map_provider_JsonIndexDatasetMapProvider_args:
|
||||
category: ???
|
||||
task_str: singlesequence
|
||||
dataset_root: ''
|
||||
limit_to: -1
|
||||
limit_sequences_to: -1
|
||||
n_frames_per_sequence: -1
|
||||
test_on_train: false
|
||||
load_point_clouds: false
|
||||
mask_images: false
|
||||
mask_depths: false
|
||||
restrict_sequence_name: []
|
||||
test_restrict_sequence_id: -1
|
||||
assert_single_seq: false
|
||||
only_test_set: false
|
||||
aux_dataset_kwargs:
|
||||
box_crop: true
|
||||
box_crop_context: 0.3
|
||||
image_width: 800
|
||||
image_height: 800
|
||||
remove_empty_masks: true
|
||||
path_manager: null
|
@ -118,6 +118,6 @@ implicit_function_IdrFeatureField_args:
|
||||
bias: 1.0
|
||||
skip_in: []
|
||||
weight_norm: true
|
||||
n_harmonic_functions_xyz: 0
|
||||
n_harmonic_functions_xyz: 1729
|
||||
pooled_feature_dim: 0
|
||||
encoding_dim: 0
|
||||
|
@ -70,6 +70,9 @@ class TestGenericModel(unittest.TestCase):
|
||||
"AngleWeightedIdentityFeatureAggregator"
|
||||
)
|
||||
args.implicit_function_class_type = "IdrFeatureField"
|
||||
idr_args = args.implicit_function_IdrFeatureField_args
|
||||
idr_args.n_harmonic_functions_xyz = 1729
|
||||
|
||||
args.renderer_class_type = "LSTMRenderer"
|
||||
gm = GenericModel(**args)
|
||||
self.assertIsInstance(gm.renderer, LSTMRenderer)
|
||||
@ -78,6 +81,7 @@ class TestGenericModel(unittest.TestCase):
|
||||
AngleWeightedIdentityFeatureAggregator,
|
||||
)
|
||||
self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
|
||||
self.assertEqual(gm._implicit_functions[0]._fn.n_harmonic_functions_xyz, 1729)
|
||||
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
|
||||
self.assertIsInstance(gm.image_feature_extractor, ResNetFeatureExtractor)
|
||||
self.assertFalse(hasattr(gm, "implicit_function"))
|
||||
|
32
tests/implicitron/test_data_source.py
Normal file
32
tests/implicitron/test_data_source.py
Normal file
@ -0,0 +1,32 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||
from pytorch3d.implicitron.tools.config import get_default_args
|
||||
|
||||
if os.environ.get("FB_TEST", False):
|
||||
from common_testing import get_tests_dir
|
||||
else:
|
||||
from tests.common_testing import get_tests_dir
|
||||
|
||||
DATA_DIR = get_tests_dir() / "implicitron/data"
|
||||
DEBUG: bool = False
|
||||
|
||||
|
||||
class TestDataSource(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.maxDiff = None
|
||||
|
||||
def test_one(self):
|
||||
cfg = get_default_args(ImplicitronDataSource)
|
||||
yaml = OmegaConf.to_yaml(cfg, sort_keys=False)
|
||||
if DEBUG:
|
||||
(DATA_DIR / "data_source.yaml").write_text(yaml)
|
||||
self.assertEqual(yaml, (DATA_DIR / "data_source.yaml").read_text())
|
Loading…
x
Reference in New Issue
Block a user