typing for trainer

Summary: Enable pyre checking of the trainer code.

Reviewed By: shapovalov

Differential Revision: D36545438

fbshipit-source-id: db1ea8d1ade2da79a2956964eb0c7ba302fa40d1
This commit is contained in:
Jeremy Reizenstein
2022-07-06 07:13:41 -07:00
committed by Facebook GitHub Bot
parent 4e87c2b7f1
commit 40fb189c29
3 changed files with 15 additions and 13 deletions

View File

@@ -24,8 +24,7 @@ import torch.nn.functional as Fu
from experiment import init_model
from omegaconf import OmegaConf
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
from pytorch3d.implicitron.dataset.utils import is_train_frame
from pytorch3d.implicitron.models.base_model import EvaluationMode
from pytorch3d.implicitron.tools.configurable import get_default_args
@@ -41,7 +40,7 @@ from tqdm import tqdm
def render_sequence(
dataset: JsonIndexDataset,
dataset: DatasetBase,
sequence_name: str,
model: torch.nn.Module,
video_path,
@@ -64,6 +63,12 @@ def render_sequence(
):
if seed is None:
seed = hash(sequence_name)
if visdom_show_preds:
viz = get_visdom_connection(server=visdom_server, port=visdom_port)
else:
viz = None
print(f"Loading all data of sequence '{sequence_name}'.")
seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
@@ -82,7 +87,7 @@ def render_sequence(
up=up,
focal_length=None,
principal_point=torch.zeros(n_eval_cameras, 2),
traj_offset_canonical=[0.0, 0.0, traj_offset],
traj_offset_canonical=(0.0, 0.0, traj_offset),
)
# sample the source views reproducibly
@@ -118,7 +123,6 @@ def render_sequence(
if visdom_show_preds and (
n % max(n_eval_cameras // 20, 1) == 0 or n == n_eval_cameras - 1
):
viz = get_visdom_connection(server=visdom_server, port=visdom_port)
show_predictions(
preds_total,
sequence_name=batch.sequence_name[0],