Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-08-02 20:02:49 +08:00
typing for trainer

Summary: Enable pyre checking of the trainer code.

Reviewed By: shapovalov

Differential Revision: D36545438

fbshipit-source-id: db1ea8d1ade2da79a2956964eb0c7ba302fa40d1
Parent: 4e87c2b7f1
Commit: 40fb189c29
@@ -220,7 +220,7 @@ def init_optimizer(
     lr: float = 0.0005,
     gamma: float = 0.1,
     momentum: float = 0.9,
-    betas: Tuple[float] = (0.9, 0.999),
+    betas: Tuple[float, ...] = (0.9, 0.999),
     milestones: tuple = (),
     max_epochs: int = 1000,
 ):
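The `betas` change is the substantive one in this hunk: in Python typing, `Tuple[float]` means a tuple of exactly one float, so the two-element default `(0.9, 0.999)` never matched the annotation, while `Tuple[float, ...]` denotes a homogeneous tuple of any length. A minimal sketch of the distinction (variable names hypothetical):

    from typing import Tuple

    one_float: Tuple[float] = (0.9,)              # a tuple of exactly one float
    adam_betas: Tuple[float, ...] = (0.9, 0.999)  # a homogeneous tuple of any length
    # bad: Tuple[float] = (0.9, 0.999)            # pyre/mypy reject: length mismatch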
@@ -257,6 +257,7 @@ def init_optimizer(

     # Get the parameters to optimize
     if hasattr(model, "_get_param_groups"):  # use the model function
+        # pyre-ignore[29]
         p_groups = model._get_param_groups(lr, wd=weight_decay)
     else:
         allprm = [prm for prm in model.parameters() if prm.requires_grad]
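The added suppression works around a checker limitation: pyre does not narrow the type of `model` from the `hasattr` guard, and `_get_param_groups` is not declared on `torch.nn.Module`, so pyre reports the call (its call-error code is 29). A minimal sketch of the pattern, with a hypothetical fallback body:

    import torch

    def param_groups(model: torch.nn.Module, lr: float, weight_decay: float):
        if hasattr(model, "_get_param_groups"):
            # pyre-ignore[29]: not declared on torch.nn.Module, so pyre
            # cannot verify the call even though hasattr() guards it
            return model._get_param_groups(lr, wd=weight_decay)
        # hypothetical fallback: one group with all trainable parameters
        return [{"params": [p for p in model.parameters() if p.requires_grad]}]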
@@ -297,9 +298,6 @@ def init_optimizer(
     for _ in range(last_epoch):
         scheduler.step()

-    # Add the max epochs here
-    scheduler.max_epochs = max_epochs
-
     optimizer.zero_grad()
     return optimizer, scheduler

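Dropping `scheduler.max_epochs = max_epochs` fits the same typing pass: assigning an attribute that the scheduler class never declares is exactly the kind of ad-hoc dynamic attribute pyre flags as undefined. A sketch of the flagged pattern, assuming a MultiStepLR scheduler (the diff does not show which scheduler `init_optimizer` builds):

    import torch

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 80])
    scheduler.max_epochs = 1000  # runs, but pyre flags an undefined attribute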
@@ -421,7 +419,7 @@ def trainvalidate(
            if total_norm > clip_grad:
                logger.info(
                    f"Clipping gradient: {total_norm}"
-                    + f" with coef {clip_grad / total_norm}."
+                    + f" with coef {clip_grad / float(total_norm)}."
                )

            optimizer.step()
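The `float(...)` cast in `trainvalidate` is also a typing fix: `torch.nn.utils.clip_grad_norm_` returns a `torch.Tensor`, so `clip_grad / total_norm` is itself a tensor, not a float. Casting makes the logged coefficient a plain Python scalar. A self-contained sketch:

    import torch

    model = torch.nn.Linear(4, 2)
    model(torch.randn(3, 4)).sum().backward()
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    coef = 1.0 / float(total_norm)  # plain float; without the cast it is a 0-d Tensor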
@@ -24,8 +24,7 @@ import torch.nn.functional as Fu
 from experiment import init_model
 from omegaconf import OmegaConf
 from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
-from pytorch3d.implicitron.dataset.dataset_base import FrameData
-from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
+from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
 from pytorch3d.implicitron.dataset.utils import is_train_frame
 from pytorch3d.implicitron.models.base_model import EvaluationMode
 from pytorch3d.implicitron.tools.configurable import get_default_args
@@ -41,7 +40,7 @@ from tqdm import tqdm


 def render_sequence(
-    dataset: JsonIndexDataset,
+    dataset: DatasetBase,
     sequence_name: str,
     model: torch.nn.Module,
     video_path,
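These two hunks are one refactor: `render_sequence` never needed the concrete `JsonIndexDataset`, only the `DatasetBase` interface (its body calls `dataset.sequence_indices_in_order`, which must be part of the base interface, or pyre would reject the widened annotation). Annotating against the base lets any Implicitron dataset be rendered. A small sketch of the widened contract, with a hypothetical helper:

    from pytorch3d.implicitron.dataset.dataset_base import DatasetBase

    def count_frames(dataset: DatasetBase, sequence_name: str) -> int:
        # hypothetical helper: relies only on the base interface, so it
        # accepts JsonIndexDataset or any other DatasetBase subclass
        return len(list(dataset.sequence_indices_in_order(sequence_name)))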
@@ -64,6 +63,12 @@ def render_sequence(
 ):
     if seed is None:
         seed = hash(sequence_name)
+
+    if visdom_show_preds:
+        viz = get_visdom_connection(server=visdom_server, port=visdom_port)
+    else:
+        viz = None
+
     print(f"Loading all data of sequence '{sequence_name}'.")
     seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
     train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
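The added block hoists the visdom connection out of the render loop; the hunk at `@@ -118,7 +123,6 @@` below removes the matching per-frame reconnect. A generic sketch of the hoisting pattern, with hypothetical names:

    def connect():  # hypothetical stand-in for get_visdom_connection()
        return object()

    def render_all(n_frames: int, show_preds: bool) -> None:
        viz = connect() if show_preds else None  # connect once, up front
        for n in range(n_frames):
            if viz is not None and n % 10 == 0:
                pass  # reuse the shared connection for periodic logging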
@@ -82,7 +87,7 @@ def render_sequence(
         up=up,
         focal_length=None,
         principal_point=torch.zeros(n_eval_cameras, 2),
-        traj_offset_canonical=[0.0, 0.0, traj_offset],
+        traj_offset_canonical=(0.0, 0.0, traj_offset),
     )

     # sample the source views reproducibly
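The list-to-tuple change matches the callee's signature: `generate_eval_video_cameras` annotates `traj_offset_canonical` as `Optional[Tuple[float, float, float]]` (see the final hunk below), and a list literal does not satisfy a `Tuple` annotation. A minimal sketch (function name hypothetical):

    from typing import Optional, Tuple

    def move(traj_offset_canonical: Optional[Tuple[float, float, float]] = None) -> None:
        pass

    move(traj_offset_canonical=(0.0, 0.0, 1.0))   # OK: a 3-tuple of floats
    # move(traj_offset_canonical=[0.0, 0.0, 1.0]) # rejected: List is not Tuple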
@@ -118,7 +123,6 @@ def render_sequence(
        if visdom_show_preds and (
            n % max(n_eval_cameras // 20, 1) == 0 or n == n_eval_cameras - 1
        ):
-            viz = get_visdom_connection(server=visdom_server, port=visdom_port)
            show_predictions(
                preds_total,
                sequence_name=batch.sequence_name[0],
@@ -21,9 +21,9 @@ def generate_eval_video_cameras(
     trajectory_scale: float = 0.2,
     scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
     up: Tuple[float, float, float] = (0.0, 0.0, 1.0),
-    focal_length: Optional[torch.FloatTensor] = None,
-    principal_point: Optional[torch.FloatTensor] = None,
-    time: Optional[torch.FloatTensor] = None,
+    focal_length: Optional[torch.Tensor] = None,
+    principal_point: Optional[torch.Tensor] = None,
+    time: Optional[torch.Tensor] = None,
     infer_up_as_plane_normal: bool = True,
     traj_offset: Optional[Tuple[float, float, float]] = None,
     traj_offset_canonical: Optional[Tuple[float, float, float]] = None,