_allow_untyped for get_default_args

Summary:
ListConfig and DictConfig members of get_default_args(X) when X is a callable will contain references to a temporary dataclass and therefore be unpicklable. Avoid this in a few cases.

Fixes https://github.com/facebookresearch/pytorch3d/issues/1144

Reviewed By: shapovalov

Differential Revision: D35258561

fbshipit-source-id: e52186825f52accee9a899e466967a4ff71b3d25
This commit is contained in:
Jeremy Reizenstein 2022-03-31 06:31:45 -07:00 committed by Facebook GitHub Bot
parent a54ad2b912
commit 24260130ce
6 changed files with 46 additions and 7 deletions

View File

@ -67,8 +67,8 @@ from pytorch3d.implicitron.dataset import utils as ds_utils
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
from pytorch3d.implicitron.dataset.implicitron_dataset import ( from pytorch3d.implicitron.dataset.implicitron_dataset import (
ImplicitronDataset,
FrameData, FrameData,
ImplicitronDataset,
) )
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
from pytorch3d.implicitron.models.base import EvaluationMode, GenericModel from pytorch3d.implicitron.models.base import EvaluationMode, GenericModel
@ -80,6 +80,7 @@ from pytorch3d.implicitron.tools.config import (
from pytorch3d.implicitron.tools.stats import Stats from pytorch3d.implicitron.tools.stats import Stats
from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.cameras import CamerasBase
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if version.parse(hydra.__version__) < version.Version("1.1"): if version.parse(hydra.__version__) < version.Version("1.1"):
@ -662,7 +663,9 @@ def _seed_all_random_engines(seed: int):
@dataclass(eq=False) @dataclass(eq=False)
class ExperimentConfig: class ExperimentConfig:
generic_model_args: DictConfig = get_default_args_field(GenericModel) generic_model_args: DictConfig = get_default_args_field(GenericModel)
solver_args: DictConfig = get_default_args_field(init_optimizer) solver_args: DictConfig = get_default_args_field(
init_optimizer, _allow_untyped=True
)
dataset_args: DictConfig = get_default_args_field(dataset_zoo) dataset_args: DictConfig = get_default_args_field(dataset_zoo)
dataloader_args: DictConfig = get_default_args_field(dataloader_zoo) dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
architecture: str = "generic" architecture: str = "generic"

View File

@ -57,6 +57,7 @@ def dataloader_zoo(
`"dataset_subset_name": torch_dataloader_object` key, value pairs. `"dataset_subset_name": torch_dataloader_object` key, value pairs.
""" """
images_per_seq_options = tuple(images_per_seq_options)
if dataset_name not in ["co3d_singlesequence", "co3d_multisequence"]: if dataset_name not in ["co3d_singlesequence", "co3d_multisequence"]:
raise ValueError(f"Unsupported dataset: {dataset_name}") raise ValueError(f"Unsupported dataset: {dataset_name}")

View File

@ -100,6 +100,8 @@ def dataset_zoo(
datasets: A dictionary containing the datasets: A dictionary containing the
`"dataset_subset_name": torch_dataset_object` key, value pairs. `"dataset_subset_name": torch_dataset_object` key, value pairs.
""" """
restrict_sequence_name = tuple(restrict_sequence_name)
aux_dataset_kwargs = dict(aux_dataset_kwargs)
datasets = {} datasets = {}

View File

@ -20,9 +20,11 @@ from .rgb_net import RayNormalColoringNetwork
@registry.register @registry.register
class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
render_features_dimensions: int = 3 render_features_dimensions: int = 3
ray_tracer_args: DictConfig = get_default_args_field(RayTracing) ray_tracer_args: DictConfig = get_default_args_field(
RayTracing, _allow_untyped=True
)
ray_normal_coloring_network_args: DictConfig = get_default_args_field( ray_normal_coloring_network_args: DictConfig = get_default_args_field(
RayNormalColoringNetwork RayNormalColoringNetwork, _allow_untyped=True
) )
bg_color: Tuple[float, ...] = (0.0,) bg_color: Tuple[float, ...] = (0.0,)
soft_mask_alpha: float = 50.0 soft_mask_alpha: float = 50.0

View File

@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import copy
import dataclasses import dataclasses
import inspect import inspect
import itertools import itertools
@ -412,7 +413,9 @@ def _is_configurable_class(C) -> bool:
return isinstance(C, type) and issubclass(C, (Configurable, ReplaceableBase)) return isinstance(C, type) and issubclass(C, (Configurable, ReplaceableBase))
def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig: def get_default_args(
C, *, _allow_untyped: bool = False, _do_not_process: Tuple[type, ...] = ()
) -> DictConfig:
""" """
Get the DictConfig of args to call C - which might be a type or a function. Get the DictConfig of args to call C - which might be a type or a function.
@ -423,6 +426,14 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
Args: Args:
C: the class or function to be processed C: the class or function to be processed
_allow_untyped: (internal use) If True, do not try to make the
output typed when it is not a Configurable or
ReplaceableBase. This avoids problems (due to local
dataclasses being remembered inside the returned
DictConfig and any of its DictConfig and ListConfig
members) when pickling the output, but will break
conversions of yaml strings to/from any emum members
of C.
_do_not_process: (internal use) When this function is called from _do_not_process: (internal use) When this function is called from
expand_args_fields, we specify any class currently being expand_args_fields, we specify any class currently being
processed, to make sure we don't try to process a class processed, to make sure we don't try to process a class
@ -462,6 +473,7 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
# regular class or function # regular class or function
field_annotations = [] field_annotations = []
kwargs = {}
for pname, defval in _params_iter(C): for pname, defval in _params_iter(C):
default = defval.default default = defval.default
if default == inspect.Parameter.empty: if default == inspect.Parameter.empty:
@ -476,6 +488,8 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
_, annotation = _resolve_optional(defval.annotation) _, annotation = _resolve_optional(defval.annotation)
kwargs[pname] = copy.deepcopy(default)
if isinstance(default, set): # force OmegaConf to convert it to ListConfig if isinstance(default, set): # force OmegaConf to convert it to ListConfig
default = tuple(default) default = tuple(default)
@ -489,6 +503,9 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
field_ = dataclasses.field(default=default) field_ = dataclasses.field(default=default)
field_annotations.append((pname, defval.annotation, field_)) field_annotations.append((pname, defval.annotation, field_))
if _allow_untyped:
return DictConfig(kwargs)
# make a temp dataclass and generate a structured config from it. # make a temp dataclass and generate a structured config from it.
return OmegaConf.structured( return OmegaConf.structured(
dataclasses.make_dataclass(f"__{C.__name__}_default_args__", field_annotations) dataclasses.make_dataclass(f"__{C.__name__}_default_args__", field_annotations)
@ -696,7 +713,9 @@ def expand_args_fields(
return some_class return some_class
def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()): def get_default_args_field(
C, *, _allow_untyped: bool = False, _do_not_process: Tuple[type, ...] = ()
):
""" """
Get a dataclass field which defaults to get_default_args(...) Get a dataclass field which defaults to get_default_args(...)
@ -708,7 +727,9 @@ def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
""" """
def create(): def create():
return get_default_args(C, _do_not_process=_do_not_process) return get_default_args(
C, _allow_untyped=_allow_untyped, _do_not_process=_do_not_process
)
return dataclasses.field(default_factory=create) return dataclasses.field(default_factory=create)

View File

@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import pickle
import textwrap import textwrap
import unittest import unittest
from dataclasses import dataclass, field, is_dataclass from dataclasses import dataclass, field, is_dataclass
@ -581,6 +582,15 @@ class TestConfig(unittest.TestCase):
remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base))) remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base)))
self.assertEqual(remerged.a, A.B1) self.assertEqual(remerged.a, A.B1)
def test_pickle(self):
def f(a: int = 1, b: str = "3"):
pass
args = get_default_args(f, _allow_untyped=True)
args2 = pickle.loads(pickle.dumps(args))
self.assertEqual(args2.a, 1)
self.assertEqual(args2.b, "3")
def test_remove_unused_components(self): def test_remove_unused_components(self):
struct = get_default_args(MainTest) struct = get_default_args(MainTest)
struct.n_ids = 32 struct.n_ids = 32