mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-18 21:30:35 +08:00
Heterogeneous raysampling -> RayBundleHeterogeneous
Summary: Added heterogeneous raysampling to pytorch3d raysampler, different cameras are sampled different number of times. It now returns RayBundle if heterogeneous raysampling is off and new RayBundleHeterogeneous (with added fields `camera_ids` and `camera_counts`). Heterogeneous raysampling is on if `n_rays_total` is not None. Reviewed By: bottler Differential Revision: D39542222 fbshipit-source-id: d3d88d822ec7696e856007c088dc36a1cfa8c625
This commit is contained in:
committed by
Facebook GitHub Bot
parent
9a0f9ae572
commit
6ae863f301
@@ -152,6 +152,8 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
min_depth=1.0,
|
||||
max_depth=10.0,
|
||||
n_pts_per_ray=10,
|
||||
n_rays_total=None,
|
||||
n_rays_per_image=None,
|
||||
):
|
||||
raysampler_params = {
|
||||
"min_x": min_x,
|
||||
@@ -161,6 +163,8 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
"n_pts_per_ray": n_pts_per_ray,
|
||||
"min_depth": min_depth,
|
||||
"max_depth": max_depth,
|
||||
"n_rays_total": n_rays_total,
|
||||
"n_rays_per_image": n_rays_per_image,
|
||||
}
|
||||
|
||||
if issubclass(raysampler_type, MultinomialRaysampler):
|
||||
@@ -168,7 +172,11 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
{"image_width": image_width, "image_height": image_height}
|
||||
)
|
||||
elif issubclass(raysampler_type, MonteCarloRaysampler):
|
||||
raysampler_params["n_rays_per_image"] = image_width * image_height
|
||||
raysampler_params["n_rays_per_image"] = (
|
||||
image_width * image_height
|
||||
if (n_rays_total is None) and (n_rays_per_image is None)
|
||||
else n_rays_per_image
|
||||
)
|
||||
else:
|
||||
raise ValueError(str(raysampler_type))
|
||||
|
||||
@@ -580,3 +588,55 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
# samples[3] has enough sources, so must contain 3 distinct values.
|
||||
self.assertLessEqual(samples[3].max(), 3)
|
||||
self.assertEqual(len(set(samples[3].tolist())), 3)
|
||||
|
||||
def test_heterogeneous_sampling(self, batch_size=8):
|
||||
"""
|
||||
Test that the output of heterogeneous sampling has the first dimension equal
|
||||
to n_rays_total and second to 1 and that ray_bundle elements from different
|
||||
sampled cameras are different and equal for same sampled cameras.
|
||||
"""
|
||||
cameras = init_random_cameras(PerspectiveCameras, batch_size, random_z=True)
|
||||
for n_rays_total in [2, 3, 17, 21, 32]:
|
||||
for cls in (MultinomialRaysampler, MonteCarloRaysampler):
|
||||
with self.subTest(cls.__name__ + ", n_rays_total=" + str(n_rays_total)):
|
||||
raysampler = self.init_raysampler(
|
||||
cls, n_rays_total=n_rays_total, n_rays_per_image=None
|
||||
)
|
||||
ray_bundle = raysampler(cameras)
|
||||
|
||||
# test weather they are of the correct shape
|
||||
for attr in ("origins", "directions", "lengths", "xys"):
|
||||
tensor = getattr(ray_bundle, attr)
|
||||
assert tensor.shape[:2] == torch.Size(
|
||||
(n_rays_total, 1)
|
||||
), tensor.shape
|
||||
|
||||
# if two camera ids are same than origins should also be the same
|
||||
# directions and xys are always different and lengths equal
|
||||
for i1, (origin1, dir1, len1, id1) in enumerate(
|
||||
zip(
|
||||
ray_bundle.origins,
|
||||
ray_bundle.directions,
|
||||
ray_bundle.lengths,
|
||||
torch.repeat_interleave(
|
||||
ray_bundle.camera_ids, ray_bundle.camera_counts
|
||||
),
|
||||
)
|
||||
):
|
||||
for i2, (origin2, dir2, len2, id2) in enumerate(
|
||||
zip(
|
||||
ray_bundle.origins,
|
||||
ray_bundle.directions,
|
||||
ray_bundle.lengths,
|
||||
torch.repeat_interleave(
|
||||
ray_bundle.camera_ids, ray_bundle.camera_counts
|
||||
),
|
||||
)
|
||||
):
|
||||
if i1 == i2:
|
||||
continue
|
||||
assert torch.allclose(
|
||||
origin1, origin2, rtol=1e-4, atol=1e-4
|
||||
) == (id1 == id2), (origin1, origin2, id1, id2)
|
||||
assert not torch.allclose(dir1, dir2), (dir1, dir2)
|
||||
self.assertClose(len1, len2), (len1, len2)
|
||||
|
||||
Reference in New Issue
Block a user