mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 14:20:38 +08:00
allow dots in param_groups
Summary: Allow a module's param_group member to specify overrides to the param groups of its members or their members. Also logging for param group assignments. This allows defining `params.basis_matrix` in the param_groups of a voxel_grid. Reviewed By: shapovalov Differential Revision: D41080667 fbshipit-source-id: 49f3b0e5b36e496f78701db0699cbb8a7e20c51e
This commit is contained in:
committed by
Facebook GitHub Bot
parent
a1f2ded58a
commit
7be49bf46f
@@ -4,13 +4,17 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
|
||||
|
||||
from ..impl.optimizer_factory import ImplicitronOptimizerFactory
|
||||
from ..impl.optimizer_factory import (
|
||||
ImplicitronOptimizerFactory,
|
||||
logger as factory_logger,
|
||||
)
|
||||
|
||||
internal = os.environ.get("FB_TEST", False)
|
||||
|
||||
@@ -23,9 +27,17 @@ class TestOptimizerFactory(unittest.TestCase):
|
||||
def _get_param_groups(self, model):
|
||||
default_cfg = get_default_args(ImplicitronOptimizerFactory)
|
||||
factory = ImplicitronOptimizerFactory(default_cfg)
|
||||
return factory._get_param_groups(model)
|
||||
oldlevel = factory_logger.level
|
||||
factory_logger.setLevel(logging.ERROR)
|
||||
out = factory._get_param_groups(model)
|
||||
factory_logger.setLevel(oldlevel)
|
||||
return out
|
||||
|
||||
def _assert_allin(self, a, param_groups, key):
|
||||
"""
|
||||
Asserts that all the parameters in a are in the group
|
||||
named by key.
|
||||
"""
|
||||
with self.subTest(f"Testing key {key}"):
|
||||
b = param_groups[key]
|
||||
for el in a:
|
||||
@@ -83,6 +95,15 @@ class TestOptimizerFactory(unittest.TestCase):
|
||||
param_groups = self._get_param_groups(root)
|
||||
self._assert_allin([pa, pb, pc], param_groups, "default")
|
||||
|
||||
def test_double_dotted(self):
|
||||
pa, pb = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(2)]
|
||||
na = Node(params=[pa, pb])
|
||||
nb = Node(children=[na])
|
||||
root = Node(children=[nb], param_groups={"m0.m0.p0": "X", "m0.m0": "Y"})
|
||||
param_groups = self._get_param_groups(root)
|
||||
self._assert_allin([pa], param_groups, "X")
|
||||
self._assert_allin([pb], param_groups, "Y")
|
||||
|
||||
def test_tree_param_groups_defined(self):
|
||||
"""
|
||||
Test generic tree assignment.
|
||||
|
||||
Reference in New Issue
Block a user