allow dots in param_groups

Summary:
Allow a module's param_group member to specify overrides to the param groups of its members or their members.
Also logging for param group assignments.

This allows defining `params.basis_matrix` in the param_groups of a voxel_grid.

Reviewed By: shapovalov

Differential Revision: D41080667

fbshipit-source-id: 49f3b0e5b36e496f78701db0699cbb8a7e20c51e
This commit is contained in:
Jeremy Reizenstein 2022-11-07 06:41:40 -08:00 committed by Facebook GitHub Bot
parent a1f2ded58a
commit 7be49bf46f
2 changed files with 56 additions and 22 deletions

View File

@ -258,15 +258,16 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
at the module level. Possible keys, including the "self" key do not have to
be defined. By default all parameters have the learning rate defined in the
optimizer. This can be overridden by setting the parameter group in `param_groups`
member of a specific module, it can be overridden at the:
- module level with self key, all the parameters and child
module's parameters will inherit it
- member level, which is the same as if the `param_groups` in that
member has key=self and value equal to that parameter group.
member of a specific module. Values are a parameter group name. The keys
specify what parameters will be affected as follows:
- self: All the parameters of the module and its child modules
- name of a parameter: A parameter with that name.
- name of a module member: All the parameters of the module and its
child modules.
This is useful if members do not have `param_groups`, for
example torch.nn.Linear.
- parameter level, only parameter with the same name as the key
will have it.
- <name of module member>.<something>: recursive. Same as if <something>
was used in param_groups of that submodule/member.
Args:
module: module from which to extract the parameters and their parameter
@ -277,7 +278,18 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
param_groups = defaultdict(list)
def traverse(module, default_group):
def traverse(module, default_group: str, mapping: Dict[str, str]) -> None:
"""
Visitor for module to assign its parameters to the relevant member of
param_groups.
Args:
module: the module being visited in a depth-first search
default_group: the param group to assign parameters to unless
otherwise overriden.
mapping: known mappings of parameters to groups for this module,
destructively modified by this function.
"""
# If key self is defined in param_groups then chenge the default param
# group for all parameters and children in the module.
if hasattr(module, "param_groups") and "self" in module.param_groups:
@ -286,25 +298,26 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
# Collect all the parameters that are directly inside the `module`,
# they will be in the default param group if they don't have
# defined group.
if hasattr(module, "param_groups"):
mapping.update(module.param_groups)
for name, param in module.named_parameters(recurse=False):
if param.requires_grad:
if hasattr(module, "param_groups") and name in module.param_groups:
param_groups[module.param_groups[name]].append(param)
else:
param_groups[default_group].append(param)
group_name = mapping.get(name, default_group)
logger.info(f"Assigning {name} to param_group {group_name}")
param_groups[group_name].append(param)
# If children have defined default param group then use it else pass
# own default.
for child_name, child in module.named_children():
if (
hasattr(module, "param_groups")
and child_name in module.param_groups
):
traverse(child, module.param_groups[child_name])
else:
traverse(child, default_group)
mapping_to_add = {
name[len(child_name) + 1 :]: group
for name, group in mapping.items()
if name.startswith(child_name + ".")
}
traverse(child, mapping.get(child_name, default_group), mapping_to_add)
traverse(module, "default")
traverse(module, "default", {})
return param_groups
def _get_group_learning_rate(self, group_name: str) -> float:

View File

@ -4,13 +4,17 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import unittest
import torch
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
from ..impl.optimizer_factory import ImplicitronOptimizerFactory
from ..impl.optimizer_factory import (
ImplicitronOptimizerFactory,
logger as factory_logger,
)
internal = os.environ.get("FB_TEST", False)
@ -23,9 +27,17 @@ class TestOptimizerFactory(unittest.TestCase):
def _get_param_groups(self, model):
default_cfg = get_default_args(ImplicitronOptimizerFactory)
factory = ImplicitronOptimizerFactory(default_cfg)
return factory._get_param_groups(model)
oldlevel = factory_logger.level
factory_logger.setLevel(logging.ERROR)
out = factory._get_param_groups(model)
factory_logger.setLevel(oldlevel)
return out
def _assert_allin(self, a, param_groups, key):
"""
Asserts that all the parameters in a are in the group
named by key.
"""
with self.subTest(f"Testing key {key}"):
b = param_groups[key]
for el in a:
@ -83,6 +95,15 @@ class TestOptimizerFactory(unittest.TestCase):
param_groups = self._get_param_groups(root)
self._assert_allin([pa, pb, pc], param_groups, "default")
def test_double_dotted(self):
pa, pb = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(2)]
na = Node(params=[pa, pb])
nb = Node(children=[na])
root = Node(children=[nb], param_groups={"m0.m0.p0": "X", "m0.m0": "Y"})
param_groups = self._get_param_groups(root)
self._assert_allin([pa], param_groups, "X")
self._assert_allin([pb], param_groups, "Y")
def test_tree_param_groups_defined(self):
"""
Test generic tree assignment.