diff --git a/projects/implicitron_trainer/impl/optimizer_factory.py b/projects/implicitron_trainer/impl/optimizer_factory.py
index 9e4a5227..8cd75884 100644
--- a/projects/implicitron_trainer/impl/optimizer_factory.py
+++ b/projects/implicitron_trainer/impl/optimizer_factory.py
@@ -7,7 +7,9 @@
 import inspect
 import logging
 import os
-from typing import Any, Dict, Optional, Tuple
+from collections import defaultdict
+from dataclasses import field
+from typing import Any, Dict, List, Optional, Tuple

 import torch.optim

@@ -64,6 +66,12 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
         weight_decay: The optimizer weight_decay (L2 penalty on model weights).
         foreach: Whether to use new "foreach" implementation of optimizer where
             available (e.g. requires PyTorch 1.12.0 for Adam)
+        group_learning_rates: Parameters or modules can be assigned to parameter
+            groups. This dictionary has names of those parameter groups as keys
+            and learning rates as values. All parameter group names have to be
+            defined in this dictionary. Parameters which do not have a predefined
+            parameter group are put into the "default" parameter group, which has
+            `lr` as its learning rate.
     """

     betas: Tuple[float, ...] = (0.9, 0.999)
@@ -78,6 +86,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
     linear_exponential_lr_milestone: int = 200
     linear_exponential_start_gamma: float = 0.1
     foreach: Optional[bool] = True
+    group_learning_rates: Dict[str, float] = field(default_factory=lambda: {})

     def __post_init__(self):
         run_auto_creation(self)
@@ -115,8 +124,10 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
             # pyre-ignore[29]
             p_groups = model._get_param_groups(self.lr, wd=self.weight_decay)
         else:
-            allprm = [prm for prm in model.parameters() if prm.requires_grad]
-            p_groups = [{"params": allprm, "lr": self.lr}]
+            p_groups = [
+                {"params": params, "lr": self._get_group_learning_rate(group)}
+                for group, params in self._get_param_groups(model).items()
+            ]

         # Intialize the optimizer
         optimizer_kwargs: Dict[str, Any] = {
@@ -233,3 +244,82 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
         else:
             raise FileNotFoundError(f"Optimizer state {opt_path} does not exist.")
         return optimizer_state
+
+    def _get_param_groups(
+        self, module: torch.nn.Module
+    ) -> Dict[str, List[torch.nn.Parameter]]:
+        """
+        Recursively visits all the modules inside `module` and sorts all the
+        parameters into parameter groups.
+
+        Uses the `param_groups` dictionary member, where keys are names of individual
+        parameters or module members and values are the names of the parameter groups
+        for those parameters or members. The "self" key is used to denote the parameter
+        group at the module level. Possible keys, including the "self" key, do not have
+        to be defined. By default all parameters use the learning rate defined in the
+        optimizer. This can be overridden by setting the parameter group in the
+        `param_groups` member of a specific module; it can be overridden at the:
+            - module level with the "self" key; all the parameters and child
+                modules' parameters will inherit it
+            - member level, which is the same as if the `param_groups` in that
+                member had key="self" and value equal to that parameter group.
+                This is useful if members do not have `param_groups`, for
+                example torch.nn.Linear.
+            - parameter level: only the parameter with the same name as the key
+                will have it.
+
+        Args:
+            module: module from which to extract the parameters and their parameter
+                groups
+        Returns:
+            dictionary with parameter groups as keys and lists of parameters as values
+        """
+
+        param_groups = defaultdict(list)
+
+        def traverse(module, default_group):
+            # If the "self" key is defined in param_groups then change the default
+            # param group for all parameters and children in the module.
+            if hasattr(module, "param_groups") and "self" in module.param_groups:
+                default_group = module.param_groups["self"]
+
+            # Collect all the parameters that are directly inside the `module`;
+            # they will be in the default param group if they don't have a
+            # defined group.
+            for name, param in module.named_parameters(recurse=False):
+                if param.requires_grad:
+                    if hasattr(module, "param_groups") and name in module.param_groups:
+                        param_groups[module.param_groups[name]].append(param)
+                    else:
+                        param_groups[default_group].append(param)
+
+            # If a child has a defined default param group then use it, else pass
+            # on our own default.
+            for child_name, child in module.named_children():
+                if (
+                    hasattr(module, "param_groups")
+                    and child_name in module.param_groups
+                ):
+                    traverse(child, module.param_groups[child_name])
+                else:
+                    traverse(child, default_group)
+
+        traverse(module, "default")
+        return param_groups
+
+    def _get_group_learning_rate(self, group_name: str) -> float:
+        """
+        Wraps the `group_learning_rates` dictionary, raising an error for unknown
+        groups and returning `self.lr` for the "default" group_name.
+
+        Args:
+            group_name: a string representing the name of the group
+        Returns:
+            learning rate for a specific group
+        """
+        if group_name == "default":
+            return self.lr
+        lr = self.group_learning_rates.get(group_name, None)
+        if lr is None:
+            raise ValueError(f"no learning rate given for group {group_name}")
+        return lr
diff --git a/projects/implicitron_trainer/tests/experiment.yaml b/projects/implicitron_trainer/tests/experiment.yaml
index d6b6beed..f9e6e329 100644
--- a/projects/implicitron_trainer/tests/experiment.yaml
+++ b/projects/implicitron_trainer/tests/experiment.yaml
@@ -409,6 +409,7 @@ optimizer_factory_ImplicitronOptimizerFactory_args:
   linear_exponential_lr_milestone: 200
   linear_exponential_start_gamma: 0.1
   foreach: true
+  group_learning_rates: {}
 training_loop_ImplicitronTrainingLoop_args:
   evaluator_class_type: ImplicitronEvaluator
   evaluator_ImplicitronEvaluator_args:
diff --git a/projects/implicitron_trainer/tests/test_optimizer_factory.py b/projects/implicitron_trainer/tests/test_optimizer_factory.py
new file mode 100644
index 00000000..23cc5aad
--- /dev/null
+++ b/projects/implicitron_trainer/tests/test_optimizer_factory.py
@@ -0,0 +1,162 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import unittest
+
+import torch
+from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
+
+from ..impl.optimizer_factory import ImplicitronOptimizerFactory
+
+internal = os.environ.get("FB_TEST", False)
+
+
+class TestOptimizerFactory(unittest.TestCase):
+    def setUp(self) -> None:
+        torch.manual_seed(42)
+        expand_args_fields(ImplicitronOptimizerFactory)
+
+    def _get_param_groups(self, model):
+        default_cfg = get_default_args(ImplicitronOptimizerFactory)
+        factory = ImplicitronOptimizerFactory(default_cfg)
+        return factory._get_param_groups(model)
+
+    def _assert_allin(self, a, param_groups, key):
+        with self.subTest(f"Testing key {key}"):
+            b = param_groups[key]
+            for el in a:
+                if el not in b:
+                    raise ValueError(
+                        f"Element {el}\n\n from:\n\n {a}\n\n not in:\n\n {b}\n\n."
+                        + f" Full param groups = \n\n{param_groups}"
+                    )
+            for el in b:
+                if el not in a:
+                    raise ValueError(
+                        f"Element {el}\n\n from:\n\n {b}\n\n not in:\n\n {a}\n\n."
+                        + f" Full param groups = \n\n{param_groups}"
+                    )
+
+    def test_default_param_group_assignment(self):
+        pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)]
+        na, nb = Node(params=[pa]), Node(params=[pb])
+        root = Node(children=[na, nb], params=[pc])
+        param_groups = self._get_param_groups(root)
+        self._assert_allin([pa, pb, pc], param_groups, "default")
+
+    def test_member_overrides_default_param_group_assignment(self):
+        pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)]
+        na, nb = Node(params=[pa]), Node(params=[pb])
+        root = Node(children=[na, nb], params=[pc], param_groups={"m1": "pb"})
+        param_groups = self._get_param_groups(root)
+        self._assert_allin([pa, pc], param_groups, "default")
+        self._assert_allin([pb], param_groups, "pb")
+
+    def test_self_overrides_member_param_group_assignment(self):
+        pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)]
+        na, nb = Node(params=[pa]), Node(params=[pb], param_groups={"self": "pb_self"})
+        root = Node(children=[na, nb], params=[pc], param_groups={"m1": "pb_member"})
+        param_groups = self._get_param_groups(root)
+        self._assert_allin([pa, pc], param_groups, "default")
+        self._assert_allin([pb], param_groups, "pb_self")
+        assert len(param_groups["pb_member"]) == 0, param_groups
+
+    def test_param_overrides_self_param_group_assignment(self):
+        pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)]
+        na, nb = Node(params=[pa]), Node(
+            params=[pb], param_groups={"self": "pb_self", "p1": "pb_param"}
+        )
+        root = Node(children=[na, nb], params=[pc], param_groups={"m1": "pb_member"})
+        param_groups = self._get_param_groups(root)
+        self._assert_allin([pa, pc], param_groups, "default")
+        self._assert_allin([pb], param_groups, "pb_self")
+        assert len(param_groups["pb_member"]) == 0, param_groups
+
+    def test_no_param_groups_defined(self):
+        pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)]
+        na, nb = Node(params=[pa]), Node(params=[pb])
+        root = Node(children=[na, nb], params=[pc])
+        param_groups = self._get_param_groups(root)
+        self._assert_allin([pa, pb, pc], param_groups, "default")
+
+    def test_tree_param_groups_defined(self):
+        """
+        Test generic tree assignment.
+
+            A0
+            |---------------------------
+            |          |               |
+            Bb         M               J-
+            |-----                     |-------
+            |    |                     |      |
+            C    Ddg                   K      Ll
+                 |--------------
+                 |   |    |    |
+                 E4  Ff   G    H-
+
+        All nodes have one parameter. Character next to the capital
+        letter means something was added to that node's `param_groups`:
+            - a lowercase letter equal to the capital means "self" is set to that letter
+            - a lowercase letter different from the capital means that member is set
+                (the one that is named like that)
+            - a number means the parameter's param group is set to that value
+            - "-" means the node does not have a `param_groups` member
+        """
+        p = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(12)]
+        L = Node(params=[p[11]], param_groups={"self": "l"})
+        K = Node(params=[p[10]], param_groups={})
+        J = Node(params=[p[9]], param_groups=None, children=[K, L])
+        M = Node(params=[p[8]], param_groups={})
+
+        E = Node(params=[p[4]], param_groups={"p0": "4"})
+        F = Node(params=[p[5]], param_groups={"self": "f"})
+        G = Node(params=[p[6]], param_groups={})
+        H = Node(params=[p[7]], param_groups=None)
+
+        D = Node(
+            params=[p[3]], param_groups={"self": "d", "m2": "g"}, children=[E, F, G, H]
+        )
+        C = Node(params=[p[2]], param_groups={})
+
+        B = Node(params=[p[1]], param_groups={"self": "b"}, children=[C, D])
+
+        A = Node(params=[p[0]], param_groups={"p0": "0"}, children=[B, M, J])
+
+        param_groups = self._get_param_groups(A)
+
+        # If parts of a group belong to two different categories the assert is repeated.
+        # parameter level
+        self._assert_allin([p[0]], param_groups, "0")
+        self._assert_allin([p[4]], param_groups, "4")
+        # self level
+        self._assert_allin([p[5]], param_groups, "f")
+        self._assert_allin([p[11]], param_groups, "l")
+        self._assert_allin([p[2], p[1]], param_groups, "b")
+        self._assert_allin([p[7], p[3]], param_groups, "d")
+        # member level
+        self._assert_allin([p[6]], param_groups, "g")
+        # inherit level
+        self._assert_allin([p[7], p[3]], param_groups, "d")
+        self._assert_allin([p[2], p[1]], param_groups, "b")
+        # default level
+        self._assert_allin([p[8], p[9], p[10]], param_groups, "default")
+
+
+class Node(torch.nn.Module):
+    def __init__(self, children=(), params=(), param_groups=None):
+        super().__init__()
+        for i, child in enumerate(children):
+            self.add_module("m" + str(i), child)
+        for i, param in enumerate(params):
+            setattr(self, "p" + str(i), param)
+        if param_groups is not None:
+            self.param_groups = param_groups
+
+    def __str__(self):
+        return (
+            "modules:\n" + str(self._modules) + "\nparameters\n" + str(self._parameters)
+        )
diff --git a/pytorch3d/implicitron/models/implicit_function/decoding_functions.py b/pytorch3d/implicitron/models/implicit_function/decoding_functions.py
index 2713ea46..fb7494a7 100644
--- a/pytorch3d/implicitron/models/implicit_function/decoding_functions.py
+++ b/pytorch3d/implicitron/models/implicit_function/decoding_functions.py
@@ -13,9 +13,10 @@ This file contains
 """

 import logging
+from dataclasses import field
 from enum import Enum
-from typing import Optional, Tuple
+from typing import Dict, Optional, Tuple

 import torch

@@ -42,8 +43,27 @@ class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
     """
     Decoding function is a torch.nn.Module which takes the embedding of a location in
     space and transforms it into the required quantity (for example density and color).
+
+    Members:
+        param_groups: dictionary where keys are names of individual parameters
+            or module members and values are the parameter group that the
+            parameter/member will be assigned to. The "self" key is used to denote the
+            parameter group at the module level. Possible keys, including the "self"
+            key, do not have to be defined.
+            By default all parameters are put into the "default" parameter group and
+            use the learning rate defined in the optimizer; this can be overridden at the:
+            - module level with the "self" key; all the parameters and child
+                modules' parameters will be put into that parameter group
+            - member level, which is the same as if the `param_groups` in that
+                member had key="self" and value equal to that parameter group.
+                This is useful if members do not have `param_groups`, for
+                example torch.nn.Linear.
+            - parameter level: the parameter with the same name as the key
+                will be put into that parameter group.
     """

+    param_groups: Dict[str, str] = field(default_factory=lambda: {})
+
     def __post_init__(self):
         super().__init__()
diff --git a/pytorch3d/implicitron/models/implicit_function/voxel_grid.py b/pytorch3d/implicitron/models/implicit_function/voxel_grid.py
index b9f8c1bf..2135c822 100644
--- a/pytorch3d/implicitron/models/implicit_function/voxel_grid.py
+++ b/pytorch3d/implicitron/models/implicit_function/voxel_grid.py
@@ -808,6 +808,21 @@ class VoxelGridModule(Configurable, torch.nn.Module):
             with mean=init_mean and std=init_std. Default 0.
         hold_voxel_grid_as_parameters: if True components of the underlying voxel grids
             will be saved as parameters and therefore be trainable. Default True.
+        param_groups: dictionary where keys are names of individual parameters
+            or module members and values are the parameter group that the
+            parameter/member will be assigned to. The "self" key is used to denote the
+            parameter group at the module level. Possible keys, including the "self"
+            key, do not have to be defined.
+            By default all parameters are put into the "default" parameter group and
+            use the learning rate defined in the optimizer; this can be overridden at the:
+            - module level with the "self" key; all the parameters and child
+                modules' parameters will be put into that parameter group
+            - member level, which is the same as if the `param_groups` in that
+                member had key="self" and value equal to that parameter group.
+                This is useful if members do not have `param_groups`, for
+                example torch.nn.Linear.
+            - parameter level: the parameter with the same name as the key
+                will be put into that parameter group.
     """

     voxel_grid_class_type: str = "FullResolutionVoxelGrid"
@@ -820,6 +835,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
     init_mean: float = 0

     hold_voxel_grid_as_parameters: bool = True
+    param_groups: Dict[str, str] = field(default_factory=lambda: {})

     def __post_init__(self):
         super().__init__()
diff --git a/tests/implicitron/test_voxel_grids.py b/tests/implicitron/test_voxel_grids.py
index 6fdcdb29..2b47de7f 100644
--- a/tests/implicitron/test_voxel_grids.py
+++ b/tests/implicitron/test_voxel_grids.py
@@ -19,7 +19,6 @@ from pytorch3d.implicitron.models.implicit_function.utils import (
 from pytorch3d.implicitron.models.implicit_function.voxel_grid import (
     CPFactorizedVoxelGrid,
     FullResolutionVoxelGrid,
-    FullResolutionVoxelGridValues,
     VMFactorizedVoxelGrid,
     VoxelGridModule,
 )
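Usage sketch (not part of the patch): the snippet below shows how the pieces above fit together, mirroring the new unit tests. A module exposes a plain `param_groups` dict that routes its members into named groups, and `ImplicitronOptimizerFactory.group_learning_rates` maps those group names to learning rates. `TinyModel`, its member names, and the "color" group are made up for illustration; the absolute import path assumes the repository root is on PYTHONPATH.

import torch

from projects.implicitron_trainer.impl.optimizer_factory import (
    ImplicitronOptimizerFactory,
)
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args


class TinyModel(torch.nn.Module):
    # Hypothetical model: `trunk` falls back to the "default" group, while the
    # `color` member is routed to the "color" group via a member-level override.
    def __init__(self) -> None:
        super().__init__()
        self.trunk = torch.nn.Linear(4, 4)
        self.color = torch.nn.Linear(4, 3)
        self.param_groups = {"color": "color"}


expand_args_fields(ImplicitronOptimizerFactory)
cfg = get_default_args(ImplicitronOptimizerFactory)
cfg.lr = 1e-3  # learning rate of the "default" group
cfg.group_learning_rates = {"color": 1e-4}  # every non-default group needs an entry

factory = ImplicitronOptimizerFactory(cfg)
groups = factory._get_param_groups(TinyModel())
for name, params in groups.items():
    # Expected: "default" with 2 tensors at lr 1e-3, "color" with 2 tensors at lr 1e-4.
    print(name, len(params), factory._get_group_learning_rate(name))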