mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
allow dots in param_groups
Summary: Allow a module's param_group member to specify overrides to the param groups of its members or their members. Also adds logging for param group assignments. This allows defining `params.basis_matrix` in the param_groups of a voxel_grid. Reviewed By: shapovalov Differential Revision: D41080667 fbshipit-source-id: 49f3b0e5b36e496f78701db0699cbb8a7e20c51e
This commit is contained in:
parent
a1f2ded58a
commit
7be49bf46f
@ -258,15 +258,16 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
|
|||||||
at the module level. Possible keys, including the "self" key do not have to
|
at the module level. Possible keys, including the "self" key do not have to
|
||||||
be defined. By default all parameters have the learning rate defined in the
|
be defined. By default all parameters have the learning rate defined in the
|
||||||
optimizer. This can be overridden by setting the parameter group in `param_groups`
|
optimizer. This can be overridden by setting the parameter group in `param_groups`
|
||||||
member of a specific module, it can be overridden at the:
|
member of a specific module. Each value is a parameter group name. The keys
|
||||||
- module level with “self” key, all the parameters and child
|
specify what parameters will be affected as follows:
|
||||||
module's parameters will inherit it
|
- “self”: All the parameters of the module and its child modules
|
||||||
- member level, which is the same as if the `param_groups` in that
|
- name of a parameter: A parameter with that name.
|
||||||
member has key=“self” and value equal to that parameter group.
|
- name of a module member: All the parameters of the module and its
|
||||||
|
child modules.
|
||||||
This is useful if members do not have `param_groups`, for
|
This is useful if members do not have `param_groups`, for
|
||||||
example torch.nn.Linear.
|
example torch.nn.Linear.
|
||||||
- parameter level, only parameter with the same name as the key
|
- <name of module member>.<something>: recursive. Same as if <something>
|
||||||
will have it.
|
was used in param_groups of that submodule/member.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
module: module from which to extract the parameters and their parameter
|
module: module from which to extract the parameters and their parameter
|
||||||
@ -277,7 +278,18 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
|
|||||||
|
|
||||||
param_groups = defaultdict(list)
|
param_groups = defaultdict(list)
|
||||||
|
|
||||||
def traverse(module, default_group):
|
def traverse(module, default_group: str, mapping: Dict[str, str]) -> None:
|
||||||
|
"""
|
||||||
|
Visitor for module to assign its parameters to the relevant member of
|
||||||
|
param_groups.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module: the module being visited in a depth-first search
|
||||||
|
default_group: the param group to assign parameters to unless
|
||||||
|
otherwise overridden.
|
||||||
|
mapping: known mappings of parameters to groups for this module,
|
||||||
|
destructively modified by this function.
|
||||||
|
"""
|
||||||
# If key self is defined in param_groups then change the default param
|
# If key self is defined in param_groups then change the default param
|
||||||
# group for all parameters and children in the module.
|
# group for all parameters and children in the module.
|
||||||
if hasattr(module, "param_groups") and "self" in module.param_groups:
|
if hasattr(module, "param_groups") and "self" in module.param_groups:
|
||||||
@ -286,25 +298,26 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
|
|||||||
# Collect all the parameters that are directly inside the `module`,
|
# Collect all the parameters that are directly inside the `module`,
|
||||||
# they will be in the default param group if they don't have
|
# they will be in the default param group if they don't have
|
||||||
# defined group.
|
# defined group.
|
||||||
|
if hasattr(module, "param_groups"):
|
||||||
|
mapping.update(module.param_groups)
|
||||||
|
|
||||||
for name, param in module.named_parameters(recurse=False):
|
for name, param in module.named_parameters(recurse=False):
|
||||||
if param.requires_grad:
|
if param.requires_grad:
|
||||||
if hasattr(module, "param_groups") and name in module.param_groups:
|
group_name = mapping.get(name, default_group)
|
||||||
param_groups[module.param_groups[name]].append(param)
|
logger.info(f"Assigning {name} to param_group {group_name}")
|
||||||
else:
|
param_groups[group_name].append(param)
|
||||||
param_groups[default_group].append(param)
|
|
||||||
|
|
||||||
# If children have defined default param group then use it else pass
|
# If children have defined default param group then use it else pass
|
||||||
# own default.
|
# own default.
|
||||||
for child_name, child in module.named_children():
|
for child_name, child in module.named_children():
|
||||||
if (
|
mapping_to_add = {
|
||||||
hasattr(module, "param_groups")
|
name[len(child_name) + 1 :]: group
|
||||||
and child_name in module.param_groups
|
for name, group in mapping.items()
|
||||||
):
|
if name.startswith(child_name + ".")
|
||||||
traverse(child, module.param_groups[child_name])
|
}
|
||||||
else:
|
traverse(child, mapping.get(child_name, default_group), mapping_to_add)
|
||||||
traverse(child, default_group)
|
|
||||||
|
|
||||||
traverse(module, "default")
|
traverse(module, "default", {})
|
||||||
return param_groups
|
return param_groups
|
||||||
|
|
||||||
def _get_group_learning_rate(self, group_name: str) -> float:
|
def _get_group_learning_rate(self, group_name: str) -> float:
|
||||||
|
@ -4,13 +4,17 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
|
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
|
||||||
|
|
||||||
from ..impl.optimizer_factory import ImplicitronOptimizerFactory
|
from ..impl.optimizer_factory import (
|
||||||
|
ImplicitronOptimizerFactory,
|
||||||
|
logger as factory_logger,
|
||||||
|
)
|
||||||
|
|
||||||
internal = os.environ.get("FB_TEST", False)
|
internal = os.environ.get("FB_TEST", False)
|
||||||
|
|
||||||
@ -23,9 +27,17 @@ class TestOptimizerFactory(unittest.TestCase):
|
|||||||
def _get_param_groups(self, model):
|
def _get_param_groups(self, model):
|
||||||
default_cfg = get_default_args(ImplicitronOptimizerFactory)
|
default_cfg = get_default_args(ImplicitronOptimizerFactory)
|
||||||
factory = ImplicitronOptimizerFactory(default_cfg)
|
factory = ImplicitronOptimizerFactory(default_cfg)
|
||||||
return factory._get_param_groups(model)
|
oldlevel = factory_logger.level
|
||||||
|
factory_logger.setLevel(logging.ERROR)
|
||||||
|
out = factory._get_param_groups(model)
|
||||||
|
factory_logger.setLevel(oldlevel)
|
||||||
|
return out
|
||||||
|
|
||||||
def _assert_allin(self, a, param_groups, key):
|
def _assert_allin(self, a, param_groups, key):
|
||||||
|
"""
|
||||||
|
Asserts that all the parameters in a are in the group
|
||||||
|
named by key.
|
||||||
|
"""
|
||||||
with self.subTest(f"Testing key {key}"):
|
with self.subTest(f"Testing key {key}"):
|
||||||
b = param_groups[key]
|
b = param_groups[key]
|
||||||
for el in a:
|
for el in a:
|
||||||
@ -83,6 +95,15 @@ class TestOptimizerFactory(unittest.TestCase):
|
|||||||
param_groups = self._get_param_groups(root)
|
param_groups = self._get_param_groups(root)
|
||||||
self._assert_allin([pa, pb, pc], param_groups, "default")
|
self._assert_allin([pa, pb, pc], param_groups, "default")
|
||||||
|
|
||||||
|
def test_double_dotted(self):
|
||||||
|
pa, pb = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(2)]
|
||||||
|
na = Node(params=[pa, pb])
|
||||||
|
nb = Node(children=[na])
|
||||||
|
root = Node(children=[nb], param_groups={"m0.m0.p0": "X", "m0.m0": "Y"})
|
||||||
|
param_groups = self._get_param_groups(root)
|
||||||
|
self._assert_allin([pa], param_groups, "X")
|
||||||
|
self._assert_allin([pb], param_groups, "Y")
|
||||||
|
|
||||||
def test_tree_param_groups_defined(self):
|
def test_tree_param_groups_defined(self):
|
||||||
"""
|
"""
|
||||||
Test generic tree assignment.
|
Test generic tree assignment.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user