From 90ab219d8817ee83a126595efdfbb2a669db39c2 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Fri, 13 May 2022 03:26:47 -0700 Subject: [PATCH] clarify expand_args_fields Summary: Fix doc and add a call to expand_args_fields for each implicit function. Reviewed By: shapovalov Differential Revision: D35929811 fbshipit-source-id: 8c3cfa56b8d8908fd2165614960e3d34b54717bb --- pytorch3d/implicitron/models/generic_model.py | 2 ++ pytorch3d/implicitron/tools/config.py | 14 +++++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/pytorch3d/implicitron/models/generic_model.py b/pytorch3d/implicitron/models/generic_model.py index 6311c6d5..ffd767d5 100644 --- a/pytorch3d/implicitron/models/generic_model.py +++ b/pytorch3d/implicitron/models/generic_model.py @@ -15,6 +15,7 @@ import torch import tqdm from pytorch3d.implicitron.tools import image_utils, vis_utils from pytorch3d.implicitron.tools.config import ( + expand_args_fields, registry, run_auto_creation, ) @@ -677,6 +678,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 implicit_function_type = registry.get( ImplicitFunctionBase, self.implicit_function_class_type ) + expand_args_fields(implicit_function_type) if self.num_passes != 1 and not implicit_function_type.allows_multiple_passes(): raise ValueError( self.implicit_function_class_type diff --git a/pytorch3d/implicitron/tools/config.py b/pytorch3d/implicitron/tools/config.py index f104007f..2e806d4d 100644 --- a/pytorch3d/implicitron/tools/config.py +++ b/pytorch3d/implicitron/tools/config.py @@ -606,12 +606,14 @@ def expand_args_fields( """ This expands a class which inherits Configurable or ReplaceableBase classes, including dataclass processing. some_class is modified in place by this function. + If expand_args_fields(some_class) has already been called, subsequent calls do + nothing and return some_class unmodified. For classes of type ReplaceableBase, you can add some_class to the registry before or after calling this function. But potential inner classes need to be registered before this function is run on the outer class. The transformations this function makes, before the concluding - dataclasses.dataclass, are as follows. if X is a base class with registered + dataclasses.dataclass, are as follows. If X is a base class with registered subclasses Y and Z, replace a class member x: X @@ -626,7 +628,9 @@ def expand_args_fields( x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y)) x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z)) def create_x(self): - self.x = registry.get(X, self.x_class_type)( + x_type = registry.get(X, self.x_class_type) + expand_args_fields(x_type) + self.x = x_type( **self.getattr(f"x_{self.x_class_type}_args) ) x_class_type: str = "UNDEFAULTED" @@ -651,7 +655,9 @@ def expand_args_fields( self.x = None return - self.x = registry.get(X, self.x_class_type)( + x_type = registry.get(X, self.x_class_type) + expand_args_fields(x_type) + self.x = x_type( **self.getattr(f"x_{self.x_class_type}_args) ) x_class_type: Optional[str] = "UNDEFAULTED" @@ -670,6 +676,7 @@ def expand_args_fields( x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X)) def create_x(self): + expand_args_fields(X) self.x = X(self.x_args) Similarly, replace, @@ -687,6 +694,7 @@ def expand_args_fields( x_enabled: bool = False def create_x(self): if self.x_enabled: + expand_args_fields(X) self.x = X(self.x_args) else: self.x = None