diff --git a/projects/implicitron_trainer/experiment.py b/projects/implicitron_trainer/experiment.py
index ae7a00b1..7e1b1a3b 100755
--- a/projects/implicitron_trainer/experiment.py
+++ b/projects/implicitron_trainer/experiment.py
@@ -67,8 +67,8 @@ from pytorch3d.implicitron.dataset import utils as ds_utils
 from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
 from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
 from pytorch3d.implicitron.dataset.implicitron_dataset import (
-    ImplicitronDataset,
     FrameData,
+    ImplicitronDataset,
 )
 from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
 from pytorch3d.implicitron.models.base import EvaluationMode, GenericModel
@@ -80,6 +80,7 @@ from pytorch3d.implicitron.tools.config import (
 from pytorch3d.implicitron.tools.stats import Stats
 from pytorch3d.renderer.cameras import CamerasBase
 
+
 logger = logging.getLogger(__name__)
 
 if version.parse(hydra.__version__) < version.Version("1.1"):
@@ -662,7 +663,9 @@ def _seed_all_random_engines(seed: int):
 @dataclass(eq=False)
 class ExperimentConfig:
     generic_model_args: DictConfig = get_default_args_field(GenericModel)
-    solver_args: DictConfig = get_default_args_field(init_optimizer)
+    solver_args: DictConfig = get_default_args_field(
+        init_optimizer, _allow_untyped=True
+    )
     dataset_args: DictConfig = get_default_args_field(dataset_zoo)
     dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
     architecture: str = "generic"
diff --git a/pytorch3d/implicitron/dataset/dataloader_zoo.py b/pytorch3d/implicitron/dataset/dataloader_zoo.py
index a49a815b..2a75f7a5 100644
--- a/pytorch3d/implicitron/dataset/dataloader_zoo.py
+++ b/pytorch3d/implicitron/dataset/dataloader_zoo.py
@@ -57,6 +57,7 @@ def dataloader_zoo(
         `"dataset_subset_name": torch_dataloader_object` key, value pairs.
     """
 
+    images_per_seq_options = tuple(images_per_seq_options)
     if dataset_name not in ["co3d_singlesequence", "co3d_multisequence"]:
         raise ValueError(f"Unsupported dataset: {dataset_name}")
 
diff --git a/pytorch3d/implicitron/dataset/dataset_zoo.py b/pytorch3d/implicitron/dataset/dataset_zoo.py
index cf96cf6a..1b16d817 100644
--- a/pytorch3d/implicitron/dataset/dataset_zoo.py
+++ b/pytorch3d/implicitron/dataset/dataset_zoo.py
@@ -100,6 +100,8 @@ def dataset_zoo(
         datasets: A dictionary containing the
             `"dataset_subset_name": torch_dataset_object` key, value pairs.
     """
+    restrict_sequence_name = tuple(restrict_sequence_name)
+    aux_dataset_kwargs = dict(aux_dataset_kwargs)
 
     datasets = {}
 
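Note on the two dataset-zoo changes above: when Hydra/OmegaConf populates these arguments, list- and dict-valued parameters arrive as `ListConfig`/`DictConfig` containers rather than plain Python ones, so normalizing them at the top of the function keeps OmegaConf types from leaking into downstream code. A minimal sketch of the behavior being normalized (the config contents here are illustrative, not taken from the trainer):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({"images_per_seq_options": [2, 3, 4]})
opts = cfg.images_per_seq_options   # an OmegaConf ListConfig, not a list

print(type(opts).__name__)              # ListConfig
print(isinstance(opts, (list, tuple)))  # False

# Normalizing at the function boundary, as dataloader_zoo now does:
opts = tuple(opts)
print(opts)                             # (2, 3, 4)
```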
diff --git a/pytorch3d/implicitron/models/renderer/sdf_renderer.py b/pytorch3d/implicitron/models/renderer/sdf_renderer.py
index 735a8a8c..a969029a 100644
--- a/pytorch3d/implicitron/models/renderer/sdf_renderer.py
+++ b/pytorch3d/implicitron/models/renderer/sdf_renderer.py
@@ -20,9 +20,11 @@ from .rgb_net import RayNormalColoringNetwork
 @registry.register
 class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
     render_features_dimensions: int = 3
-    ray_tracer_args: DictConfig = get_default_args_field(RayTracing)
+    ray_tracer_args: DictConfig = get_default_args_field(
+        RayTracing, _allow_untyped=True
+    )
     ray_normal_coloring_network_args: DictConfig = get_default_args_field(
-        RayNormalColoringNetwork
+        RayNormalColoringNetwork, _allow_untyped=True
     )
     bg_color: Tuple[float, ...] = (0.0,)
     soft_mask_alpha: float = 50.0
diff --git a/pytorch3d/implicitron/tools/config.py b/pytorch3d/implicitron/tools/config.py
index 79ae2eb7..5e7cb2bf 100644
--- a/pytorch3d/implicitron/tools/config.py
+++ b/pytorch3d/implicitron/tools/config.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import copy
 import dataclasses
 import inspect
 import itertools
@@ -412,7 +413,9 @@ def _is_configurable_class(C) -> bool:
     return isinstance(C, type) and issubclass(C, (Configurable, ReplaceableBase))
 
 
-def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig:
+def get_default_args(
+    C, *, _allow_untyped: bool = False, _do_not_process: Tuple[type, ...] = ()
+) -> DictConfig:
     """
     Get the DictConfig of args to call C - which might be a type or a function.
 
@@ -423,6 +426,14 @@
 
     Args:
         C: the class or function to be processed
+        _allow_untyped: (internal use) If True, do not try to make the
+            output typed when it is not a Configurable or
+            ReplaceableBase. This avoids problems (due to local
+            dataclasses being remembered inside the returned
+            DictConfig and any of its DictConfig and ListConfig
+            members) when pickling the output, but will break
+            conversions of yaml strings to/from any enum members
+            of C.
         _do_not_process: (internal use) When this function is called from
             expand_args_fields, we specify any class currently being
             processed, to make sure we don't try to process a class
@@ -462,6 +473,7 @@
 
     # regular class or function
     field_annotations = []
+    kwargs = {}
     for pname, defval in _params_iter(C):
         default = defval.default
         if default == inspect.Parameter.empty:
@@ -476,6 +488,8 @@
 
         _, annotation = _resolve_optional(defval.annotation)
 
+        kwargs[pname] = copy.deepcopy(default)
+
         if isinstance(default, set):
             # force OmegaConf to convert it to ListConfig
             default = tuple(default)
@@ -489,6 +503,9 @@
             field_ = dataclasses.field(default=default)
         field_annotations.append((pname, defval.annotation, field_))
 
+    if _allow_untyped:
+        return DictConfig(kwargs)
+
     # make a temp dataclass and generate a structured config from it.
     return OmegaConf.structured(
         dataclasses.make_dataclass(f"__{C.__name__}_default_args__", field_annotations)
@@ -696,7 +713,9 @@
     return some_class
 
 
-def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
+def get_default_args_field(
+    C, *, _allow_untyped: bool = False, _do_not_process: Tuple[type, ...] = ()
+):
     """
     Get a dataclass field which defaults to get_default_args(...)
 
@@ -708,7 +727,9 @@ def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
     """
 
     def create():
-        return get_default_args(C, _do_not_process=_do_not_process)
+        return get_default_args(
+            C, _allow_untyped=_allow_untyped, _do_not_process=_do_not_process
+        )
 
     return dataclasses.field(default_factory=create)
 
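Why the `_allow_untyped` escape hatch: without it, `get_default_args` builds a throwaway dataclass via `dataclasses.make_dataclass` and wraps it with `OmegaConf.structured`, and the resulting `DictConfig` remembers that dataclass as its backing type. Since the class is created on the fly and is not importable by name, pickle cannot find it at load time. A standalone sketch of the failure mode (the names `Tmp` and `__f_default_args__` are illustrative, not the library's internals verbatim):

```python
import dataclasses
import pickle

from omegaconf import DictConfig, OmegaConf

# A dynamically created dataclass, like the temporary
# "__<name>_default_args__" class that get_default_args builds internally.
Tmp = dataclasses.make_dataclass(
    "__f_default_args__", [("a", int, dataclasses.field(default=1))]
)

typed = OmegaConf.structured(Tmp)  # the DictConfig remembers Tmp as its type
try:
    pickle.dumps(typed)
except Exception as err:  # PicklingError: attribute lookup on the module fails
    print(f"typed config cannot be pickled: {err}")

# An untyped DictConfig, which is what _allow_untyped=True now returns,
# carries no class reference and round-trips fine:
untyped = DictConfig({"a": 1})
print(pickle.loads(pickle.dumps(untyped)).a)  # 1
```

The trade-off, as the new docstring notes, is that the untyped output loses the typed conversion of yaml strings to and from enum members.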
= ()): """ def create(): - return get_default_args(C, _do_not_process=_do_not_process) + return get_default_args( + C, _allow_untyped=_allow_untyped, _do_not_process=_do_not_process + ) return dataclasses.field(default_factory=create) diff --git a/tests/implicitron/test_config.py b/tests/implicitron/test_config.py index 9f0cc77f..fc3ccde0 100644 --- a/tests/implicitron/test_config.py +++ b/tests/implicitron/test_config.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import pickle import textwrap import unittest from dataclasses import dataclass, field, is_dataclass @@ -581,6 +582,15 @@ class TestConfig(unittest.TestCase): remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base))) self.assertEqual(remerged.a, A.B1) + def test_pickle(self): + def f(a: int = 1, b: str = "3"): + pass + + args = get_default_args(f, _allow_untyped=True) + args2 = pickle.loads(pickle.dumps(args)) + self.assertEqual(args2.a, 1) + self.assertEqual(args2.b, "3") + def test_remove_unused_components(self): struct = get_default_args(MainTest) struct.n_ids = 32