mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
clarify expand_args_fields
Summary: Fix doc and add a call to expand_args_fields for each implicit function. Reviewed By: shapovalov Differential Revision: D35929811 fbshipit-source-id: 8c3cfa56b8d8908fd2165614960e3d34b54717bb
This commit is contained in:
parent
9e57b994ca
commit
90ab219d88
@ -15,6 +15,7 @@ import torch
|
||||
import tqdm
|
||||
from pytorch3d.implicitron.tools import image_utils, vis_utils
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
expand_args_fields,
|
||||
registry,
|
||||
run_auto_creation,
|
||||
)
|
||||
@ -677,6 +678,7 @@ class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
||||
implicit_function_type = registry.get(
|
||||
ImplicitFunctionBase, self.implicit_function_class_type
|
||||
)
|
||||
expand_args_fields(implicit_function_type)
|
||||
if self.num_passes != 1 and not implicit_function_type.allows_multiple_passes():
|
||||
raise ValueError(
|
||||
self.implicit_function_class_type
|
||||
|
@ -606,12 +606,14 @@ def expand_args_fields(
|
||||
"""
|
||||
This expands a class which inherits Configurable or ReplaceableBase classes,
|
||||
including dataclass processing. some_class is modified in place by this function.
|
||||
If expand_args_fields(some_class) has already been called, subsequent calls do
|
||||
nothing and return some_class unmodified.
|
||||
For classes of type ReplaceableBase, you can add some_class to the registry before
|
||||
or after calling this function. But potential inner classes need to be registered
|
||||
before this function is run on the outer class.
|
||||
|
||||
The transformations this function makes, before the concluding
|
||||
dataclasses.dataclass, are as follows. if X is a base class with registered
|
||||
dataclasses.dataclass, are as follows. If X is a base class with registered
|
||||
subclasses Y and Z, replace a class member
|
||||
|
||||
x: X
|
||||
@ -626,7 +628,9 @@ def expand_args_fields(
|
||||
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
|
||||
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
|
||||
def create_x(self):
|
||||
self.x = registry.get(X, self.x_class_type)(
|
||||
x_type = registry.get(X, self.x_class_type)
|
||||
expand_args_fields(x_type)
|
||||
self.x = x_type(
|
||||
**self.getattr(f"x_{self.x_class_type}_args)
|
||||
)
|
||||
x_class_type: str = "UNDEFAULTED"
|
||||
@ -651,7 +655,9 @@ def expand_args_fields(
|
||||
self.x = None
|
||||
return
|
||||
|
||||
self.x = registry.get(X, self.x_class_type)(
|
||||
x_type = registry.get(X, self.x_class_type)
|
||||
expand_args_fields(x_type)
|
||||
self.x = x_type(
|
||||
**self.getattr(f"x_{self.x_class_type}_args)
|
||||
)
|
||||
x_class_type: Optional[str] = "UNDEFAULTED"
|
||||
@ -670,6 +676,7 @@ def expand_args_fields(
|
||||
|
||||
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
|
||||
def create_x(self):
|
||||
expand_args_fields(X)
|
||||
self.x = X(self.x_args)
|
||||
|
||||
Similarly, replace,
|
||||
@ -687,6 +694,7 @@ def expand_args_fields(
|
||||
x_enabled: bool = False
|
||||
def create_x(self):
|
||||
if self.x_enabled:
|
||||
expand_args_fields(X)
|
||||
self.x = X(self.x_args)
|
||||
else:
|
||||
self.x = None
|
||||
|
Loading…
x
Reference in New Issue
Block a user