ImplicitronRayBundle

Summary: new implicitronRayBundle with added cameraIDs and camera counts. Added to enable a single raybundle inside Implicitron and easier extension in the future. Since RayBundle is named tuple and RayBundleHeterogeneous is dataclass and RayBundleHeterogeneous cannot inherit RayBundle. So if there was no ImplicitronRayBundle every function that uses RayBundle now would have to use Union[RayBundle, RaybundleHeterogeneous] which is confusing and unecessary complicated.

Reviewed By: bottler, kjchalup

Differential Revision: D39262999

fbshipit-source-id: ece160e32f6c88c3977e408e966789bf8307af59
This commit is contained in:
Darijan Gudelj
2022-10-03 08:36:47 -07:00
committed by Facebook GitHub Bot
parent 6ae863f301
commit ad8907d373
18 changed files with 259 additions and 100 deletions

View File

@@ -8,7 +8,7 @@ import unittest
import torch
from pytorch3d.implicitron.models.renderer.ray_point_refiner import RayPointRefiner
from pytorch3d.renderer import RayBundle
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
from tests.common_testing import TestCaseMixin
@@ -24,7 +24,14 @@ class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
add_input_samples=add_input_samples,
)
lengths = torch.arange(length, dtype=torch.float32).expand(3, 25, length)
bundle = RayBundle(lengths=lengths, origins=None, directions=None, xys=None)
bundle = ImplicitronRayBundle(
lengths=lengths,
origins=None,
directions=None,
xys=None,
camera_ids=None,
camera_counts=None,
)
weights = torch.ones(3, 25, length)
refined = ray_point_refiner(bundle, weights)

View File

@@ -13,8 +13,10 @@ from pytorch3d.implicitron.models.implicit_function.scene_representation_network
SRNImplicitFunction,
SRNPixelGenerator,
)
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import get_default_args
from pytorch3d.renderer import PerspectiveCameras, RayBundle
from pytorch3d.renderer import PerspectiveCameras
from tests.common_testing import TestCaseMixin
_BATCH_SIZE: int = 3
@@ -31,12 +33,17 @@ class TestSRN(TestCaseMixin, unittest.TestCase):
def test_pixel_generator(self):
SRNPixelGenerator()
def _get_bundle(self, *, device) -> RayBundle:
def _get_bundle(self, *, device) -> ImplicitronRayBundle:
origins = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
directions = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
lengths = torch.rand(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, device=device)
bundle = RayBundle(
lengths=lengths, origins=origins, directions=directions, xys=None
bundle = ImplicitronRayBundle(
lengths=lengths,
origins=origins,
directions=directions,
xys=None,
camera_ids=None,
camera_counts=None,
)
return bundle