_allow_untyped for get_default_args

Summary:
ListConfig and DictConfig members of get_default_args(X) when X is a callable will contain references to a temporary dataclass and therefore be unpicklable. Avoid this in a few cases.

Fixes https://github.com/facebookresearch/pytorch3d/issues/1144

Reviewed By: shapovalov

Differential Revision: D35258561

fbshipit-source-id: e52186825f52accee9a899e466967a4ff71b3d25
This commit is contained in:
Jeremy Reizenstein
2022-03-31 06:31:45 -07:00
committed by Facebook GitHub Bot
parent a54ad2b912
commit 24260130ce
6 changed files with 46 additions and 7 deletions

View File

@@ -57,6 +57,7 @@ def dataloader_zoo(
`"dataset_subset_name": torch_dataloader_object` key, value pairs.
"""
images_per_seq_options = tuple(images_per_seq_options)
if dataset_name not in ["co3d_singlesequence", "co3d_multisequence"]:
raise ValueError(f"Unsupported dataset: {dataset_name}")

View File

@@ -100,6 +100,8 @@ def dataset_zoo(
datasets: A dictionary containing the
`"dataset_subset_name": torch_dataset_object` key, value pairs.
"""
restrict_sequence_name = tuple(restrict_sequence_name)
aux_dataset_kwargs = dict(aux_dataset_kwargs)
datasets = {}

View File

@@ -20,9 +20,11 @@ from .rgb_net import RayNormalColoringNetwork
@registry.register
class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
render_features_dimensions: int = 3
ray_tracer_args: DictConfig = get_default_args_field(RayTracing)
ray_tracer_args: DictConfig = get_default_args_field(
RayTracing, _allow_untyped=True
)
ray_normal_coloring_network_args: DictConfig = get_default_args_field(
RayNormalColoringNetwork
RayNormalColoringNetwork, _allow_untyped=True
)
bg_color: Tuple[float, ...] = (0.0,)
soft_mask_alpha: float = 50.0

View File

@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
import dataclasses
import inspect
import itertools
@@ -412,7 +413,9 @@ def _is_configurable_class(C) -> bool:
return isinstance(C, type) and issubclass(C, (Configurable, ReplaceableBase))
def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig:
def get_default_args(
C, *, _allow_untyped: bool = False, _do_not_process: Tuple[type, ...] = ()
) -> DictConfig:
"""
Get the DictConfig of args to call C - which might be a type or a function.
@@ -423,6 +426,14 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
Args:
C: the class or function to be processed
_allow_untyped: (internal use) If True, do not try to make the
output typed when it is not a Configurable or
ReplaceableBase. This avoids problems (due to local
dataclasses being remembered inside the returned
DictConfig and any of its DictConfig and ListConfig
members) when pickling the output, but will break
conversions of yaml strings to/from any emum members
of C.
_do_not_process: (internal use) When this function is called from
expand_args_fields, we specify any class currently being
processed, to make sure we don't try to process a class
@@ -462,6 +473,7 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
# regular class or function
field_annotations = []
kwargs = {}
for pname, defval in _params_iter(C):
default = defval.default
if default == inspect.Parameter.empty:
@@ -476,6 +488,8 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
_, annotation = _resolve_optional(defval.annotation)
kwargs[pname] = copy.deepcopy(default)
if isinstance(default, set): # force OmegaConf to convert it to ListConfig
default = tuple(default)
@@ -489,6 +503,9 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
field_ = dataclasses.field(default=default)
field_annotations.append((pname, defval.annotation, field_))
if _allow_untyped:
return DictConfig(kwargs)
# make a temp dataclass and generate a structured config from it.
return OmegaConf.structured(
dataclasses.make_dataclass(f"__{C.__name__}_default_args__", field_annotations)
@@ -696,7 +713,9 @@ def expand_args_fields(
return some_class
def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
def get_default_args_field(
C, *, _allow_untyped: bool = False, _do_not_process: Tuple[type, ...] = ()
):
"""
Get a dataclass field which defaults to get_default_args(...)
@@ -708,7 +727,9 @@ def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
"""
def create():
return get_default_args(C, _do_not_process=_do_not_process)
return get_default_args(
C, _allow_untyped=_allow_untyped, _do_not_process=_do_not_process
)
return dataclasses.field(default_factory=create)