mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
typing for trainer
Summary: Enable pyre checking of the trainer code. Reviewed By: shapovalov Differential Revision: D36545438 fbshipit-source-id: db1ea8d1ade2da79a2956964eb0c7ba302fa40d1
This commit is contained in:
parent
4e87c2b7f1
commit
40fb189c29
@ -220,7 +220,7 @@ def init_optimizer(
|
||||
lr: float = 0.0005,
|
||||
gamma: float = 0.1,
|
||||
momentum: float = 0.9,
|
||||
betas: Tuple[float] = (0.9, 0.999),
|
||||
betas: Tuple[float, ...] = (0.9, 0.999),
|
||||
milestones: tuple = (),
|
||||
max_epochs: int = 1000,
|
||||
):
|
||||
@ -257,6 +257,7 @@ def init_optimizer(
|
||||
|
||||
# Get the parameters to optimize
|
||||
if hasattr(model, "_get_param_groups"): # use the model function
|
||||
# pyre-ignore[29]
|
||||
p_groups = model._get_param_groups(lr, wd=weight_decay)
|
||||
else:
|
||||
allprm = [prm for prm in model.parameters() if prm.requires_grad]
|
||||
@ -297,9 +298,6 @@ def init_optimizer(
|
||||
for _ in range(last_epoch):
|
||||
scheduler.step()
|
||||
|
||||
# Add the max epochs here
|
||||
scheduler.max_epochs = max_epochs
|
||||
|
||||
optimizer.zero_grad()
|
||||
return optimizer, scheduler
|
||||
|
||||
@ -421,7 +419,7 @@ def trainvalidate(
|
||||
if total_norm > clip_grad:
|
||||
logger.info(
|
||||
f"Clipping gradient: {total_norm}"
|
||||
+ f" with coef {clip_grad / total_norm}."
|
||||
+ f" with coef {clip_grad / float(total_norm)}."
|
||||
)
|
||||
|
||||
optimizer.step()
|
||||
|
@ -24,8 +24,7 @@ import torch.nn.functional as Fu
|
||||
from experiment import init_model
|
||||
from omegaconf import OmegaConf
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData
|
||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
|
||||
from pytorch3d.implicitron.dataset.utils import is_train_frame
|
||||
from pytorch3d.implicitron.models.base_model import EvaluationMode
|
||||
from pytorch3d.implicitron.tools.configurable import get_default_args
|
||||
@ -41,7 +40,7 @@ from tqdm import tqdm
|
||||
|
||||
|
||||
def render_sequence(
|
||||
dataset: JsonIndexDataset,
|
||||
dataset: DatasetBase,
|
||||
sequence_name: str,
|
||||
model: torch.nn.Module,
|
||||
video_path,
|
||||
@ -64,6 +63,12 @@ def render_sequence(
|
||||
):
|
||||
if seed is None:
|
||||
seed = hash(sequence_name)
|
||||
|
||||
if visdom_show_preds:
|
||||
viz = get_visdom_connection(server=visdom_server, port=visdom_port)
|
||||
else:
|
||||
viz = None
|
||||
|
||||
print(f"Loading all data of sequence '{sequence_name}'.")
|
||||
seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
|
||||
train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
|
||||
@ -82,7 +87,7 @@ def render_sequence(
|
||||
up=up,
|
||||
focal_length=None,
|
||||
principal_point=torch.zeros(n_eval_cameras, 2),
|
||||
traj_offset_canonical=[0.0, 0.0, traj_offset],
|
||||
traj_offset_canonical=(0.0, 0.0, traj_offset),
|
||||
)
|
||||
|
||||
# sample the source views reproducibly
|
||||
@ -118,7 +123,6 @@ def render_sequence(
|
||||
if visdom_show_preds and (
|
||||
n % max(n_eval_cameras // 20, 1) == 0 or n == n_eval_cameras - 1
|
||||
):
|
||||
viz = get_visdom_connection(server=visdom_server, port=visdom_port)
|
||||
show_predictions(
|
||||
preds_total,
|
||||
sequence_name=batch.sequence_name[0],
|
||||
|
@ -21,9 +21,9 @@ def generate_eval_video_cameras(
|
||||
trajectory_scale: float = 0.2,
|
||||
scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
|
||||
up: Tuple[float, float, float] = (0.0, 0.0, 1.0),
|
||||
focal_length: Optional[torch.FloatTensor] = None,
|
||||
principal_point: Optional[torch.FloatTensor] = None,
|
||||
time: Optional[torch.FloatTensor] = None,
|
||||
focal_length: Optional[torch.Tensor] = None,
|
||||
principal_point: Optional[torch.Tensor] = None,
|
||||
time: Optional[torch.Tensor] = None,
|
||||
infer_up_as_plane_normal: bool = True,
|
||||
traj_offset: Optional[Tuple[float, float, float]] = None,
|
||||
traj_offset_canonical: Optional[Tuple[float, float, float]] = None,
|
||||
|
Loading…
x
Reference in New Issue
Block a user