mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	allow dots in param_groups
Summary: Allow a module's param_group member to specify overrides to the param groups of its members or their members. Also logging for param group assignments. This allows defining `params.basis_matrix` in the param_groups of a voxel_grid. Reviewed By: shapovalov Differential Revision: D41080667 fbshipit-source-id: 49f3b0e5b36e496f78701db0699cbb8a7e20c51e
This commit is contained in:
		
							parent
							
								
									a1f2ded58a
								
							
						
					
					
						commit
						7be49bf46f
					
				@ -258,15 +258,16 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
 | 
			
		||||
        at the module level. Possible keys, including the "self" key do not have to
 | 
			
		||||
        be defined. By default all parameters have the learning rate defined in the
 | 
			
		||||
        optimizer. This can be overridden by setting the parameter group in `param_groups`
 | 
			
		||||
        member of a specific module, it can be overridden at the:
 | 
			
		||||
            - module level with “self” key, all the parameters and child
 | 
			
		||||
                module's parameters will inherit it
 | 
			
		||||
            - member level, which is the same as if the `param_groups` in that
 | 
			
		||||
                member has key=“self” and value equal to that parameter group.
 | 
			
		||||
        member of a specific module. Values are a parameter group name. The keys
 | 
			
		||||
        specify what parameters will be affected as follows:
 | 
			
		||||
            - “self”: All the parameters of the module and its child modules
 | 
			
		||||
            - name of a parameter: A parameter with that name.
 | 
			
		||||
            - name of a module member: All the parameters of the module and its
 | 
			
		||||
                child modules.
 | 
			
		||||
                This is useful if members do not have `param_groups`, for
 | 
			
		||||
                example torch.nn.Linear.
 | 
			
		||||
            - parameter level, only parameter with the same name as the key
 | 
			
		||||
                will have it.
 | 
			
		||||
            - <name of module member>.<something>: recursive. Same as if <something>
 | 
			
		||||
                was used in param_groups of that submodule/member.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            module: module from which to extract the parameters and their parameter
 | 
			
		||||
@ -277,7 +278,18 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
 | 
			
		||||
 | 
			
		||||
        param_groups = defaultdict(list)
 | 
			
		||||
 | 
			
		||||
        def traverse(module, default_group):
 | 
			
		||||
        def traverse(module, default_group: str, mapping: Dict[str, str]) -> None:
 | 
			
		||||
            """
 | 
			
		||||
            Visitor for module to assign its parameters to the relevant member of
 | 
			
		||||
            param_groups.
 | 
			
		||||
 | 
			
		||||
            Args:
 | 
			
		||||
                module: the module being visited in a depth-first search
 | 
			
		||||
                default_group: the param group to assign parameters to unless
 | 
			
		||||
                                otherwise overriden.
 | 
			
		||||
                mapping: known mappings of parameters to groups for this module,
 | 
			
		||||
                    destructively modified by this function.
 | 
			
		||||
            """
 | 
			
		||||
            # If key self is defined in param_groups then chenge the default param
 | 
			
		||||
            # group for all parameters and children in the module.
 | 
			
		||||
            if hasattr(module, "param_groups") and "self" in module.param_groups:
 | 
			
		||||
@ -286,25 +298,26 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
 | 
			
		||||
            # Collect all the parameters that are directly inside the `module`,
 | 
			
		||||
            # they will be in the default param group if they don't have
 | 
			
		||||
            # defined group.
 | 
			
		||||
            if hasattr(module, "param_groups"):
 | 
			
		||||
                mapping.update(module.param_groups)
 | 
			
		||||
 | 
			
		||||
            for name, param in module.named_parameters(recurse=False):
 | 
			
		||||
                if param.requires_grad:
 | 
			
		||||
                    if hasattr(module, "param_groups") and name in module.param_groups:
 | 
			
		||||
                        param_groups[module.param_groups[name]].append(param)
 | 
			
		||||
                    else:
 | 
			
		||||
                        param_groups[default_group].append(param)
 | 
			
		||||
                    group_name = mapping.get(name, default_group)
 | 
			
		||||
                    logger.info(f"Assigning {name} to param_group {group_name}")
 | 
			
		||||
                    param_groups[group_name].append(param)
 | 
			
		||||
 | 
			
		||||
            # If children have defined default param group then use it else pass
 | 
			
		||||
            # own default.
 | 
			
		||||
            for child_name, child in module.named_children():
 | 
			
		||||
                if (
 | 
			
		||||
                    hasattr(module, "param_groups")
 | 
			
		||||
                    and child_name in module.param_groups
 | 
			
		||||
                ):
 | 
			
		||||
                    traverse(child, module.param_groups[child_name])
 | 
			
		||||
                else:
 | 
			
		||||
                    traverse(child, default_group)
 | 
			
		||||
                mapping_to_add = {
 | 
			
		||||
                    name[len(child_name) + 1 :]: group
 | 
			
		||||
                    for name, group in mapping.items()
 | 
			
		||||
                    if name.startswith(child_name + ".")
 | 
			
		||||
                }
 | 
			
		||||
                traverse(child, mapping.get(child_name, default_group), mapping_to_add)
 | 
			
		||||
 | 
			
		||||
        traverse(module, "default")
 | 
			
		||||
        traverse(module, "default", {})
 | 
			
		||||
        return param_groups
 | 
			
		||||
 | 
			
		||||
    def _get_group_learning_rate(self, group_name: str) -> float:
 | 
			
		||||
 | 
			
		||||
@ -4,13 +4,17 @@
 | 
			
		||||
# This source code is licensed under the BSD-style license found in the
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
import logging
 | 
			
		||||
import os
 | 
			
		||||
import unittest
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
 | 
			
		||||
 | 
			
		||||
from ..impl.optimizer_factory import ImplicitronOptimizerFactory
 | 
			
		||||
from ..impl.optimizer_factory import (
 | 
			
		||||
    ImplicitronOptimizerFactory,
 | 
			
		||||
    logger as factory_logger,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
internal = os.environ.get("FB_TEST", False)
 | 
			
		||||
 | 
			
		||||
@ -23,9 +27,17 @@ class TestOptimizerFactory(unittest.TestCase):
 | 
			
		||||
    def _get_param_groups(self, model):
 | 
			
		||||
        default_cfg = get_default_args(ImplicitronOptimizerFactory)
 | 
			
		||||
        factory = ImplicitronOptimizerFactory(default_cfg)
 | 
			
		||||
        return factory._get_param_groups(model)
 | 
			
		||||
        oldlevel = factory_logger.level
 | 
			
		||||
        factory_logger.setLevel(logging.ERROR)
 | 
			
		||||
        out = factory._get_param_groups(model)
 | 
			
		||||
        factory_logger.setLevel(oldlevel)
 | 
			
		||||
        return out
 | 
			
		||||
 | 
			
		||||
    def _assert_allin(self, a, param_groups, key):
 | 
			
		||||
        """
 | 
			
		||||
        Asserts that all the parameters in a are in the group
 | 
			
		||||
        named by key.
 | 
			
		||||
        """
 | 
			
		||||
        with self.subTest(f"Testing key {key}"):
 | 
			
		||||
            b = param_groups[key]
 | 
			
		||||
            for el in a:
 | 
			
		||||
@ -83,6 +95,15 @@ class TestOptimizerFactory(unittest.TestCase):
 | 
			
		||||
        param_groups = self._get_param_groups(root)
 | 
			
		||||
        self._assert_allin([pa, pb, pc], param_groups, "default")
 | 
			
		||||
 | 
			
		||||
    def test_double_dotted(self):
 | 
			
		||||
        pa, pb = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(2)]
 | 
			
		||||
        na = Node(params=[pa, pb])
 | 
			
		||||
        nb = Node(children=[na])
 | 
			
		||||
        root = Node(children=[nb], param_groups={"m0.m0.p0": "X", "m0.m0": "Y"})
 | 
			
		||||
        param_groups = self._get_param_groups(root)
 | 
			
		||||
        self._assert_allin([pa], param_groups, "X")
 | 
			
		||||
        self._assert_allin([pb], param_groups, "Y")
 | 
			
		||||
 | 
			
		||||
    def test_tree_param_groups_defined(self):
 | 
			
		||||
        """
 | 
			
		||||
        Test generic tree assignment.
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user