mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
test fix for param_groups
Summary: param_groups only expected on MLPDecoder, not ElementwiseDecoder Reviewed By: shapovalov Differential Revision: D40508539 fbshipit-source-id: ea040ad6f7e26bd7d87e5de2eaadae2cf4b04faf
This commit is contained in:
parent
fe5bdb2fb5
commit
9535c576e0
@ -43,27 +43,8 @@ class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
Decoding function is a torch.nn.Module which takes the embedding of a location in
|
Decoding function is a torch.nn.Module which takes the embedding of a location in
|
||||||
space and transforms it into the required quantity (for example density and color).
|
space and transforms it into the required quantity (for example density and color).
|
||||||
|
|
||||||
Members:
|
|
||||||
param_groups: dictionary where keys are names of individual parameters
|
|
||||||
or module members and values are the parameter group where the
|
|
||||||
parameter/member will be sorted to. "self" key is used to denote the
|
|
||||||
parameter group at the module level. Possible keys, including the "self" key
|
|
||||||
do not have to be defined. By default all parameters are put into "default"
|
|
||||||
parameter group and have the learning rate defined in the optimizer,
|
|
||||||
it can be overridden at the:
|
|
||||||
- module level with “self” key, all the parameters and child
|
|
||||||
module's parameters will be put to that parameter group
|
|
||||||
- member level, which is the same as if the `param_groups` in that
|
|
||||||
member has key=“self” and value equal to that parameter group.
|
|
||||||
This is useful if members do not have `param_groups`, for
|
|
||||||
example torch.nn.Linear.
|
|
||||||
- parameter level, parameter with the same name as the key
|
|
||||||
will be put to that parameter group.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
param_groups: Dict[str, str] = field(default_factory=lambda: {})
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -280,11 +261,30 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
|
|||||||
class MLPDecoder(DecoderFunctionBase):
|
class MLPDecoder(DecoderFunctionBase):
|
||||||
"""
|
"""
|
||||||
Decoding function which uses `MLPWithIputSkips` to convert the embedding to output.
|
Decoding function which uses `MLPWithIputSkips` to convert the embedding to output.
|
||||||
If using Implicitron config system `input_dim` of the `network` is changed to the
|
The `input_dim` of the `network` is set from the value of `input_dim` member.
|
||||||
value of `input_dim` member and `input_skips` is removed.
|
|
||||||
|
Members:
|
||||||
|
input_dim: dimension of input.
|
||||||
|
param_groups: dictionary where keys are names of individual parameters
|
||||||
|
or module members and values are the parameter group where the
|
||||||
|
parameter/member will be sorted to. "self" key is used to denote the
|
||||||
|
parameter group at the module level. Possible keys, including the "self" key
|
||||||
|
do not have to be defined. By default all parameters are put into "default"
|
||||||
|
parameter group and have the learning rate defined in the optimizer,
|
||||||
|
it can be overridden at the:
|
||||||
|
- module level with “self” key, all the parameters and child
|
||||||
|
module's parameters will be put to that parameter group
|
||||||
|
- member level, which is the same as if the `param_groups` in that
|
||||||
|
member has key=“self” and value equal to that parameter group.
|
||||||
|
This is useful if members do not have `param_groups`, for
|
||||||
|
example torch.nn.Linear.
|
||||||
|
- parameter level, parameter with the same name as the key
|
||||||
|
will be put to that parameter group.
|
||||||
|
network_args: configuration for MLPWithInputSkips
|
||||||
"""
|
"""
|
||||||
|
|
||||||
input_dim: int = 3
|
input_dim: int = 3
|
||||||
|
param_groups: Dict[str, str] = field(default_factory=lambda: {})
|
||||||
network: MLPWithInputSkips
|
network: MLPWithInputSkips
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user