typing for trainer

Summary: Enable pyre checking of the trainer code.

Reviewed By: shapovalov

Differential Revision: D36545438

fbshipit-source-id: db1ea8d1ade2da79a2956964eb0c7ba302fa40d1
This commit is contained in:
Jeremy Reizenstein 2022-07-06 07:13:41 -07:00 committed by Facebook GitHub Bot
parent 4e87c2b7f1
commit 40fb189c29
3 changed files with 15 additions and 13 deletions

View File

@@ -220,7 +220,7 @@ def init_optimizer(
lr: float = 0.0005,
gamma: float = 0.1,
momentum: float = 0.9,
betas: Tuple[float] = (0.9, 0.999),
betas: Tuple[float, ...] = (0.9, 0.999),
milestones: tuple = (),
max_epochs: int = 1000,
):
@@ -257,6 +257,7 @@ def init_optimizer(
# Get the parameters to optimize
if hasattr(model, "_get_param_groups"): # use the model function
# pyre-ignore[29]
p_groups = model._get_param_groups(lr, wd=weight_decay)
else:
allprm = [prm for prm in model.parameters() if prm.requires_grad]
@@ -297,9 +298,6 @@ def init_optimizer(
for _ in range(last_epoch):
scheduler.step()
# Add the max epochs here
scheduler.max_epochs = max_epochs
optimizer.zero_grad()
return optimizer, scheduler
@@ -421,7 +419,7 @@ def trainvalidate(
if total_norm > clip_grad:
logger.info(
f"Clipping gradient: {total_norm}"
+ f" with coef {clip_grad / total_norm}."
+ f" with coef {clip_grad / float(total_norm)}."
)
optimizer.step()

View File

@@ -24,8 +24,7 @@ import torch.nn.functional as Fu
from experiment import init_model
from omegaconf import OmegaConf
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
from pytorch3d.implicitron.dataset.utils import is_train_frame
from pytorch3d.implicitron.models.base_model import EvaluationMode
from pytorch3d.implicitron.tools.configurable import get_default_args
@@ -41,7 +40,7 @@ from tqdm import tqdm
def render_sequence(
dataset: JsonIndexDataset,
dataset: DatasetBase,
sequence_name: str,
model: torch.nn.Module,
video_path,
@@ -64,6 +63,12 @@ def render_sequence(
):
if seed is None:
seed = hash(sequence_name)
if visdom_show_preds:
viz = get_visdom_connection(server=visdom_server, port=visdom_port)
else:
viz = None
print(f"Loading all data of sequence '{sequence_name}'.")
seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
@@ -82,7 +87,7 @@ def render_sequence(
up=up,
focal_length=None,
principal_point=torch.zeros(n_eval_cameras, 2),
traj_offset_canonical=[0.0, 0.0, traj_offset],
traj_offset_canonical=(0.0, 0.0, traj_offset),
)
# sample the source views reproducibly
@@ -118,7 +123,6 @@ def render_sequence(
if visdom_show_preds and (
n % max(n_eval_cameras // 20, 1) == 0 or n == n_eval_cameras - 1
):
viz = get_visdom_connection(server=visdom_server, port=visdom_port)
show_predictions(
preds_total,
sequence_name=batch.sequence_name[0],

View File

@@ -21,9 +21,9 @@ def generate_eval_video_cameras(
trajectory_scale: float = 0.2,
scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
up: Tuple[float, float, float] = (0.0, 0.0, 1.0),
focal_length: Optional[torch.FloatTensor] = None,
principal_point: Optional[torch.FloatTensor] = None,
time: Optional[torch.FloatTensor] = None,
focal_length: Optional[torch.Tensor] = None,
principal_point: Optional[torch.Tensor] = None,
time: Optional[torch.Tensor] = None,
infer_up_as_plane_normal: bool = True,
traj_offset: Optional[Tuple[float, float, float]] = None,
traj_offset_canonical: Optional[Tuple[float, float, float]] = None,