_allow_untyped for get_default_args

Summary:
ListConfig and DictConfig members of get_default_args(X) when X is a callable will contain references to a temporary dataclass and therefore be unpicklable. Avoid this in a few cases.

Fixes https://github.com/facebookresearch/pytorch3d/issues/1144

Reviewed By: shapovalov

Differential Revision: D35258561

fbshipit-source-id: e52186825f52accee9a899e466967a4ff71b3d25
This commit is contained in:
Jeremy Reizenstein 2022-03-31 06:31:45 -07:00 committed by Facebook GitHub Bot
parent a54ad2b912
commit 24260130ce
6 changed files with 46 additions and 7 deletions

View File

@ -67,8 +67,8 @@ from pytorch3d.implicitron.dataset import utils as ds_utils
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
from pytorch3d.implicitron.dataset.implicitron_dataset import (
ImplicitronDataset,
FrameData,
ImplicitronDataset,
)
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
from pytorch3d.implicitron.models.base import EvaluationMode, GenericModel
@ -80,6 +80,7 @@ from pytorch3d.implicitron.tools.config import (
from pytorch3d.implicitron.tools.stats import Stats
from pytorch3d.renderer.cameras import CamerasBase
logger = logging.getLogger(__name__)
if version.parse(hydra.__version__) < version.Version("1.1"):
@ -662,7 +663,9 @@ def _seed_all_random_engines(seed: int):
@dataclass(eq=False)
class ExperimentConfig:
generic_model_args: DictConfig = get_default_args_field(GenericModel)
solver_args: DictConfig = get_default_args_field(init_optimizer)
solver_args: DictConfig = get_default_args_field(
init_optimizer, _allow_untyped=True
)
dataset_args: DictConfig = get_default_args_field(dataset_zoo)
dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
architecture: str = "generic"

View File

@ -57,6 +57,7 @@ def dataloader_zoo(
`"dataset_subset_name": torch_dataloader_object` key, value pairs.
"""
images_per_seq_options = tuple(images_per_seq_options)
if dataset_name not in ["co3d_singlesequence", "co3d_multisequence"]:
raise ValueError(f"Unsupported dataset: {dataset_name}")

View File

@ -100,6 +100,8 @@ def dataset_zoo(
datasets: A dictionary containing the
`"dataset_subset_name": torch_dataset_object` key, value pairs.
"""
restrict_sequence_name = tuple(restrict_sequence_name)
aux_dataset_kwargs = dict(aux_dataset_kwargs)
datasets = {}

View File

@ -20,9 +20,11 @@ from .rgb_net import RayNormalColoringNetwork
@registry.register
class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
render_features_dimensions: int = 3
ray_tracer_args: DictConfig = get_default_args_field(RayTracing)
ray_tracer_args: DictConfig = get_default_args_field(
RayTracing, _allow_untyped=True
)
ray_normal_coloring_network_args: DictConfig = get_default_args_field(
RayNormalColoringNetwork
RayNormalColoringNetwork, _allow_untyped=True
)
bg_color: Tuple[float, ...] = (0.0,)
soft_mask_alpha: float = 50.0

View File

@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
import dataclasses
import inspect
import itertools
@ -412,7 +413,9 @@ def _is_configurable_class(C) -> bool:
return isinstance(C, type) and issubclass(C, (Configurable, ReplaceableBase))
def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig:
def get_default_args(
C, *, _allow_untyped: bool = False, _do_not_process: Tuple[type, ...] = ()
) -> DictConfig:
"""
Get the DictConfig of args to call C - which might be a type or a function.
@ -423,6 +426,14 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
Args:
C: the class or function to be processed
_allow_untyped: (internal use) If True, do not try to make the
output typed when it is not a Configurable or
ReplaceableBase. This avoids problems (due to local
dataclasses being remembered inside the returned
DictConfig and any of its DictConfig and ListConfig
members) when pickling the output, but will break
conversions of yaml strings to/from any emum members
of C.
_do_not_process: (internal use) When this function is called from
expand_args_fields, we specify any class currently being
processed, to make sure we don't try to process a class
@ -462,6 +473,7 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
# regular class or function
field_annotations = []
kwargs = {}
for pname, defval in _params_iter(C):
default = defval.default
if default == inspect.Parameter.empty:
@ -476,6 +488,8 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
_, annotation = _resolve_optional(defval.annotation)
kwargs[pname] = copy.deepcopy(default)
if isinstance(default, set): # force OmegaConf to convert it to ListConfig
default = tuple(default)
@ -489,6 +503,9 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
field_ = dataclasses.field(default=default)
field_annotations.append((pname, defval.annotation, field_))
if _allow_untyped:
return DictConfig(kwargs)
# make a temp dataclass and generate a structured config from it.
return OmegaConf.structured(
dataclasses.make_dataclass(f"__{C.__name__}_default_args__", field_annotations)
@ -696,7 +713,9 @@ def expand_args_fields(
return some_class
def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
def get_default_args_field(
C, *, _allow_untyped: bool = False, _do_not_process: Tuple[type, ...] = ()
):
"""
Get a dataclass field which defaults to get_default_args(...)
@ -708,7 +727,9 @@ def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
"""
def create():
return get_default_args(C, _do_not_process=_do_not_process)
return get_default_args(
C, _allow_untyped=_allow_untyped, _do_not_process=_do_not_process
)
return dataclasses.field(default_factory=create)

View File

@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import pickle
import textwrap
import unittest
from dataclasses import dataclass, field, is_dataclass
@ -581,6 +582,15 @@ class TestConfig(unittest.TestCase):
remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base)))
self.assertEqual(remerged.a, A.B1)
def test_pickle(self):
def f(a: int = 1, b: str = "3"):
pass
args = get_default_args(f, _allow_untyped=True)
args2 = pickle.loads(pickle.dumps(args))
self.assertEqual(args2.a, 1)
self.assertEqual(args2.b, "3")
def test_remove_unused_components(self):
struct = get_default_args(MainTest)
struct.n_ids = 32