Different learning rates for different parts

Summary:
Adds the ability to have different learning rates for different parts of the model. The trainable parts of Implicitron now have a new member:

       param_groups: dictionary where keys are names of individual parameters
            or module members, and values are the parameter group into which
            the parameter/member will be sorted. The "self" key denotes the
            parameter group at the module level. Possible keys, including the
            "self" key, do not have to be defined. By default, all parameters
            are put into the "default" parameter group and have the learning
            rate defined in the optimizer; this can be overridden at the:
                - module level, with the "self" key: all of the module's
                    parameters and its child modules' parameters will be put
                    into that parameter group
                - member level, which is the same as if `param_groups` in that
                    member had key="self" and value equal to that parameter
                    group. This is useful for members that do not have
                    `param_groups`, for example torch.nn.Linear.
                - parameter level: the parameter with the same name as the key
                    will be put into that parameter group.

In the optimizer factory, parameters and their parameter-group assignments are then gathered recursively.
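For illustration, here is a minimal sketch of that recursive gathering, assuming hypothetical module and group names (`Decoder`, `gather_param_groups`, and the learning-rate mapping below are invented for this example; this is not the actual Implicitron optimizer factory):

    import torch

    class Decoder(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Module level: route everything in this module to "decoder";
            # parameter level: the parameter named "scale" goes to "slow".
            self.param_groups = {"self": "decoder", "scale": "slow"}
            self.linear = torch.nn.Linear(8, 8)
            self.scale = torch.nn.Parameter(torch.ones(1))

    def gather_param_groups(module, default="default"):
        # Recursively sort parameters into named groups per the rules above.
        groups = {}
        overrides = getattr(module, "param_groups", {})
        default = overrides.get("self", default)  # module-level override
        for name, param in module.named_parameters(recurse=False):
            groups.setdefault(overrides.get(name, default), []).append(param)
        for child_name, child in module.named_children():
            child_default = overrides.get(child_name, default)  # member level
            for group, params in gather_param_groups(child, child_default).items():
                groups.setdefault(group, []).extend(params)
        return groups

    # One torch.optim parameter group per name, each with its own learning rate.
    learning_rates = {"default": 1e-3, "decoder": 1e-4, "slow": 1e-5}
    optimizer = torch.optim.Adam(
        [
            {"params": params, "lr": learning_rates[group]}
            for group, params in gather_param_groups(Decoder()).items()
        ]
    )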

Reviewed By: shapovalov

Differential Revision: D40145802

fbshipit-source-id: 631c02b8d79ee1c0eb4c31e6e42dbd3d2882078a
Jeremy Reizenstein authored on 2022-10-18 15:58:18 -07:00, committed by Facebook GitHub Bot
parent a819ecb00b
commit fe5bdb2fb5
6 changed files with 293 additions and 5 deletions


@@ -13,9 +13,10 @@ This file contains
 """
 import logging
+from dataclasses import field
 from enum import Enum
-from typing import Optional, Tuple
+from typing import Dict, Optional, Tuple
 import torch
@@ -42,8 +43,27 @@ class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
     """
     Decoding function is a torch.nn.Module which takes the embedding of a location in
     space and transforms it into the required quantity (for example density and color).
+
+    Members:
+        param_groups: dictionary where keys are names of individual parameters
+            or module members, and values are the parameter group into which
+            the parameter/member will be sorted. The "self" key denotes the
+            parameter group at the module level. Possible keys, including the
+            "self" key, do not have to be defined. By default, all parameters
+            are put into the "default" parameter group and have the learning
+            rate defined in the optimizer; this can be overridden at the:
+                - module level, with the "self" key: all of the module's
+                    parameters and its child modules' parameters will be put
+                    into that parameter group
+                - member level, which is the same as if `param_groups` in that
+                    member had key="self" and value equal to that parameter
+                    group. This is useful for members that do not have
+                    `param_groups`, for example torch.nn.Linear.
+                - parameter level: the parameter with the same name as the key
+                    will be put into that parameter group.
     """
+
+    param_groups: Dict[str, str] = field(default_factory=lambda: {})
+
     def __post_init__(self):
         super().__init__()
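A hedged illustration of the member-level rule from the docstring above (plain torch.nn modules with invented names, not Implicitron code): since torch.nn.Linear has no `param_groups` of its own, the parent can route it wholesale:

    import torch

    class Outer(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Member level: every parameter of `mlp` lands in the "mlp_fast"
            # group, exactly as if torch.nn.Linear could itself declare
            # param_groups = {"self": "mlp_fast"}.
            self.param_groups = {"mlp": "mlp_fast"}
            self.mlp = torch.nn.Linear(16, 16)
            self.bias = torch.nn.Parameter(torch.zeros(16))  # stays in "default"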


@@ -808,6 +808,21 @@ class VoxelGridModule(Configurable, torch.nn.Module):
             with mean=init_mean and std=init_std. Default 0.
         hold_voxel_grid_as_parameters: if True components of the underlying voxel grids
             will be saved as parameters and therefore be trainable. Default True.
+        param_groups: dictionary where keys are names of individual parameters
+            or module members, and values are the parameter group into which
+            the parameter/member will be sorted. The "self" key denotes the
+            parameter group at the module level. Possible keys, including the
+            "self" key, do not have to be defined. By default, all parameters
+            are put into the "default" parameter group and have the learning
+            rate defined in the optimizer; this can be overridden at the:
+                - module level, with the "self" key: all of the module's
+                    parameters and its child modules' parameters will be put
+                    into that parameter group
+                - member level, which is the same as if `param_groups` in that
+                    member had key="self" and value equal to that parameter
+                    group. This is useful for members that do not have
+                    `param_groups`, for example torch.nn.Linear.
+                - parameter level: the parameter with the same name as the key
+                    will be put into that parameter group.
     """
 
     voxel_grid_class_type: str = "FullResolutionVoxelGrid"
@@ -820,6 +835,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
     init_mean: float = 0
     hold_voxel_grid_as_parameters: bool = True
+    param_groups: Dict[str, str] = field(default_factory=lambda: {})
 
     def __post_init__(self):
         super().__init__()
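And a usage sketch for the new member (hedged: construction of Configurable classes normally goes through Implicitron's config machinery, and the group name here is invented):

    from pytorch3d.implicitron.tools.config import expand_args_fields
    from pytorch3d.implicitron.models.implicit_function.voxel_grid import VoxelGridModule

    # Route the whole voxel grid module to its own parameter group via the
    # module-level "self" key; everything else stays in "default".
    expand_args_fields(VoxelGridModule)
    grid = VoxelGridModule(param_groups={"self": "voxel_grid"})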