mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 06:10:34 +08:00
NeRF Raysampler
Summary: Implements the NeRF raysampler. Reviewed By: nikhilaravi Differential Revision: D25684403 fbshipit-source-id: 616a60f047c79479f60a6a75d214f87cbfb06d28
This commit is contained in:
committed by
Facebook GitHub Bot
parent
fba419b7f7
commit
7cbda3ec17
122
projects/nerf/tests/test_raysampler.py
Normal file
122
projects/nerf/tests/test_raysampler.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from nerf.raysampler import NeRFRaysampler, ProbabilisticRaysampler
|
||||
from pytorch3d.renderer import PerspectiveCameras
|
||||
from pytorch3d.transforms.rotation_conversions import random_rotations
|
||||
|
||||
|
||||
class TestRaysampler(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
torch.manual_seed(42)
|
||||
|
||||
def test_raysampler_caching(self, batch_size=10):
|
||||
"""
|
||||
Tests the consistency of the NeRF raysampler caching.
|
||||
"""
|
||||
|
||||
raysampler = NeRFRaysampler(
|
||||
min_x=0.0,
|
||||
max_x=10.0,
|
||||
min_y=0.0,
|
||||
max_y=10.0,
|
||||
n_pts_per_ray=10,
|
||||
min_depth=0.1,
|
||||
max_depth=10.0,
|
||||
n_rays_per_image=12,
|
||||
image_width=10,
|
||||
image_height=10,
|
||||
stratified=False,
|
||||
stratified_test=False,
|
||||
invert_directions=True,
|
||||
)
|
||||
|
||||
raysampler.eval()
|
||||
|
||||
cameras, rays = [], []
|
||||
|
||||
for _ in range(batch_size):
|
||||
|
||||
R = random_rotations(1)
|
||||
T = torch.randn(1, 3)
|
||||
focal_length = torch.rand(1, 2) + 0.5
|
||||
principal_point = torch.randn(1, 2)
|
||||
|
||||
camera = PerspectiveCameras(
|
||||
focal_length=focal_length,
|
||||
principal_point=principal_point,
|
||||
R=R,
|
||||
T=T,
|
||||
)
|
||||
|
||||
cameras.append(camera)
|
||||
rays.append(raysampler(camera))
|
||||
|
||||
raysampler.precache_rays(cameras, list(range(batch_size)))
|
||||
|
||||
for cam_index, rays_ in enumerate(rays):
|
||||
rays_cached_ = raysampler(
|
||||
cameras=cameras[cam_index],
|
||||
chunksize=None,
|
||||
chunk_idx=0,
|
||||
camera_hash=cam_index,
|
||||
caching=False,
|
||||
)
|
||||
|
||||
for v, v_cached in zip(rays_, rays_cached_):
|
||||
self.assertTrue(torch.allclose(v, v_cached))
|
||||
|
||||
def test_probabilistic_raysampler(self, batch_size=1, n_pts_per_ray=60):
|
||||
"""
|
||||
Check that the probabilisitc ray sampler does not crash for various
|
||||
settings.
|
||||
"""
|
||||
|
||||
raysampler_grid = NeRFRaysampler(
|
||||
min_x=0.0,
|
||||
max_x=10.0,
|
||||
min_y=0.0,
|
||||
max_y=10.0,
|
||||
n_pts_per_ray=n_pts_per_ray,
|
||||
min_depth=1.0,
|
||||
max_depth=10.0,
|
||||
n_rays_per_image=12,
|
||||
image_width=10,
|
||||
image_height=10,
|
||||
stratified=False,
|
||||
stratified_test=False,
|
||||
invert_directions=True,
|
||||
)
|
||||
|
||||
R = random_rotations(batch_size)
|
||||
T = torch.randn(batch_size, 3)
|
||||
focal_length = torch.rand(batch_size, 2) + 0.5
|
||||
principal_point = torch.randn(batch_size, 2)
|
||||
camera = PerspectiveCameras(
|
||||
focal_length=focal_length,
|
||||
principal_point=principal_point,
|
||||
R=R,
|
||||
T=T,
|
||||
)
|
||||
|
||||
raysampler_grid.eval()
|
||||
|
||||
ray_bundle = raysampler_grid(cameras=camera)
|
||||
|
||||
ray_weights = torch.rand_like(ray_bundle.lengths)
|
||||
|
||||
# Just check that we dont crash for all possible settings.
|
||||
for stratified_test in (True, False):
|
||||
for stratified in (True, False):
|
||||
raysampler_prob = ProbabilisticRaysampler(
|
||||
n_pts_per_ray=n_pts_per_ray,
|
||||
stratified=stratified,
|
||||
stratified_test=stratified_test,
|
||||
add_input_samples=True,
|
||||
)
|
||||
for mode in ("train", "eval"):
|
||||
getattr(raysampler_prob, mode)()
|
||||
for _ in range(10):
|
||||
raysampler_prob(ray_bundle, ray_weights)
|
||||
Reference in New Issue
Block a user