mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Summary: New ImplicitronRayBundle with added camera IDs and camera counts. Added to enable a single ray bundle inside Implicitron and easier extension in the future. RayBundle is a named tuple and RayBundleHeterogeneous is a dataclass, so RayBundleHeterogeneous cannot inherit from RayBundle. Without ImplicitronRayBundle, every function that now uses RayBundle would have to use Union[RayBundle, RayBundleHeterogeneous], which is confusing and unnecessarily complicated. Reviewed By: bottler, kjchalup Differential Revision: D39262999 fbshipit-source-id: ece160e32f6c88c3977e408e966789bf8307af59
65 lines
2.5 KiB
Python
65 lines
2.5 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import unittest
|
|
|
|
import torch
|
|
from pytorch3d.implicitron.models.renderer.ray_point_refiner import RayPointRefiner
|
|
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
|
|
from tests.common_testing import TestCaseMixin
|
|
|
|
|
|
class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
    """Tests for ``RayPointRefiner`` in deterministic and random sampling modes."""

    def test_simple(self):
        """Refine a bundle with uniform weights and check the resampled lengths.

        The bundle has lengths ``0 .. length - 1`` on a (3, 25) grid of rays
        and uniform weights. Deterministic sampling is expected to return
        ``n_pts_per_ray`` evenly spaced points on ``[0.5, length - 1.5]``.
        With ``add_input_samples`` set, the input lengths are expected to be
        merged in and the result sorted. Random sampling is checked for
        shape, range (without input samples) and strict monotonicity.
        """
        # Seed the global RNG so the random-sampling branch is reproducible.
        # NOTE(review): assumes RayPointRefiner draws from torch's default
        # generator -- confirm.
        torch.manual_seed(42)

        length = 15
        n_pts_per_ray = 10

        # Nothing below depends on add_input_samples, so build it once.
        # Only lengths is populated. The other fields are None so the test
        # can check that they pass through the refiner unchanged.
        lengths = torch.arange(length, dtype=torch.float32).expand(3, 25, length)
        bundle = ImplicitronRayBundle(
            lengths=lengths,
            origins=None,
            directions=None,
            xys=None,
            camera_ids=None,
            camera_counts=None,
        )
        weights = torch.ones(3, 25, length)

        # Deterministic expectation for uniform weights: evenly spaced points
        # between the first and last bin midpoints.
        expected = torch.linspace(0.5, length - 1.5, n_pts_per_ray)
        expected = expected.expand(3, 25, n_pts_per_ray)

        for add_input_samples in [False, True]:
            # subTest makes a failure report which parameterization broke.
            with self.subTest(add_input_samples=add_input_samples):
                ray_point_refiner = RayPointRefiner(
                    n_pts_per_ray=n_pts_per_ray,
                    random_sampling=False,
                    add_input_samples=add_input_samples,
                )
                refined = ray_point_refiner(bundle, weights)

                self.assertIsNone(refined.directions)
                self.assertIsNone(refined.origins)
                self.assertIsNone(refined.xys)

                if add_input_samples:
                    # Input lengths are merged with the new samples and sorted.
                    full_expected = torch.cat((lengths, expected), dim=-1).sort().values
                else:
                    full_expected = expected
                self.assertClose(refined.lengths, full_expected)

                ray_point_refiner_random = RayPointRefiner(
                    n_pts_per_ray=n_pts_per_ray,
                    random_sampling=True,
                    add_input_samples=add_input_samples,
                )
                refined_random = ray_point_refiner_random(bundle, weights)
                lengths_random = refined_random.lengths
                self.assertEqual(lengths_random.shape, full_expected.shape)
                if not add_input_samples:
                    # Random samples should lie strictly inside the
                    # midpoint span used by the deterministic case.
                    self.assertGreater(lengths_random.min().item(), 0.5)
                    self.assertLess(lengths_random.max().item(), length - 1.5)

                # Check sorted (strictly increasing along the last dimension).
                self.assertTrue(
                    (lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
                )
|