Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-11-04 18:02:14 +08:00
	Generation of test camera trajectories
Summary: Implements methods for generating trajectories of test cameras.

Reviewed By: nikhilaravi

Differential Revision: D26100869

fbshipit-source-id: cf2b61a34d4c749cd8cba881e97f6c388e57d1f8
parent 9751f1f185
commit dc28b615ae

projects/nerf/nerf/eval_video_utils.py (new file, 152 additions)

@@ -0,0 +1,152 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import math
from typing import List, Tuple

import torch
from pytorch3d.renderer import look_at_view_transform, PerspectiveCameras


def generate_eval_video_cameras(
    train_dataset,
    n_eval_cams: int = 100,
    trajectory_type: str = "figure_eight",
    trajectory_scale: float = 0.2,
    scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
    up: Tuple[float, float, float] = (0.0, 0.0, 1.0),
) -> List[dict]:
    """
    Generate a camera trajectory for visualizing a NeRF model.

    Args:
        train_dataset: The training dataset object.
        n_eval_cams: Number of cameras in the trajectory.
        trajectory_type: The type of the camera trajectory. Can be one of:
            circular: Rotating around the center of the scene at a fixed radius.
            figure_eight: Figure-of-8 trajectory around the center of the
                central camera of the training dataset.
            trefoil_knot: Same as 'figure_eight', but the trajectory has the shape
                of a trefoil knot (https://en.wikipedia.org/wiki/Trefoil_knot).
            figure_eight_knot: Same as 'figure_eight', but the trajectory has the shape
                of a figure-eight knot
                (https://en.wikipedia.org/wiki/Figure-eight_knot_(mathematics)).
        trajectory_scale: The extent of the trajectory.
        scene_center: The center of the scene in world coordinates, which all
            cameras of the generated trajectory look at.
        up: The "up" vector of the scene (= the normal of the scene floor).
            Active for `trajectory_type="circular"`.

    Returns:
        A list of dictionaries, one per camera, which can be used as the test dataset.
    """
    if trajectory_type in ("figure_eight", "trefoil_knot", "figure_eight_knot"):
        cam_centers = torch.cat(
            [e["camera"].get_camera_center() for e in train_dataset]
        )
        # get the nearest camera center to the mean of centers
        mean_camera_idx = (
            ((cam_centers - cam_centers.mean(dim=0)[None]) ** 2)
            .sum(dim=1)
            .min(dim=0)
            .indices
        )
        # generate the knot trajectory in canonical coords
        time = torch.linspace(0, 2 * math.pi, n_eval_cams + 1)[:n_eval_cams]
        if trajectory_type == "trefoil_knot":
            traj = _trefoil_knot(time)
        elif trajectory_type == "figure_eight_knot":
            traj = _figure_eight_knot(time)
        elif trajectory_type == "figure_eight":
            traj = _figure_eight(time)
        traj[:, 2] -= traj[:, 2].max()

        # transform the canonical knot to the coord frame of the mean camera
        traj_trans = (
            train_dataset[mean_camera_idx]["camera"]
            .get_world_to_view_transform()
            .inverse()
        )
        traj_trans = traj_trans.scale(cam_centers.std(dim=0).mean() * trajectory_scale)
        traj = traj_trans.transform_points(traj)
    elif trajectory_type == "circular":
        cam_centers = torch.cat(
            [e["camera"].get_camera_center() for e in train_dataset]
        )

        # fit a plane to the camera centers
        plane_mean = cam_centers.mean(dim=0)
        cam_centers_c = cam_centers - plane_mean[None]

        if up is not None:
            # use the up vector instead of the plane through the camera centers
            plane_normal = torch.FloatTensor(up)
        else:
            cov = (cam_centers_c.t() @ cam_centers_c) / cam_centers_c.shape[0]
            _, e_vec = torch.symeig(cov, eigenvectors=True)
            plane_normal = e_vec[:, 0]

        plane_dist = (plane_normal[None] * cam_centers_c).sum(dim=-1)
        cam_centers_on_plane = cam_centers_c - plane_dist[:, None] * plane_normal[None]

        cov = (
            cam_centers_on_plane.t() @ cam_centers_on_plane
        ) / cam_centers_on_plane.shape[0]
        _, e_vec = torch.symeig(cov, eigenvectors=True)
        traj_radius = (cam_centers_on_plane ** 2).sum(dim=1).sqrt().mean()
        angle = torch.linspace(0, 2.0 * math.pi, n_eval_cams)
        traj = traj_radius * torch.stack(
            (torch.zeros_like(angle), angle.cos(), angle.sin()), dim=-1
        )
        traj = traj @ e_vec.t() + plane_mean[None]

    else:
        raise ValueError(f"Unknown trajectory_type {trajectory_type}.")
    # point all cameras towards the center of the scene
    R, T = look_at_view_transform(
        eye=traj,
        at=(scene_center,),  # (1, 3)
        up=(up,),  # (1, 3)
        device=traj.device,
    )

    # get the average focal length and principal point
    focal = torch.cat([e["camera"].focal_length for e in train_dataset]).mean(dim=0)
    p0 = torch.cat([e["camera"].principal_point for e in train_dataset]).mean(dim=0)

    # assemble the dataset
    test_dataset = [
        {
            "image": None,
            "camera": PerspectiveCameras(
                focal_length=focal[None],
                principal_point=p0[None],
                R=R_[None],
                T=T_[None],
            ),
            "camera_idx": i,
        }
        for i, (R_, T_) in enumerate(zip(R, T))
    ]

    return test_dataset
def _figure_eight_knot(t: torch.Tensor, z_scale: float = 0.5):
    x = (2 + (2 * t).cos()) * (3 * t).cos()
    y = (2 + (2 * t).cos()) * (3 * t).sin()
    z = (4 * t).sin() * z_scale
    return torch.stack((x, y, z), dim=-1)


def _trefoil_knot(t: torch.Tensor, z_scale: float = 0.5):
    x = t.sin() + 2 * (2 * t).sin()
    y = t.cos() - 2 * (2 * t).cos()
    z = -(3 * t).sin() * z_scale
    return torch.stack((x, y, z), dim=-1)


def _figure_eight(t: torch.Tensor, z_scale: float = 0.5):
    x = t.cos()
    y = (2 * t).sin() / 2
    z = t.sin() * z_scale
    return torch.stack((x, y, z), dim=-1)
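For orientation, a minimal usage sketch follows; it is not part of the commit. The import path and the synthetic `train_dataset` are assumptions made for illustration: the function only needs an indexable sequence of dicts whose `"camera"` entries are `PerspectiveCameras` with `focal_length` and `principal_point` set.

# Hypothetical usage sketch; the module path below assumes the projects/nerf layout.
import math

import torch
from pytorch3d.renderer import PerspectiveCameras, look_at_view_transform

from nerf.eval_video_utils import generate_eval_video_cameras  # assumed import path

# Build a tiny synthetic "training dataset": a ring of 8 cameras looking at the origin.
angle = torch.linspace(0, 2 * math.pi, 9)[:8]
eye = torch.stack(
    (2.0 * angle.cos(), 2.0 * angle.sin(), torch.full_like(angle, 1.0)), dim=-1
)
R, T = look_at_view_transform(eye=eye, at=((0.0, 0.0, 0.0),), up=((0.0, 0.0, 1.0),))
train_dataset = [
    {
        "image": None,
        "camera": PerspectiveCameras(
            focal_length=torch.tensor([[1.0, 1.0]]),
            principal_point=torch.tensor([[0.0, 0.0]]),
            R=R_[None],
            T=T_[None],
        ),
        "camera_idx": i,
    }
    for i, (R_, T_) in enumerate(zip(R, T))
]

# Generate 100 test cameras on a figure-of-8 around the central training camera.
test_dataset = generate_eval_video_cameras(
    train_dataset,
    n_eval_cams=100,
    trajectory_type="figure_eight",
    trajectory_scale=0.2,
    scene_center=(0.0, 0.0, 0.0),
    up=(0.0, 0.0, 1.0),
)

# Each entry carries only a camera; "image" is None since there is no ground
# truth for the generated viewpoints.
print(len(test_dataset), test_dataset[0]["camera"].get_camera_center().shape)

Each returned entry mirrors the structure of the training samples, so downstream evaluation code can iterate over the generated trajectory exactly as it would over an ordinary dataset.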