mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 22:00:35 +08:00
enable_get_default_args to allow pickling get_default_args(f)
Summary:
Try again to solve https://github.com/facebookresearch/pytorch3d/issues/1144 pickling problem.
D35258561 (24260130ce) didn't work.
When writing a function or vanilla class C which you want people to be able to call get_default_args on, you must add the line enable_get_default_args(C) to it. This causes autogeneration of a hidden dataclass in the module.
Reviewed By: davnov134
Differential Revision: D35364410
fbshipit-source-id: 53f6e6fff43e7142ae18ca3b06de7d0c849ef965
This commit is contained in:
committed by
Facebook GitHub Bot
parent
4c48beb226
commit
e10a90140d
@@ -19,6 +19,7 @@ from pytorch3d.implicitron.tools.config import (
|
||||
_is_actually_dataclass,
|
||||
_ProcessType,
|
||||
_Registry,
|
||||
enable_get_default_args,
|
||||
expand_args_fields,
|
||||
get_default_args,
|
||||
get_default_args_field,
|
||||
@@ -236,6 +237,8 @@ class TestConfig(unittest.TestCase):
|
||||
def __init__(self, a: Any = 1, b: Any = 2):
|
||||
self.a, self.b = a, b
|
||||
|
||||
enable_get_default_args(Foo)
|
||||
|
||||
@dataclass()
|
||||
class Bar:
|
||||
aa: int = 9
|
||||
@@ -480,10 +483,14 @@ class TestConfig(unittest.TestCase):
|
||||
def get_tuple(self):
|
||||
return self.tuple_member
|
||||
|
||||
enable_get_default_args(SimpleClass)
|
||||
|
||||
def f(*, a: int = 3, b: str = "kj"):
|
||||
self.assertEqual(a, 3)
|
||||
self.assertEqual(b, "kj")
|
||||
|
||||
enable_get_default_args(f)
|
||||
|
||||
class C(Configurable):
|
||||
simple: DictConfig = get_default_args_field(SimpleClass)
|
||||
# simple2: SimpleClass2 = SimpleClass2()
|
||||
@@ -567,10 +574,14 @@ class TestConfig(unittest.TestCase):
|
||||
def C_fn(a: A = A.B1):
|
||||
pass
|
||||
|
||||
enable_get_default_args(C_fn)
|
||||
|
||||
class C_cl:
|
||||
def __init__(self, a: A = A.B1) -> None:
|
||||
pass
|
||||
|
||||
enable_get_default_args(C_cl)
|
||||
|
||||
for C_ in [C, C_fn, C_cl]:
|
||||
base = get_default_args(C_)
|
||||
self.assertEqual(base.a, A.B1)
|
||||
@@ -586,14 +597,20 @@ class TestConfig(unittest.TestCase):
|
||||
self.assertEqual(remerged.a, A.B1)
|
||||
|
||||
def test_pickle(self):
|
||||
def f(a: int = 1, b: str = "3"):
|
||||
def func(a: int = 1, b: str = "3"):
|
||||
pass
|
||||
|
||||
args = get_default_args(f, _allow_untyped=True)
|
||||
enable_get_default_args(func)
|
||||
|
||||
args = get_default_args(func)
|
||||
args2 = pickle.loads(pickle.dumps(args))
|
||||
self.assertEqual(args2.a, 1)
|
||||
self.assertEqual(args2.b, "3")
|
||||
|
||||
args_regenerated = get_default_args(func)
|
||||
pickle.dumps(args_regenerated)
|
||||
pickle.dumps(args)
|
||||
|
||||
def test_remove_unused_components(self):
|
||||
struct = get_default_args(MainTest)
|
||||
struct.n_ids = 32
|
||||
@@ -674,6 +691,9 @@ class MockClassWithInit: # noqa: B903
|
||||
self.field_reference_type = field_reference_type
|
||||
|
||||
|
||||
enable_get_default_args(MockClassWithInit)
|
||||
|
||||
|
||||
class TestRawClasses(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self._instances = {
|
||||
|
||||
Reference in New Issue
Block a user