mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
_allow_untyped for get_default_args
Summary: ListConfig and DictConfig members of get_default_args(X) when X is a callable will contain references to a temporary dataclass and therefore be unpicklable. Avoid this in a few cases. Fixes https://github.com/facebookresearch/pytorch3d/issues/1144 Reviewed By: shapovalov Differential Revision: D35258561 fbshipit-source-id: e52186825f52accee9a899e466967a4ff71b3d25
This commit is contained in:
parent
a54ad2b912
commit
24260130ce
@ -67,8 +67,8 @@ from pytorch3d.implicitron.dataset import utils as ds_utils
|
||||
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
|
||||
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
|
||||
from pytorch3d.implicitron.dataset.implicitron_dataset import (
|
||||
ImplicitronDataset,
|
||||
FrameData,
|
||||
ImplicitronDataset,
|
||||
)
|
||||
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
|
||||
from pytorch3d.implicitron.models.base import EvaluationMode, GenericModel
|
||||
@ -80,6 +80,7 @@ from pytorch3d.implicitron.tools.config import (
|
||||
from pytorch3d.implicitron.tools.stats import Stats
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if version.parse(hydra.__version__) < version.Version("1.1"):
|
||||
@ -662,7 +663,9 @@ def _seed_all_random_engines(seed: int):
|
||||
@dataclass(eq=False)
|
||||
class ExperimentConfig:
|
||||
generic_model_args: DictConfig = get_default_args_field(GenericModel)
|
||||
solver_args: DictConfig = get_default_args_field(init_optimizer)
|
||||
solver_args: DictConfig = get_default_args_field(
|
||||
init_optimizer, _allow_untyped=True
|
||||
)
|
||||
dataset_args: DictConfig = get_default_args_field(dataset_zoo)
|
||||
dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
|
||||
architecture: str = "generic"
|
||||
|
@ -57,6 +57,7 @@ def dataloader_zoo(
|
||||
`"dataset_subset_name": torch_dataloader_object` key, value pairs.
|
||||
"""
|
||||
|
||||
images_per_seq_options = tuple(images_per_seq_options)
|
||||
if dataset_name not in ["co3d_singlesequence", "co3d_multisequence"]:
|
||||
raise ValueError(f"Unsupported dataset: {dataset_name}")
|
||||
|
||||
|
@ -100,6 +100,8 @@ def dataset_zoo(
|
||||
datasets: A dictionary containing the
|
||||
`"dataset_subset_name": torch_dataset_object` key, value pairs.
|
||||
"""
|
||||
restrict_sequence_name = tuple(restrict_sequence_name)
|
||||
aux_dataset_kwargs = dict(aux_dataset_kwargs)
|
||||
|
||||
datasets = {}
|
||||
|
||||
|
@ -20,9 +20,11 @@ from .rgb_net import RayNormalColoringNetwork
|
||||
@registry.register
|
||||
class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
||||
render_features_dimensions: int = 3
|
||||
ray_tracer_args: DictConfig = get_default_args_field(RayTracing)
|
||||
ray_tracer_args: DictConfig = get_default_args_field(
|
||||
RayTracing, _allow_untyped=True
|
||||
)
|
||||
ray_normal_coloring_network_args: DictConfig = get_default_args_field(
|
||||
RayNormalColoringNetwork
|
||||
RayNormalColoringNetwork, _allow_untyped=True
|
||||
)
|
||||
bg_color: Tuple[float, ...] = (0.0,)
|
||||
soft_mask_alpha: float = 50.0
|
||||
|
@ -4,6 +4,7 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
import inspect
|
||||
import itertools
|
||||
@ -412,7 +413,9 @@ def _is_configurable_class(C) -> bool:
|
||||
return isinstance(C, type) and issubclass(C, (Configurable, ReplaceableBase))
|
||||
|
||||
|
||||
def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig:
|
||||
def get_default_args(
|
||||
C, *, _allow_untyped: bool = False, _do_not_process: Tuple[type, ...] = ()
|
||||
) -> DictConfig:
|
||||
"""
|
||||
Get the DictConfig of args to call C - which might be a type or a function.
|
||||
|
||||
@ -423,6 +426,14 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
|
||||
|
||||
Args:
|
||||
C: the class or function to be processed
|
||||
_allow_untyped: (internal use) If True, do not try to make the
|
||||
output typed when it is not a Configurable or
|
||||
ReplaceableBase. This avoids problems (due to local
|
||||
dataclasses being remembered inside the returned
|
||||
DictConfig and any of its DictConfig and ListConfig
|
||||
members) when pickling the output, but will break
|
||||
conversions of yaml strings to/from any emum members
|
||||
of C.
|
||||
_do_not_process: (internal use) When this function is called from
|
||||
expand_args_fields, we specify any class currently being
|
||||
processed, to make sure we don't try to process a class
|
||||
@ -462,6 +473,7 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
|
||||
|
||||
# regular class or function
|
||||
field_annotations = []
|
||||
kwargs = {}
|
||||
for pname, defval in _params_iter(C):
|
||||
default = defval.default
|
||||
if default == inspect.Parameter.empty:
|
||||
@ -476,6 +488,8 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
|
||||
|
||||
_, annotation = _resolve_optional(defval.annotation)
|
||||
|
||||
kwargs[pname] = copy.deepcopy(default)
|
||||
|
||||
if isinstance(default, set): # force OmegaConf to convert it to ListConfig
|
||||
default = tuple(default)
|
||||
|
||||
@ -489,6 +503,9 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
|
||||
field_ = dataclasses.field(default=default)
|
||||
field_annotations.append((pname, defval.annotation, field_))
|
||||
|
||||
if _allow_untyped:
|
||||
return DictConfig(kwargs)
|
||||
|
||||
# make a temp dataclass and generate a structured config from it.
|
||||
return OmegaConf.structured(
|
||||
dataclasses.make_dataclass(f"__{C.__name__}_default_args__", field_annotations)
|
||||
@ -696,7 +713,9 @@ def expand_args_fields(
|
||||
return some_class
|
||||
|
||||
|
||||
def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
|
||||
def get_default_args_field(
|
||||
C, *, _allow_untyped: bool = False, _do_not_process: Tuple[type, ...] = ()
|
||||
):
|
||||
"""
|
||||
Get a dataclass field which defaults to get_default_args(...)
|
||||
|
||||
@ -708,7 +727,9 @@ def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
|
||||
"""
|
||||
|
||||
def create():
|
||||
return get_default_args(C, _do_not_process=_do_not_process)
|
||||
return get_default_args(
|
||||
C, _allow_untyped=_allow_untyped, _do_not_process=_do_not_process
|
||||
)
|
||||
|
||||
return dataclasses.field(default_factory=create)
|
||||
|
||||
|
@ -4,6 +4,7 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import pickle
|
||||
import textwrap
|
||||
import unittest
|
||||
from dataclasses import dataclass, field, is_dataclass
|
||||
@ -581,6 +582,15 @@ class TestConfig(unittest.TestCase):
|
||||
remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base)))
|
||||
self.assertEqual(remerged.a, A.B1)
|
||||
|
||||
def test_pickle(self):
|
||||
def f(a: int = 1, b: str = "3"):
|
||||
pass
|
||||
|
||||
args = get_default_args(f, _allow_untyped=True)
|
||||
args2 = pickle.loads(pickle.dumps(args))
|
||||
self.assertEqual(args2.a, 1)
|
||||
self.assertEqual(args2.b, "3")
|
||||
|
||||
def test_remove_unused_components(self):
|
||||
struct = get_default_args(MainTest)
|
||||
struct.n_ids = 32
|
||||
|
Loading…
x
Reference in New Issue
Block a user