mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	typing for trainer
Summary: Enable pyre checking of the trainer code. Reviewed By: shapovalov Differential Revision: D36545438 fbshipit-source-id: db1ea8d1ade2da79a2956964eb0c7ba302fa40d1
This commit is contained in:
		
							parent
							
								
									4e87c2b7f1
								
							
						
					
					
						commit
						40fb189c29
					
				@ -220,7 +220,7 @@ def init_optimizer(
 | 
			
		||||
    lr: float = 0.0005,
 | 
			
		||||
    gamma: float = 0.1,
 | 
			
		||||
    momentum: float = 0.9,
 | 
			
		||||
    betas: Tuple[float] = (0.9, 0.999),
 | 
			
		||||
    betas: Tuple[float, ...] = (0.9, 0.999),
 | 
			
		||||
    milestones: tuple = (),
 | 
			
		||||
    max_epochs: int = 1000,
 | 
			
		||||
):
 | 
			
		||||
@ -257,6 +257,7 @@ def init_optimizer(
 | 
			
		||||
 | 
			
		||||
    # Get the parameters to optimize
 | 
			
		||||
    if hasattr(model, "_get_param_groups"):  # use the model function
 | 
			
		||||
        # pyre-ignore[29]
 | 
			
		||||
        p_groups = model._get_param_groups(lr, wd=weight_decay)
 | 
			
		||||
    else:
 | 
			
		||||
        allprm = [prm for prm in model.parameters() if prm.requires_grad]
 | 
			
		||||
@ -297,9 +298,6 @@ def init_optimizer(
 | 
			
		||||
    for _ in range(last_epoch):
 | 
			
		||||
        scheduler.step()
 | 
			
		||||
 | 
			
		||||
    # Add the max epochs here
 | 
			
		||||
    scheduler.max_epochs = max_epochs
 | 
			
		||||
 | 
			
		||||
    optimizer.zero_grad()
 | 
			
		||||
    return optimizer, scheduler
 | 
			
		||||
 | 
			
		||||
@ -421,7 +419,7 @@ def trainvalidate(
 | 
			
		||||
                if total_norm > clip_grad:
 | 
			
		||||
                    logger.info(
 | 
			
		||||
                        f"Clipping gradient: {total_norm}"
 | 
			
		||||
                        + f" with coef {clip_grad / total_norm}."
 | 
			
		||||
                        + f" with coef {clip_grad / float(total_norm)}."
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
            optimizer.step()
 | 
			
		||||
 | 
			
		||||
@ -24,8 +24,7 @@ import torch.nn.functional as Fu
 | 
			
		||||
from experiment import init_model
 | 
			
		||||
from omegaconf import OmegaConf
 | 
			
		||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
 | 
			
		||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData
 | 
			
		||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
 | 
			
		||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
 | 
			
		||||
from pytorch3d.implicitron.dataset.utils import is_train_frame
 | 
			
		||||
from pytorch3d.implicitron.models.base_model import EvaluationMode
 | 
			
		||||
from pytorch3d.implicitron.tools.configurable import get_default_args
 | 
			
		||||
@ -41,7 +40,7 @@ from tqdm import tqdm
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def render_sequence(
 | 
			
		||||
    dataset: JsonIndexDataset,
 | 
			
		||||
    dataset: DatasetBase,
 | 
			
		||||
    sequence_name: str,
 | 
			
		||||
    model: torch.nn.Module,
 | 
			
		||||
    video_path,
 | 
			
		||||
@ -64,6 +63,12 @@ def render_sequence(
 | 
			
		||||
):
 | 
			
		||||
    if seed is None:
 | 
			
		||||
        seed = hash(sequence_name)
 | 
			
		||||
 | 
			
		||||
    if visdom_show_preds:
 | 
			
		||||
        viz = get_visdom_connection(server=visdom_server, port=visdom_port)
 | 
			
		||||
    else:
 | 
			
		||||
        viz = None
 | 
			
		||||
 | 
			
		||||
    print(f"Loading all data of sequence '{sequence_name}'.")
 | 
			
		||||
    seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
 | 
			
		||||
    train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
 | 
			
		||||
@ -82,7 +87,7 @@ def render_sequence(
 | 
			
		||||
        up=up,
 | 
			
		||||
        focal_length=None,
 | 
			
		||||
        principal_point=torch.zeros(n_eval_cameras, 2),
 | 
			
		||||
        traj_offset_canonical=[0.0, 0.0, traj_offset],
 | 
			
		||||
        traj_offset_canonical=(0.0, 0.0, traj_offset),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # sample the source views reproducibly
 | 
			
		||||
@ -118,7 +123,6 @@ def render_sequence(
 | 
			
		||||
            if visdom_show_preds and (
 | 
			
		||||
                n % max(n_eval_cameras // 20, 1) == 0 or n == n_eval_cameras - 1
 | 
			
		||||
            ):
 | 
			
		||||
                viz = get_visdom_connection(server=visdom_server, port=visdom_port)
 | 
			
		||||
                show_predictions(
 | 
			
		||||
                    preds_total,
 | 
			
		||||
                    sequence_name=batch.sequence_name[0],
 | 
			
		||||
 | 
			
		||||
@ -21,9 +21,9 @@ def generate_eval_video_cameras(
 | 
			
		||||
    trajectory_scale: float = 0.2,
 | 
			
		||||
    scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
 | 
			
		||||
    up: Tuple[float, float, float] = (0.0, 0.0, 1.0),
 | 
			
		||||
    focal_length: Optional[torch.FloatTensor] = None,
 | 
			
		||||
    principal_point: Optional[torch.FloatTensor] = None,
 | 
			
		||||
    time: Optional[torch.FloatTensor] = None,
 | 
			
		||||
    focal_length: Optional[torch.Tensor] = None,
 | 
			
		||||
    principal_point: Optional[torch.Tensor] = None,
 | 
			
		||||
    time: Optional[torch.Tensor] = None,
 | 
			
		||||
    infer_up_as_plane_normal: bool = True,
 | 
			
		||||
    traj_offset: Optional[Tuple[float, float, float]] = None,
 | 
			
		||||
    traj_offset_canonical: Optional[Tuple[float, float, float]] = None,
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user