Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-08-02 03:42:50 +08:00).
Summary: Implements the NeRF raysampler. Reviewed By: nikhilaravi. Differential Revision: D25684403. fbshipit-source-id: 616a60f047c79479f60a6a75d214f87cbfb06d28
123 lines · 3.6 KiB · Python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
|
|
import unittest
|
|
|
|
import torch
|
|
from nerf.raysampler import NeRFRaysampler, ProbabilisticRaysampler
|
|
from pytorch3d.renderer import PerspectiveCameras
|
|
from pytorch3d.transforms.rotation_conversions import random_rotations
|
|
|
|
|
|
class TestRaysampler(unittest.TestCase):
|
|
def setUp(self) -> None:
|
|
torch.manual_seed(42)
|
|
|
|
def test_raysampler_caching(self, batch_size=10):
|
|
"""
|
|
Tests the consistency of the NeRF raysampler caching.
|
|
"""
|
|
|
|
raysampler = NeRFRaysampler(
|
|
min_x=0.0,
|
|
max_x=10.0,
|
|
min_y=0.0,
|
|
max_y=10.0,
|
|
n_pts_per_ray=10,
|
|
min_depth=0.1,
|
|
max_depth=10.0,
|
|
n_rays_per_image=12,
|
|
image_width=10,
|
|
image_height=10,
|
|
stratified=False,
|
|
stratified_test=False,
|
|
invert_directions=True,
|
|
)
|
|
|
|
raysampler.eval()
|
|
|
|
cameras, rays = [], []
|
|
|
|
for _ in range(batch_size):
|
|
|
|
R = random_rotations(1)
|
|
T = torch.randn(1, 3)
|
|
focal_length = torch.rand(1, 2) + 0.5
|
|
principal_point = torch.randn(1, 2)
|
|
|
|
camera = PerspectiveCameras(
|
|
focal_length=focal_length,
|
|
principal_point=principal_point,
|
|
R=R,
|
|
T=T,
|
|
)
|
|
|
|
cameras.append(camera)
|
|
rays.append(raysampler(camera))
|
|
|
|
raysampler.precache_rays(cameras, list(range(batch_size)))
|
|
|
|
for cam_index, rays_ in enumerate(rays):
|
|
rays_cached_ = raysampler(
|
|
cameras=cameras[cam_index],
|
|
chunksize=None,
|
|
chunk_idx=0,
|
|
camera_hash=cam_index,
|
|
caching=False,
|
|
)
|
|
|
|
for v, v_cached in zip(rays_, rays_cached_):
|
|
self.assertTrue(torch.allclose(v, v_cached))
|
|
|
|
def test_probabilistic_raysampler(self, batch_size=1, n_pts_per_ray=60):
|
|
"""
|
|
Check that the probabilisitc ray sampler does not crash for various
|
|
settings.
|
|
"""
|
|
|
|
raysampler_grid = NeRFRaysampler(
|
|
min_x=0.0,
|
|
max_x=10.0,
|
|
min_y=0.0,
|
|
max_y=10.0,
|
|
n_pts_per_ray=n_pts_per_ray,
|
|
min_depth=1.0,
|
|
max_depth=10.0,
|
|
n_rays_per_image=12,
|
|
image_width=10,
|
|
image_height=10,
|
|
stratified=False,
|
|
stratified_test=False,
|
|
invert_directions=True,
|
|
)
|
|
|
|
R = random_rotations(batch_size)
|
|
T = torch.randn(batch_size, 3)
|
|
focal_length = torch.rand(batch_size, 2) + 0.5
|
|
principal_point = torch.randn(batch_size, 2)
|
|
camera = PerspectiveCameras(
|
|
focal_length=focal_length,
|
|
principal_point=principal_point,
|
|
R=R,
|
|
T=T,
|
|
)
|
|
|
|
raysampler_grid.eval()
|
|
|
|
ray_bundle = raysampler_grid(cameras=camera)
|
|
|
|
ray_weights = torch.rand_like(ray_bundle.lengths)
|
|
|
|
# Just check that we dont crash for all possible settings.
|
|
for stratified_test in (True, False):
|
|
for stratified in (True, False):
|
|
raysampler_prob = ProbabilisticRaysampler(
|
|
n_pts_per_ray=n_pts_per_ray,
|
|
stratified=stratified,
|
|
stratified_test=stratified_test,
|
|
add_input_samples=True,
|
|
)
|
|
for mode in ("train", "eval"):
|
|
getattr(raysampler_prob, mode)()
|
|
for _ in range(10):
|
|
raysampler_prob(ray_bundle, ray_weights)
|