mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Generation of test camera trajectories
Summary: Implements methods for generating trajectories of test cameras. Reviewed By: nikhilaravi Differential Revision: D26100869 fbshipit-source-id: cf2b61a34d4c749cd8cba881e97f6c388e57d1f8
This commit is contained in:
parent
9751f1f185
commit
dc28b615ae
152
projects/nerf/nerf/eval_video_utils.py
Normal file
152
projects/nerf/nerf/eval_video_utils.py
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
import math
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pytorch3d.renderer import look_at_view_transform, PerspectiveCameras
|
||||||
|
|
||||||
|
|
||||||
|
def generate_eval_video_cameras(
|
||||||
|
train_dataset,
|
||||||
|
n_eval_cams: int = 100,
|
||||||
|
trajectory_type: str = "figure_eight",
|
||||||
|
trajectory_scale: float = 0.2,
|
||||||
|
scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
|
||||||
|
up: Tuple[float, float, float] = (0.0, 0.0, 1.0),
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Generate a camera trajectory for visualizing a NeRF model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_dataset: The training dataset object.
|
||||||
|
n_eval_cams: Number of cameras in the trajectory.
|
||||||
|
trajectory_type: The type of the camera trajectory. Can be one of:
|
||||||
|
circular: Rotating around the center of the scene at a fixed radius.
|
||||||
|
figure_eight: Figure-of-8 trajectory around the center of the
|
||||||
|
central camera of the training dataset.
|
||||||
|
trefoil_knot: Same as 'figure_eight', but the trajectory has a shape
|
||||||
|
of a trefoil knot (https://en.wikipedia.org/wiki/Trefoil_knot).
|
||||||
|
figure_eight_knot: Same as 'figure_eight', but the trajectory has a shape
|
||||||
|
of a figure-eight knot
|
||||||
|
(https://en.wikipedia.org/wiki/Figure-eight_knot_(mathematics)).
|
||||||
|
trajectory_scale: The extent of the trajectory.
|
||||||
|
up: The "up" vector of the scene (=the normal of the scene floor).
|
||||||
|
Active for the `trajectory_type="circular"`.
|
||||||
|
scene_center: The center of the scene in world coordinates which all
|
||||||
|
the cameras from the generated trajectory look at.
|
||||||
|
Returns:
|
||||||
|
Dictionary of camera instances which can be used as the test dataset
|
||||||
|
"""
|
||||||
|
if trajectory_type in ("figure_eight", "trefoil_knot", "figure_eight_knot"):
|
||||||
|
cam_centers = torch.cat(
|
||||||
|
[e["camera"].get_camera_center() for e in train_dataset]
|
||||||
|
)
|
||||||
|
# get the nearest camera center to the mean of centers
|
||||||
|
mean_camera_idx = (
|
||||||
|
((cam_centers - cam_centers.mean(dim=0)[None]) ** 2)
|
||||||
|
.sum(dim=1)
|
||||||
|
.min(dim=0)
|
||||||
|
.indices
|
||||||
|
)
|
||||||
|
# generate the knot trajectory in canonical coords
|
||||||
|
time = torch.linspace(0, 2 * math.pi, n_eval_cams + 1)[:n_eval_cams]
|
||||||
|
if trajectory_type == "trefoil_knot":
|
||||||
|
traj = _trefoil_knot(time)
|
||||||
|
elif trajectory_type == "figure_eight_knot":
|
||||||
|
traj = _figure_eight_knot(time)
|
||||||
|
elif trajectory_type == "figure_eight":
|
||||||
|
traj = _figure_eight(time)
|
||||||
|
traj[:, 2] -= traj[:, 2].max()
|
||||||
|
|
||||||
|
# transform the canonical knot to the coord frame of the mean camera
|
||||||
|
traj_trans = (
|
||||||
|
train_dataset[mean_camera_idx]["camera"]
|
||||||
|
.get_world_to_view_transform()
|
||||||
|
.inverse()
|
||||||
|
)
|
||||||
|
traj_trans = traj_trans.scale(cam_centers.std(dim=0).mean() * trajectory_scale)
|
||||||
|
traj = traj_trans.transform_points(traj)
|
||||||
|
|
||||||
|
elif trajectory_type == "circular":
|
||||||
|
cam_centers = torch.cat(
|
||||||
|
[e["camera"].get_camera_center() for e in train_dataset]
|
||||||
|
)
|
||||||
|
|
||||||
|
# fit plane to the camera centers
|
||||||
|
plane_mean = cam_centers.mean(dim=0)
|
||||||
|
cam_centers_c = cam_centers - plane_mean[None]
|
||||||
|
|
||||||
|
if up is not None:
|
||||||
|
# us the up vector instad of the plane through the camera centers
|
||||||
|
plane_normal = torch.FloatTensor(up)
|
||||||
|
else:
|
||||||
|
cov = (cam_centers_c.t() @ cam_centers_c) / cam_centers_c.shape[0]
|
||||||
|
_, e_vec = torch.symeig(cov, eigenvectors=True)
|
||||||
|
plane_normal = e_vec[:, 0]
|
||||||
|
|
||||||
|
plane_dist = (plane_normal[None] * cam_centers_c).sum(dim=-1)
|
||||||
|
cam_centers_on_plane = cam_centers_c - plane_dist[:, None] * plane_normal[None]
|
||||||
|
|
||||||
|
cov = (
|
||||||
|
cam_centers_on_plane.t() @ cam_centers_on_plane
|
||||||
|
) / cam_centers_on_plane.shape[0]
|
||||||
|
_, e_vec = torch.symeig(cov, eigenvectors=True)
|
||||||
|
traj_radius = (cam_centers_on_plane ** 2).sum(dim=1).sqrt().mean()
|
||||||
|
angle = torch.linspace(0, 2.0 * math.pi, n_eval_cams)
|
||||||
|
traj = traj_radius * torch.stack(
|
||||||
|
(torch.zeros_like(angle), angle.cos(), angle.sin()), dim=-1
|
||||||
|
)
|
||||||
|
traj = traj @ e_vec.t() + plane_mean[None]
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Uknown trajectory_type {trajectory_type}.")
|
||||||
|
|
||||||
|
# point all cameras towards the center of the scene
|
||||||
|
R, T = look_at_view_transform(
|
||||||
|
eye=traj,
|
||||||
|
at=(scene_center,), # (1, 3)
|
||||||
|
up=(up,), # (1, 3)
|
||||||
|
device=traj.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# get the average focal length and principal point
|
||||||
|
focal = torch.cat([e["camera"].focal_length for e in train_dataset]).mean(dim=0)
|
||||||
|
p0 = torch.cat([e["camera"].principal_point for e in train_dataset]).mean(dim=0)
|
||||||
|
|
||||||
|
# assemble the dataset
|
||||||
|
test_dataset = [
|
||||||
|
{
|
||||||
|
"image": None,
|
||||||
|
"camera": PerspectiveCameras(
|
||||||
|
focal_length=focal[None],
|
||||||
|
principal_point=p0[None],
|
||||||
|
R=R_[None],
|
||||||
|
T=T_[None],
|
||||||
|
),
|
||||||
|
"camera_idx": i,
|
||||||
|
}
|
||||||
|
for i, (R_, T_) in enumerate(zip(R, T))
|
||||||
|
]
|
||||||
|
|
||||||
|
return test_dataset
|
||||||
|
|
||||||
|
|
||||||
|
def _figure_eight_knot(t: torch.Tensor, z_scale: float = 0.5):
|
||||||
|
x = (2 + (2 * t).cos()) * (3 * t).cos()
|
||||||
|
y = (2 + (2 * t).cos()) * (3 * t).sin()
|
||||||
|
z = (4 * t).sin() * z_scale
|
||||||
|
return torch.stack((x, y, z), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def _trefoil_knot(t: torch.Tensor, z_scale: float = 0.5):
|
||||||
|
x = t.sin() + 2 * (2 * t).sin()
|
||||||
|
y = t.cos() - 2 * (2 * t).cos()
|
||||||
|
z = -(3 * t).sin() * z_scale
|
||||||
|
return torch.stack((x, y, z), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def _figure_eight(t: torch.Tensor, z_scale: float = 0.5):
|
||||||
|
x = t.cos()
|
||||||
|
y = (2 * t).sin() / 2
|
||||||
|
z = t.sin() * z_scale
|
||||||
|
return torch.stack((x, y, z), dim=-1)
|
Loading…
x
Reference in New Issue
Block a user