From dc28b615ae00e482baad252b30bee9e6ce76365f Mon Sep 17 00:00:00 2001 From: David Novotny Date: Tue, 2 Feb 2021 05:42:59 -0800 Subject: [PATCH] Generation of test camera trajectories Summary: Implements methods for generating trajectories of test cameras. Reviewed By: nikhilaravi Differential Revision: D26100869 fbshipit-source-id: cf2b61a34d4c749cd8cba881e97f6c388e57d1f8 --- projects/nerf/nerf/eval_video_utils.py | 152 +++++++++++++++++++++++++ 1 file changed, 152 insertions(+) create mode 100644 projects/nerf/nerf/eval_video_utils.py diff --git a/projects/nerf/nerf/eval_video_utils.py b/projects/nerf/nerf/eval_video_utils.py new file mode 100644 index 00000000..b2849efd --- /dev/null +++ b/projects/nerf/nerf/eval_video_utils.py @@ -0,0 +1,152 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +import math +from typing import Tuple + +import torch +from pytorch3d.renderer import look_at_view_transform, PerspectiveCameras + + +def generate_eval_video_cameras( + train_dataset, + n_eval_cams: int = 100, + trajectory_type: str = "figure_eight", + trajectory_scale: float = 0.2, + scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0), + up: Tuple[float, float, float] = (0.0, 0.0, 1.0), +) -> dict: + """ + Generate a camera trajectory for visualizing a NeRF model. + + Args: + train_dataset: The training dataset object. + n_eval_cams: Number of cameras in the trajectory. + trajectory_type: The type of the camera trajectory. Can be one of: + circular: Rotating around the center of the scene at a fixed radius. + figure_eight: Figure-of-8 trajectory around the center of the + central camera of the training dataset. + trefoil_knot: Same as 'figure_eight', but the trajectory has a shape + of a trefoil knot (https://en.wikipedia.org/wiki/Trefoil_knot). + figure_eight_knot: Same as 'figure_eight', but the trajectory has a shape + of a figure-eight knot + (https://en.wikipedia.org/wiki/Figure-eight_knot_(mathematics)). + trajectory_scale: The extent of the trajectory. + up: The "up" vector of the scene (=the normal of the scene floor). + Active for the `trajectory_type="circular"`. + scene_center: The center of the scene in world coordinates which all + the cameras from the generated trajectory look at. + Returns: + Dictionary of camera instances which can be used as the test dataset + """ + if trajectory_type in ("figure_eight", "trefoil_knot", "figure_eight_knot"): + cam_centers = torch.cat( + [e["camera"].get_camera_center() for e in train_dataset] + ) + # get the nearest camera center to the mean of centers + mean_camera_idx = ( + ((cam_centers - cam_centers.mean(dim=0)[None]) ** 2) + .sum(dim=1) + .min(dim=0) + .indices + ) + # generate the knot trajectory in canonical coords + time = torch.linspace(0, 2 * math.pi, n_eval_cams + 1)[:n_eval_cams] + if trajectory_type == "trefoil_knot": + traj = _trefoil_knot(time) + elif trajectory_type == "figure_eight_knot": + traj = _figure_eight_knot(time) + elif trajectory_type == "figure_eight": + traj = _figure_eight(time) + traj[:, 2] -= traj[:, 2].max() + + # transform the canonical knot to the coord frame of the mean camera + traj_trans = ( + train_dataset[mean_camera_idx]["camera"] + .get_world_to_view_transform() + .inverse() + ) + traj_trans = traj_trans.scale(cam_centers.std(dim=0).mean() * trajectory_scale) + traj = traj_trans.transform_points(traj) + + elif trajectory_type == "circular": + cam_centers = torch.cat( + [e["camera"].get_camera_center() for e in train_dataset] + ) + + # fit plane to the camera centers + plane_mean = cam_centers.mean(dim=0) + cam_centers_c = cam_centers - plane_mean[None] + + if up is not None: + # us the up vector instad of the plane through the camera centers + plane_normal = torch.FloatTensor(up) + else: + cov = (cam_centers_c.t() @ cam_centers_c) / cam_centers_c.shape[0] + _, e_vec = torch.symeig(cov, eigenvectors=True) + plane_normal = e_vec[:, 0] + + plane_dist = (plane_normal[None] * cam_centers_c).sum(dim=-1) + cam_centers_on_plane = cam_centers_c - plane_dist[:, None] * plane_normal[None] + + cov = ( + cam_centers_on_plane.t() @ cam_centers_on_plane + ) / cam_centers_on_plane.shape[0] + _, e_vec = torch.symeig(cov, eigenvectors=True) + traj_radius = (cam_centers_on_plane ** 2).sum(dim=1).sqrt().mean() + angle = torch.linspace(0, 2.0 * math.pi, n_eval_cams) + traj = traj_radius * torch.stack( + (torch.zeros_like(angle), angle.cos(), angle.sin()), dim=-1 + ) + traj = traj @ e_vec.t() + plane_mean[None] + + else: + raise ValueError(f"Uknown trajectory_type {trajectory_type}.") + + # point all cameras towards the center of the scene + R, T = look_at_view_transform( + eye=traj, + at=(scene_center,), # (1, 3) + up=(up,), # (1, 3) + device=traj.device, + ) + + # get the average focal length and principal point + focal = torch.cat([e["camera"].focal_length for e in train_dataset]).mean(dim=0) + p0 = torch.cat([e["camera"].principal_point for e in train_dataset]).mean(dim=0) + + # assemble the dataset + test_dataset = [ + { + "image": None, + "camera": PerspectiveCameras( + focal_length=focal[None], + principal_point=p0[None], + R=R_[None], + T=T_[None], + ), + "camera_idx": i, + } + for i, (R_, T_) in enumerate(zip(R, T)) + ] + + return test_dataset + + +def _figure_eight_knot(t: torch.Tensor, z_scale: float = 0.5): + x = (2 + (2 * t).cos()) * (3 * t).cos() + y = (2 + (2 * t).cos()) * (3 * t).sin() + z = (4 * t).sin() * z_scale + return torch.stack((x, y, z), dim=-1) + + +def _trefoil_knot(t: torch.Tensor, z_scale: float = 0.5): + x = t.sin() + 2 * (2 * t).sin() + y = t.cos() - 2 * (2 * t).cos() + z = -(3 * t).sin() * z_scale + return torch.stack((x, y, z), dim=-1) + + +def _figure_eight(t: torch.Tensor, z_scale: float = 0.5): + x = t.cos() + y = (2 * t).sin() / 2 + z = t.sin() * z_scale + return torch.stack((x, y, z), dim=-1)