enable_get_default_args to allow pickling get_default_args(f)

Summary:
Try again to solve https://github.com/facebookresearch/pytorch3d/issues/1144 pickling problem.
D35258561 (24260130ce) didn't work.

When writing a function or vanilla class C which you want people to be able to call get_default_args on, you must add the line enable_get_default_args(C) to it. This causes autogeneration of a hidden dataclass in the module.

Reviewed By: davnov134

Differential Revision: D35364410

fbshipit-source-id: 53f6e6fff43e7142ae18ca3b06de7d0c849ef965
This commit is contained in:
Jeremy Reizenstein
2022-04-06 03:32:31 -07:00
committed by Facebook GitHub Bot
parent 4c48beb226
commit e10a90140d
7 changed files with 102 additions and 50 deletions

View File

@@ -19,6 +19,7 @@ from pytorch3d.implicitron.tools.config import (
_is_actually_dataclass,
_ProcessType,
_Registry,
enable_get_default_args,
expand_args_fields,
get_default_args,
get_default_args_field,
@@ -236,6 +237,8 @@ class TestConfig(unittest.TestCase):
def __init__(self, a: Any = 1, b: Any = 2):
self.a, self.b = a, b
enable_get_default_args(Foo)
@dataclass()
class Bar:
aa: int = 9
@@ -480,10 +483,14 @@ class TestConfig(unittest.TestCase):
def get_tuple(self):
return self.tuple_member
enable_get_default_args(SimpleClass)
def f(*, a: int = 3, b: str = "kj"):
self.assertEqual(a, 3)
self.assertEqual(b, "kj")
enable_get_default_args(f)
class C(Configurable):
simple: DictConfig = get_default_args_field(SimpleClass)
# simple2: SimpleClass2 = SimpleClass2()
@@ -567,10 +574,14 @@ class TestConfig(unittest.TestCase):
def C_fn(a: A = A.B1):
pass
enable_get_default_args(C_fn)
class C_cl:
def __init__(self, a: A = A.B1) -> None:
pass
enable_get_default_args(C_cl)
for C_ in [C, C_fn, C_cl]:
base = get_default_args(C_)
self.assertEqual(base.a, A.B1)
@@ -586,14 +597,20 @@ class TestConfig(unittest.TestCase):
self.assertEqual(remerged.a, A.B1)
def test_pickle(self):
def f(a: int = 1, b: str = "3"):
def func(a: int = 1, b: str = "3"):
pass
args = get_default_args(f, _allow_untyped=True)
enable_get_default_args(func)
args = get_default_args(func)
args2 = pickle.loads(pickle.dumps(args))
self.assertEqual(args2.a, 1)
self.assertEqual(args2.b, "3")
args_regenerated = get_default_args(func)
pickle.dumps(args_regenerated)
pickle.dumps(args)
def test_remove_unused_components(self):
struct = get_default_args(MainTest)
struct.n_ids = 32
@@ -674,6 +691,9 @@ class MockClassWithInit: # noqa: B903
self.field_reference_type = field_reference_type
enable_get_default_args(MockClassWithInit)
class TestRawClasses(unittest.TestCase):
def setUp(self) -> None:
self._instances = {