mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 06:10:34 +08:00
_allow_untyped for get_default_args
Summary: ListConfig and DictConfig members of get_default_args(X) when X is a callable will contain references to a temporary dataclass and therefore be unpicklable. Avoid this in a few cases. Fixes https://github.com/facebookresearch/pytorch3d/issues/1144 Reviewed By: shapovalov Differential Revision: D35258561 fbshipit-source-id: e52186825f52accee9a899e466967a4ff71b3d25
This commit is contained in:
committed by
Facebook GitHub Bot
parent
a54ad2b912
commit
24260130ce
@@ -57,6 +57,7 @@ def dataloader_zoo(
|
||||
`"dataset_subset_name": torch_dataloader_object` key, value pairs.
|
||||
"""
|
||||
|
||||
images_per_seq_options = tuple(images_per_seq_options)
|
||||
if dataset_name not in ["co3d_singlesequence", "co3d_multisequence"]:
|
||||
raise ValueError(f"Unsupported dataset: {dataset_name}")
|
||||
|
||||
|
||||
@@ -100,6 +100,8 @@ def dataset_zoo(
|
||||
datasets: A dictionary containing the
|
||||
`"dataset_subset_name": torch_dataset_object` key, value pairs.
|
||||
"""
|
||||
restrict_sequence_name = tuple(restrict_sequence_name)
|
||||
aux_dataset_kwargs = dict(aux_dataset_kwargs)
|
||||
|
||||
datasets = {}
|
||||
|
||||
|
||||
@@ -20,9 +20,11 @@ from .rgb_net import RayNormalColoringNetwork
|
||||
@registry.register
|
||||
class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
||||
render_features_dimensions: int = 3
|
||||
ray_tracer_args: DictConfig = get_default_args_field(RayTracing)
|
||||
ray_tracer_args: DictConfig = get_default_args_field(
|
||||
RayTracing, _allow_untyped=True
|
||||
)
|
||||
ray_normal_coloring_network_args: DictConfig = get_default_args_field(
|
||||
RayNormalColoringNetwork
|
||||
RayNormalColoringNetwork, _allow_untyped=True
|
||||
)
|
||||
bg_color: Tuple[float, ...] = (0.0,)
|
||||
soft_mask_alpha: float = 50.0
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
import inspect
|
||||
import itertools
|
||||
@@ -412,7 +413,9 @@ def _is_configurable_class(C) -> bool:
|
||||
return isinstance(C, type) and issubclass(C, (Configurable, ReplaceableBase))
|
||||
|
||||
|
||||
def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig:
|
||||
def get_default_args(
|
||||
C, *, _allow_untyped: bool = False, _do_not_process: Tuple[type, ...] = ()
|
||||
) -> DictConfig:
|
||||
"""
|
||||
Get the DictConfig of args to call C - which might be a type or a function.
|
||||
|
||||
@@ -423,6 +426,14 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
|
||||
|
||||
Args:
|
||||
C: the class or function to be processed
|
||||
_allow_untyped: (internal use) If True, do not try to make the
|
||||
output typed when it is not a Configurable or
|
||||
ReplaceableBase. This avoids problems (due to local
|
||||
dataclasses being remembered inside the returned
|
||||
DictConfig and any of its DictConfig and ListConfig
|
||||
members) when pickling the output, but will break
|
||||
conversions of yaml strings to/from any emum members
|
||||
of C.
|
||||
_do_not_process: (internal use) When this function is called from
|
||||
expand_args_fields, we specify any class currently being
|
||||
processed, to make sure we don't try to process a class
|
||||
@@ -462,6 +473,7 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
|
||||
|
||||
# regular class or function
|
||||
field_annotations = []
|
||||
kwargs = {}
|
||||
for pname, defval in _params_iter(C):
|
||||
default = defval.default
|
||||
if default == inspect.Parameter.empty:
|
||||
@@ -476,6 +488,8 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
|
||||
|
||||
_, annotation = _resolve_optional(defval.annotation)
|
||||
|
||||
kwargs[pname] = copy.deepcopy(default)
|
||||
|
||||
if isinstance(default, set): # force OmegaConf to convert it to ListConfig
|
||||
default = tuple(default)
|
||||
|
||||
@@ -489,6 +503,9 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
|
||||
field_ = dataclasses.field(default=default)
|
||||
field_annotations.append((pname, defval.annotation, field_))
|
||||
|
||||
if _allow_untyped:
|
||||
return DictConfig(kwargs)
|
||||
|
||||
# make a temp dataclass and generate a structured config from it.
|
||||
return OmegaConf.structured(
|
||||
dataclasses.make_dataclass(f"__{C.__name__}_default_args__", field_annotations)
|
||||
@@ -696,7 +713,9 @@ def expand_args_fields(
|
||||
return some_class
|
||||
|
||||
|
||||
def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
|
||||
def get_default_args_field(
|
||||
C, *, _allow_untyped: bool = False, _do_not_process: Tuple[type, ...] = ()
|
||||
):
|
||||
"""
|
||||
Get a dataclass field which defaults to get_default_args(...)
|
||||
|
||||
@@ -708,7 +727,9 @@ def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
|
||||
"""
|
||||
|
||||
def create():
|
||||
return get_default_args(C, _do_not_process=_do_not_process)
|
||||
return get_default_args(
|
||||
C, _allow_untyped=_allow_untyped, _do_not_process=_do_not_process
|
||||
)
|
||||
|
||||
return dataclasses.field(default_factory=create)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user