typing for trainer

Summary: Enable pyre checking of the trainer code.

Reviewed By: shapovalov

Differential Revision: D36545438

fbshipit-source-id: db1ea8d1ade2da79a2956964eb0c7ba302fa40d1
This commit is contained in:
Jeremy Reizenstein
2022-07-06 07:13:41 -07:00
committed by Facebook GitHub Bot
parent 4e87c2b7f1
commit 40fb189c29
3 changed files with 15 additions and 13 deletions

View File

@@ -220,7 +220,7 @@ def init_optimizer(
lr: float = 0.0005,
gamma: float = 0.1,
momentum: float = 0.9,
betas: Tuple[float] = (0.9, 0.999),
betas: Tuple[float, ...] = (0.9, 0.999),
milestones: tuple = (),
max_epochs: int = 1000,
):
@@ -257,6 +257,7 @@ def init_optimizer(
# Get the parameters to optimize
if hasattr(model, "_get_param_groups"): # use the model function
# pyre-ignore[29]
p_groups = model._get_param_groups(lr, wd=weight_decay)
else:
allprm = [prm for prm in model.parameters() if prm.requires_grad]
@@ -297,9 +298,6 @@ def init_optimizer(
for _ in range(last_epoch):
scheduler.step()
# Add the max epochs here
scheduler.max_epochs = max_epochs
optimizer.zero_grad()
return optimizer, scheduler
@@ -421,7 +419,7 @@ def trainvalidate(
if total_norm > clip_grad:
logger.info(
f"Clipping gradient: {total_norm}"
+ f" with coef {clip_grad / total_norm}."
+ f" with coef {clip_grad / float(total_norm)}."
)
optimizer.step()