mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 13:50:35 +08:00
New raysamplers
Summary: New MultinomialRaysampler succeeds GridRaysampler bringing masking and subsampling. Correspondingly, NDCMultinomialRaysampler succeeds NDCGridRaysampler. Reviewed By: nikhilaravi, shapovalov Differential Revision: D33256897 fbshipit-source-id: cd80ec6f35b110d1d20a75c62f4e889ba8fa5d45
This commit is contained in:
committed by
Facebook GitHub Bot
parent
174738c33e
commit
3eb4233844
@@ -10,9 +10,9 @@ from fvcore.common.benchmark import benchmark
|
||||
from pytorch3d.renderer import (
|
||||
FoVOrthographicCameras,
|
||||
FoVPerspectiveCameras,
|
||||
GridRaysampler,
|
||||
MonteCarloRaysampler,
|
||||
NDCGridRaysampler,
|
||||
MultinomialRaysampler,
|
||||
NDCMultinomialRaysampler,
|
||||
OrthographicCameras,
|
||||
PerspectiveCameras,
|
||||
)
|
||||
@@ -21,7 +21,11 @@ from test_raysampling import TestRaysampling
|
||||
|
||||
def bm_raysampling() -> None:
|
||||
case_grid = {
|
||||
"raysampler_type": [GridRaysampler, NDCGridRaysampler, MonteCarloRaysampler],
|
||||
"raysampler_type": [
|
||||
MultinomialRaysampler,
|
||||
NDCMultinomialRaysampler,
|
||||
MonteCarloRaysampler,
|
||||
],
|
||||
"camera_type": [
|
||||
PerspectiveCameras,
|
||||
OrthographicCameras,
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from numbers import Real
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
@@ -190,3 +191,13 @@ class TestCaseMixin(unittest.TestCase):
|
||||
if msg is not None:
|
||||
self.fail(f"{msg} {err}")
|
||||
self.fail(err)
|
||||
|
||||
def assertConstant(self, input: TensorOrArray, value: Real) -> None:
|
||||
"""
|
||||
Asserts input is entirely filled with value.
|
||||
|
||||
Args:
|
||||
input: tensor or array
|
||||
"""
|
||||
self.assertEqual(input.min(), value)
|
||||
self.assertEqual(input.max(), value)
|
||||
|
||||
@@ -5,17 +5,27 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import unittest
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.ops import eyes
|
||||
from pytorch3d.renderer import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler
|
||||
from pytorch3d.renderer import (
|
||||
MonteCarloRaysampler,
|
||||
MultinomialRaysampler,
|
||||
NDCGridRaysampler,
|
||||
NDCMultinomialRaysampler,
|
||||
)
|
||||
from pytorch3d.renderer.cameras import (
|
||||
FoVOrthographicCameras,
|
||||
FoVPerspectiveCameras,
|
||||
OrthographicCameras,
|
||||
PerspectiveCameras,
|
||||
)
|
||||
from pytorch3d.renderer.implicit.raysampling import (
|
||||
_jiggle_within_stratas,
|
||||
_safe_multinomial,
|
||||
)
|
||||
from pytorch3d.renderer.implicit.utils import (
|
||||
ray_bundle_to_ray_points,
|
||||
ray_bundle_variables_to_ray_points,
|
||||
@@ -93,14 +103,16 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
@staticmethod
|
||||
def raysampler(
|
||||
raysampler_type=GridRaysampler,
|
||||
camera_type=PerspectiveCameras,
|
||||
n_pts_per_ray=10,
|
||||
batch_size=1,
|
||||
image_width=10,
|
||||
image_height=20,
|
||||
):
|
||||
|
||||
raysampler_type,
|
||||
camera_type,
|
||||
n_pts_per_ray: int,
|
||||
batch_size: int,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> Callable[[], None]:
|
||||
"""
|
||||
Used for benchmarks.
|
||||
"""
|
||||
device = torch.device("cuda")
|
||||
|
||||
# init raysamplers
|
||||
@@ -120,7 +132,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
# init a batch of random cameras
|
||||
cameras = init_random_cameras(camera_type, batch_size, random_z=True).to(device)
|
||||
|
||||
def run_raysampler():
|
||||
def run_raysampler() -> None:
|
||||
raysampler(cameras=cameras)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
@@ -128,7 +140,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
@staticmethod
|
||||
def init_raysampler(
|
||||
raysampler_type=GridRaysampler,
|
||||
raysampler_type,
|
||||
min_x=-1.0,
|
||||
max_x=1.0,
|
||||
min_y=-1.0,
|
||||
@@ -149,7 +161,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
"max_depth": max_depth,
|
||||
}
|
||||
|
||||
if issubclass(raysampler_type, GridRaysampler):
|
||||
if issubclass(raysampler_type, MultinomialRaysampler):
|
||||
raysampler_params.update(
|
||||
{"image_width": image_width, "image_height": image_height}
|
||||
)
|
||||
@@ -158,7 +170,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
else:
|
||||
raise ValueError(str(raysampler_type))
|
||||
|
||||
if issubclass(raysampler_type, NDCGridRaysampler):
|
||||
if issubclass(raysampler_type, NDCMultinomialRaysampler):
|
||||
# NDCGridRaysampler does not use min/max_x/y
|
||||
for k in ("min_x", "max_x", "min_y", "max_y"):
|
||||
del raysampler_params[k]
|
||||
@@ -191,8 +203,8 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
for raysampler_type in (
|
||||
MonteCarloRaysampler,
|
||||
GridRaysampler,
|
||||
NDCGridRaysampler,
|
||||
MultinomialRaysampler,
|
||||
NDCMultinomialRaysampler,
|
||||
):
|
||||
|
||||
raysampler = TestRaysampling.init_raysampler(
|
||||
@@ -208,7 +220,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
n_pts_per_ray=n_pts_per_ray,
|
||||
)
|
||||
|
||||
if issubclass(raysampler_type, NDCGridRaysampler):
|
||||
if issubclass(raysampler_type, NDCMultinomialRaysampler):
|
||||
# adjust the gt bounds for NDCGridRaysampler
|
||||
if image_width >= image_height:
|
||||
range_x = image_width / image_height
|
||||
@@ -297,7 +309,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
Checks the shapes of raysampler outputs.
|
||||
"""
|
||||
|
||||
if isinstance(raysampler, GridRaysampler):
|
||||
if isinstance(raysampler, MultinomialRaysampler):
|
||||
spatial_size = [image_height, image_width]
|
||||
elif isinstance(raysampler, MonteCarloRaysampler):
|
||||
spatial_size = [image_height * image_width]
|
||||
@@ -386,7 +398,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
# check that projected world points' xy coordinates
|
||||
# range correctly between [minx/y, max/y]
|
||||
if isinstance(raysampler, GridRaysampler):
|
||||
if isinstance(raysampler, MultinomialRaysampler):
|
||||
# get the expected coordinates along each grid axis
|
||||
ys, xs = [
|
||||
torch.linspace(
|
||||
@@ -518,3 +530,51 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
state = module1.state_dict()
|
||||
module2.load_state_dict(state)
|
||||
|
||||
def test_jiggle(self):
|
||||
# random data which is in ascending order along the last dimension
|
||||
scale = 180
|
||||
data = scale * torch.cumsum(torch.rand(8, 3, 4, 20), dim=-1)
|
||||
|
||||
out = _jiggle_within_stratas(data)
|
||||
self.assertTupleEqual(out.shape, data.shape)
|
||||
|
||||
# Check `out` is in ascending order
|
||||
self.assertGreater(torch.diff(out, dim=-1).min(), 0)
|
||||
|
||||
self.assertConstant(out[..., :-1] < data[..., 1:], True)
|
||||
self.assertConstant(data[..., :-1] < out[..., 1:], True)
|
||||
|
||||
jiggles = out - data
|
||||
# jiggles is random between -scale/2 and scale/2
|
||||
self.assertLess(jiggles.min(), -0.4 * scale)
|
||||
self.assertGreater(jiggles.min(), -0.5 * scale)
|
||||
self.assertGreater(jiggles.max(), 0.4 * scale)
|
||||
self.assertLess(jiggles.max(), 0.5 * scale)
|
||||
|
||||
def test_safe_multinomial(self):
|
||||
mask = [
|
||||
[1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0],
|
||||
[1, 1, 1, 1, 0],
|
||||
]
|
||||
tmask = torch.tensor(mask, dtype=torch.float32)
|
||||
|
||||
for _ in range(5):
|
||||
random_scalar = torch.rand(1)
|
||||
samples = _safe_multinomial(tmask * random_scalar, 3)
|
||||
self.assertTupleEqual(samples.shape, (4, 3))
|
||||
|
||||
# samples[0] is exactly determined
|
||||
self.assertConstant(samples[0], 0)
|
||||
|
||||
self.assertGreaterEqual(samples[1].min(), 0)
|
||||
self.assertLessEqual(samples[1].max(), 1)
|
||||
|
||||
# samples[2] is exactly determined
|
||||
self.assertSetEqual(set(samples[2].tolist()), {0, 1, 2})
|
||||
|
||||
# samples[3] has enough sources, so must contain 3 distinct values.
|
||||
self.assertLessEqual(samples[3].max(), 3)
|
||||
self.assertEqual(len(set(samples[3].tolist())), 3)
|
||||
|
||||
Reference in New Issue
Block a user