Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-12-20 06:10:34 +08:00
different learning rate for different parts
Summary:
Adds the ability to use different learning rates for different parts of the model. The trainable parts of Implicitron have a new member

param_groups: dictionary where keys are names of individual parameters
or of a module's members, and values are the parameter groups the
parameters/members will be sorted into. The "self" key denotes the
parameter group at the module level. Possible keys, including the "self" key,
do not have to be defined. By default all parameters are put into the "default"
parameter group and use the learning rate defined in the optimizer;
this can be overridden at the:
 - module level with the "self" key: all the parameters, including child
   modules' parameters, will be put into that parameter group
 - member level, which is the same as if `param_groups` in that
   member had key="self" with that parameter group as its value.
   This is useful for members that do not have `param_groups`, for
   example torch.nn.Linear.
 - parameter level: the parameter with the same name as the key
   will be put into that parameter group.

In the optimizer factory, parameters and their learning rates are then gathered recursively.
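As a minimal sketch of these rules (the `ToyModel` module and the inline `sort_params` helper are hypothetical and only restate the precedence above; the real gathering is done by `ImplicitronOptimizerFactory._get_param_groups` in the diff below), assuming nothing beyond a plain torch install:

from collections import defaultdict

import torch


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Linear(4, 4)  # no key: inherits the module-level group
        self.head = torch.nn.Linear(4, 3)      # member-level key below
        self.scale = torch.nn.Parameter(torch.ones(1))  # parameter-level key below
        # "self": module-level default group for everything in this module;
        # "head": member-level group for a child without its own `param_groups`;
        # "scale": parameter-level group for this single parameter.
        self.param_groups = {"self": "model", "head": "head", "scale": "slow"}


def sort_params(module, default_group="default", groups=None):
    # Hypothetical restatement of the precedence rules from the summary.
    groups = defaultdict(list) if groups is None else groups
    pg = getattr(module, "param_groups", {})
    default_group = pg.get("self", default_group)
    for name, param in module.named_parameters(recurse=False):
        if param.requires_grad:
            groups[pg.get(name, default_group)].append(name)
    for child_name, child in module.named_children():
        sort_params(child, pg.get(child_name, default_group), groups)
    return groups


print(dict(sort_params(ToyModel())))
# {'slow': ['scale'], 'model': ['weight', 'bias'], 'head': ['weight', 'bias']}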
Reviewed By: shapovalov
Differential Revision: D40145802
fbshipit-source-id: 631c02b8d79ee1c0eb4c31e6e42dbd3d2882078a
committed by Facebook GitHub Bot
parent a819ecb00b
commit fe5bdb2fb5
@@ -7,7 +7,9 @@
 import inspect
 import logging
 import os
-from typing import Any, Dict, Optional, Tuple
+from collections import defaultdict
+from dataclasses import field
+from typing import Any, Dict, List, Optional, Tuple

 import torch.optim

@@ -64,6 +66,12 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
         weight_decay: The optimizer weight_decay (L2 penalty on model weights).
         foreach: Whether to use new "foreach" implementation of optimizer where
             available (e.g. requires PyTorch 1.12.0 for Adam)
+        group_learning_rates: Parameters or modules can be assigned to parameter
+            groups. This dictionary has names of those parameter groups as keys
+            and learning rates as values. All parameter group names have to be
+            defined in this dictionary. Parameters which do not have a predefined
+            parameter group are put into the "default" parameter group, which has
+            `lr` as its learning rate.
     """

     betas: Tuple[float, ...] = (0.9, 0.999)
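A hedged configuration sketch (the group names "color" and "pose" are made up for illustration; in practice these values are set through the Implicitron experiment config). Every group name the model's `param_groups` produces must appear here unless it is "default", which uses `lr`:

optimizer_factory_args = {
    "lr": 5e-4,  # learning rate of the "default" parameter group
    "group_learning_rates": {
        "color": 1e-4,  # groups that the model's `param_groups` sorts parameters into
        "pose": 1e-5,
    },
}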
@@ -78,6 +86,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
     linear_exponential_lr_milestone: int = 200
     linear_exponential_start_gamma: float = 0.1
     foreach: Optional[bool] = True
+    group_learning_rates: Dict[str, float] = field(default_factory=lambda: {})

     def __post_init__(self):
         run_auto_creation(self)
@@ -115,8 +124,10 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
             # pyre-ignore[29]
             p_groups = model._get_param_groups(self.lr, wd=self.weight_decay)
         else:
-            allprm = [prm for prm in model.parameters() if prm.requires_grad]
-            p_groups = [{"params": allprm, "lr": self.lr}]
+            p_groups = [
+                {"params": params, "lr": self._get_group_learning_rate(group)}
+                for group, params in self._get_param_groups(model).items()
+            ]

         # Intialize the optimizer
         optimizer_kwargs: Dict[str, Any] = {
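The resulting `p_groups` list uses the standard torch.optim per-parameter-group format, where a per-group "lr" overrides the optimizer-level learning rate; a minimal standalone sketch with arbitrary tensors and rates:

import torch

slow = torch.nn.Parameter(torch.zeros(3))
fast = torch.nn.Parameter(torch.zeros(3))
p_groups = [
    {"params": [slow], "lr": 1e-4},  # e.g. the "default" group
    {"params": [fast], "lr": 1e-2},  # e.g. a group with its own learning rate
]
optimizer = torch.optim.Adam(p_groups, lr=1e-3, betas=(0.9, 0.999))
print([g["lr"] for g in optimizer.param_groups])  # [0.0001, 0.01]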
@@ -233,3 +244,82 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
         else:
             raise FileNotFoundError(f"Optimizer state {opt_path} does not exist.")
         return optimizer_state
+
+    def _get_param_groups(
+        self, module: torch.nn.Module
+    ) -> Dict[str, List[torch.nn.Parameter]]:
+        """
+        Recursively visits all the modules inside the `module` and sorts all the
+        parameters into parameter groups.
+
+        Uses the `param_groups` dictionary member, where keys are names of individual
+        parameters or module members and values are the names of the parameter groups
+        for those parameters or members. The "self" key denotes the parameter group
+        at the module level. Possible keys, including the "self" key, do not have to
+        be defined. By default all parameters have the learning rate defined in the
+        optimizer. This can be overridden by setting the parameter group in the
+        `param_groups` member of a specific module; it can be overridden at the:
+            - module level with the "self" key: all the parameters and child
+                modules' parameters will inherit it
+            - member level, which is the same as if `param_groups` in that
+                member had key="self" with that parameter group as its value.
+                This is useful if members do not have `param_groups`, for
+                example torch.nn.Linear.
+            - parameter level: only the parameter with the same name as the key
+                will have it.
+
+        Args:
+            module: module from which to extract the parameters and their parameter
+                groups
+        Returns:
+            dictionary with parameter groups as keys and lists of parameters as values
+        """
+
+        param_groups = defaultdict(list)
+
+        def traverse(module, default_group):
+            # If the key "self" is defined in param_groups then change the default
+            # param group for all parameters and children in the module.
+            if hasattr(module, "param_groups") and "self" in module.param_groups:
+                default_group = module.param_groups["self"]
+
+            # Collect all the parameters that are directly inside the `module`;
+            # they will be in the default param group if they don't have a
+            # defined group.
+            for name, param in module.named_parameters(recurse=False):
+                if param.requires_grad:
+                    if hasattr(module, "param_groups") and name in module.param_groups:
+                        param_groups[module.param_groups[name]].append(param)
+                    else:
+                        param_groups[default_group].append(param)
+
+            # If children have a defined default param group then use it, else pass
+            # down our own default.
+            for child_name, child in module.named_children():
+                if (
+                    hasattr(module, "param_groups")
+                    and child_name in module.param_groups
+                ):
+                    traverse(child, module.param_groups[child_name])
+                else:
+                    traverse(child, default_group)
+
+        traverse(module, "default")
+        return param_groups
+
+    def _get_group_learning_rate(self, group_name: str) -> float:
+        """
+        Wraps the `group_learning_rates` dictionary, providing errors, and returns
+        `self.lr` for the "default" group_name.
+
+        Args:
+            group_name: a string representing the name of the group
+        Returns:
+            learning rate for a specific group
+        """
+        if group_name == "default":
+            return self.lr
+        lr = self.group_learning_rates.get(group_name, None)
+        if lr is None:
+            raise ValueError(f"no learning rate given for group {group_name}")
+        return lr
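To make the lookup semantics concrete: "default" resolves to `lr`, a known group resolves to its configured rate, and an unknown group raises. A standalone stand-in (the `_LookupDemo` class is hypothetical; it mirrors the logic of `_get_group_learning_rate` above):

class _LookupDemo:
    # Hypothetical stand-in mirroring the factory's fields.
    lr = 1e-3
    group_learning_rates = {"color": 1e-4}

    def _get_group_learning_rate(self, group_name: str) -> float:
        if group_name == "default":
            return self.lr
        lr = self.group_learning_rates.get(group_name, None)
        if lr is None:
            raise ValueError(f"no learning rate given for group {group_name}")
        return lr


demo = _LookupDemo()
assert demo._get_group_learning_rate("default") == 1e-3
assert demo._get_group_learning_rate("color") == 1e-4
# demo._get_group_learning_rate("pose") would raise ValueError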