enable_get_default_args to allow pickling get_default_args(f)

Summary:
Try again to solve https://github.com/facebookresearch/pytorch3d/issues/1144 pickling problem.
D35258561 (24260130ce) didn't work.

When writing a function or vanilla class C which you want people to be able to call get_default_args on, you must add the line enable_get_default_args(C) to it. This causes autogeneration of a hidden dataclass in the module.

Reviewed By: davnov134

Differential Revision: D35364410

fbshipit-source-id: 53f6e6fff43e7142ae18ca3b06de7d0c849ef965
This commit is contained in:
Jeremy Reizenstein 2022-04-06 03:32:31 -07:00 committed by Facebook GitHub Bot
parent 4c48beb226
commit e10a90140d
7 changed files with 102 additions and 50 deletions

View File

@ -74,6 +74,7 @@ from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as eval
from pytorch3d.implicitron.models.base import EvaluationMode, GenericModel
from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools.config import (
enable_get_default_args,
get_default_args_field,
remove_unused_components,
)
@ -304,6 +305,9 @@ def init_optimizer(
return optimizer, scheduler
enable_get_default_args(init_optimizer)
def trainvalidate(
model,
stats,
@ -663,9 +667,7 @@ def _seed_all_random_engines(seed: int):
@dataclass(eq=False)
class ExperimentConfig:
generic_model_args: DictConfig = get_default_args_field(GenericModel)
solver_args: DictConfig = get_default_args_field(
init_optimizer, _allow_untyped=True
)
solver_args: DictConfig = get_default_args_field(init_optimizer)
dataset_args: DictConfig = get_default_args_field(dataset_zoo)
dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
architecture: str = "generic"

View File

@ -7,6 +7,7 @@
from typing import Dict, Sequence
import torch
from pytorch3d.implicitron.tools.config import enable_get_default_args
from .implicitron_dataset import FrameData, ImplicitronDatasetBase
from .scene_batch_sampler import SceneBatchSampler
@ -56,8 +57,6 @@ def dataloader_zoo(
dataloaders: A dictionary containing the
`"dataset_subset_name": torch_dataloader_object` key, value pairs.
"""
images_per_seq_options = tuple(images_per_seq_options)
if dataset_name not in ["co3d_singlesequence", "co3d_multisequence"]:
raise ValueError(f"Unsupported dataset: {dataset_name}")
@ -96,3 +95,6 @@ def dataloader_zoo(
raise ValueError(f"Unsupported dataset: {dataset_name}")
return dataloaders
enable_get_default_args(dataloader_zoo)

View File

@ -11,6 +11,7 @@ import os
from typing import Any, Dict, List, Optional, Sequence
from iopath.common.file_io import PathManager
from pytorch3d.implicitron.tools.config import enable_get_default_args
from .implicitron_dataset import ImplicitronDataset, ImplicitronDatasetBase
from .utils import (
@ -100,9 +101,6 @@ def dataset_zoo(
datasets: A dictionary containing the
`"dataset_subset_name": torch_dataset_object` key, value pairs.
"""
restrict_sequence_name = tuple(restrict_sequence_name)
aux_dataset_kwargs = dict(aux_dataset_kwargs)
datasets = {}
# TODO:
@ -222,6 +220,9 @@ def dataset_zoo(
return datasets
enable_get_default_args(dataset_zoo)
def _get_co3d_set_names_mapping(
dataset_name: str,
test_on_train: bool,

View File

@ -7,6 +7,7 @@ import logging
from typing import List, Tuple
import torch
from pytorch3d.implicitron.tools.config import enable_get_default_args
from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle
from torch import nn
@ -106,3 +107,6 @@ class RayNormalColoringNetwork(torch.nn.Module):
x = self.tanh(x)
return x
enable_get_default_args(RayNormalColoringNetwork)

View File

@ -20,11 +20,9 @@ from .rgb_net import RayNormalColoringNetwork
@registry.register
class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
render_features_dimensions: int = 3
ray_tracer_args: DictConfig = get_default_args_field(
RayTracing, _allow_untyped=True
)
ray_tracer_args: DictConfig = get_default_args_field(RayTracing)
ray_normal_coloring_network_args: DictConfig = get_default_args_field(
RayNormalColoringNetwork, _allow_untyped=True
RayNormalColoringNetwork
)
bg_color: Tuple[float, ...] = (0.0,)
soft_mask_alpha: float = 50.0

View File

@ -4,10 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
import dataclasses
import inspect
import itertools
import sys
import warnings
from collections import Counter, defaultdict
from enum import Enum
@ -27,8 +27,7 @@ Core functionality:
- expand_args_fields -- Expands a class like `dataclasses.dataclass`. Runs automatically.
- get_default_args -- gets an omegaconf.DictConfig for initializing
a given class or calling a given function.
- get_default_args -- gets an omegaconf.DictConfig for initializing a given class.
- run_auto_creation -- Initialises nested members. To be called in __post_init__.
@ -46,6 +45,7 @@ Additional utility functions:
- remove_unused_components -- used for simplifying a DictConfig instance.
- get_default_args_field -- default for DictConfig member of another configurable.
- enable_get_default_args -- Allows get_default_args on a function or plain class.
1. The simplest usage of this functionality is as follows. First a schema is defined
@ -413,34 +413,25 @@ def _is_configurable_class(C) -> bool:
return isinstance(C, type) and issubclass(C, (Configurable, ReplaceableBase))
def get_default_args(
C, *, _allow_untyped: bool = False, _do_not_process: Tuple[type, ...] = ()
) -> DictConfig:
def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig:
"""
Get the DictConfig of args to call C - which might be a type or a function.
Get the DictConfig corresponding to the defaults in a dataclass or
configurable. Normal use is to provide a dataclass can be provided as C.
If enable_get_default_args has been called on a function or plain class,
then that function or class can be provided as C.
If C is a subclass of Configurable or ReplaceableBase, we make sure
it has been processed with expand_args_fields. If C is a dataclass,
including a subclass of Configurable or ReplaceableBase, the output
will be a typed DictConfig.
it has been processed with expand_args_fields.
Args:
C: the class or function to be processed
_allow_untyped: (internal use) If True, do not try to make the
output typed when it is not a Configurable or
ReplaceableBase. This avoids problems (due to local
dataclasses being remembered inside the returned
DictConfig and any of its DictConfig and ListConfig
members) when pickling the output, but will break
conversions of yaml strings to/from any emum members
of C.
_do_not_process: (internal use) When this function is called from
expand_args_fields, we specify any class currently being
processed, to make sure we don't try to process a class
while it is already being processed.
Returns:
new DictConfig object
new DictConfig object, which is typed.
"""
if C is None:
return DictConfig({})
@ -471,9 +462,46 @@ def get_default_args(
if _is_configurable_class(C):
raise ValueError(f"Failed to process {C}")
# regular class or function
if not inspect.isfunction(C) and not inspect.isclass(C):
raise ValueError(f"Unexpected {C}")
dataclass_name = _dataclass_name_for_function(C)
dataclass = getattr(sys.modules[C.__module__], dataclass_name, None)
if dataclass is None:
raise ValueError(
f"Cannot get args for {C}. Was enable_get_default_args forgotten?"
)
return OmegaConf.structured(dataclass)
def _dataclass_name_for_function(C: Any) -> str:
"""
Returns the name of the dataclass which enable_get_default_args(C)
creates.
"""
name = f"_{C.__name__}_default_args_"
return name
def enable_get_default_args(C: Any, *, overwrite: bool = True) -> None:
"""
If C is a function or a plain class with an __init__ function,
and you want get_default_args(C) to work, then add
`enable_get_default_args(C)` straight after the definition of C.
This makes a dataclass corresponding to the default arguments of C
and stores it in the same module as C.
Args:
C: a function, or a class with an __init__ function. Must
have types for all its defaulted args.
overwrite: whether to allow calling this a second time on
the same function.
"""
if not inspect.isfunction(C) and not inspect.isclass(C):
raise ValueError(f"Unexpected {C}")
field_annotations = []
kwargs = {}
for pname, defval in _params_iter(C):
default = defval.default
if default == inspect.Parameter.empty:
@ -488,8 +516,6 @@ def get_default_args(
_, annotation = _resolve_optional(defval.annotation)
kwargs[pname] = copy.deepcopy(default)
if isinstance(default, set): # force OmegaConf to convert it to ListConfig
default = tuple(default)
@ -503,13 +529,16 @@ def get_default_args(
field_ = dataclasses.field(default=default)
field_annotations.append((pname, defval.annotation, field_))
if _allow_untyped:
return DictConfig(kwargs)
# make a temp dataclass and generate a structured config from it.
return OmegaConf.structured(
dataclasses.make_dataclass(f"__{C.__name__}_default_args__", field_annotations)
)
name = _dataclass_name_for_function(C)
module = sys.modules[C.__module__]
if hasattr(module, name):
if overwrite:
warnings.warn(f"Overwriting {name} in {C.__module__}.")
else:
raise ValueError(f"Cannot overwrite {name} in {C.__module__}.")
dc = dataclasses.make_dataclass(name, field_annotations)
dc.__module__ = C.__module__
setattr(module, name, dc)
def _params_iter(C):
@ -715,9 +744,7 @@ def expand_args_fields(
return some_class
def get_default_args_field(
C, *, _allow_untyped: bool = False, _do_not_process: Tuple[type, ...] = ()
):
def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
"""
Get a dataclass field which defaults to get_default_args(...)
@ -729,9 +756,7 @@ def get_default_args_field(
"""
def create():
return get_default_args(
C, _allow_untyped=_allow_untyped, _do_not_process=_do_not_process
)
return get_default_args(C, _do_not_process=_do_not_process)
return dataclasses.field(default_factory=create)

View File

@ -19,6 +19,7 @@ from pytorch3d.implicitron.tools.config import (
_is_actually_dataclass,
_ProcessType,
_Registry,
enable_get_default_args,
expand_args_fields,
get_default_args,
get_default_args_field,
@ -236,6 +237,8 @@ class TestConfig(unittest.TestCase):
def __init__(self, a: Any = 1, b: Any = 2):
self.a, self.b = a, b
enable_get_default_args(Foo)
@dataclass()
class Bar:
aa: int = 9
@ -480,10 +483,14 @@ class TestConfig(unittest.TestCase):
def get_tuple(self):
return self.tuple_member
enable_get_default_args(SimpleClass)
def f(*, a: int = 3, b: str = "kj"):
self.assertEqual(a, 3)
self.assertEqual(b, "kj")
enable_get_default_args(f)
class C(Configurable):
simple: DictConfig = get_default_args_field(SimpleClass)
# simple2: SimpleClass2 = SimpleClass2()
@ -567,10 +574,14 @@ class TestConfig(unittest.TestCase):
def C_fn(a: A = A.B1):
pass
enable_get_default_args(C_fn)
class C_cl:
def __init__(self, a: A = A.B1) -> None:
pass
enable_get_default_args(C_cl)
for C_ in [C, C_fn, C_cl]:
base = get_default_args(C_)
self.assertEqual(base.a, A.B1)
@ -586,14 +597,20 @@ class TestConfig(unittest.TestCase):
self.assertEqual(remerged.a, A.B1)
def test_pickle(self):
def f(a: int = 1, b: str = "3"):
def func(a: int = 1, b: str = "3"):
pass
args = get_default_args(f, _allow_untyped=True)
enable_get_default_args(func)
args = get_default_args(func)
args2 = pickle.loads(pickle.dumps(args))
self.assertEqual(args2.a, 1)
self.assertEqual(args2.b, "3")
args_regenerated = get_default_args(func)
pickle.dumps(args_regenerated)
pickle.dumps(args)
def test_remove_unused_components(self):
struct = get_default_args(MainTest)
struct.n_ids = 32
@ -674,6 +691,9 @@ class MockClassWithInit: # noqa: B903
self.field_reference_type = field_reference_type
enable_get_default_args(MockClassWithInit)
class TestRawClasses(unittest.TestCase):
def setUp(self) -> None:
self._instances = {