mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
rasterizer.to without cameras
Summary: As reported in https://github.com/facebookresearch/pytorch3d/pull/1100, a rasterizer couldn't be moved if it was missing the optional cameras member. Fix that. This matters because the renderer.to calls rasterizer.to, so this to() could be called even by a user who never sets a cameras member. Reviewed By: nikhilaravi Differential Revision: D34643841 fbshipit-source-id: 7e26e32e8bc585eb1ee533052754a7b59bc7467a
This commit is contained in:
parent
4a1f176054
commit
c371a9a6cc
@ -110,7 +110,8 @@ class MeshRasterizer(nn.Module):
|
||||
|
||||
def to(self, device):
|
||||
# Manually move to device cameras as it is not a subclass of nn.Module
|
||||
self.cameras = self.cameras.to(device)
|
||||
if self.cameras is not None:
|
||||
self.cameras = self.cameras.to(device)
|
||||
return self
|
||||
|
||||
def transform(self, meshes_world, **kwargs) -> torch.Tensor:
|
||||
|
@ -115,7 +115,8 @@ class PointsRasterizer(nn.Module):
|
||||
|
||||
def to(self, device):
|
||||
# Manually move to device cameras as it is not a subclass of nn.Module
|
||||
self.cameras = self.cameras.to(device)
|
||||
if self.cameras is not None:
|
||||
self.cameras = self.cameras.to(device)
|
||||
return self
|
||||
|
||||
def forward(self, point_clouds, **kwargs) -> PointFragments:
|
||||
|
@ -134,6 +134,12 @@ class TestMeshRasterizer(unittest.TestCase):
|
||||
|
||||
self.assertTrue(torch.allclose(image, image_ref))
|
||||
|
||||
def test_simple_to(self):
|
||||
# Check that to() works without a cameras object.
|
||||
device = torch.device("cuda:0")
|
||||
rasterizer = MeshRasterizer()
|
||||
rasterizer.to(device)
|
||||
|
||||
|
||||
class TestPointRasterizer(unittest.TestCase):
|
||||
def test_simple_sphere(self):
|
||||
@ -203,3 +209,9 @@ class TestPointRasterizer(unittest.TestCase):
|
||||
image[image >= 0] = 1.0
|
||||
image[image < 0] = 0.0
|
||||
self.assertTrue(torch.allclose(image, image_ref[..., 0]))
|
||||
|
||||
def test_simple_to(self):
|
||||
# Check that to() works without a cameras object.
|
||||
device = torch.device("cuda:0")
|
||||
rasterizer = PointsRasterizer()
|
||||
rasterizer.to(device)
|
||||
|
Loading…
x
Reference in New Issue
Block a user