dataset_map_provider

Summary: replace dataset_zoo with a pluggable DatasetMapProvider. The logic is now in annotated_file_dataset_map_provider.

Reviewed By: shapovalov

Differential Revision: D36443965

fbshipit-source-id: 9087649802810055e150b2fbfcc3c197a761f28a
This commit is contained in:
Jeremy Reizenstein 2022-05-20 07:50:30 -07:00 committed by Facebook GitHub Bot
parent 69c6d06ed8
commit 79c61a2d86
15 changed files with 305 additions and 175 deletions

View File

@ -66,7 +66,7 @@ If you have a custom `experiment.py` script (as in the Option 2 above), replace
To run training, pass a yaml config file, followed by a list of overridden arguments.
For example, to train NeRF on the first skateboard sequence from CO3D dataset, you can run:
```shell
dataset_args=data_source_args.dataset_args
dataset_args=data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf $dataset_args.dataset_root=<DATASET_ROOT> $dataset_args.category='skateboard' $dataset_args.test_restrict_sequence_id=0 test_when_finished=True exp_dir=<CHECKPOINT_DIR>
```
@ -85,7 +85,7 @@ To run evaluation on the latest checkpoint after (or during) training, simply ad
E.g. for executing the evaluation on the NeRF skateboard sequence, you can run:
```shell
dataset_args=data_source_args.dataset_args
dataset_args=data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf $dataset_args.dataset_root=<CO3D_DATASET_ROOT> $dataset_args.category='skateboard' $dataset_args.test_restrict_sequence_id=0 exp_dir=<CHECKPOINT_DIR> eval_only=True
```
Evaluation prints the metrics to `stdout` and dumps them to a json file in `exp_dir`.
@ -236,7 +236,7 @@ generic_model_args: GenericModel
╘== ReductionFeatureAggregator
solver_args: init_optimizer
data_source_args: ImplicitronDataSource
└-- dataset_args
└-- dataset_map_provider_*_args
└-- dataloader_args
```

View File

@ -6,6 +6,7 @@ architecture: generic
visualize_interval: 0
visdom_port: 8097
data_source_args:
dataset_provider_class_type: JsonIndexDatasetMapProvider
dataloader_args:
batch_size: 10
dataset_len: 1000
@ -21,7 +22,7 @@ data_source_args:
- 8
- 9
- 10
dataset_args:
dataset_map_provider_JsonIndexDatasetMapProvider_args:
dataset_root: ${oc.env:CO3D_DATASET_ROOT}
load_point_clouds: false
mask_depths: false

View File

@ -17,9 +17,9 @@ data_source_args:
- 8
- 9
- 10
dataset_args:
dataset_map_provider_JsonIndexDatasetMapProvider_args:
assert_single_seq: false
dataset_name: co3d_multisequence
task_str: multisequence
load_point_clouds: false
mask_depths: false
mask_images: false

View File

@ -9,8 +9,8 @@ data_source_args:
num_workers: 8
images_per_seq_options:
- 2
dataset_args:
dataset_name: co3d_singlesequence
dataset_map_provider_JsonIndexDatasetMapProvider_args:
dataset_name: singlesequence
assert_single_seq: true
n_frames_per_sequence: -1
test_restrict_sequence_id: 0

View File

@ -67,7 +67,7 @@ from pytorch3d.implicitron.dataset import utils as ds_utils
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
from pytorch3d.implicitron.dataset.dataloader_zoo import Dataloaders
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.dataset_zoo import Datasets
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
@ -552,7 +552,7 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
def _eval_and_dump(
cfg,
task: Task,
datasets: Datasets,
datasets: DatasetMap,
dataloaders: Dataloaders,
model,
stats,

View File

@ -24,8 +24,7 @@ import torch.nn.functional as Fu
from experiment import init_model
from omegaconf import OmegaConf
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.dataset_base import FrameData, ImplicitronDatasetBase
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
from pytorch3d.implicitron.dataset.utils import is_train_frame
from pytorch3d.implicitron.models.base_model import EvaluationMode
@ -296,7 +295,7 @@ def export_scenes(
output_directory: Optional[str] = None,
render_size: Tuple[int, int] = (512, 512),
video_size: Optional[Tuple[int, int]] = None,
split: str = "train", # train | test
split: str = "train", # train | val | test
n_source_views: int = 9,
n_eval_cameras: int = 40,
visdom_server="http://127.0.0.1",
@ -324,14 +323,15 @@ def export_scenes(
config.gpu_idx = gpu_idx
config.exp_dir = exp_dir
# important so that the CO3D dataset gets loaded in full
config.data_source_args.dataset_args.test_on_train = False
dataset_args = (
config.data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
)
dataset_args.test_on_train = False
# Set the rendering image size
config.generic_model_args.render_image_width = render_size[0]
config.generic_model_args.render_image_height = render_size[1]
if restrict_sequence_name is not None:
config.data_source_args.dataset_args.restrict_sequence_name = (
restrict_sequence_name
)
dataset_args.restrict_sequence_name = restrict_sequence_name
# Set up the CUDA env for the visualization
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
@ -344,8 +344,8 @@ def export_scenes(
# Setup the dataset
datasource = ImplicitronDataSource(**config.data_source_args)
datasets = dataset_zoo(**datasource.dataset_args)
dataset: Optional[ImplicitronDatasetBase] = getattr(datasets, split, None)
dataset_map = datasource.dataset_map_provider.get_dataset_map()
dataset = dataset_map[split]
if dataset is None:
raise ValueError(f"{split} dataset not provided")

View File

@ -4,19 +4,18 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from enum import Enum
from typing import Tuple
from omegaconf import DictConfig
from pytorch3d.implicitron.tools.config import get_default_args_field, ReplaceableBase
from pytorch3d.implicitron.tools.config import (
get_default_args_field,
ReplaceableBase,
run_auto_creation,
)
from . import json_index_dataset_map_provider # noqa
from .dataloader_zoo import dataloader_zoo, Dataloaders
from .dataset_zoo import dataset_zoo, Datasets
class Task(Enum):
SINGLE_SEQUENCE = "singlesequence"
MULTI_SEQUENCE = "multisequence"
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
class DataSourceBase(ReplaceableBase):
@ -25,24 +24,31 @@ class DataSourceBase(ReplaceableBase):
and DataLoader configuration.
"""
def get_datasets_and_dataloaders(self) -> Tuple[Datasets, Dataloaders]:
def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, Dataloaders]:
raise NotImplementedError()
class ImplicitronDataSource(DataSourceBase):
class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
"""
Represents the data used in Implicitron. This is the only implementation
of DataSourceBase provided.
Members:
dataset_map_provider_class_type: identifies type for dataset_map_provider.
e.g. JsonIndexDatasetMapProvider for Co3D.
"""
dataset_args: DictConfig = get_default_args_field(dataset_zoo)
dataset_map_provider: DatasetMapProviderBase
dataset_map_provider_class_type: str
dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
def get_datasets_and_dataloaders(self) -> Tuple[Datasets, Dataloaders]:
datasets = dataset_zoo(**self.dataset_args)
def __post_init__(self):
run_auto_creation(self)
def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, Dataloaders]:
datasets = self.dataset_map_provider.get_dataset_map()
dataloaders = dataloader_zoo(datasets, **self.dataloader_args)
return datasets, dataloaders
def get_task(self) -> Task:
eval_task = self.dataset_args["dataset_name"].split("_")[-1]
return Task(eval_task)
return self.dataset_map_provider.get_task()

View File

@ -11,7 +11,7 @@ import torch
from pytorch3d.implicitron.tools.config import enable_get_default_args
from .dataset_base import FrameData, ImplicitronDatasetBase
from .dataset_zoo import Datasets
from .dataset_map_provider import DatasetMap
from .scene_batch_sampler import SceneBatchSampler
@ -33,7 +33,7 @@ class Dataloaders:
def dataloader_zoo(
datasets: Datasets,
datasets: DatasetMap,
batch_size: int = 1,
num_workers: int = 0,
dataset_len: int = 1000,
@ -49,7 +49,6 @@ def dataloader_zoo(
Args:
datasets: A dictionary containing the
`"dataset_subset_name": torch_dataset_object` key, value pairs.
dataset_name: The name of the returned dataset.
batch_size: The size of the batch of the dataloader.
num_workers: Number data-loading threads.
dataset_len: The number of batches in a training epoch.

View File

@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from enum import Enum
from typing import Iterator, Optional
from pytorch3d.implicitron.tools.config import ReplaceableBase
from .dataset_base import ImplicitronDatasetBase
@dataclass
class DatasetMap:
"""
A collection of datasets for implicitron.
Members:
train: a dataset for training
val: a dataset for validating during training
test: a dataset for final evaluation
"""
train: Optional[ImplicitronDatasetBase]
val: Optional[ImplicitronDatasetBase]
test: Optional[ImplicitronDatasetBase]
def __getitem__(self, split: str) -> Optional[ImplicitronDatasetBase]:
"""
Get one of the datasets by key (name of data split)
"""
if split not in ["train", "val", "test"]:
raise ValueError(f"{split} was not a valid split name (train/val/test)")
return getattr(self, split)
def iter_datasets(self) -> Iterator[ImplicitronDatasetBase]:
"""
Iterator over all datasets.
"""
if self.train is not None:
yield self.train
if self.val is not None:
yield self.val
if self.test is not None:
yield self.test
class Task(Enum):
SINGLE_SEQUENCE = "singlesequence"
MULTI_SEQUENCE = "multisequence"
class DatasetMapProviderBase(ReplaceableBase):
"""
Base class for a provider of training / validation and testing
dataset objects.
"""
def get_dataset_map(self) -> DatasetMap:
"""
Returns:
An object containing the torch.Dataset objects in train/val/test fields.
"""
raise NotImplementedError()
def get_task(self) -> Task:
raise NotImplementedError()

View File

@ -7,13 +7,13 @@
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional, Sequence
from dataclasses import field
from typing import Any, Dict, List, Sequence
from iopath.common.file_io import PathManager
from pytorch3d.implicitron.tools.config import enable_get_default_args
from omegaconf import DictConfig
from pytorch3d.implicitron.tools.config import registry
from .dataset_base import ImplicitronDatasetBase
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
from .implicitron_dataset import ImplicitronDataset
from .utils import (
DATASET_TYPE_KNOWN,
@ -34,6 +34,11 @@ DATASET_CONFIGS: Dict[str, Dict[str, Any]] = {
}
}
def _make_default_config() -> DictConfig:
return DictConfig(DATASET_CONFIGS["default"])
# fmt: off
CO3D_CATEGORIES: List[str] = list(reversed([
"baseballbat", "banana", "bicycle", "microwave", "tv",
@ -53,59 +58,16 @@ CO3D_CATEGORIES: List[str] = list(reversed([
_CO3D_DATASET_ROOT: str = os.getenv("CO3D_DATASET_ROOT", "")
@dataclass
class Datasets:
@registry.register
class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
"""
A provider of datasets for implicitron.
Members:
train: a dataset for training
val: a dataset for validating during training
test: a dataset for final evaluation
"""
train: Optional[ImplicitronDatasetBase]
val: Optional[ImplicitronDatasetBase]
test: Optional[ImplicitronDatasetBase]
def iter_datasets(self) -> Iterator[ImplicitronDatasetBase]:
"""
Iterator over all datasets.
"""
if self.train is not None:
yield self.train
if self.val is not None:
yield self.val
if self.test is not None:
yield self.test
def dataset_zoo(
dataset_name: str = "co3d_singlesequence",
dataset_root: str = _CO3D_DATASET_ROOT,
category: str = "DEFAULT",
limit_to: int = -1,
limit_sequences_to: int = -1,
n_frames_per_sequence: int = -1,
test_on_train: bool = False,
load_point_clouds: bool = False,
mask_images: bool = False,
mask_depths: bool = False,
restrict_sequence_name: Sequence[str] = (),
test_restrict_sequence_id: int = -1,
assert_single_seq: bool = False,
only_test_set: bool = False,
aux_dataset_kwargs: dict = DATASET_CONFIGS["default"],
path_manager: Optional[PathManager] = None,
) -> Datasets:
"""
Generates the training / validation and testing dataset objects.
Generates the training / validation and testing dataset objects for
a dataset laid out on disk like Co3D, with annotations in json files.
Args:
dataset_name: The name of the returned dataset.
dataset_root: The root folder of the dataset.
category: The object category of the dataset.
task_str: "multisequence" or "singlesequence".
dataset_root: The root folder of the dataset.
limit_to: Limit the dataset to the first #limit_to frames.
limit_sequences_to: Limit the dataset to the first
#limit_sequences_to sequences.
@ -119,58 +81,78 @@ def dataset_zoo(
restrict_sequence_name: Restrict the dataset sequences to the ones
present in the given list of names.
test_restrict_sequence_id: The ID of the loaded sequence.
Active for dataset_name='co3d_singlesequence'.
Active for task_str='singlesequence'.
assert_single_seq: Assert that only frames from a single sequence
are present in all generated datasets.
only_test_set: Load only the test set.
aux_dataset_kwargs: Specifies additional arguments to the
ImplicitronDataset constructor call.
Returns:
datasets: A dictionary containing the
`"dataset_subset_name": torch_dataset_object` key, value pairs.
path_manager: Optional[PathManager] for interpreting paths
"""
if only_test_set and test_on_train:
raise ValueError("Cannot have only_test_set and test_on_train")
# TODO:
# - implement loading multiple categories
category: str
task_str: str = "singlesequence"
dataset_root: str = _CO3D_DATASET_ROOT
limit_to: int = -1
limit_sequences_to: int = -1
n_frames_per_sequence: int = -1
test_on_train: bool = False
load_point_clouds: bool = False
mask_images: bool = False
mask_depths: bool = False
restrict_sequence_name: Sequence[str] = ()
test_restrict_sequence_id: int = -1
assert_single_seq: bool = False
only_test_set: bool = False
aux_dataset_kwargs: DictConfig = field(default_factory=_make_default_config)
path_manager: Any = None
if dataset_name in ["co3d_singlesequence", "co3d_multisequence"]:
frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
subset_lists_file = os.path.join(dataset_root, category, "set_lists.json")
def get_dataset_map(self) -> DatasetMap:
if self.only_test_set and self.test_on_train:
raise ValueError("Cannot have only_test_set and test_on_train")
# TODO:
# - implement loading multiple categories
frame_file = os.path.join(
self.dataset_root, self.category, "frame_annotations.jgz"
)
sequence_file = os.path.join(
self.dataset_root, self.category, "sequence_annotations.jgz"
)
subset_lists_file = os.path.join(
self.dataset_root, self.category, "set_lists.json"
)
common_kwargs = {
"dataset_root": dataset_root,
"limit_to": limit_to,
"limit_sequences_to": limit_sequences_to,
"load_point_clouds": load_point_clouds,
"mask_images": mask_images,
"mask_depths": mask_depths,
"path_manager": path_manager,
"dataset_root": self.dataset_root,
"limit_to": self.limit_to,
"limit_sequences_to": self.limit_sequences_to,
"load_point_clouds": self.load_point_clouds,
"mask_images": self.mask_images,
"mask_depths": self.mask_depths,
"path_manager": self.path_manager,
"frame_annotations_file": frame_file,
"sequence_annotations_file": sequence_file,
"subset_lists_file": subset_lists_file,
**aux_dataset_kwargs,
**self.aux_dataset_kwargs,
}
# This maps the common names of the dataset subsets ("train"/"val"/"test")
# to the names of the subsets in the CO3D dataset.
set_names_mapping = _get_co3d_set_names_mapping(
dataset_name,
test_on_train,
only_test_set,
self.get_task(),
self.test_on_train,
self.only_test_set,
)
# load the evaluation batches
task = dataset_name.split("_")[-1]
batch_indices_path = os.path.join(
dataset_root,
category,
f"eval_batches_{task}.json",
self.dataset_root,
self.category,
f"eval_batches_{self.task_str}.json",
)
if path_manager is not None:
batch_indices_path = path_manager.get_local_path(batch_indices_path)
if self.path_manager is not None:
batch_indices_path = self.path_manager.get_local_path(batch_indices_path)
if not os.path.isfile(batch_indices_path):
# The batch indices file does not exist.
# Most probably the user has not specified the root folder.
@ -181,25 +163,31 @@ def dataset_zoo(
with open(batch_indices_path, "r") as f:
eval_batch_index = json.load(f)
restrict_sequence_name = self.restrict_sequence_name
if task == "singlesequence":
assert (
test_restrict_sequence_id is not None and test_restrict_sequence_id >= 0
), (
"Please specify an integer id 'test_restrict_sequence_id'"
+ " of the sequence considered for 'singlesequence'"
+ " training and evaluation."
)
assert len(restrict_sequence_name) == 0, (
"For the 'singlesequence' task, the restrict_sequence_name has"
" to be unset while test_restrict_sequence_id has to be set to an"
" integer defining the order of the evaluation sequence."
)
if self.get_task() == Task.SINGLE_SEQUENCE:
if (
self.test_restrict_sequence_id is None
or self.test_restrict_sequence_id < 0
):
raise ValueError(
"Please specify an integer id 'test_restrict_sequence_id'"
+ " of the sequence considered for 'singlesequence'"
+ " training and evaluation."
)
if len(self.restrict_sequence_name) > 0:
raise ValueError(
"For the 'singlesequence' task, the restrict_sequence_name has"
" to be unset while test_restrict_sequence_id has to be set to an"
" integer defining the order of the evaluation sequence."
)
# a sort-stable set() equivalent:
eval_batches_sequence_names = list(
{b[0][0]: None for b in eval_batch_index}.keys()
)
eval_sequence_name = eval_batches_sequence_names[test_restrict_sequence_id]
eval_sequence_name = eval_batches_sequence_names[
self.test_restrict_sequence_id
]
eval_batch_index = [
b for b in eval_batch_index if b[0][0] == eval_sequence_name
]
@ -207,14 +195,14 @@ def dataset_zoo(
restrict_sequence_name = [eval_sequence_name]
train_dataset = None
if not only_test_set:
if not self.only_test_set:
train_dataset = ImplicitronDataset(
n_frames_per_sequence=n_frames_per_sequence,
n_frames_per_sequence=self.n_frames_per_sequence,
subsets=set_names_mapping["train"],
pick_sequence=restrict_sequence_name,
**common_kwargs,
)
if test_on_train:
if self.test_on_train:
assert train_dataset is not None
val_dataset = test_dataset = train_dataset
else:
@ -237,29 +225,26 @@ def dataset_zoo(
test_dataset.eval_batches = test_dataset.seq_frame_index_to_dataset_index(
eval_batch_index
)
datasets = Datasets(train=train_dataset, val=val_dataset, test=test_dataset)
datasets = DatasetMap(train=train_dataset, val=val_dataset, test=test_dataset)
else:
raise ValueError(f"Unsupported dataset: {dataset_name}")
if self.assert_single_seq:
# check there's only one sequence in all datasets
sequence_names = {
sequence_name
for dset in datasets.iter_datasets()
for sequence_name in dset.sequence_names()
}
if len(sequence_names) > 1:
raise ValueError("Multiple sequences loaded but expected one")
if assert_single_seq:
# check there's only one sequence in all datasets
sequence_names = {
sequence_name
for dset in datasets.iter_datasets()
for sequence_name in dset.sequence_names()
}
if len(sequence_names) > 1:
raise ValueError("Multiple sequences loaded but expected one")
return datasets
return datasets
enable_get_default_args(dataset_zoo)
def get_task(self) -> Task:
return Task(self.task_str)
def _get_co3d_set_names_mapping(
dataset_name: str,
task: Task,
test_on_train: bool,
only_test: bool,
) -> Dict[str, List[str]]:
@ -273,7 +258,7 @@ def _get_co3d_set_names_mapping(
- val (if not test_on_train)
- test (if not test_on_train)
"""
single_seq = dataset_name == "co3d_singlesequence"
single_seq = task == Task.SINGLE_SEQUENCE
if only_test:
set_names_mapping = {}

View File

@ -12,11 +12,12 @@ from typing import Any, cast, Dict, List, Optional, Tuple
import lpips
import torch
from iopath.common.file_io import PathManager
from pytorch3d.implicitron.dataset.data_source import Task
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
from pytorch3d.implicitron.dataset.dataset_base import FrameData, ImplicitronDatasetBase
from pytorch3d.implicitron.dataset.dataset_zoo import CO3D_CATEGORIES, dataset_zoo
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
CO3D_CATEGORIES,
)
from pytorch3d.implicitron.dataset.utils import is_known_frame
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
aggregate_nvs_results,
@ -101,23 +102,21 @@ def evaluate_dbir_for_category(
torch.manual_seed(42)
dataset_name = {
Task.SINGLE_SEQUENCE: "co3d_singlesequence",
Task.MULTI_SEQUENCE: "co3d_multisequence",
}[task]
datasets = dataset_zoo(
category=category,
dataset_root=os.environ["CO3D_DATASET_ROOT"],
assert_single_seq=task == Task.SINGLE_SEQUENCE,
dataset_name=dataset_name,
test_on_train=False,
load_point_clouds=True,
test_restrict_sequence_id=single_sequence_id,
path_manager=path_manager,
dataset_map_provider_args = {
"category": category,
"dataset_root": os.environ["CO3D_DATASET_ROOT"],
"assert_single_seq": task == Task.SINGLE_SEQUENCE,
"task_str": task.value,
"test_on_train": False,
"load_point_clouds": True,
"test_restrict_sequence_id": single_sequence_id,
"path_manager": path_manager,
}
data_source = ImplicitronDataSource(
dataset_map_provider_JsonIndexDatasetMapProvider_args=dataset_map_provider_args
)
dataloaders = dataloader_zoo(datasets)
datasets, dataloaders = data_source.get_datasets_and_dataloaders()
test_dataset = datasets.test
test_dataloader = dataloaders.test

View File

@ -0,0 +1,33 @@
dataset_map_provider_class_type: ???
dataloader_args:
batch_size: 1
num_workers: 0
dataset_len: 1000
dataset_len_val: 1
images_per_seq_options:
- 2
sample_consecutive_frames: false
consecutive_frames_max_gap: 0
consecutive_frames_max_gap_seconds: 0.1
dataset_map_provider_JsonIndexDatasetMapProvider_args:
category: ???
task_str: singlesequence
dataset_root: ''
limit_to: -1
limit_sequences_to: -1
n_frames_per_sequence: -1
test_on_train: false
load_point_clouds: false
mask_images: false
mask_depths: false
restrict_sequence_name: []
test_restrict_sequence_id: -1
assert_single_seq: false
only_test_set: false
aux_dataset_kwargs:
box_crop: true
box_crop_context: 0.3
image_width: 800
image_height: 800
remove_empty_masks: true
path_manager: null

View File

@ -118,6 +118,6 @@ implicit_function_IdrFeatureField_args:
bias: 1.0
skip_in: []
weight_norm: true
n_harmonic_functions_xyz: 0
n_harmonic_functions_xyz: 1729
pooled_feature_dim: 0
encoding_dim: 0

View File

@ -70,6 +70,9 @@ class TestGenericModel(unittest.TestCase):
"AngleWeightedIdentityFeatureAggregator"
)
args.implicit_function_class_type = "IdrFeatureField"
idr_args = args.implicit_function_IdrFeatureField_args
idr_args.n_harmonic_functions_xyz = 1729
args.renderer_class_type = "LSTMRenderer"
gm = GenericModel(**args)
self.assertIsInstance(gm.renderer, LSTMRenderer)
@ -78,6 +81,7 @@ class TestGenericModel(unittest.TestCase):
AngleWeightedIdentityFeatureAggregator,
)
self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
self.assertEqual(gm._implicit_functions[0]._fn.n_harmonic_functions_xyz, 1729)
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
self.assertIsInstance(gm.image_feature_extractor, ResNetFeatureExtractor)
self.assertFalse(hasattr(gm, "implicit_function"))

View File

@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
from omegaconf import OmegaConf
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.tools.config import get_default_args
if os.environ.get("FB_TEST", False):
from common_testing import get_tests_dir
else:
from tests.common_testing import get_tests_dir
DATA_DIR = get_tests_dir() / "implicitron/data"
DEBUG: bool = False
class TestDataSource(unittest.TestCase):
def setUp(self):
self.maxDiff = None
def test_one(self):
cfg = get_default_args(ImplicitronDataSource)
yaml = OmegaConf.to_yaml(cfg, sort_keys=False)
if DEBUG:
(DATA_DIR / "data_source.yaml").write_text(yaml)
self.assertEqual(yaml, (DATA_DIR / "data_source.yaml").read_text())