New raysamplers

Summary: New MultinomialRaysampler succeeds GridRaysampler bringing masking and subsampling. Correspondingly, NDCMultinomialRaysampler succeeds NDCGridRaysampler.

Reviewed By: nikhilaravi, shapovalov

Differential Revision: D33256897

fbshipit-source-id: cd80ec6f35b110d1d20a75c62f4e889ba8fa5d45
This commit is contained in:
Jeremy Reizenstein
2022-01-24 10:51:03 -08:00
committed by Facebook GitHub Bot
parent 174738c33e
commit 3eb4233844
7 changed files with 412 additions and 61 deletions

View File

@@ -10,9 +10,9 @@ from fvcore.common.benchmark import benchmark
from pytorch3d.renderer import (
FoVOrthographicCameras,
FoVPerspectiveCameras,
GridRaysampler,
MonteCarloRaysampler,
NDCGridRaysampler,
MultinomialRaysampler,
NDCMultinomialRaysampler,
OrthographicCameras,
PerspectiveCameras,
)
@@ -21,7 +21,11 @@ from test_raysampling import TestRaysampling
def bm_raysampling() -> None:
case_grid = {
"raysampler_type": [GridRaysampler, NDCGridRaysampler, MonteCarloRaysampler],
"raysampler_type": [
MultinomialRaysampler,
NDCMultinomialRaysampler,
MonteCarloRaysampler,
],
"camera_type": [
PerspectiveCameras,
OrthographicCameras,

View File

@@ -6,6 +6,7 @@
import os
import unittest
from numbers import Real
from pathlib import Path
from typing import Callable, Optional, Union
@@ -190,3 +191,13 @@ class TestCaseMixin(unittest.TestCase):
if msg is not None:
self.fail(f"{msg} {err}")
self.fail(err)
def assertConstant(self, input: TensorOrArray, value: Real) -> None:
"""
Asserts input is entirely filled with value.
Args:
input: tensor or array
"""
self.assertEqual(input.min(), value)
self.assertEqual(input.max(), value)

View File

@@ -5,17 +5,27 @@
# LICENSE file in the root directory of this source tree.
import unittest
from typing import Callable
import torch
from common_testing import TestCaseMixin
from pytorch3d.ops import eyes
from pytorch3d.renderer import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler
from pytorch3d.renderer import (
MonteCarloRaysampler,
MultinomialRaysampler,
NDCGridRaysampler,
NDCMultinomialRaysampler,
)
from pytorch3d.renderer.cameras import (
FoVOrthographicCameras,
FoVPerspectiveCameras,
OrthographicCameras,
PerspectiveCameras,
)
from pytorch3d.renderer.implicit.raysampling import (
_jiggle_within_stratas,
_safe_multinomial,
)
from pytorch3d.renderer.implicit.utils import (
ray_bundle_to_ray_points,
ray_bundle_variables_to_ray_points,
@@ -93,14 +103,16 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
@staticmethod
def raysampler(
raysampler_type=GridRaysampler,
camera_type=PerspectiveCameras,
n_pts_per_ray=10,
batch_size=1,
image_width=10,
image_height=20,
):
raysampler_type,
camera_type,
n_pts_per_ray: int,
batch_size: int,
image_width: int,
image_height: int,
) -> Callable[[], None]:
"""
Used for benchmarks.
"""
device = torch.device("cuda")
# init raysamplers
@@ -120,7 +132,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
# init a batch of random cameras
cameras = init_random_cameras(camera_type, batch_size, random_z=True).to(device)
def run_raysampler():
def run_raysampler() -> None:
raysampler(cameras=cameras)
torch.cuda.synchronize()
@@ -128,7 +140,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
@staticmethod
def init_raysampler(
raysampler_type=GridRaysampler,
raysampler_type,
min_x=-1.0,
max_x=1.0,
min_y=-1.0,
@@ -149,7 +161,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
"max_depth": max_depth,
}
if issubclass(raysampler_type, GridRaysampler):
if issubclass(raysampler_type, MultinomialRaysampler):
raysampler_params.update(
{"image_width": image_width, "image_height": image_height}
)
@@ -158,7 +170,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
else:
raise ValueError(str(raysampler_type))
if issubclass(raysampler_type, NDCGridRaysampler):
if issubclass(raysampler_type, NDCMultinomialRaysampler):
# NDCGridRaysampler does not use min/max_x/y
for k in ("min_x", "max_x", "min_y", "max_y"):
del raysampler_params[k]
@@ -191,8 +203,8 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
for raysampler_type in (
MonteCarloRaysampler,
GridRaysampler,
NDCGridRaysampler,
MultinomialRaysampler,
NDCMultinomialRaysampler,
):
raysampler = TestRaysampling.init_raysampler(
@@ -208,7 +220,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
n_pts_per_ray=n_pts_per_ray,
)
if issubclass(raysampler_type, NDCGridRaysampler):
if issubclass(raysampler_type, NDCMultinomialRaysampler):
# adjust the gt bounds for NDCGridRaysampler
if image_width >= image_height:
range_x = image_width / image_height
@@ -297,7 +309,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
Checks the shapes of raysampler outputs.
"""
if isinstance(raysampler, GridRaysampler):
if isinstance(raysampler, MultinomialRaysampler):
spatial_size = [image_height, image_width]
elif isinstance(raysampler, MonteCarloRaysampler):
spatial_size = [image_height * image_width]
@@ -386,7 +398,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
# check that projected world points' xy coordinates
# range correctly between [minx/y, max/y]
if isinstance(raysampler, GridRaysampler):
if isinstance(raysampler, MultinomialRaysampler):
# get the expected coordinates along each grid axis
ys, xs = [
torch.linspace(
@@ -518,3 +530,51 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
)
state = module1.state_dict()
module2.load_state_dict(state)
def test_jiggle(self):
# random data which is in ascending order along the last dimension
scale = 180
data = scale * torch.cumsum(torch.rand(8, 3, 4, 20), dim=-1)
out = _jiggle_within_stratas(data)
self.assertTupleEqual(out.shape, data.shape)
# Check `out` is in ascending order
self.assertGreater(torch.diff(out, dim=-1).min(), 0)
self.assertConstant(out[..., :-1] < data[..., 1:], True)
self.assertConstant(data[..., :-1] < out[..., 1:], True)
jiggles = out - data
# jiggles is random between -scale/2 and scale/2
self.assertLess(jiggles.min(), -0.4 * scale)
self.assertGreater(jiggles.min(), -0.5 * scale)
self.assertGreater(jiggles.max(), 0.4 * scale)
self.assertLess(jiggles.max(), 0.5 * scale)
def test_safe_multinomial(self):
mask = [
[1, 0, 0, 0, 0],
[1, 1, 0, 0, 0],
[1, 1, 1, 0, 0],
[1, 1, 1, 1, 0],
]
tmask = torch.tensor(mask, dtype=torch.float32)
for _ in range(5):
random_scalar = torch.rand(1)
samples = _safe_multinomial(tmask * random_scalar, 3)
self.assertTupleEqual(samples.shape, (4, 3))
# samples[0] is exactly determined
self.assertConstant(samples[0], 0)
self.assertGreaterEqual(samples[1].min(), 0)
self.assertLessEqual(samples[1].max(), 1)
# samples[2] is exactly determined
self.assertSetEqual(set(samples[2].tolist()), {0, 1, 2})
# samples[3] has enough sources, so must contain 3 distinct values.
self.assertLessEqual(samples[3].max(), 3)
self.assertEqual(len(set(samples[3].tolist())), 3)