diff --git a/pytorch3d/renderer/mesh/rasterizer.py b/pytorch3d/renderer/mesh/rasterizer.py index bd73f16d..f5c249b3 100644 --- a/pytorch3d/renderer/mesh/rasterizer.py +++ b/pytorch3d/renderer/mesh/rasterizer.py @@ -110,7 +110,8 @@ class MeshRasterizer(nn.Module): def to(self, device): # Manually move to device cameras as it is not a subclass of nn.Module - self.cameras = self.cameras.to(device) + if self.cameras is not None: + self.cameras = self.cameras.to(device) return self def transform(self, meshes_world, **kwargs) -> torch.Tensor: diff --git a/pytorch3d/renderer/points/rasterizer.py b/pytorch3d/renderer/points/rasterizer.py index cd1cacfc..73233ac7 100644 --- a/pytorch3d/renderer/points/rasterizer.py +++ b/pytorch3d/renderer/points/rasterizer.py @@ -115,7 +115,8 @@ class PointsRasterizer(nn.Module): def to(self, device): # Manually move to device cameras as it is not a subclass of nn.Module - self.cameras = self.cameras.to(device) + if self.cameras is not None: + self.cameras = self.cameras.to(device) return self def forward(self, point_clouds, **kwargs) -> PointFragments: diff --git a/tests/test_rasterizer.py b/tests/test_rasterizer.py index 82e12a97..a1f99365 100644 --- a/tests/test_rasterizer.py +++ b/tests/test_rasterizer.py @@ -134,6 +134,12 @@ class TestMeshRasterizer(unittest.TestCase): self.assertTrue(torch.allclose(image, image_ref)) + def test_simple_to(self): + # Check that to() works without a cameras object. + device = torch.device("cuda:0") + rasterizer = MeshRasterizer() + rasterizer.to(device) + class TestPointRasterizer(unittest.TestCase): def test_simple_sphere(self): @@ -203,3 +209,9 @@ class TestPointRasterizer(unittest.TestCase): image[image >= 0] = 1.0 image[image < 0] = 0.0 self.assertTrue(torch.allclose(image, image_ref[..., 0])) + + def test_simple_to(self): + # Check that to() works without a cameras object. + device = torch.device("cuda:0") + rasterizer = PointsRasterizer() + rasterizer.to(device)