mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Generation of test camera trajectories
Summary: Implements methods for generating trajectories of test cameras. Reviewed By: nikhilaravi Differential Revision: D26100869 fbshipit-source-id: cf2b61a34d4c749cd8cba881e97f6c388e57d1f8
This commit is contained in:
parent
9751f1f185
commit
dc28b615ae
152
projects/nerf/nerf/eval_video_utils.py
Normal file
152
projects/nerf/nerf/eval_video_utils.py
Normal file
@ -0,0 +1,152 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
import math
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d.renderer import look_at_view_transform, PerspectiveCameras
|
||||
|
||||
|
||||
def generate_eval_video_cameras(
|
||||
train_dataset,
|
||||
n_eval_cams: int = 100,
|
||||
trajectory_type: str = "figure_eight",
|
||||
trajectory_scale: float = 0.2,
|
||||
scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
|
||||
up: Tuple[float, float, float] = (0.0, 0.0, 1.0),
|
||||
) -> dict:
|
||||
"""
|
||||
Generate a camera trajectory for visualizing a NeRF model.
|
||||
|
||||
Args:
|
||||
train_dataset: The training dataset object.
|
||||
n_eval_cams: Number of cameras in the trajectory.
|
||||
trajectory_type: The type of the camera trajectory. Can be one of:
|
||||
circular: Rotating around the center of the scene at a fixed radius.
|
||||
figure_eight: Figure-of-8 trajectory around the center of the
|
||||
central camera of the training dataset.
|
||||
trefoil_knot: Same as 'figure_eight', but the trajectory has a shape
|
||||
of a trefoil knot (https://en.wikipedia.org/wiki/Trefoil_knot).
|
||||
figure_eight_knot: Same as 'figure_eight', but the trajectory has a shape
|
||||
of a figure-eight knot
|
||||
(https://en.wikipedia.org/wiki/Figure-eight_knot_(mathematics)).
|
||||
trajectory_scale: The extent of the trajectory.
|
||||
up: The "up" vector of the scene (=the normal of the scene floor).
|
||||
Active for the `trajectory_type="circular"`.
|
||||
scene_center: The center of the scene in world coordinates which all
|
||||
the cameras from the generated trajectory look at.
|
||||
Returns:
|
||||
Dictionary of camera instances which can be used as the test dataset
|
||||
"""
|
||||
if trajectory_type in ("figure_eight", "trefoil_knot", "figure_eight_knot"):
|
||||
cam_centers = torch.cat(
|
||||
[e["camera"].get_camera_center() for e in train_dataset]
|
||||
)
|
||||
# get the nearest camera center to the mean of centers
|
||||
mean_camera_idx = (
|
||||
((cam_centers - cam_centers.mean(dim=0)[None]) ** 2)
|
||||
.sum(dim=1)
|
||||
.min(dim=0)
|
||||
.indices
|
||||
)
|
||||
# generate the knot trajectory in canonical coords
|
||||
time = torch.linspace(0, 2 * math.pi, n_eval_cams + 1)[:n_eval_cams]
|
||||
if trajectory_type == "trefoil_knot":
|
||||
traj = _trefoil_knot(time)
|
||||
elif trajectory_type == "figure_eight_knot":
|
||||
traj = _figure_eight_knot(time)
|
||||
elif trajectory_type == "figure_eight":
|
||||
traj = _figure_eight(time)
|
||||
traj[:, 2] -= traj[:, 2].max()
|
||||
|
||||
# transform the canonical knot to the coord frame of the mean camera
|
||||
traj_trans = (
|
||||
train_dataset[mean_camera_idx]["camera"]
|
||||
.get_world_to_view_transform()
|
||||
.inverse()
|
||||
)
|
||||
traj_trans = traj_trans.scale(cam_centers.std(dim=0).mean() * trajectory_scale)
|
||||
traj = traj_trans.transform_points(traj)
|
||||
|
||||
elif trajectory_type == "circular":
|
||||
cam_centers = torch.cat(
|
||||
[e["camera"].get_camera_center() for e in train_dataset]
|
||||
)
|
||||
|
||||
# fit plane to the camera centers
|
||||
plane_mean = cam_centers.mean(dim=0)
|
||||
cam_centers_c = cam_centers - plane_mean[None]
|
||||
|
||||
if up is not None:
|
||||
# us the up vector instad of the plane through the camera centers
|
||||
plane_normal = torch.FloatTensor(up)
|
||||
else:
|
||||
cov = (cam_centers_c.t() @ cam_centers_c) / cam_centers_c.shape[0]
|
||||
_, e_vec = torch.symeig(cov, eigenvectors=True)
|
||||
plane_normal = e_vec[:, 0]
|
||||
|
||||
plane_dist = (plane_normal[None] * cam_centers_c).sum(dim=-1)
|
||||
cam_centers_on_plane = cam_centers_c - plane_dist[:, None] * plane_normal[None]
|
||||
|
||||
cov = (
|
||||
cam_centers_on_plane.t() @ cam_centers_on_plane
|
||||
) / cam_centers_on_plane.shape[0]
|
||||
_, e_vec = torch.symeig(cov, eigenvectors=True)
|
||||
traj_radius = (cam_centers_on_plane ** 2).sum(dim=1).sqrt().mean()
|
||||
angle = torch.linspace(0, 2.0 * math.pi, n_eval_cams)
|
||||
traj = traj_radius * torch.stack(
|
||||
(torch.zeros_like(angle), angle.cos(), angle.sin()), dim=-1
|
||||
)
|
||||
traj = traj @ e_vec.t() + plane_mean[None]
|
||||
|
||||
else:
|
||||
raise ValueError(f"Uknown trajectory_type {trajectory_type}.")
|
||||
|
||||
# point all cameras towards the center of the scene
|
||||
R, T = look_at_view_transform(
|
||||
eye=traj,
|
||||
at=(scene_center,), # (1, 3)
|
||||
up=(up,), # (1, 3)
|
||||
device=traj.device,
|
||||
)
|
||||
|
||||
# get the average focal length and principal point
|
||||
focal = torch.cat([e["camera"].focal_length for e in train_dataset]).mean(dim=0)
|
||||
p0 = torch.cat([e["camera"].principal_point for e in train_dataset]).mean(dim=0)
|
||||
|
||||
# assemble the dataset
|
||||
test_dataset = [
|
||||
{
|
||||
"image": None,
|
||||
"camera": PerspectiveCameras(
|
||||
focal_length=focal[None],
|
||||
principal_point=p0[None],
|
||||
R=R_[None],
|
||||
T=T_[None],
|
||||
),
|
||||
"camera_idx": i,
|
||||
}
|
||||
for i, (R_, T_) in enumerate(zip(R, T))
|
||||
]
|
||||
|
||||
return test_dataset
|
||||
|
||||
|
||||
def _figure_eight_knot(t: torch.Tensor, z_scale: float = 0.5):
|
||||
x = (2 + (2 * t).cos()) * (3 * t).cos()
|
||||
y = (2 + (2 * t).cos()) * (3 * t).sin()
|
||||
z = (4 * t).sin() * z_scale
|
||||
return torch.stack((x, y, z), dim=-1)
|
||||
|
||||
|
||||
def _trefoil_knot(t: torch.Tensor, z_scale: float = 0.5):
|
||||
x = t.sin() + 2 * (2 * t).sin()
|
||||
y = t.cos() - 2 * (2 * t).cos()
|
||||
z = -(3 * t).sin() * z_scale
|
||||
return torch.stack((x, y, z), dim=-1)
|
||||
|
||||
|
||||
def _figure_eight(t: torch.Tensor, z_scale: float = 0.5):
|
||||
x = t.cos()
|
||||
y = (2 * t).sin() / 2
|
||||
z = t.sin() * z_scale
|
||||
return torch.stack((x, y, z), dim=-1)
|
Loading…
x
Reference in New Issue
Block a user