pytorch3d/projects/nerf/nerf/eval_video_utils.py
Tim Hatch 34bbb3ad32 apply import merging for fbcode/vision/fair (2 of 2)
Summary:
Applies new import merging and sorting from µsort v1.0.

When merging imports, µsort will make a best-effort to move associated
comments to match merged elements, but there are known limitations due to
the dynamic nature of Python and developer tooling. These changes should
not produce any dangerous runtime changes, but may require touch-ups to
satisfy linters and other tooling.

Note that µsort uses case-insensitive, lexicographical sorting, which
results in a different ordering compared to isort. This provides a more
consistent sorting order, matching the case-insensitive order used when
sorting import statements by module name, and ensures that "frog", "FROG",
and "Frog" always sort next to each other.

For details on µsort's sorting and merging semantics, see the user guide:
https://usort.readthedocs.io/en/stable/guide.html#sorting

Reviewed By: bottler

Differential Revision: D35553814

fbshipit-source-id: be49bdb6a4c25264ff8d4db3a601f18736d17be1
2022-04-13 06:51:33 -07:00

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Any, Dict, List, Tuple

import torch
from pytorch3d.renderer import look_at_view_transform, PerspectiveCameras


def generate_eval_video_cameras(
    train_dataset,
    n_eval_cams: int = 100,
    trajectory_type: str = "figure_eight",
    trajectory_scale: float = 0.2,
    scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
    up: Tuple[float, float, float] = (0.0, 0.0, 1.0),
) -> List[Dict[str, Any]]:
"""
Generate a camera trajectory for visualizing a NeRF model.
Args:
train_dataset: The training dataset object.
n_eval_cams: Number of cameras in the trajectory.
trajectory_type: The type of the camera trajectory. Can be one of:
circular: Rotating around the center of the scene at a fixed radius.
figure_eight: Figure-of-8 trajectory around the center of the
central camera of the training dataset.
trefoil_knot: Same as 'figure_eight', but the trajectory has a shape
of a trefoil knot (https://en.wikipedia.org/wiki/Trefoil_knot).
figure_eight_knot: Same as 'figure_eight', but the trajectory has a shape
of a figure-eight knot
(https://en.wikipedia.org/wiki/Figure-eight_knot_(mathematics)).
trajectory_scale: The extent of the trajectory.
up: The "up" vector of the scene (=the normal of the scene floor).
Active for the `trajectory_type="circular"`.
scene_center: The center of the scene in world coordinates which all
the cameras from the generated trajectory look at.
Returns:
Dictionary of camera instances which can be used as the test dataset
"""
    if trajectory_type in ("figure_eight", "trefoil_knot", "figure_eight_knot"):
        cam_centers = torch.cat(
            [e["camera"].get_camera_center() for e in train_dataset]
        )
        # index of the camera whose center is closest to the mean of all centers
        mean_camera_idx = (
            ((cam_centers - cam_centers.mean(dim=0)[None]) ** 2)
            .sum(dim=1)
            .min(dim=0)
            .indices
        )
        # generate the knot trajectory in canonical coords
        time = torch.linspace(0, 2 * math.pi, n_eval_cams + 1)[:n_eval_cams]
        if trajectory_type == "trefoil_knot":
            traj = _trefoil_knot(time)
        elif trajectory_type == "figure_eight_knot":
            traj = _figure_eight_knot(time)
        elif trajectory_type == "figure_eight":
            traj = _figure_eight(time)
        traj[:, 2] -= traj[:, 2].max()
        # transform the canonical knot to the coord frame of the mean camera
        traj_trans = (
            train_dataset[mean_camera_idx]["camera"]
            .get_world_to_view_transform()
            .inverse()
        )
        traj_trans = traj_trans.scale(cam_centers.std(dim=0).mean() * trajectory_scale)
        traj = traj_trans.transform_points(traj)
    elif trajectory_type == "circular":
        cam_centers = torch.cat(
            [e["camera"].get_camera_center() for e in train_dataset]
        )
        # fit a plane to the camera centers
        plane_mean = cam_centers.mean(dim=0)
        cam_centers_c = cam_centers - plane_mean[None]
        if up is not None:
            # use the up vector instead of the plane through the camera centers
            plane_normal = torch.FloatTensor(up)
        else:
            cov = (cam_centers_c.t() @ cam_centers_c) / cam_centers_c.shape[0]
            # eigenvalues come back in ascending order, so the first
            # eigenvector is the direction of least variance, i.e. the normal
            _, e_vec = torch.linalg.eigh(cov)
            plane_normal = e_vec[:, 0]
        # project the camera centers onto the fitted plane
        plane_dist = (plane_normal[None] * cam_centers_c).sum(dim=-1)
        cam_centers_on_plane = cam_centers_c - plane_dist[:, None] * plane_normal[None]
        cov = (
            cam_centers_on_plane.t() @ cam_centers_on_plane
        ) / cam_centers_on_plane.shape[0]
        # the last two eigenvectors (largest variance) span the camera plane
        _, e_vec = torch.linalg.eigh(cov)
        traj_radius = (cam_centers_on_plane ** 2).sum(dim=1).sqrt().mean()
        angle = torch.linspace(0, 2.0 * math.pi, n_eval_cams)
        traj = traj_radius * torch.stack(
            (torch.zeros_like(angle), angle.cos(), angle.sin()), dim=-1
        )
        traj = traj @ e_vec.t() + plane_mean[None]
    else:
        raise ValueError(f"Unknown trajectory_type {trajectory_type}.")
    # point all cameras towards the center of the scene
    R, T = look_at_view_transform(
        eye=traj,
        at=(scene_center,),  # (1, 3)
        up=(up,),  # (1, 3)
        device=traj.device,
    )
    # get the average focal length and principal point
    focal = torch.cat([e["camera"].focal_length for e in train_dataset]).mean(dim=0)
    p0 = torch.cat([e["camera"].principal_point for e in train_dataset]).mean(dim=0)
    # assemble the dataset
    test_dataset = [
        {
            "image": None,
            "camera": PerspectiveCameras(
                focal_length=focal[None],
                principal_point=p0[None],
                R=R_[None],
                T=T_[None],
            ),
            "camera_idx": i,
        }
        for i, (R_, T_) in enumerate(zip(R, T))
    ]
    return test_dataset
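

# Canonical curve parametrizations below, evaluated at angles t in [0, 2 * pi).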
def _figure_eight_knot(t: torch.Tensor, z_scale: float = 0.5):
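    """
    Figure-eight knot parametrization:
    x = (2 + cos 2t) cos 3t, y = (2 + cos 2t) sin 3t, z = z_scale * sin 4t.
    """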
    x = (2 + (2 * t).cos()) * (3 * t).cos()
    y = (2 + (2 * t).cos()) * (3 * t).sin()
    z = (4 * t).sin() * z_scale
    return torch.stack((x, y, z), dim=-1)

def _trefoil_knot(t: torch.Tensor, z_scale: float = 0.5):
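    """
    Trefoil knot parametrization:
    x = sin t + 2 sin 2t, y = cos t - 2 cos 2t, z = -z_scale * sin 3t.
    """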
    x = t.sin() + 2 * (2 * t).sin()
    y = t.cos() - 2 * (2 * t).cos()
    z = -(3 * t).sin() * z_scale
    return torch.stack((x, y, z), dim=-1)

def _figure_eight(t: torch.Tensor, z_scale: float = 0.5):
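    """
    Figure of 8 (lemniscate of Gerono), lifted out of the plane:
    x = cos t, y = sin 2t / 2, z = z_scale * sin t.
    """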
    x = t.cos()
    y = (2 * t).sin() / 2
    z = t.sin() * z_scale
    return torch.stack((x, y, z), dim=-1)
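

if __name__ == "__main__":
    # Minimal usage sketch (an illustrative addition, not part of the original
    # module). It builds a synthetic `train_dataset` with the only structure
    # generate_eval_video_cameras relies on above: an iterable of dicts whose
    # "camera" entry is a single-view PerspectiveCameras with focal_length and
    # principal_point set. All numeric values here are arbitrary.
    n_train = 20
    train_R, train_T = look_at_view_transform(
        dist=2.7,
        elev=torch.linspace(15.0, 45.0, n_train),
        azim=torch.linspace(-180.0, 180.0, n_train),
    )
    dummy_train_dataset = [
        {
            "image": None,
            "camera": PerspectiveCameras(
                focal_length=torch.tensor([[2.0, 2.0]]),
                principal_point=torch.tensor([[0.0, 0.0]]),
                R=train_R[i][None],
                T=train_T[i][None],
            ),
            "camera_idx": i,
        }
        for i in range(n_train)
    ]
    eval_dataset = generate_eval_video_cameras(
        dummy_train_dataset, n_eval_cams=10, trajectory_type="circular"
    )
    print(f"generated {len(eval_dataset)} evaluation cameras")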