diff --git a/pytorch3d/implicitron/models/implicit_function/decoding_functions.py b/pytorch3d/implicitron/models/implicit_function/decoding_functions.py index fb7494a7..eef3481a 100644 --- a/pytorch3d/implicitron/models/implicit_function/decoding_functions.py +++ b/pytorch3d/implicitron/models/implicit_function/decoding_functions.py @@ -43,27 +43,8 @@ class DecoderFunctionBase(ReplaceableBase, torch.nn.Module): """ Decoding function is a torch.nn.Module which takes the embedding of a location in space and transforms it into the required quantity (for example density and color). - - Members: - param_groups: dictionary where keys are names of individual parameters - or module members and values are the parameter group where the - parameter/member will be sorted to. "self" key is used to denote the - parameter group at the module level. Possible keys, including the "self" key - do not have to be defined. By default all parameters are put into "default" - parameter group and have the learning rate defined in the optimizer, - it can be overridden at the: - - module level with “self” key, all the parameters and child - module's parameters will be put to that parameter group - - member level, which is the same as if the `param_groups` in that - member has key=“self” and value equal to that parameter group. - This is useful if members do not have `param_groups`, for - example torch.nn.Linear. - - parameter level, parameter with the same name as the key - will be put to that parameter group. """ - param_groups: Dict[str, str] = field(default_factory=lambda: {}) - def __post_init__(self): super().__init__() @@ -280,11 +261,30 @@ class MLPWithInputSkips(Configurable, torch.nn.Module): class MLPDecoder(DecoderFunctionBase): """ Decoding function which uses `MLPWithIputSkips` to convert the embedding to output. - If using Implicitron config system `input_dim` of the `network` is changed to the - value of `input_dim` member and `input_skips` is removed. 
+ The `input_dim` of the `network` is set from the value of the `input_dim` member. + + Members: + input_dim: dimension of input. + param_groups: dictionary where keys are names of individual parameters + or module members and values are the parameter group where the + parameter/member will be sorted to. "self" key is used to denote the + parameter group at the module level. Possible keys, including the "self" key + do not have to be defined. By default all parameters are put into "default" + parameter group and have the learning rate defined in the optimizer, + it can be overridden at the: + - module level with “self” key, all the parameters and child + module's parameters will be put to that parameter group + - member level, which is the same as if the `param_groups` in that + member has key=“self” and value equal to that parameter group. + This is useful if members do not have `param_groups`, for + example torch.nn.Linear. + - parameter level, parameter with the same name as the key + will be put to that parameter group. + network_args: configuration for MLPWithInputSkips """ input_dim: int = 3 + param_groups: Dict[str, str] = field(default_factory=lambda: {}) network: MLPWithInputSkips def __post_init__(self):