_allow_untyped for get_default_args

Summary:
ListConfig and DictConfig members of get_default_args(X) when X is a callable will contain references to a temporary dataclass and therefore be unpicklable. Avoid this in a few cases.

Fixes https://github.com/facebookresearch/pytorch3d/issues/1144

Reviewed By: shapovalov

Differential Revision: D35258561

fbshipit-source-id: e52186825f52accee9a899e466967a4ff71b3d25
This commit is contained in:
Jeremy Reizenstein
2022-03-31 06:31:45 -07:00
committed by Facebook GitHub Bot
parent a54ad2b912
commit 24260130ce
6 changed files with 46 additions and 7 deletions

View File

@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import pickle
import textwrap
import unittest
from dataclasses import dataclass, field, is_dataclass
@@ -581,6 +582,15 @@ class TestConfig(unittest.TestCase):
remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base)))
self.assertEqual(remerged.a, A.B1)
def test_pickle(self):
def f(a: int = 1, b: str = "3"):
pass
args = get_default_args(f, _allow_untyped=True)
args2 = pickle.loads(pickle.dumps(args))
self.assertEqual(args2.a, 1)
self.assertEqual(args2.b, "3")
def test_remove_unused_components(self):
struct = get_default_args(MainTest)
struct.n_ids = 32