allow dots in param_groups

Summary:
Allow a module's `param_groups` member to specify overrides for the param groups of its members or, recursively, their members.
Also adds logging of param group assignments.

This allows defining `params.basis_matrix` in the param_groups of a voxel_grid.
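
For illustration, a minimal sketch of the kind of override this enables; the class and group names below are invented for the example and are not the actual Implicitron voxel grid API:

```python
import torch


class GridValues(torch.nn.Module):
    # Stands in for the `params` member of a voxel grid that holds its tensors.
    def __init__(self):
        super().__init__()
        self.basis_matrix = torch.nn.Parameter(torch.eye(3))
        self.density = torch.nn.Parameter(torch.zeros(8))


class VoxelGridModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.params = GridValues()
        # With this change a key may contain dots, so the grid can single out
        # one parameter of its `params` member; everything else falls back to
        # the optimizer's default param group.
        self.param_groups = {"params.basis_matrix": "basis_matrix_group"}
```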

Reviewed By: shapovalov

Differential Revision: D41080667

fbshipit-source-id: 49f3b0e5b36e496f78701db0699cbb8a7e20c51e
Author: Jeremy Reizenstein, 2022-11-07 06:41:40 -08:00 (committed by Facebook GitHub Bot)
parent a1f2ded58a
commit 7be49bf46f
2 changed files with 56 additions and 22 deletions

impl/optimizer_factory.py

@@ -258,15 +258,16 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
         at the module level. Possible keys, including the "self" key do not have to
         be defined. By default all parameters have the learning rate defined in the
         optimizer. This can be overridden by setting the parameter group in `param_groups`
-        member of a specific module, it can be overridden at the:
-            - module level with self key, all the parameters and child
-                module's parameters will inherit it
-            - member level, which is the same as if the `param_groups` in that
-                member has key=self and value equal to that parameter group.
+        member of a specific module. Values are a parameter group name. The keys
+        specify what parameters will be affected as follows:
+            - self: All the parameters of the module and its child modules
+            - name of a parameter: A parameter with that name.
+            - name of a module member: All the parameters of the module and its
+                child modules.
                 This is useful if members do not have `param_groups`, for
                 example torch.nn.Linear.
-            - parameter level, only parameter with the same name as the key
-                will have it.
+            - <name of module member>.<something>: recursive. Same as if <something>
+                was used in param_groups of that submodule/member.
 
         Args:
             module: module from which to extract the parameters and their parameter
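
To make the four kinds of keys concrete, here is a hedged toy example (module and group names are made up) together with the grouping the docstring above implies:

```python
import torch


class Inner(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.zeros(4))


class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = torch.nn.Parameter(torch.zeros(4))
        self.inner = Inner()
        self.linear = torch.nn.Linear(4, 4)  # plain module, no param_groups of its own
        self.param_groups = {
            "self": "outer_group",      # default for this module and its children
            "bias": "bias_group",       # a single parameter of this module
            "linear": "linear_group",   # every parameter of a member module
            "inner.weight": "inner_w",  # dotted key reaching into a member
        }


# Expected grouping, following the rules above:
#   bias          -> "bias_group"
#   linear.weight -> "linear_group"
#   linear.bias   -> "linear_group"
#   inner.weight  -> "inner_w"
# Any other parameter under Outer would fall back to "outer_group" via "self".
```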
@@ -277,7 +278,18 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
         param_groups = defaultdict(list)
 
-        def traverse(module, default_group):
+        def traverse(module, default_group: str, mapping: Dict[str, str]) -> None:
+            """
+            Visitor for module to assign its parameters to the relevant member of
+            param_groups.
+
+            Args:
+                module: the module being visited in a depth-first search
+                default_group: the param group to assign parameters to unless
+                    otherwise overridden.
+                mapping: known mappings of parameters to groups for this module,
+                    destructively modified by this function.
+            """
             # If key self is defined in param_groups then change the default param
             # group for all parameters and children in the module.
             if hasattr(module, "param_groups") and "self" in module.param_groups:
@@ -286,25 +298,26 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
             # Collect all the parameters that are directly inside the `module`,
             # they will be in the default param group if they don't have
             # defined group.
+            if hasattr(module, "param_groups"):
+                mapping.update(module.param_groups)
+
             for name, param in module.named_parameters(recurse=False):
                 if param.requires_grad:
-                    if hasattr(module, "param_groups") and name in module.param_groups:
-                        param_groups[module.param_groups[name]].append(param)
-                    else:
-                        param_groups[default_group].append(param)
+                    group_name = mapping.get(name, default_group)
+                    logger.info(f"Assigning {name} to param_group {group_name}")
+                    param_groups[group_name].append(param)
 
             # If children have defined default param group then use it else pass
             # own default.
             for child_name, child in module.named_children():
-                if (
-                    hasattr(module, "param_groups")
-                    and child_name in module.param_groups
-                ):
-                    traverse(child, module.param_groups[child_name])
-                else:
-                    traverse(child, default_group)
+                mapping_to_add = {
+                    name[len(child_name) + 1 :]: group
+                    for name, group in mapping.items()
+                    if name.startswith(child_name + ".")
+                }
+                traverse(child, mapping.get(child_name, default_group), mapping_to_add)
 
-        traverse(module, "default")
+        traverse(module, "default", {})
         return param_groups
 
     def _get_group_learning_rate(self, group_name: str) -> float:
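
The returned dict maps group names to lists of parameters. As a hedged sketch of how such a dict could be consumed (the real factory builds its optimizer from its own configuration and `_get_group_learning_rate`; `build_optimizer`, `base_lr` and `group_lrs` below are illustrative names):

```python
import torch


def build_optimizer(param_groups, base_lr=1e-3, group_lrs=None):
    # Turn {group_name: [parameters]} into torch.optim-style param group dicts,
    # giving each named group its own learning rate where one is configured.
    group_lrs = group_lrs or {}
    torch_groups = [
        {"params": params, "lr": group_lrs.get(name, base_lr)}
        for name, params in param_groups.items()
    ]
    return torch.optim.Adam(torch_groups)
```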

tests/test_optimizer_factory.py

@@ -4,13 +4,17 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 import os
 import unittest
 
 import torch
 from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
 
-from ..impl.optimizer_factory import ImplicitronOptimizerFactory
+from ..impl.optimizer_factory import (
+    ImplicitronOptimizerFactory,
+    logger as factory_logger,
+)
 
 internal = os.environ.get("FB_TEST", False)
@@ -23,9 +27,17 @@ class TestOptimizerFactory(unittest.TestCase):
     def _get_param_groups(self, model):
         default_cfg = get_default_args(ImplicitronOptimizerFactory)
         factory = ImplicitronOptimizerFactory(default_cfg)
-        return factory._get_param_groups(model)
+        oldlevel = factory_logger.level
+        factory_logger.setLevel(logging.ERROR)
+        out = factory._get_param_groups(model)
+        factory_logger.setLevel(oldlevel)
+        return out
 
     def _assert_allin(self, a, param_groups, key):
+        """
+        Asserts that all the parameters in a are in the group
+        named by key.
+        """
         with self.subTest(f"Testing key {key}"):
             b = param_groups[key]
             for el in a:
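
The `_get_param_groups` helper above silences the factory logger around the call so the new per-parameter log lines do not clutter test output, then restores the old level. A hedged variant of the same helper (same behavior, just restoring the level even if the call raises) could use try/finally:

```python
import logging


def _get_param_groups(self, model):
    default_cfg = get_default_args(ImplicitronOptimizerFactory)
    factory = ImplicitronOptimizerFactory(default_cfg)
    oldlevel = factory_logger.level
    factory_logger.setLevel(logging.ERROR)
    try:
        return factory._get_param_groups(model)
    finally:
        factory_logger.setLevel(oldlevel)
```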
@@ -83,6 +95,15 @@ class TestOptimizerFactory(unittest.TestCase):
         param_groups = self._get_param_groups(root)
         self._assert_allin([pa, pb, pc], param_groups, "default")
 
+    def test_double_dotted(self):
+        pa, pb = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(2)]
+        na = Node(params=[pa, pb])
+        nb = Node(children=[na])
+        root = Node(children=[nb], param_groups={"m0.m0.p0": "X", "m0.m0": "Y"})
+        param_groups = self._get_param_groups(root)
+        self._assert_allin([pa], param_groups, "X")
+        self._assert_allin([pb], param_groups, "Y")
+
     def test_tree_param_groups_defined(self):
         """
         Test generic tree assignment.
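
The `Node` helper used by `test_double_dotted` is defined earlier in the test file and is not part of this diff. Judging from keys like "m0.m0.p0", it plausibly registers children as m0, m1, ... and parameters as p0, p1, ...; the sketch below is a guess at its shape, not the actual definition:

```python
import torch


class Node(torch.nn.Module):
    # Hypothetical reconstruction: children become submodules m0, m1, ... and
    # parameters become p0, p1, ..., so dotted keys like "m0.m0.p0" resolve.
    def __init__(self, children=(), params=(), param_groups=None):
        super().__init__()
        for i, child in enumerate(children):
            self.add_module(f"m{i}", child)
        for i, param in enumerate(params):
            setattr(self, f"p{i}", param)
        if param_groups is not None:
            self.param_groups = param_groups
```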