diff --git a/projects/implicitron_trainer/experiment.py b/projects/implicitron_trainer/experiment.py index 7e1b1a3b..972b44c0 100755 --- a/projects/implicitron_trainer/experiment.py +++ b/projects/implicitron_trainer/experiment.py @@ -74,6 +74,7 @@ from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as eval from pytorch3d.implicitron.models.base import EvaluationMode, GenericModel from pytorch3d.implicitron.tools import model_io, vis_utils from pytorch3d.implicitron.tools.config import ( + enable_get_default_args, get_default_args_field, remove_unused_components, ) @@ -304,6 +305,9 @@ def init_optimizer( return optimizer, scheduler +enable_get_default_args(init_optimizer) + + def trainvalidate( model, stats, @@ -663,9 +667,7 @@ def _seed_all_random_engines(seed: int): @dataclass(eq=False) class ExperimentConfig: generic_model_args: DictConfig = get_default_args_field(GenericModel) - solver_args: DictConfig = get_default_args_field( - init_optimizer, _allow_untyped=True - ) + solver_args: DictConfig = get_default_args_field(init_optimizer) dataset_args: DictConfig = get_default_args_field(dataset_zoo) dataloader_args: DictConfig = get_default_args_field(dataloader_zoo) architecture: str = "generic" diff --git a/pytorch3d/implicitron/dataset/dataloader_zoo.py b/pytorch3d/implicitron/dataset/dataloader_zoo.py index 2a75f7a5..b5812e25 100644 --- a/pytorch3d/implicitron/dataset/dataloader_zoo.py +++ b/pytorch3d/implicitron/dataset/dataloader_zoo.py @@ -7,6 +7,7 @@ from typing import Dict, Sequence import torch +from pytorch3d.implicitron.tools.config import enable_get_default_args from .implicitron_dataset import FrameData, ImplicitronDatasetBase from .scene_batch_sampler import SceneBatchSampler @@ -56,8 +57,6 @@ def dataloader_zoo( dataloaders: A dictionary containing the `"dataset_subset_name": torch_dataloader_object` key, value pairs. """ - - images_per_seq_options = tuple(images_per_seq_options) if dataset_name not in ["co3d_singlesequence", "co3d_multisequence"]: raise ValueError(f"Unsupported dataset: {dataset_name}") @@ -96,3 +95,6 @@ def dataloader_zoo( raise ValueError(f"Unsupported dataset: {dataset_name}") return dataloaders + + +enable_get_default_args(dataloader_zoo) diff --git a/pytorch3d/implicitron/dataset/dataset_zoo.py b/pytorch3d/implicitron/dataset/dataset_zoo.py index 1b16d817..d93bfb0d 100644 --- a/pytorch3d/implicitron/dataset/dataset_zoo.py +++ b/pytorch3d/implicitron/dataset/dataset_zoo.py @@ -11,6 +11,7 @@ import os from typing import Any, Dict, List, Optional, Sequence from iopath.common.file_io import PathManager +from pytorch3d.implicitron.tools.config import enable_get_default_args from .implicitron_dataset import ImplicitronDataset, ImplicitronDatasetBase from .utils import ( @@ -100,9 +101,6 @@ def dataset_zoo( datasets: A dictionary containing the `"dataset_subset_name": torch_dataset_object` key, value pairs. """ - restrict_sequence_name = tuple(restrict_sequence_name) - aux_dataset_kwargs = dict(aux_dataset_kwargs) - datasets = {} # TODO: @@ -222,6 +220,9 @@ def dataset_zoo( return datasets +enable_get_default_args(dataset_zoo) + + def _get_co3d_set_names_mapping( dataset_name: str, test_on_train: bool, diff --git a/pytorch3d/implicitron/models/renderer/rgb_net.py b/pytorch3d/implicitron/models/renderer/rgb_net.py index 68c77390..e1b81587 100644 --- a/pytorch3d/implicitron/models/renderer/rgb_net.py +++ b/pytorch3d/implicitron/models/renderer/rgb_net.py @@ -7,6 +7,7 @@ import logging from typing import List, Tuple import torch +from pytorch3d.implicitron.tools.config import enable_get_default_args from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle from torch import nn @@ -106,3 +107,6 @@ class RayNormalColoringNetwork(torch.nn.Module): x = self.tanh(x) return x + + +enable_get_default_args(RayNormalColoringNetwork) diff --git a/pytorch3d/implicitron/models/renderer/sdf_renderer.py b/pytorch3d/implicitron/models/renderer/sdf_renderer.py index a969029a..735a8a8c 100644 --- a/pytorch3d/implicitron/models/renderer/sdf_renderer.py +++ b/pytorch3d/implicitron/models/renderer/sdf_renderer.py @@ -20,11 +20,9 @@ from .rgb_net import RayNormalColoringNetwork @registry.register class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): render_features_dimensions: int = 3 - ray_tracer_args: DictConfig = get_default_args_field( - RayTracing, _allow_untyped=True - ) + ray_tracer_args: DictConfig = get_default_args_field(RayTracing) ray_normal_coloring_network_args: DictConfig = get_default_args_field( - RayNormalColoringNetwork, _allow_untyped=True + RayNormalColoringNetwork ) bg_color: Tuple[float, ...] = (0.0,) soft_mask_alpha: float = 50.0 diff --git a/pytorch3d/implicitron/tools/config.py b/pytorch3d/implicitron/tools/config.py index 0883ca76..476c4711 100644 --- a/pytorch3d/implicitron/tools/config.py +++ b/pytorch3d/implicitron/tools/config.py @@ -4,10 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import copy import dataclasses import inspect import itertools +import sys import warnings from collections import Counter, defaultdict from enum import Enum @@ -27,8 +27,7 @@ Core functionality: - expand_args_fields -- Expands a class like `dataclasses.dataclass`. Runs automatically. -- get_default_args -- gets an omegaconf.DictConfig for initializing - a given class or calling a given function. +- get_default_args -- gets an omegaconf.DictConfig for initializing a given class. - run_auto_creation -- Initialises nested members. To be called in __post_init__. @@ -46,6 +45,7 @@ Additional utility functions: - remove_unused_components -- used for simplifying a DictConfig instance. - get_default_args_field -- default for DictConfig member of another configurable. +- enable_get_default_args -- Allows get_default_args on a function or plain class. 1. The simplest usage of this functionality is as follows. First a schema is defined @@ -413,34 +413,25 @@ def _is_configurable_class(C) -> bool: return isinstance(C, type) and issubclass(C, (Configurable, ReplaceableBase)) -def get_default_args( - C, *, _allow_untyped: bool = False, _do_not_process: Tuple[type, ...] = () -) -> DictConfig: +def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig: """ - Get the DictConfig of args to call C - which might be a type or a function. + Get the DictConfig corresponding to the defaults in a dataclass or + configurable. Normal use is to provide a dataclass can be provided as C. + If enable_get_default_args has been called on a function or plain class, + then that function or class can be provided as C. If C is a subclass of Configurable or ReplaceableBase, we make sure - it has been processed with expand_args_fields. If C is a dataclass, - including a subclass of Configurable or ReplaceableBase, the output - will be a typed DictConfig. + it has been processed with expand_args_fields. Args: C: the class or function to be processed - _allow_untyped: (internal use) If True, do not try to make the - output typed when it is not a Configurable or - ReplaceableBase. This avoids problems (due to local - dataclasses being remembered inside the returned - DictConfig and any of its DictConfig and ListConfig - members) when pickling the output, but will break - conversions of yaml strings to/from any emum members - of C. _do_not_process: (internal use) When this function is called from expand_args_fields, we specify any class currently being processed, to make sure we don't try to process a class while it is already being processed. Returns: - new DictConfig object + new DictConfig object, which is typed. """ if C is None: return DictConfig({}) @@ -471,9 +462,46 @@ def get_default_args( if _is_configurable_class(C): raise ValueError(f"Failed to process {C}") - # regular class or function + if not inspect.isfunction(C) and not inspect.isclass(C): + raise ValueError(f"Unexpected {C}") + + dataclass_name = _dataclass_name_for_function(C) + dataclass = getattr(sys.modules[C.__module__], dataclass_name, None) + if dataclass is None: + raise ValueError( + f"Cannot get args for {C}. Was enable_get_default_args forgotten?" + ) + + return OmegaConf.structured(dataclass) + + +def _dataclass_name_for_function(C: Any) -> str: + """ + Returns the name of the dataclass which enable_get_default_args(C) + creates. + """ + name = f"_{C.__name__}_default_args_" + return name + + +def enable_get_default_args(C: Any, *, overwrite: bool = True) -> None: + """ + If C is a function or a plain class with an __init__ function, + and you want get_default_args(C) to work, then add + `enable_get_default_args(C)` straight after the definition of C. + This makes a dataclass corresponding to the default arguments of C + and stores it in the same module as C. + + Args: + C: a function, or a class with an __init__ function. Must + have types for all its defaulted args. + overwrite: whether to allow calling this a second time on + the same function. + """ + if not inspect.isfunction(C) and not inspect.isclass(C): + raise ValueError(f"Unexpected {C}") + field_annotations = [] - kwargs = {} for pname, defval in _params_iter(C): default = defval.default if default == inspect.Parameter.empty: @@ -488,8 +516,6 @@ def get_default_args( _, annotation = _resolve_optional(defval.annotation) - kwargs[pname] = copy.deepcopy(default) - if isinstance(default, set): # force OmegaConf to convert it to ListConfig default = tuple(default) @@ -503,13 +529,16 @@ def get_default_args( field_ = dataclasses.field(default=default) field_annotations.append((pname, defval.annotation, field_)) - if _allow_untyped: - return DictConfig(kwargs) - - # make a temp dataclass and generate a structured config from it. - return OmegaConf.structured( - dataclasses.make_dataclass(f"__{C.__name__}_default_args__", field_annotations) - ) + name = _dataclass_name_for_function(C) + module = sys.modules[C.__module__] + if hasattr(module, name): + if overwrite: + warnings.warn(f"Overwriting {name} in {C.__module__}.") + else: + raise ValueError(f"Cannot overwrite {name} in {C.__module__}.") + dc = dataclasses.make_dataclass(name, field_annotations) + dc.__module__ = C.__module__ + setattr(module, name, dc) def _params_iter(C): @@ -715,9 +744,7 @@ def expand_args_fields( return some_class -def get_default_args_field( - C, *, _allow_untyped: bool = False, _do_not_process: Tuple[type, ...] = () -): +def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()): """ Get a dataclass field which defaults to get_default_args(...) @@ -729,9 +756,7 @@ def get_default_args_field( """ def create(): - return get_default_args( - C, _allow_untyped=_allow_untyped, _do_not_process=_do_not_process - ) + return get_default_args(C, _do_not_process=_do_not_process) return dataclasses.field(default_factory=create) diff --git a/tests/implicitron/test_config.py b/tests/implicitron/test_config.py index 54c9d7f6..e53fe0bd 100644 --- a/tests/implicitron/test_config.py +++ b/tests/implicitron/test_config.py @@ -19,6 +19,7 @@ from pytorch3d.implicitron.tools.config import ( _is_actually_dataclass, _ProcessType, _Registry, + enable_get_default_args, expand_args_fields, get_default_args, get_default_args_field, @@ -236,6 +237,8 @@ class TestConfig(unittest.TestCase): def __init__(self, a: Any = 1, b: Any = 2): self.a, self.b = a, b + enable_get_default_args(Foo) + @dataclass() class Bar: aa: int = 9 @@ -480,10 +483,14 @@ class TestConfig(unittest.TestCase): def get_tuple(self): return self.tuple_member + enable_get_default_args(SimpleClass) + def f(*, a: int = 3, b: str = "kj"): self.assertEqual(a, 3) self.assertEqual(b, "kj") + enable_get_default_args(f) + class C(Configurable): simple: DictConfig = get_default_args_field(SimpleClass) # simple2: SimpleClass2 = SimpleClass2() @@ -567,10 +574,14 @@ class TestConfig(unittest.TestCase): def C_fn(a: A = A.B1): pass + enable_get_default_args(C_fn) + class C_cl: def __init__(self, a: A = A.B1) -> None: pass + enable_get_default_args(C_cl) + for C_ in [C, C_fn, C_cl]: base = get_default_args(C_) self.assertEqual(base.a, A.B1) @@ -586,14 +597,20 @@ class TestConfig(unittest.TestCase): self.assertEqual(remerged.a, A.B1) def test_pickle(self): - def f(a: int = 1, b: str = "3"): + def func(a: int = 1, b: str = "3"): pass - args = get_default_args(f, _allow_untyped=True) + enable_get_default_args(func) + + args = get_default_args(func) args2 = pickle.loads(pickle.dumps(args)) self.assertEqual(args2.a, 1) self.assertEqual(args2.b, "3") + args_regenerated = get_default_args(func) + pickle.dumps(args_regenerated) + pickle.dumps(args) + def test_remove_unused_components(self): struct = get_default_args(MainTest) struct.n_ids = 32 @@ -674,6 +691,9 @@ class MockClassWithInit: # noqa: B903 self.field_reference_type = field_reference_type +enable_get_default_args(MockClassWithInit) + + class TestRawClasses(unittest.TestCase): def setUp(self) -> None: self._instances = {