dataset_map_provider

Summary: replace dataset_zoo with a pluggable DatasetMapProvider. The logic is now in json_index_dataset_map_provider.

Reviewed By: shapovalov

Differential Revision: D36443965

fbshipit-source-id: 9087649802810055e150b2fbfcc3c197a761f28a
This commit is contained in:
Jeremy Reizenstein 2022-05-20 07:50:30 -07:00 committed by Facebook GitHub Bot
parent 69c6d06ed8
commit 79c61a2d86
15 changed files with 305 additions and 175 deletions

View File

@ -66,7 +66,7 @@ If you have a custom `experiment.py` script (as in the Option 2 above), replace
To run training, pass a yaml config file, followed by a list of overridden arguments. To run training, pass a yaml config file, followed by a list of overridden arguments.
For example, to train NeRF on the first skateboard sequence from CO3D dataset, you can run: For example, to train NeRF on the first skateboard sequence from CO3D dataset, you can run:
```shell ```shell
dataset_args=data_source_args.dataset_args dataset_args=data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf $dataset_args.dataset_root=<DATASET_ROOT> $dataset_args.category='skateboard' $dataset_args.test_restrict_sequence_id=0 test_when_finished=True exp_dir=<CHECKPOINT_DIR> pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf $dataset_args.dataset_root=<DATASET_ROOT> $dataset_args.category='skateboard' $dataset_args.test_restrict_sequence_id=0 test_when_finished=True exp_dir=<CHECKPOINT_DIR>
``` ```
@ -85,7 +85,7 @@ To run evaluation on the latest checkpoint after (or during) training, simply ad
E.g. for executing the evaluation on the NeRF skateboard sequence, you can run: E.g. for executing the evaluation on the NeRF skateboard sequence, you can run:
```shell ```shell
dataset_args=data_source_args.dataset_args dataset_args=data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf $dataset_args.dataset_root=<CO3D_DATASET_ROOT> $dataset_args.category='skateboard' $dataset_args.test_restrict_sequence_id=0 exp_dir=<CHECKPOINT_DIR> eval_only=True pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf $dataset_args.dataset_root=<CO3D_DATASET_ROOT> $dataset_args.category='skateboard' $dataset_args.test_restrict_sequence_id=0 exp_dir=<CHECKPOINT_DIR> eval_only=True
``` ```
Evaluation prints the metrics to `stdout` and dumps them to a json file in `exp_dir`. Evaluation prints the metrics to `stdout` and dumps them to a json file in `exp_dir`.
@ -236,7 +236,7 @@ generic_model_args: GenericModel
╘== ReductionFeatureAggregator ╘== ReductionFeatureAggregator
solver_args: init_optimizer solver_args: init_optimizer
data_source_args: ImplicitronDataSource data_source_args: ImplicitronDataSource
└-- dataset_args └-- dataset_map_provider_*_args
└-- dataloader_args └-- dataloader_args
``` ```

View File

@ -6,6 +6,7 @@ architecture: generic
visualize_interval: 0 visualize_interval: 0
visdom_port: 8097 visdom_port: 8097
data_source_args: data_source_args:
dataset_map_provider_class_type: JsonIndexDatasetMapProvider
dataloader_args: dataloader_args:
batch_size: 10 batch_size: 10
dataset_len: 1000 dataset_len: 1000
@ -21,7 +22,7 @@ data_source_args:
- 8 - 8
- 9 - 9
- 10 - 10
dataset_args: dataset_map_provider_JsonIndexDatasetMapProvider_args:
dataset_root: ${oc.env:CO3D_DATASET_ROOT} dataset_root: ${oc.env:CO3D_DATASET_ROOT}
load_point_clouds: false load_point_clouds: false
mask_depths: false mask_depths: false

View File

@ -17,9 +17,9 @@ data_source_args:
- 8 - 8
- 9 - 9
- 10 - 10
dataset_args: dataset_map_provider_JsonIndexDatasetMapProvider_args:
assert_single_seq: false assert_single_seq: false
dataset_name: co3d_multisequence task_str: multisequence
load_point_clouds: false load_point_clouds: false
mask_depths: false mask_depths: false
mask_images: false mask_images: false

View File

@ -9,8 +9,8 @@ data_source_args:
num_workers: 8 num_workers: 8
images_per_seq_options: images_per_seq_options:
- 2 - 2
dataset_args: dataset_map_provider_JsonIndexDatasetMapProvider_args:
dataset_name: co3d_singlesequence task_str: singlesequence
assert_single_seq: true assert_single_seq: true
n_frames_per_sequence: -1 n_frames_per_sequence: -1
test_restrict_sequence_id: 0 test_restrict_sequence_id: 0

View File

@ -67,7 +67,7 @@ from pytorch3d.implicitron.dataset import utils as ds_utils
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
from pytorch3d.implicitron.dataset.dataloader_zoo import Dataloaders from pytorch3d.implicitron.dataset.dataloader_zoo import Dataloaders
from pytorch3d.implicitron.dataset.dataset_base import FrameData from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.dataset_zoo import Datasets from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
@ -552,7 +552,7 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
def _eval_and_dump( def _eval_and_dump(
cfg, cfg,
task: Task, task: Task,
datasets: Datasets, datasets: DatasetMap,
dataloaders: Dataloaders, dataloaders: Dataloaders,
model, model,
stats, stats,

View File

@ -24,8 +24,7 @@ import torch.nn.functional as Fu
from experiment import init_model from experiment import init_model
from omegaconf import OmegaConf from omegaconf import OmegaConf
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.dataset_base import FrameData, ImplicitronDatasetBase from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
from pytorch3d.implicitron.dataset.utils import is_train_frame from pytorch3d.implicitron.dataset.utils import is_train_frame
from pytorch3d.implicitron.models.base_model import EvaluationMode from pytorch3d.implicitron.models.base_model import EvaluationMode
@ -296,7 +295,7 @@ def export_scenes(
output_directory: Optional[str] = None, output_directory: Optional[str] = None,
render_size: Tuple[int, int] = (512, 512), render_size: Tuple[int, int] = (512, 512),
video_size: Optional[Tuple[int, int]] = None, video_size: Optional[Tuple[int, int]] = None,
split: str = "train", # train | test split: str = "train", # train | val | test
n_source_views: int = 9, n_source_views: int = 9,
n_eval_cameras: int = 40, n_eval_cameras: int = 40,
visdom_server="http://127.0.0.1", visdom_server="http://127.0.0.1",
@ -324,14 +323,15 @@ def export_scenes(
config.gpu_idx = gpu_idx config.gpu_idx = gpu_idx
config.exp_dir = exp_dir config.exp_dir = exp_dir
# important so that the CO3D dataset gets loaded in full # important so that the CO3D dataset gets loaded in full
config.data_source_args.dataset_args.test_on_train = False dataset_args = (
config.data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
)
dataset_args.test_on_train = False
# Set the rendering image size # Set the rendering image size
config.generic_model_args.render_image_width = render_size[0] config.generic_model_args.render_image_width = render_size[0]
config.generic_model_args.render_image_height = render_size[1] config.generic_model_args.render_image_height = render_size[1]
if restrict_sequence_name is not None: if restrict_sequence_name is not None:
config.data_source_args.dataset_args.restrict_sequence_name = ( dataset_args.restrict_sequence_name = restrict_sequence_name
restrict_sequence_name
)
# Set up the CUDA env for the visualization # Set up the CUDA env for the visualization
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
@ -344,8 +344,8 @@ def export_scenes(
# Setup the dataset # Setup the dataset
datasource = ImplicitronDataSource(**config.data_source_args) datasource = ImplicitronDataSource(**config.data_source_args)
datasets = dataset_zoo(**datasource.dataset_args) dataset_map = datasource.dataset_map_provider.get_dataset_map()
dataset: Optional[ImplicitronDatasetBase] = getattr(datasets, split, None) dataset = dataset_map[split]
if dataset is None: if dataset is None:
raise ValueError(f"{split} dataset not provided") raise ValueError(f"{split} dataset not provided")

View File

@ -4,19 +4,18 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from enum import Enum
from typing import Tuple from typing import Tuple
from omegaconf import DictConfig from omegaconf import DictConfig
from pytorch3d.implicitron.tools.config import get_default_args_field, ReplaceableBase from pytorch3d.implicitron.tools.config import (
get_default_args_field,
ReplaceableBase,
run_auto_creation,
)
from . import json_index_dataset_map_provider # noqa
from .dataloader_zoo import dataloader_zoo, Dataloaders from .dataloader_zoo import dataloader_zoo, Dataloaders
from .dataset_zoo import dataset_zoo, Datasets from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
class Task(Enum):
SINGLE_SEQUENCE = "singlesequence"
MULTI_SEQUENCE = "multisequence"
class DataSourceBase(ReplaceableBase): class DataSourceBase(ReplaceableBase):
@ -25,24 +24,31 @@ class DataSourceBase(ReplaceableBase):
and DataLoader configuration. and DataLoader configuration.
""" """
def get_datasets_and_dataloaders(self) -> Tuple[Datasets, Dataloaders]: def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, Dataloaders]:
raise NotImplementedError() raise NotImplementedError()
class ImplicitronDataSource(DataSourceBase): class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
""" """
Represents the data used in Implicitron. This is the only implementation Represents the data used in Implicitron. This is the only implementation
of DataSourceBase provided. of DataSourceBase provided.
Members:
dataset_map_provider_class_type: identifies type for dataset_map_provider.
e.g. JsonIndexDatasetMapProvider for Co3D.
""" """
dataset_args: DictConfig = get_default_args_field(dataset_zoo) dataset_map_provider: DatasetMapProviderBase
dataset_map_provider_class_type: str
dataloader_args: DictConfig = get_default_args_field(dataloader_zoo) dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
def get_datasets_and_dataloaders(self) -> Tuple[Datasets, Dataloaders]: def __post_init__(self):
datasets = dataset_zoo(**self.dataset_args) run_auto_creation(self)
def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, Dataloaders]:
datasets = self.dataset_map_provider.get_dataset_map()
dataloaders = dataloader_zoo(datasets, **self.dataloader_args) dataloaders = dataloader_zoo(datasets, **self.dataloader_args)
return datasets, dataloaders return datasets, dataloaders
def get_task(self) -> Task: def get_task(self) -> Task:
eval_task = self.dataset_args["dataset_name"].split("_")[-1] return self.dataset_map_provider.get_task()
return Task(eval_task)

View File

@ -11,7 +11,7 @@ import torch
from pytorch3d.implicitron.tools.config import enable_get_default_args from pytorch3d.implicitron.tools.config import enable_get_default_args
from .dataset_base import FrameData, ImplicitronDatasetBase from .dataset_base import FrameData, ImplicitronDatasetBase
from .dataset_zoo import Datasets from .dataset_map_provider import DatasetMap
from .scene_batch_sampler import SceneBatchSampler from .scene_batch_sampler import SceneBatchSampler
@ -33,7 +33,7 @@ class Dataloaders:
def dataloader_zoo( def dataloader_zoo(
datasets: Datasets, datasets: DatasetMap,
batch_size: int = 1, batch_size: int = 1,
num_workers: int = 0, num_workers: int = 0,
dataset_len: int = 1000, dataset_len: int = 1000,
@ -49,7 +49,6 @@ def dataloader_zoo(
Args: Args:
datasets: A dictionary containing the datasets: A dictionary containing the
`"dataset_subset_name": torch_dataset_object` key, value pairs. `"dataset_subset_name": torch_dataset_object` key, value pairs.
dataset_name: The name of the returned dataset.
batch_size: The size of the batch of the dataloader. batch_size: The size of the batch of the dataloader.
num_workers: Number data-loading threads. num_workers: Number data-loading threads.
dataset_len: The number of batches in a training epoch. dataset_len: The number of batches in a training epoch.

View File

@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from enum import Enum
from typing import Iterator, Optional
from pytorch3d.implicitron.tools.config import ReplaceableBase
from .dataset_base import ImplicitronDatasetBase
@dataclass
class DatasetMap:
    """
    Bundle of the per-split datasets used by implicitron.

    Any of the splits may be absent, in which case the member is None.

    Members:
        train: a dataset for training
        val: a dataset for validating during training
        test: a dataset for final evaluation
    """

    train: Optional[ImplicitronDatasetBase]
    val: Optional[ImplicitronDatasetBase]
    test: Optional[ImplicitronDatasetBase]

    def __getitem__(self, split: str) -> Optional[ImplicitronDatasetBase]:
        """
        Look up one of the datasets by the name of its data split.

        Args:
            split: one of "train", "val" or "test".

        Returns:
            The dataset stored for that split, or None if it is not set.

        Raises:
            ValueError: if split is not one of the three split names.
        """
        if split in ("train", "val", "test"):
            return getattr(self, split)
        raise ValueError(f"{split} was not a valid split name (train/val/test)")

    def iter_datasets(self) -> Iterator[ImplicitronDatasetBase]:
        """
        Yield every dataset which is set, in the order train, val, test.
        """
        for split_name in ("train", "val", "test"):
            # Read the member only when its turn comes, one split per step.
            candidate = getattr(self, split_name)
            if candidate is not None:
                yield candidate
class Task(Enum):
    """
    The task types a dataset map provider can report from get_task().

    Each member's value is the plain string name of the task
    ("singlesequence" or "multisequence").
    """

    SINGLE_SEQUENCE = "singlesequence"
    MULTI_SEQUENCE = "multisequence"
class DatasetMapProviderBase(ReplaceableBase):
    """
    Base class for a provider of training / validation and testing
    dataset objects.

    Both methods here raise NotImplementedError, so an implementation
    has to override them.
    """

    def get_dataset_map(self) -> DatasetMap:
        """
        Returns:
            An object containing the torch.Dataset objects in train/val/test fields.
        """
        raise NotImplementedError()

    def get_task(self) -> Task:
        """
        Returns:
            The Task which this provider's datasets are for.
        """
        raise NotImplementedError()

View File

@ -7,13 +7,13 @@
import json import json
import os import os
from dataclasses import dataclass from dataclasses import field
from typing import Any, Dict, Iterator, List, Optional, Sequence from typing import Any, Dict, List, Sequence
from iopath.common.file_io import PathManager from omegaconf import DictConfig
from pytorch3d.implicitron.tools.config import enable_get_default_args from pytorch3d.implicitron.tools.config import registry
from .dataset_base import ImplicitronDatasetBase from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
from .implicitron_dataset import ImplicitronDataset from .implicitron_dataset import ImplicitronDataset
from .utils import ( from .utils import (
DATASET_TYPE_KNOWN, DATASET_TYPE_KNOWN,
@ -34,6 +34,11 @@ DATASET_CONFIGS: Dict[str, Dict[str, Any]] = {
} }
} }
def _make_default_config() -> DictConfig:
    """
    Wrap the "default" entry of DATASET_CONFIGS in a newly constructed DictConfig.
    """
    return DictConfig(DATASET_CONFIGS["default"])
# fmt: off # fmt: off
CO3D_CATEGORIES: List[str] = list(reversed([ CO3D_CATEGORIES: List[str] = list(reversed([
"baseballbat", "banana", "bicycle", "microwave", "tv", "baseballbat", "banana", "bicycle", "microwave", "tv",
@ -53,59 +58,16 @@ CO3D_CATEGORIES: List[str] = list(reversed([
_CO3D_DATASET_ROOT: str = os.getenv("CO3D_DATASET_ROOT", "") _CO3D_DATASET_ROOT: str = os.getenv("CO3D_DATASET_ROOT", "")
@dataclass @registry.register
class Datasets: class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
""" """
A provider of datasets for implicitron. Generates the training / validation and testing dataset objects for
a dataset laid out on disk like Co3D, with annotations in json files.
Members:
train: a dataset for training
val: a dataset for validating during training
test: a dataset for final evaluation
"""
train: Optional[ImplicitronDatasetBase]
val: Optional[ImplicitronDatasetBase]
test: Optional[ImplicitronDatasetBase]
def iter_datasets(self) -> Iterator[ImplicitronDatasetBase]:
"""
Iterator over all datasets.
"""
if self.train is not None:
yield self.train
if self.val is not None:
yield self.val
if self.test is not None:
yield self.test
def dataset_zoo(
dataset_name: str = "co3d_singlesequence",
dataset_root: str = _CO3D_DATASET_ROOT,
category: str = "DEFAULT",
limit_to: int = -1,
limit_sequences_to: int = -1,
n_frames_per_sequence: int = -1,
test_on_train: bool = False,
load_point_clouds: bool = False,
mask_images: bool = False,
mask_depths: bool = False,
restrict_sequence_name: Sequence[str] = (),
test_restrict_sequence_id: int = -1,
assert_single_seq: bool = False,
only_test_set: bool = False,
aux_dataset_kwargs: dict = DATASET_CONFIGS["default"],
path_manager: Optional[PathManager] = None,
) -> Datasets:
"""
Generates the training / validation and testing dataset objects.
Args: Args:
dataset_name: The name of the returned dataset.
dataset_root: The root folder of the dataset.
category: The object category of the dataset. category: The object category of the dataset.
task_str: "multisequence" or "singlesequence".
dataset_root: The root folder of the dataset.
limit_to: Limit the dataset to the first #limit_to frames. limit_to: Limit the dataset to the first #limit_to frames.
limit_sequences_to: Limit the dataset to the first limit_sequences_to: Limit the dataset to the first
#limit_sequences_to sequences. #limit_sequences_to sequences.
@ -119,58 +81,78 @@ def dataset_zoo(
restrict_sequence_name: Restrict the dataset sequences to the ones restrict_sequence_name: Restrict the dataset sequences to the ones
present in the given list of names. present in the given list of names.
test_restrict_sequence_id: The ID of the loaded sequence. test_restrict_sequence_id: The ID of the loaded sequence.
Active for dataset_name='co3d_singlesequence'. Active for task_str='singlesequence'.
assert_single_seq: Assert that only frames from a single sequence assert_single_seq: Assert that only frames from a single sequence
are present in all generated datasets. are present in all generated datasets.
only_test_set: Load only the test set. only_test_set: Load only the test set.
aux_dataset_kwargs: Specifies additional arguments to the aux_dataset_kwargs: Specifies additional arguments to the
ImplicitronDataset constructor call. ImplicitronDataset constructor call.
path_manager: Optional[PathManager] for interpreting paths
Returns:
datasets: A dictionary containing the
`"dataset_subset_name": torch_dataset_object` key, value pairs.
""" """
if only_test_set and test_on_train:
category: str
task_str: str = "singlesequence"
dataset_root: str = _CO3D_DATASET_ROOT
limit_to: int = -1
limit_sequences_to: int = -1
n_frames_per_sequence: int = -1
test_on_train: bool = False
load_point_clouds: bool = False
mask_images: bool = False
mask_depths: bool = False
restrict_sequence_name: Sequence[str] = ()
test_restrict_sequence_id: int = -1
assert_single_seq: bool = False
only_test_set: bool = False
aux_dataset_kwargs: DictConfig = field(default_factory=_make_default_config)
path_manager: Any = None
def get_dataset_map(self) -> DatasetMap:
if self.only_test_set and self.test_on_train:
raise ValueError("Cannot have only_test_set and test_on_train") raise ValueError("Cannot have only_test_set and test_on_train")
# TODO: # TODO:
# - implement loading multiple categories # - implement loading multiple categories
if dataset_name in ["co3d_singlesequence", "co3d_multisequence"]: frame_file = os.path.join(
frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz") self.dataset_root, self.category, "frame_annotations.jgz"
sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz") )
subset_lists_file = os.path.join(dataset_root, category, "set_lists.json") sequence_file = os.path.join(
self.dataset_root, self.category, "sequence_annotations.jgz"
)
subset_lists_file = os.path.join(
self.dataset_root, self.category, "set_lists.json"
)
common_kwargs = { common_kwargs = {
"dataset_root": dataset_root, "dataset_root": self.dataset_root,
"limit_to": limit_to, "limit_to": self.limit_to,
"limit_sequences_to": limit_sequences_to, "limit_sequences_to": self.limit_sequences_to,
"load_point_clouds": load_point_clouds, "load_point_clouds": self.load_point_clouds,
"mask_images": mask_images, "mask_images": self.mask_images,
"mask_depths": mask_depths, "mask_depths": self.mask_depths,
"path_manager": path_manager, "path_manager": self.path_manager,
"frame_annotations_file": frame_file, "frame_annotations_file": frame_file,
"sequence_annotations_file": sequence_file, "sequence_annotations_file": sequence_file,
"subset_lists_file": subset_lists_file, "subset_lists_file": subset_lists_file,
**aux_dataset_kwargs, **self.aux_dataset_kwargs,
} }
# This maps the common names of the dataset subsets ("train"/"val"/"test") # This maps the common names of the dataset subsets ("train"/"val"/"test")
# to the names of the subsets in the CO3D dataset. # to the names of the subsets in the CO3D dataset.
set_names_mapping = _get_co3d_set_names_mapping( set_names_mapping = _get_co3d_set_names_mapping(
dataset_name, self.get_task(),
test_on_train, self.test_on_train,
only_test_set, self.only_test_set,
) )
# load the evaluation batches # load the evaluation batches
task = dataset_name.split("_")[-1]
batch_indices_path = os.path.join( batch_indices_path = os.path.join(
dataset_root, self.dataset_root,
category, self.category,
f"eval_batches_{task}.json", f"eval_batches_{self.task_str}.json",
) )
if path_manager is not None: if self.path_manager is not None:
batch_indices_path = path_manager.get_local_path(batch_indices_path) batch_indices_path = self.path_manager.get_local_path(batch_indices_path)
if not os.path.isfile(batch_indices_path): if not os.path.isfile(batch_indices_path):
# The batch indices file does not exist. # The batch indices file does not exist.
# Most probably the user has not specified the root folder. # Most probably the user has not specified the root folder.
@ -181,16 +163,20 @@ def dataset_zoo(
with open(batch_indices_path, "r") as f: with open(batch_indices_path, "r") as f:
eval_batch_index = json.load(f) eval_batch_index = json.load(f)
restrict_sequence_name = self.restrict_sequence_name
if task == "singlesequence": if self.get_task() == Task.SINGLE_SEQUENCE:
assert ( if (
test_restrict_sequence_id is not None and test_restrict_sequence_id >= 0 self.test_restrict_sequence_id is None
), ( or self.test_restrict_sequence_id < 0
):
raise ValueError(
"Please specify an integer id 'test_restrict_sequence_id'" "Please specify an integer id 'test_restrict_sequence_id'"
+ " of the sequence considered for 'singlesequence'" + " of the sequence considered for 'singlesequence'"
+ " training and evaluation." + " training and evaluation."
) )
assert len(restrict_sequence_name) == 0, ( if len(self.restrict_sequence_name) > 0:
raise ValueError(
"For the 'singlesequence' task, the restrict_sequence_name has" "For the 'singlesequence' task, the restrict_sequence_name has"
" to be unset while test_restrict_sequence_id has to be set to an" " to be unset while test_restrict_sequence_id has to be set to an"
" integer defining the order of the evaluation sequence." " integer defining the order of the evaluation sequence."
@ -199,7 +185,9 @@ def dataset_zoo(
eval_batches_sequence_names = list( eval_batches_sequence_names = list(
{b[0][0]: None for b in eval_batch_index}.keys() {b[0][0]: None for b in eval_batch_index}.keys()
) )
eval_sequence_name = eval_batches_sequence_names[test_restrict_sequence_id] eval_sequence_name = eval_batches_sequence_names[
self.test_restrict_sequence_id
]
eval_batch_index = [ eval_batch_index = [
b for b in eval_batch_index if b[0][0] == eval_sequence_name b for b in eval_batch_index if b[0][0] == eval_sequence_name
] ]
@ -207,14 +195,14 @@ def dataset_zoo(
restrict_sequence_name = [eval_sequence_name] restrict_sequence_name = [eval_sequence_name]
train_dataset = None train_dataset = None
if not only_test_set: if not self.only_test_set:
train_dataset = ImplicitronDataset( train_dataset = ImplicitronDataset(
n_frames_per_sequence=n_frames_per_sequence, n_frames_per_sequence=self.n_frames_per_sequence,
subsets=set_names_mapping["train"], subsets=set_names_mapping["train"],
pick_sequence=restrict_sequence_name, pick_sequence=restrict_sequence_name,
**common_kwargs, **common_kwargs,
) )
if test_on_train: if self.test_on_train:
assert train_dataset is not None assert train_dataset is not None
val_dataset = test_dataset = train_dataset val_dataset = test_dataset = train_dataset
else: else:
@ -237,12 +225,9 @@ def dataset_zoo(
test_dataset.eval_batches = test_dataset.seq_frame_index_to_dataset_index( test_dataset.eval_batches = test_dataset.seq_frame_index_to_dataset_index(
eval_batch_index eval_batch_index
) )
datasets = Datasets(train=train_dataset, val=val_dataset, test=test_dataset) datasets = DatasetMap(train=train_dataset, val=val_dataset, test=test_dataset)
else: if self.assert_single_seq:
raise ValueError(f"Unsupported dataset: {dataset_name}")
if assert_single_seq:
# check there's only one sequence in all datasets # check there's only one sequence in all datasets
sequence_names = { sequence_names = {
sequence_name sequence_name
@ -254,12 +239,12 @@ def dataset_zoo(
return datasets return datasets
def get_task(self) -> Task:
enable_get_default_args(dataset_zoo) return Task(self.task_str)
def _get_co3d_set_names_mapping( def _get_co3d_set_names_mapping(
dataset_name: str, task: Task,
test_on_train: bool, test_on_train: bool,
only_test: bool, only_test: bool,
) -> Dict[str, List[str]]: ) -> Dict[str, List[str]]:
@ -273,7 +258,7 @@ def _get_co3d_set_names_mapping(
- val (if not test_on_train) - val (if not test_on_train)
- test (if not test_on_train) - test (if not test_on_train)
""" """
single_seq = dataset_name == "co3d_singlesequence" single_seq = task == Task.SINGLE_SEQUENCE
if only_test: if only_test:
set_names_mapping = {} set_names_mapping = {}

View File

@ -12,11 +12,12 @@ from typing import Any, cast, Dict, List, Optional, Tuple
import lpips import lpips
import torch import torch
from iopath.common.file_io import PathManager from iopath.common.file_io import PathManager
from pytorch3d.implicitron.dataset.data_source import Task from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
from pytorch3d.implicitron.dataset.dataset_base import FrameData, ImplicitronDatasetBase from pytorch3d.implicitron.dataset.dataset_base import FrameData, ImplicitronDatasetBase
from pytorch3d.implicitron.dataset.dataset_zoo import CO3D_CATEGORIES, dataset_zoo
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
CO3D_CATEGORIES,
)
from pytorch3d.implicitron.dataset.utils import is_known_frame from pytorch3d.implicitron.dataset.utils import is_known_frame
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import ( from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
aggregate_nvs_results, aggregate_nvs_results,
@ -101,23 +102,21 @@ def evaluate_dbir_for_category(
torch.manual_seed(42) torch.manual_seed(42)
dataset_name = { dataset_map_provider_args = {
Task.SINGLE_SEQUENCE: "co3d_singlesequence", "category": category,
Task.MULTI_SEQUENCE: "co3d_multisequence", "dataset_root": os.environ["CO3D_DATASET_ROOT"],
}[task] "assert_single_seq": task == Task.SINGLE_SEQUENCE,
"task_str": task.value,
datasets = dataset_zoo( "test_on_train": False,
category=category, "load_point_clouds": True,
dataset_root=os.environ["CO3D_DATASET_ROOT"], "test_restrict_sequence_id": single_sequence_id,
assert_single_seq=task == Task.SINGLE_SEQUENCE, "path_manager": path_manager,
dataset_name=dataset_name, }
test_on_train=False, data_source = ImplicitronDataSource(
load_point_clouds=True, dataset_map_provider_JsonIndexDatasetMapProvider_args=dataset_map_provider_args
test_restrict_sequence_id=single_sequence_id,
path_manager=path_manager,
) )
dataloaders = dataloader_zoo(datasets) datasets, dataloaders = data_source.get_datasets_and_dataloaders()
test_dataset = datasets.test test_dataset = datasets.test
test_dataloader = dataloaders.test test_dataloader = dataloaders.test

View File

@ -0,0 +1,33 @@
dataset_map_provider_class_type: ???
dataloader_args:
batch_size: 1
num_workers: 0
dataset_len: 1000
dataset_len_val: 1
images_per_seq_options:
- 2
sample_consecutive_frames: false
consecutive_frames_max_gap: 0
consecutive_frames_max_gap_seconds: 0.1
dataset_map_provider_JsonIndexDatasetMapProvider_args:
category: ???
task_str: singlesequence
dataset_root: ''
limit_to: -1
limit_sequences_to: -1
n_frames_per_sequence: -1
test_on_train: false
load_point_clouds: false
mask_images: false
mask_depths: false
restrict_sequence_name: []
test_restrict_sequence_id: -1
assert_single_seq: false
only_test_set: false
aux_dataset_kwargs:
box_crop: true
box_crop_context: 0.3
image_width: 800
image_height: 800
remove_empty_masks: true
path_manager: null

View File

@ -118,6 +118,6 @@ implicit_function_IdrFeatureField_args:
bias: 1.0 bias: 1.0
skip_in: [] skip_in: []
weight_norm: true weight_norm: true
n_harmonic_functions_xyz: 0 n_harmonic_functions_xyz: 1729
pooled_feature_dim: 0 pooled_feature_dim: 0
encoding_dim: 0 encoding_dim: 0

View File

@ -70,6 +70,9 @@ class TestGenericModel(unittest.TestCase):
"AngleWeightedIdentityFeatureAggregator" "AngleWeightedIdentityFeatureAggregator"
) )
args.implicit_function_class_type = "IdrFeatureField" args.implicit_function_class_type = "IdrFeatureField"
idr_args = args.implicit_function_IdrFeatureField_args
idr_args.n_harmonic_functions_xyz = 1729
args.renderer_class_type = "LSTMRenderer" args.renderer_class_type = "LSTMRenderer"
gm = GenericModel(**args) gm = GenericModel(**args)
self.assertIsInstance(gm.renderer, LSTMRenderer) self.assertIsInstance(gm.renderer, LSTMRenderer)
@ -78,6 +81,7 @@ class TestGenericModel(unittest.TestCase):
AngleWeightedIdentityFeatureAggregator, AngleWeightedIdentityFeatureAggregator,
) )
self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField) self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
self.assertEqual(gm._implicit_functions[0]._fn.n_harmonic_functions_xyz, 1729)
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder) self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
self.assertIsInstance(gm.image_feature_extractor, ResNetFeatureExtractor) self.assertIsInstance(gm.image_feature_extractor, ResNetFeatureExtractor)
self.assertFalse(hasattr(gm, "implicit_function")) self.assertFalse(hasattr(gm, "implicit_function"))

View File

@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
from omegaconf import OmegaConf
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.tools.config import get_default_args
if os.environ.get("FB_TEST", False):
from common_testing import get_tests_dir
else:
from tests.common_testing import get_tests_dir
DATA_DIR = get_tests_dir() / "implicitron/data"
DEBUG: bool = False
class TestDataSource(unittest.TestCase):
    """
    Compares the default configuration of ImplicitronDataSource, rendered
    as yaml, with the stored file data_source.yaml in DATA_DIR.
    """

    def setUp(self):
        # No length limit on the diff printed when the yaml comparison fails.
        self.maxDiff = None

    def test_one(self):
        """
        The yaml of the default args must equal the contents of data_source.yaml.
        """
        cfg = get_default_args(ImplicitronDataSource)
        yaml = OmegaConf.to_yaml(cfg, sort_keys=False)
        if DEBUG:
            # With DEBUG switched on, the stored file is overwritten with the
            # current output first, so the comparison below then passes.
            (DATA_DIR / "data_source.yaml").write_text(yaml)
        self.assertEqual(yaml, (DATA_DIR / "data_source.yaml").read_text())