mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 05:40:34 +08:00
ImplicitronRayBundle
Summary: new implicitronRayBundle with added cameraIDs and camera counts. Added to enable a single raybundle inside Implicitron and easier extension in the future. Since RayBundle is named tuple and RayBundleHeterogeneous is dataclass and RayBundleHeterogeneous cannot inherit RayBundle. So if there was no ImplicitronRayBundle every function that uses RayBundle now would have to use Union[RayBundle, RaybundleHeterogeneous] which is confusing and unecessary complicated. Reviewed By: bottler, kjchalup Differential Revision: D39262999 fbshipit-source-id: ece160e32f6c88c3977e408e966789bf8307af59
This commit is contained in:
committed by
Facebook GitHub Bot
parent
6ae863f301
commit
ad8907d373
@@ -8,7 +8,7 @@ import unittest
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.models.renderer.ray_point_refiner import RayPointRefiner
|
||||
from pytorch3d.renderer import RayBundle
|
||||
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
|
||||
from tests.common_testing import TestCaseMixin
|
||||
|
||||
|
||||
@@ -24,7 +24,14 @@ class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
|
||||
add_input_samples=add_input_samples,
|
||||
)
|
||||
lengths = torch.arange(length, dtype=torch.float32).expand(3, 25, length)
|
||||
bundle = RayBundle(lengths=lengths, origins=None, directions=None, xys=None)
|
||||
bundle = ImplicitronRayBundle(
|
||||
lengths=lengths,
|
||||
origins=None,
|
||||
directions=None,
|
||||
xys=None,
|
||||
camera_ids=None,
|
||||
camera_counts=None,
|
||||
)
|
||||
weights = torch.ones(3, 25, length)
|
||||
refined = ray_point_refiner(bundle, weights)
|
||||
|
||||
|
||||
@@ -13,8 +13,10 @@ from pytorch3d.implicitron.models.implicit_function.scene_representation_network
|
||||
SRNImplicitFunction,
|
||||
SRNPixelGenerator,
|
||||
)
|
||||
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
|
||||
from pytorch3d.implicitron.tools.config import get_default_args
|
||||
from pytorch3d.renderer import PerspectiveCameras, RayBundle
|
||||
from pytorch3d.renderer import PerspectiveCameras
|
||||
|
||||
from tests.common_testing import TestCaseMixin
|
||||
|
||||
_BATCH_SIZE: int = 3
|
||||
@@ -31,12 +33,17 @@ class TestSRN(TestCaseMixin, unittest.TestCase):
|
||||
def test_pixel_generator(self):
|
||||
SRNPixelGenerator()
|
||||
|
||||
def _get_bundle(self, *, device) -> RayBundle:
|
||||
def _get_bundle(self, *, device) -> ImplicitronRayBundle:
|
||||
origins = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
|
||||
directions = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
|
||||
lengths = torch.rand(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, device=device)
|
||||
bundle = RayBundle(
|
||||
lengths=lengths, origins=origins, directions=directions, xys=None
|
||||
bundle = ImplicitronRayBundle(
|
||||
lengths=lengths,
|
||||
origins=origins,
|
||||
directions=directions,
|
||||
xys=None,
|
||||
camera_ids=None,
|
||||
camera_counts=None,
|
||||
)
|
||||
return bundle
|
||||
|
||||
|
||||
Reference in New Issue
Block a user