test fix for param_groups

Summary: param_groups only expected on MLPDecoder, not ElementwiseDecoder

Reviewed By: shapovalov

Differential Revision: D40508539

fbshipit-source-id: ea040ad6f7e26bd7d87e5de2eaadae2cf4b04faf
Jeremy Reizenstein 2022-10-19 04:08:30 -07:00 committed by Facebook GitHub Bot
parent fe5bdb2fb5
commit 9535c576e0


@@ -43,27 +43,8 @@ class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
"""
Decoding function is a torch.nn.Module which takes the embedding of a location in
space and transforms it into the required quantity (for example density and color).
    Members:
        param_groups: dictionary where keys are names of individual parameters
            or module members and values are the parameter group that the
            parameter/member will be assigned to. The "self" key denotes the
            parameter group of the module itself. All keys, including "self",
            are optional. By default all parameters are put into the "default"
            parameter group and use the learning rate defined in the optimizer;
            this can be overridden at the:
                - module level: with the "self" key, all the parameters and
                  child modules' parameters will be put into that parameter group
                - member level: same as if the `param_groups` of that member
                  had the "self" key set to that parameter group. This is
                  useful for members which do not have `param_groups`, for
                  example torch.nn.Linear.
                - parameter level: a parameter with the same name as the key
                  will be put into that parameter group.
"""
param_groups: Dict[str, str] = field(default_factory=lambda: {})
def __post_init__(self):
super().__init__()
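
The override precedence described in the docstring above (parameter-level key, then member-level key, then the module's "self" key, then the inherited default) can be sketched as follows. This is a hypothetical standalone helper for illustration, not the Implicitron implementation; the name `resolve_group` and the `Decoder` class are invented here.

```python
import torch


def resolve_group(module, inherited="default"):
    """Yield (parameter_name, group_name) pairs for `module` and its children.

    A hypothetical sketch of param_groups resolution: a key naming a
    parameter wins over a key naming a member, which wins over "self",
    which wins over the group inherited from the parent module.
    """
    overrides = getattr(module, "param_groups", {})
    # "self" overrides the group inherited from the parent module.
    own_default = overrides.get("self", inherited)
    # Direct parameters: a key with the parameter's name wins.
    for name, _ in module.named_parameters(recurse=False):
        yield name, overrides.get(name, own_default)
    # Child members: a key with the member's name acts as that child's "self".
    for child_name, child in module.named_children():
        child_default = overrides.get(child_name, own_default)
        for name, group in resolve_group(child, child_default):
            yield f"{child_name}.{name}", group


class Decoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Module-level group plus a member-level override for `head`.
        self.param_groups = {"self": "decoder", "head": "head_lr"}
        self.body = torch.nn.Linear(4, 4)  # falls back to the "self" group
        self.head = torch.nn.Linear(4, 1)  # member-level override applies


groups = dict(resolve_group(Decoder()))
# body.* resolve to "decoder"; head.* resolve to "head_lr"
```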
@@ -280,11 +261,30 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
class MLPDecoder(DecoderFunctionBase):
"""
    Decoding function which uses `MLPWithInputSkips` to convert the embedding to output.
-    If using Implicitron config system `input_dim` of the `network` is changed to the
-    value of `input_dim` member and `input_skips` is removed.
+    The `input_dim` of the `network` is set from the value of `input_dim` member.
Members:
input_dim: dimension of input.
        param_groups: dictionary where keys are names of individual parameters
            or module members and values are the parameter group that the
            parameter/member will be assigned to. The "self" key denotes the
            parameter group of the module itself. All keys, including "self",
            are optional. By default all parameters are put into the "default"
            parameter group and use the learning rate defined in the optimizer;
            this can be overridden at the:
                - module level: with the "self" key, all the parameters and
                  child modules' parameters will be put into that parameter group
                - member level: same as if the `param_groups` of that member
                  had the "self" key set to that parameter group. This is
                  useful for members which do not have `param_groups`, for
                  example torch.nn.Linear.
                - parameter level: a parameter with the same name as the key
                  will be put into that parameter group.
network_args: configuration for MLPWithInputSkips
"""
input_dim: int = 3
param_groups: Dict[str, str] = field(default_factory=lambda: {})
network: MLPWithInputSkips
def __post_init__(self):
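
The "learning rate defined in the optimizer" that a parameter group overrides corresponds to torch.optim's standard per-group options. A minimal sketch, assuming parameters have already been partitioned into groups; the group names and learning rates here are illustrative, not taken from the commit:

```python
import torch

# Two sub-networks standing in for parameters that resolved to different groups.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 1))

# Illustrative per-group learning rates.
group_lrs = {"default": 1e-3, "head_lr": 1e-4}

# torch.optim accepts a list of dicts, one per parameter group,
# each carrying its own optimizer options.
optimizer = torch.optim.Adam(
    [
        {"params": model[0].parameters(), "lr": group_lrs["default"]},
        {"params": model[1].parameters(), "lr": group_lrs["head_lr"]},
    ]
)
```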