allow dots in param_groups

Summary:
Allow a module's param_group member to specify overrides to the param groups of its members or their members.
Also logging for param group assignments.

This allows defining `params.basis_matrix` in the param_groups of a voxel_grid.

Reviewed By: shapovalov

Differential Revision: D41080667

fbshipit-source-id: 49f3b0e5b36e496f78701db0699cbb8a7e20c51e
This commit is contained in:
Jeremy Reizenstein
2022-11-07 06:41:40 -08:00
committed by Facebook GitHub Bot
parent a1f2ded58a
commit 7be49bf46f
2 changed files with 56 additions and 22 deletions

View File

@@ -4,13 +4,17 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import unittest
import torch
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
from ..impl.optimizer_factory import ImplicitronOptimizerFactory
from ..impl.optimizer_factory import (
ImplicitronOptimizerFactory,
logger as factory_logger,
)
internal = os.environ.get("FB_TEST", False)
@@ -23,9 +27,17 @@ class TestOptimizerFactory(unittest.TestCase):
def _get_param_groups(self, model):
default_cfg = get_default_args(ImplicitronOptimizerFactory)
factory = ImplicitronOptimizerFactory(default_cfg)
return factory._get_param_groups(model)
oldlevel = factory_logger.level
factory_logger.setLevel(logging.ERROR)
out = factory._get_param_groups(model)
factory_logger.setLevel(oldlevel)
return out
def _assert_allin(self, a, param_groups, key):
"""
Asserts that all the parameters in a are in the group
named by key.
"""
with self.subTest(f"Testing key {key}"):
b = param_groups[key]
for el in a:
@@ -83,6 +95,15 @@ class TestOptimizerFactory(unittest.TestCase):
param_groups = self._get_param_groups(root)
self._assert_allin([pa, pb, pc], param_groups, "default")
def test_double_dotted(self):
pa, pb = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(2)]
na = Node(params=[pa, pb])
nb = Node(children=[na])
root = Node(children=[nb], param_groups={"m0.m0.p0": "X", "m0.m0": "Y"})
param_groups = self._get_param_groups(root)
self._assert_allin([pa], param_groups, "X")
self._assert_allin([pb], param_groups, "Y")
def test_tree_param_groups_defined(self):
"""
Test generic tree assignment.