Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-12-21 23:00:34 +08:00
typing for trainer
Summary: Enable pyre checking of the trainer code.

Reviewed By: shapovalov

Differential Revision: D36545438

fbshipit-source-id: db1ea8d1ade2da79a2956964eb0c7ba302fa40d1
Committed by Facebook GitHub Bot
Parent: 4e87c2b7f1
Commit: 40fb189c29
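For readers unfamiliar with Pyre: files opt into checking via a mode header, and individual diagnostics are suppressed inline, which is what this diff does with `# pyre-ignore[29]` (29 is Pyre's call-error category). A minimal sketch of the mechanism, not the actual trainer file:

```python
# pyre-strict
# File-level mode header: Pyre fully checks this file.


def call_hook(obj: object) -> None:
    if hasattr(obj, "hook"):
        # hasattr() does not narrow the type for Pyre, so the dynamic
        # call below needs an inline suppression; the trainer diff pins
        # the specific error code with pyre-ignore[29].
        # pyre-ignore
        obj.hook()
```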
@@ -220,7 +220,7 @@ def init_optimizer(
     lr: float = 0.0005,
     gamma: float = 0.1,
     momentum: float = 0.9,
-    betas: Tuple[float] = (0.9, 0.999),
+    betas: Tuple[float, ...] = (0.9, 0.999),
     milestones: tuple = (),
     max_epochs: int = 1000,
 ):
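Context on the `betas` fix: `Tuple[float]` denotes a tuple of exactly one float, so the two-element default `(0.9, 0.999)` does not match it; `Tuple[float, ...]` is the variable-length homogeneous form. A minimal sketch of the error this silences:

```python
from typing import Tuple

# Rejected by Pyre/mypy: Tuple[float] means "a tuple of exactly one
# float", but the default value below has two elements.
bad: Tuple[float] = (0.9, 0.999)

# Accepted: Tuple[float, ...] is a homogeneous tuple of any length.
good: Tuple[float, ...] = (0.9, 0.999)
```

A fixed-size `Tuple[float, float]` would also fit Adam's two betas; `Tuple[float, ...]` is simply the laxer choice the commit makes.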
@@ -257,6 +257,7 @@ def init_optimizer(
 
     # Get the parameters to optimize
     if hasattr(model, "_get_param_groups"):  # use the model function
+        # pyre-ignore[29]
         p_groups = model._get_param_groups(lr, wd=weight_decay)
     else:
         allprm = [prm for prm in model.parameters() if prm.requires_grad]
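On the pattern being guarded: a model may expose a `_get_param_groups` hook that returns per-group optimizer settings, but `torch.nn.Module` declares no such method, so Pyre cannot verify the call even after the `hasattr` check, hence the suppression. A hedged sketch of what such a hook might return (the exact signature in this codebase is only partly visible in the diff):

```python
import torch


class ModelWithGroups(torch.nn.Module):
    """Toy module exposing the hook the diff probes with hasattr()."""

    def __init__(self) -> None:
        super().__init__()
        self.backbone = torch.nn.Linear(8, 8)
        self.head = torch.nn.Linear(8, 2)

    def _get_param_groups(self, lr: float, wd: float = 0.0):
        # Per-group settings in the format torch.optim expects: a list of
        # dicts, each with "params" plus optional per-group overrides.
        return [
            {"params": self.backbone.parameters(), "lr": lr * 0.1, "weight_decay": wd},
            {"params": self.head.parameters(), "lr": lr, "weight_decay": wd},
        ]


model = ModelWithGroups()
optimizer = torch.optim.Adam(model._get_param_groups(1e-3, wd=1e-4))
```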
@@ -297,9 +298,6 @@ def init_optimizer(
     for _ in range(last_epoch):
         scheduler.step()
 
-    # Add the max epochs here
-    scheduler.max_epochs = max_epochs
-
     optimizer.zero_grad()
     return optimizer, scheduler
 
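The deleted `scheduler.max_epochs = max_epochs` attached an ad-hoc attribute at runtime; whether anything still consumed it is not visible here, but the pattern itself is hard to type-check, which fits the commit's goal. A sketch of the removed pattern (scheduler construction details are assumptions):

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import MultiStepLR

param = torch.nn.Parameter(torch.zeros(1))
scheduler = MultiStepLR(SGD([param], lr=0.1), milestones=[10], gamma=0.1)

# The removed pattern: it works at runtime, but MultiStepLR declares no
# such attribute, so a strict type checker cannot verify later reads of
# scheduler.max_epochs.
scheduler.max_epochs = 1000
```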
@@ -421,7 +419,7 @@ def trainvalidate(
         if total_norm > clip_grad:
             logger.info(
                 f"Clipping gradient: {total_norm}"
-                + f" with coef {clip_grad / total_norm}."
+                + f" with coef {clip_grad / float(total_norm)}."
             )
 
         optimizer.step()
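The `float(...)` cast fixes typing, not behavior: `torch.nn.utils.clip_grad_norm_` returns the total norm as a `torch.Tensor`, so `clip_grad / total_norm` is a `Tensor`, while `float / float` is what the log message (and Pyre) really want. A runnable sketch of the surrounding pattern (names other than `total_norm` and `clip_grad` are assumptions):

```python
import torch

model = torch.nn.Linear(4, 4)
loss = model(torch.randn(2, 4)).sum()
loss.backward()

clip_grad = 1.0
# clip_grad_norm_ clips the gradients in place and returns their total
# norm as a (0-dim) Tensor.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
if total_norm > clip_grad:
    # float() turns the 0-dim Tensor into a plain Python float, so the
    # division below is float/float rather than float/Tensor.
    coef = clip_grad / float(total_norm)
    print(f"Clipping gradient: {total_norm} with coef {coef}.")
```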