Mirror of https://github.com/facebookresearch/pytorch3d.git
allow dots in param_groups
Summary: Allow a module's `param_groups` member to specify overrides for the param groups of its members or their members, and add logging for param group assignments. This allows defining `params.basis_matrix` in the `param_groups` of a voxel_grid.

Reviewed By: shapovalov

Differential Revision: D41080667

fbshipit-source-id: 49f3b0e5b36e496f78701db0699cbb8a7e20c51e
This commit: 7be49bf46f (parent: a1f2ded58a)
@@ -258,15 +258,16 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
         at the module level. Possible keys, including the "self" key do not have to
         be defined. By default all parameters have the learning rate defined in the
         optimizer. This can be overridden by setting the parameter group in `param_groups`
-        member of a specific module, it can be overridden at the:
-            - module level with “self” key, all the parameters and child
-                module's parameters will inherit it
-            - member level, which is the same as if the `param_groups` in that
-                member has key=“self” and value equal to that parameter group.
+        member of a specific module. Values are a parameter group name. The keys
+        specify what parameters will be affected as follows:
+            - “self”: All the parameters of the module and its child modules
+            - name of a parameter: A parameter with that name.
+            - name of a module member: All the parameters of the module and its
+                child modules.
                 This is useful if members do not have `param_groups`, for
                 example torch.nn.Linear.
-            - parameter level, only parameter with the same name as the key
-                will have it.
+            - <name of module member>.<something>: recursive. Same as if <something>
+                was used in param_groups of that submodule/member.

         Args:
             module: module from which to extract the parameters and their parameter
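To make the four key forms concrete, a hypothetical `param_groups` value could look like the sketch below. Only `basis_matrix` is grounded in the commit message; the `mlp`, `grid`, and group names are invented for illustration, and only the key semantics follow the docstring above:

    param_groups = {
        "self": "slow",                # this module and all children, unless overridden
        "basis_matrix": "basis",       # a single parameter of this module
        "mlp": "mlp_lr",               # every parameter of the child member mlp
        "grid.density": "density_lr",  # recursive: same as putting
                                       # {"density": "density_lr"} in the
                                       # param_groups of the grid member
    }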
@@ -277,7 +278,18 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
         param_groups = defaultdict(list)

-        def traverse(module, default_group):
+        def traverse(module, default_group: str, mapping: Dict[str, str]) -> None:
+            """
+            Visitor for module to assign its parameters to the relevant member of
+            param_groups.
+
+            Args:
+                module: the module being visited in a depth-first search
+                default_group: the param group to assign parameters to unless
+                    otherwise overridden.
+                mapping: known mappings of parameters to groups for this module,
+                    destructively modified by this function.
+            """
             # If key self is defined in param_groups then change the default param
             # group for all parameters and children in the module.
             if hasattr(module, "param_groups") and "self" in module.param_groups:
@@ -286,25 +298,26 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
             # Collect all the parameters that are directly inside the `module`,
             # they will be in the default param group if they don't have
             # defined group.
+            if hasattr(module, "param_groups"):
+                mapping.update(module.param_groups)
+
             for name, param in module.named_parameters(recurse=False):
                 if param.requires_grad:
-                    if hasattr(module, "param_groups") and name in module.param_groups:
-                        param_groups[module.param_groups[name]].append(param)
-                    else:
-                        param_groups[default_group].append(param)
+                    group_name = mapping.get(name, default_group)
+                    logger.info(f"Assigning {name} to param_group {group_name}")
+                    param_groups[group_name].append(param)

             # If children have defined default param group then use it else pass
             # own default.
             for child_name, child in module.named_children():
-                if (
-                    hasattr(module, "param_groups")
-                    and child_name in module.param_groups
-                ):
-                    traverse(child, module.param_groups[child_name])
-                else:
-                    traverse(child, default_group)
+                mapping_to_add = {
+                    name[len(child_name) + 1 :]: group
+                    for name, group in mapping.items()
+                    if name.startswith(child_name + ".")
+                }
+                traverse(child, mapping.get(child_name, default_group), mapping_to_add)

-        traverse(module, "default")
+        traverse(module, "default", {})
         return param_groups

     def _get_group_learning_rate(self, group_name: str) -> float:
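The recursion can be exercised outside Implicitron. Below is a minimal, self-contained sketch of the same traversal (logging omitted; the Leaf/Root modules and the group name "special" are invented for illustration, while the traverse body mirrors the new code above):

    from collections import defaultdict
    from typing import Dict

    import torch

    class Leaf(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.zeros(3))

    class Root(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.child = Leaf()
            # A dotted key overriding a grandchild parameter from the root.
            self.param_groups = {"child.weight": "special"}

    param_groups = defaultdict(list)

    def traverse(module, default_group: str, mapping: Dict[str, str]) -> None:
        # "self" overrides the default group for this whole subtree.
        if hasattr(module, "param_groups") and "self" in module.param_groups:
            default_group = module.param_groups["self"]
        if hasattr(module, "param_groups"):
            mapping.update(module.param_groups)
        # Direct parameters: use an override if one is known, else the default.
        for name, param in module.named_parameters(recurse=False):
            if param.requires_grad:
                param_groups[mapping.get(name, default_group)].append(param)
        # Recurse, stripping the "<child_name>." prefix from dotted keys so
        # they behave as if written in the child's own param_groups.
        for child_name, child in module.named_children():
            mapping_to_add = {
                name[len(child_name) + 1 :]: group
                for name, group in mapping.items()
                if name.startswith(child_name + ".")
            }
            traverse(child, mapping.get(child_name, default_group), mapping_to_add)

    traverse(Root(), "default", {})
    assert len(param_groups["special"]) == 1  # Leaf.weight resolved via the dotted key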
The second file in the diff is the optimizer factory test suite (TestOptimizerFactory):

@@ -4,13 +4,17 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import logging
 import os
 import unittest

 import torch
 from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args

-from ..impl.optimizer_factory import ImplicitronOptimizerFactory
+from ..impl.optimizer_factory import (
+    ImplicitronOptimizerFactory,
+    logger as factory_logger,
+)

 internal = os.environ.get("FB_TEST", False)
@@ -23,9 +27,17 @@ class TestOptimizerFactory(unittest.TestCase):
     def _get_param_groups(self, model):
         default_cfg = get_default_args(ImplicitronOptimizerFactory)
         factory = ImplicitronOptimizerFactory(default_cfg)
-        return factory._get_param_groups(model)
+        oldlevel = factory_logger.level
+        factory_logger.setLevel(logging.ERROR)
+        out = factory._get_param_groups(model)
+        factory_logger.setLevel(oldlevel)
+        return out

     def _assert_allin(self, a, param_groups, key):
         """
         Asserts that all the parameters in a are in the group
         named by key.
         """
         with self.subTest(f"Testing key {key}"):
             b = param_groups[key]
             for el in a:
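The save/restore of the logger level above keeps the new logger.info calls from spamming test output, but it is not exception-safe: if _get_param_groups raises, the old level is never restored. A try/finally context manager would be a safer equivalent; a small sketch, where quiet_logger is an invented helper, not part of the commit:

    import contextlib
    import logging

    @contextlib.contextmanager
    def quiet_logger(log: logging.Logger, level: int = logging.ERROR):
        # Temporarily raise a logger's threshold, restoring it even on error.
        old = log.level
        log.setLevel(level)
        try:
            yield
        finally:
            log.setLevel(old)

    # Hypothetical usage in the helper above:
    #     with quiet_logger(factory_logger):
    #         return factory._get_param_groups(model)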
@@ -83,6 +95,15 @@ class TestOptimizerFactory(unittest.TestCase):
         param_groups = self._get_param_groups(root)
         self._assert_allin([pa, pb, pc], param_groups, "default")

+    def test_double_dotted(self):
+        pa, pb = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(2)]
+        na = Node(params=[pa, pb])
+        nb = Node(children=[na])
+        root = Node(children=[nb], param_groups={"m0.m0.p0": "X", "m0.m0": "Y"})
+        param_groups = self._get_param_groups(root)
+        self._assert_allin([pa], param_groups, "X")
+        self._assert_allin([pb], param_groups, "Y")
+
     def test_tree_param_groups_defined(self):
         """
         Test generic tree assignment.
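In test_double_dotted above, the Node helper (defined earlier in the test file, outside this diff) appears to register children as m0, m1, ... and parameters as p0, p1, ..., so "m0.m0.p0" names pa and "m0.m0" sets the default group of the grandchild na. Under that naming assumption, the mapping propagates roughly as follows:

    # Hypothetical trace, assuming Node names children m0, m1, ... and params p0, p1, ...
    # root:            mapping = {"m0.m0.p0": "X", "m0.m0": "Y"}
    # nb (root.m0):    mapping_to_add = {"m0.p0": "X", "m0": "Y"}  # "m0." prefix stripped
    # na (root.m0.m0): default_group = "Y", mapping_to_add = {"p0": "X"}
    # pa ("p0") -> group "X";  pb ("p1") -> default group "Y"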