diff --git a/projects/implicitron_trainer/experiment.py b/projects/implicitron_trainer/experiment.py index 43c59712..00f06ad7 100755 --- a/projects/implicitron_trainer/experiment.py +++ b/projects/implicitron_trainer/experiment.py @@ -220,7 +220,7 @@ def init_optimizer( lr: float = 0.0005, gamma: float = 0.1, momentum: float = 0.9, - betas: Tuple[float] = (0.9, 0.999), + betas: Tuple[float, ...] = (0.9, 0.999), milestones: tuple = (), max_epochs: int = 1000, ): @@ -257,6 +257,7 @@ def init_optimizer( # Get the parameters to optimize if hasattr(model, "_get_param_groups"): # use the model function + # pyre-ignore[29] p_groups = model._get_param_groups(lr, wd=weight_decay) else: allprm = [prm for prm in model.parameters() if prm.requires_grad] @@ -297,9 +298,6 @@ def init_optimizer( for _ in range(last_epoch): scheduler.step() - # Add the max epochs here - scheduler.max_epochs = max_epochs - optimizer.zero_grad() return optimizer, scheduler @@ -421,7 +419,7 @@ def trainvalidate( if total_norm > clip_grad: logger.info( f"Clipping gradient: {total_norm}" - + f" with coef {clip_grad / total_norm}." + + f" with coef {clip_grad / float(total_norm)}." ) optimizer.step() diff --git a/projects/implicitron_trainer/visualize_reconstruction.py b/projects/implicitron_trainer/visualize_reconstruction.py index f8a9aa1f..8ba43e9b 100644 --- a/projects/implicitron_trainer/visualize_reconstruction.py +++ b/projects/implicitron_trainer/visualize_reconstruction.py @@ -24,8 +24,7 @@ import torch.nn.functional as Fu from experiment import init_model from omegaconf import OmegaConf from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource -from pytorch3d.implicitron.dataset.dataset_base import FrameData -from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset +from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData from pytorch3d.implicitron.dataset.utils import is_train_frame from pytorch3d.implicitron.models.base_model import EvaluationMode from pytorch3d.implicitron.tools.configurable import get_default_args @@ -41,7 +40,7 @@ from tqdm import tqdm def render_sequence( - dataset: JsonIndexDataset, + dataset: DatasetBase, sequence_name: str, model: torch.nn.Module, video_path, @@ -64,6 +63,12 @@ def render_sequence( ): if seed is None: seed = hash(sequence_name) + + if visdom_show_preds: + viz = get_visdom_connection(server=visdom_server, port=visdom_port) + else: + viz = None + print(f"Loading all data of sequence '{sequence_name}'.") seq_idx = list(dataset.sequence_indices_in_order(sequence_name)) train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers) @@ -82,7 +87,7 @@ def render_sequence( up=up, focal_length=None, principal_point=torch.zeros(n_eval_cameras, 2), - traj_offset_canonical=[0.0, 0.0, traj_offset], + traj_offset_canonical=(0.0, 0.0, traj_offset), ) # sample the source views reproducibly @@ -118,7 +123,6 @@ def render_sequence( if visdom_show_preds and ( n % max(n_eval_cameras // 20, 1) == 0 or n == n_eval_cameras - 1 ): - viz = get_visdom_connection(server=visdom_server, port=visdom_port) show_predictions( preds_total, sequence_name=batch.sequence_name[0], diff --git a/pytorch3d/implicitron/tools/eval_video_trajectory.py b/pytorch3d/implicitron/tools/eval_video_trajectory.py index 0c0f65ae..e69eefdb 100644 --- a/pytorch3d/implicitron/tools/eval_video_trajectory.py +++ b/pytorch3d/implicitron/tools/eval_video_trajectory.py @@ -21,9 +21,9 @@ def generate_eval_video_cameras( trajectory_scale: float = 0.2, scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0), up: Tuple[float, float, float] = (0.0, 0.0, 1.0), - focal_length: Optional[torch.FloatTensor] = None, - principal_point: Optional[torch.FloatTensor] = None, - time: Optional[torch.FloatTensor] = None, + focal_length: Optional[torch.Tensor] = None, + principal_point: Optional[torch.Tensor] = None, + time: Optional[torch.Tensor] = None, infer_up_as_plane_normal: bool = True, traj_offset: Optional[Tuple[float, float, float]] = None, traj_offset_canonical: Optional[Tuple[float, float, float]] = None,