mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
enable_get_default_args to allow pickling get_default_args(f)
Summary:
Try again to solve https://github.com/facebookresearch/pytorch3d/issues/1144 pickling problem.
D35258561 (24260130ce
) didn't work.
When writing a function or vanilla class C which you want people to be able to call get_default_args on, you must add the line enable_get_default_args(C) to it. This causes autogeneration of a hidden dataclass in the module.
Reviewed By: davnov134
Differential Revision: D35364410
fbshipit-source-id: 53f6e6fff43e7142ae18ca3b06de7d0c849ef965
This commit is contained in:
parent
4c48beb226
commit
e10a90140d
@ -74,6 +74,7 @@ from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as eval
|
||||
from pytorch3d.implicitron.models.base import EvaluationMode, GenericModel
|
||||
from pytorch3d.implicitron.tools import model_io, vis_utils
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
enable_get_default_args,
|
||||
get_default_args_field,
|
||||
remove_unused_components,
|
||||
)
|
||||
@ -304,6 +305,9 @@ def init_optimizer(
|
||||
return optimizer, scheduler
|
||||
|
||||
|
||||
enable_get_default_args(init_optimizer)
|
||||
|
||||
|
||||
def trainvalidate(
|
||||
model,
|
||||
stats,
|
||||
@ -663,9 +667,7 @@ def _seed_all_random_engines(seed: int):
|
||||
@dataclass(eq=False)
|
||||
class ExperimentConfig:
|
||||
generic_model_args: DictConfig = get_default_args_field(GenericModel)
|
||||
solver_args: DictConfig = get_default_args_field(
|
||||
init_optimizer, _allow_untyped=True
|
||||
)
|
||||
solver_args: DictConfig = get_default_args_field(init_optimizer)
|
||||
dataset_args: DictConfig = get_default_args_field(dataset_zoo)
|
||||
dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
|
||||
architecture: str = "generic"
|
||||
|
@ -7,6 +7,7 @@
|
||||
from typing import Dict, Sequence
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.config import enable_get_default_args
|
||||
|
||||
from .implicitron_dataset import FrameData, ImplicitronDatasetBase
|
||||
from .scene_batch_sampler import SceneBatchSampler
|
||||
@ -56,8 +57,6 @@ def dataloader_zoo(
|
||||
dataloaders: A dictionary containing the
|
||||
`"dataset_subset_name": torch_dataloader_object` key, value pairs.
|
||||
"""
|
||||
|
||||
images_per_seq_options = tuple(images_per_seq_options)
|
||||
if dataset_name not in ["co3d_singlesequence", "co3d_multisequence"]:
|
||||
raise ValueError(f"Unsupported dataset: {dataset_name}")
|
||||
|
||||
@ -96,3 +95,6 @@ def dataloader_zoo(
|
||||
raise ValueError(f"Unsupported dataset: {dataset_name}")
|
||||
|
||||
return dataloaders
|
||||
|
||||
|
||||
enable_get_default_args(dataloader_zoo)
|
||||
|
@ -11,6 +11,7 @@ import os
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
from iopath.common.file_io import PathManager
|
||||
from pytorch3d.implicitron.tools.config import enable_get_default_args
|
||||
|
||||
from .implicitron_dataset import ImplicitronDataset, ImplicitronDatasetBase
|
||||
from .utils import (
|
||||
@ -100,9 +101,6 @@ def dataset_zoo(
|
||||
datasets: A dictionary containing the
|
||||
`"dataset_subset_name": torch_dataset_object` key, value pairs.
|
||||
"""
|
||||
restrict_sequence_name = tuple(restrict_sequence_name)
|
||||
aux_dataset_kwargs = dict(aux_dataset_kwargs)
|
||||
|
||||
datasets = {}
|
||||
|
||||
# TODO:
|
||||
@ -222,6 +220,9 @@ def dataset_zoo(
|
||||
return datasets
|
||||
|
||||
|
||||
enable_get_default_args(dataset_zoo)
|
||||
|
||||
|
||||
def _get_co3d_set_names_mapping(
|
||||
dataset_name: str,
|
||||
test_on_train: bool,
|
||||
|
@ -7,6 +7,7 @@ import logging
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.config import enable_get_default_args
|
||||
from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle
|
||||
from torch import nn
|
||||
|
||||
@ -106,3 +107,6 @@ class RayNormalColoringNetwork(torch.nn.Module):
|
||||
|
||||
x = self.tanh(x)
|
||||
return x
|
||||
|
||||
|
||||
enable_get_default_args(RayNormalColoringNetwork)
|
||||
|
@ -20,11 +20,9 @@ from .rgb_net import RayNormalColoringNetwork
|
||||
@registry.register
|
||||
class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
||||
render_features_dimensions: int = 3
|
||||
ray_tracer_args: DictConfig = get_default_args_field(
|
||||
RayTracing, _allow_untyped=True
|
||||
)
|
||||
ray_tracer_args: DictConfig = get_default_args_field(RayTracing)
|
||||
ray_normal_coloring_network_args: DictConfig = get_default_args_field(
|
||||
RayNormalColoringNetwork, _allow_untyped=True
|
||||
RayNormalColoringNetwork
|
||||
)
|
||||
bg_color: Tuple[float, ...] = (0.0,)
|
||||
soft_mask_alpha: float = 50.0
|
||||
|
@ -4,10 +4,10 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
import inspect
|
||||
import itertools
|
||||
import sys
|
||||
import warnings
|
||||
from collections import Counter, defaultdict
|
||||
from enum import Enum
|
||||
@ -27,8 +27,7 @@ Core functionality:
|
||||
|
||||
- expand_args_fields -- Expands a class like `dataclasses.dataclass`. Runs automatically.
|
||||
|
||||
- get_default_args -- gets an omegaconf.DictConfig for initializing
|
||||
a given class or calling a given function.
|
||||
- get_default_args -- gets an omegaconf.DictConfig for initializing a given class.
|
||||
|
||||
- run_auto_creation -- Initialises nested members. To be called in __post_init__.
|
||||
|
||||
@ -46,6 +45,7 @@ Additional utility functions:
|
||||
|
||||
- remove_unused_components -- used for simplifying a DictConfig instance.
|
||||
- get_default_args_field -- default for DictConfig member of another configurable.
|
||||
- enable_get_default_args -- Allows get_default_args on a function or plain class.
|
||||
|
||||
|
||||
1. The simplest usage of this functionality is as follows. First a schema is defined
|
||||
@ -413,34 +413,25 @@ def _is_configurable_class(C) -> bool:
|
||||
return isinstance(C, type) and issubclass(C, (Configurable, ReplaceableBase))
|
||||
|
||||
|
||||
def get_default_args(
|
||||
C, *, _allow_untyped: bool = False, _do_not_process: Tuple[type, ...] = ()
|
||||
) -> DictConfig:
|
||||
def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig:
|
||||
"""
|
||||
Get the DictConfig of args to call C - which might be a type or a function.
|
||||
Get the DictConfig corresponding to the defaults in a dataclass or
|
||||
configurable. Normal use is to provide a dataclass can be provided as C.
|
||||
If enable_get_default_args has been called on a function or plain class,
|
||||
then that function or class can be provided as C.
|
||||
|
||||
If C is a subclass of Configurable or ReplaceableBase, we make sure
|
||||
it has been processed with expand_args_fields. If C is a dataclass,
|
||||
including a subclass of Configurable or ReplaceableBase, the output
|
||||
will be a typed DictConfig.
|
||||
it has been processed with expand_args_fields.
|
||||
|
||||
Args:
|
||||
C: the class or function to be processed
|
||||
_allow_untyped: (internal use) If True, do not try to make the
|
||||
output typed when it is not a Configurable or
|
||||
ReplaceableBase. This avoids problems (due to local
|
||||
dataclasses being remembered inside the returned
|
||||
DictConfig and any of its DictConfig and ListConfig
|
||||
members) when pickling the output, but will break
|
||||
conversions of yaml strings to/from any emum members
|
||||
of C.
|
||||
_do_not_process: (internal use) When this function is called from
|
||||
expand_args_fields, we specify any class currently being
|
||||
processed, to make sure we don't try to process a class
|
||||
while it is already being processed.
|
||||
|
||||
Returns:
|
||||
new DictConfig object
|
||||
new DictConfig object, which is typed.
|
||||
"""
|
||||
if C is None:
|
||||
return DictConfig({})
|
||||
@ -471,9 +462,46 @@ def get_default_args(
|
||||
if _is_configurable_class(C):
|
||||
raise ValueError(f"Failed to process {C}")
|
||||
|
||||
# regular class or function
|
||||
if not inspect.isfunction(C) and not inspect.isclass(C):
|
||||
raise ValueError(f"Unexpected {C}")
|
||||
|
||||
dataclass_name = _dataclass_name_for_function(C)
|
||||
dataclass = getattr(sys.modules[C.__module__], dataclass_name, None)
|
||||
if dataclass is None:
|
||||
raise ValueError(
|
||||
f"Cannot get args for {C}. Was enable_get_default_args forgotten?"
|
||||
)
|
||||
|
||||
return OmegaConf.structured(dataclass)
|
||||
|
||||
|
||||
def _dataclass_name_for_function(C: Any) -> str:
|
||||
"""
|
||||
Returns the name of the dataclass which enable_get_default_args(C)
|
||||
creates.
|
||||
"""
|
||||
name = f"_{C.__name__}_default_args_"
|
||||
return name
|
||||
|
||||
|
||||
def enable_get_default_args(C: Any, *, overwrite: bool = True) -> None:
|
||||
"""
|
||||
If C is a function or a plain class with an __init__ function,
|
||||
and you want get_default_args(C) to work, then add
|
||||
`enable_get_default_args(C)` straight after the definition of C.
|
||||
This makes a dataclass corresponding to the default arguments of C
|
||||
and stores it in the same module as C.
|
||||
|
||||
Args:
|
||||
C: a function, or a class with an __init__ function. Must
|
||||
have types for all its defaulted args.
|
||||
overwrite: whether to allow calling this a second time on
|
||||
the same function.
|
||||
"""
|
||||
if not inspect.isfunction(C) and not inspect.isclass(C):
|
||||
raise ValueError(f"Unexpected {C}")
|
||||
|
||||
field_annotations = []
|
||||
kwargs = {}
|
||||
for pname, defval in _params_iter(C):
|
||||
default = defval.default
|
||||
if default == inspect.Parameter.empty:
|
||||
@ -488,8 +516,6 @@ def get_default_args(
|
||||
|
||||
_, annotation = _resolve_optional(defval.annotation)
|
||||
|
||||
kwargs[pname] = copy.deepcopy(default)
|
||||
|
||||
if isinstance(default, set): # force OmegaConf to convert it to ListConfig
|
||||
default = tuple(default)
|
||||
|
||||
@ -503,13 +529,16 @@ def get_default_args(
|
||||
field_ = dataclasses.field(default=default)
|
||||
field_annotations.append((pname, defval.annotation, field_))
|
||||
|
||||
if _allow_untyped:
|
||||
return DictConfig(kwargs)
|
||||
|
||||
# make a temp dataclass and generate a structured config from it.
|
||||
return OmegaConf.structured(
|
||||
dataclasses.make_dataclass(f"__{C.__name__}_default_args__", field_annotations)
|
||||
)
|
||||
name = _dataclass_name_for_function(C)
|
||||
module = sys.modules[C.__module__]
|
||||
if hasattr(module, name):
|
||||
if overwrite:
|
||||
warnings.warn(f"Overwriting {name} in {C.__module__}.")
|
||||
else:
|
||||
raise ValueError(f"Cannot overwrite {name} in {C.__module__}.")
|
||||
dc = dataclasses.make_dataclass(name, field_annotations)
|
||||
dc.__module__ = C.__module__
|
||||
setattr(module, name, dc)
|
||||
|
||||
|
||||
def _params_iter(C):
|
||||
@ -715,9 +744,7 @@ def expand_args_fields(
|
||||
return some_class
|
||||
|
||||
|
||||
def get_default_args_field(
|
||||
C, *, _allow_untyped: bool = False, _do_not_process: Tuple[type, ...] = ()
|
||||
):
|
||||
def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
|
||||
"""
|
||||
Get a dataclass field which defaults to get_default_args(...)
|
||||
|
||||
@ -729,9 +756,7 @@ def get_default_args_field(
|
||||
"""
|
||||
|
||||
def create():
|
||||
return get_default_args(
|
||||
C, _allow_untyped=_allow_untyped, _do_not_process=_do_not_process
|
||||
)
|
||||
return get_default_args(C, _do_not_process=_do_not_process)
|
||||
|
||||
return dataclasses.field(default_factory=create)
|
||||
|
||||
|
@ -19,6 +19,7 @@ from pytorch3d.implicitron.tools.config import (
|
||||
_is_actually_dataclass,
|
||||
_ProcessType,
|
||||
_Registry,
|
||||
enable_get_default_args,
|
||||
expand_args_fields,
|
||||
get_default_args,
|
||||
get_default_args_field,
|
||||
@ -236,6 +237,8 @@ class TestConfig(unittest.TestCase):
|
||||
def __init__(self, a: Any = 1, b: Any = 2):
|
||||
self.a, self.b = a, b
|
||||
|
||||
enable_get_default_args(Foo)
|
||||
|
||||
@dataclass()
|
||||
class Bar:
|
||||
aa: int = 9
|
||||
@ -480,10 +483,14 @@ class TestConfig(unittest.TestCase):
|
||||
def get_tuple(self):
|
||||
return self.tuple_member
|
||||
|
||||
enable_get_default_args(SimpleClass)
|
||||
|
||||
def f(*, a: int = 3, b: str = "kj"):
|
||||
self.assertEqual(a, 3)
|
||||
self.assertEqual(b, "kj")
|
||||
|
||||
enable_get_default_args(f)
|
||||
|
||||
class C(Configurable):
|
||||
simple: DictConfig = get_default_args_field(SimpleClass)
|
||||
# simple2: SimpleClass2 = SimpleClass2()
|
||||
@ -567,10 +574,14 @@ class TestConfig(unittest.TestCase):
|
||||
def C_fn(a: A = A.B1):
|
||||
pass
|
||||
|
||||
enable_get_default_args(C_fn)
|
||||
|
||||
class C_cl:
|
||||
def __init__(self, a: A = A.B1) -> None:
|
||||
pass
|
||||
|
||||
enable_get_default_args(C_cl)
|
||||
|
||||
for C_ in [C, C_fn, C_cl]:
|
||||
base = get_default_args(C_)
|
||||
self.assertEqual(base.a, A.B1)
|
||||
@ -586,14 +597,20 @@ class TestConfig(unittest.TestCase):
|
||||
self.assertEqual(remerged.a, A.B1)
|
||||
|
||||
def test_pickle(self):
|
||||
def f(a: int = 1, b: str = "3"):
|
||||
def func(a: int = 1, b: str = "3"):
|
||||
pass
|
||||
|
||||
args = get_default_args(f, _allow_untyped=True)
|
||||
enable_get_default_args(func)
|
||||
|
||||
args = get_default_args(func)
|
||||
args2 = pickle.loads(pickle.dumps(args))
|
||||
self.assertEqual(args2.a, 1)
|
||||
self.assertEqual(args2.b, "3")
|
||||
|
||||
args_regenerated = get_default_args(func)
|
||||
pickle.dumps(args_regenerated)
|
||||
pickle.dumps(args)
|
||||
|
||||
def test_remove_unused_components(self):
|
||||
struct = get_default_args(MainTest)
|
||||
struct.n_ids = 32
|
||||
@ -674,6 +691,9 @@ class MockClassWithInit: # noqa: B903
|
||||
self.field_reference_type = field_reference_type
|
||||
|
||||
|
||||
enable_get_default_args(MockClassWithInit)
|
||||
|
||||
|
||||
class TestRawClasses(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self._instances = {
|
||||
|
Loading…
x
Reference in New Issue
Block a user