diff --git a/pytorch3d/structures/meshes.py b/pytorch3d/structures/meshes.py index e1181b0b..5be111f2 100644 --- a/pytorch3d/structures/meshes.py +++ b/pytorch3d/structures/meshes.py @@ -325,6 +325,13 @@ class Meshes(object): self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device) if self._N > 0: self.device = self._verts_list[0].device + if not ( + all(v.device == self.device for v in verts) + and all(f.device == self.device for f in faces) + ): + raise ValueError( + "All Verts and Faces tensors should be on same device." + ) self._num_verts_per_mesh = torch.tensor( [len(v) for v in self._verts_list], device=self.device ) @@ -341,7 +348,6 @@ class Meshes(object): dtype=torch.bool, device=self.device, ) - if (len(self._num_verts_per_mesh.unique()) == 1) and ( len(self._num_faces_per_mesh.unique()) == 1 ): @@ -355,6 +361,10 @@ class Meshes(object): self._N = self._verts_padded.shape[0] self._V = self._verts_padded.shape[1] + if verts.device != faces.device: + msg = "Verts and Faces tensors should be on same device. \n Got {} and {}." + raise ValueError(msg.format(verts.device, faces.device)) + self.device = self._verts_padded.device self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device) if self._N > 0: diff --git a/pytorch3d/structures/pointclouds.py b/pytorch3d/structures/pointclouds.py index 97fb339d..4b48eb6a 100644 --- a/pytorch3d/structures/pointclouds.py +++ b/pytorch3d/structures/pointclouds.py @@ -180,11 +180,13 @@ class Pointclouds(object): self._num_points_per_cloud = [] if self._N > 0: + self.device = self._points_list[0].device for p in self._points_list: if len(p) > 0 and (p.dim() != 2 or p.shape[1] != 3): raise ValueError("Clouds in list must be of shape Px3 or empty") + if p.device != self.device: + raise ValueError("All points must be on the same device") - self.device = self._points_list[0].device num_points_per_cloud = torch.tensor( [len(p) for p in self._points_list], device=self.device ) @@ -261,6 +263,10 @@ class Pointclouds(object): raise ValueError( "A cloud has mismatched numbers of points and inputs" ) + if d.device != self.device: + raise ValueError( + "All auxillary inputs must be on the same device as the points." + ) if p > 0: if d.dim() != 2: raise ValueError( @@ -283,6 +289,10 @@ class Pointclouds(object): "Inputs tensor must have the right maximum \ number of points in each cloud." ) + if aux_input.device != self.device: + raise ValueError( + "All auxillary inputs must be on the same device as the points." + ) aux_input_C = aux_input.shape[2] return None, aux_input, aux_input_C else: diff --git a/tests/test_meshes.py b/tests/test_meshes.py index 7a6446f1..f79ac3c9 100644 --- a/tests/test_meshes.py +++ b/tests/test_meshes.py @@ -1,5 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +import random import unittest import numpy as np @@ -162,6 +163,29 @@ class TestMeshes(TestCaseMixin, unittest.TestCase): torch.tensor([0, 3, 8], dtype=torch.int64), ) + def test_init_error(self): + # Check if correct errors are raised when verts/faces are on + # different devices + + mesh = TestMeshes.init_mesh(10, 10, 100) + verts_list = mesh.verts_list() # all tensors on cpu + verts_list = [ + v.to("cuda:0") if random.uniform(0, 1) > 0.5 else v for v in verts_list + ] + faces_list = mesh.faces_list() + + with self.assertRaises(ValueError) as cm: + Meshes(verts=verts_list, faces=faces_list) + self.assertTrue("same device" in cm.msg) + + verts_padded = mesh.verts_padded() # on cpu + verts_padded = verts_padded.to("cuda:0") + faces_padded = mesh.faces_padded() + + with self.assertRaises(ValueError) as cm: + Meshes(verts=verts_padded, faces=faces_padded) + self.assertTrue("same device" in cm.msg) + def test_simple_random_meshes(self): # Define the test mesh object either as a list or tensor of faces/verts. diff --git a/tests/test_pointclouds.py b/tests/test_pointclouds.py index 701254c1..220044c4 100644 --- a/tests/test_pointclouds.py +++ b/tests/test_pointclouds.py @@ -1,6 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +import random import unittest import numpy as np @@ -126,6 +127,44 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase): torch.tensor([0, 1, 2, 5, 6, 7, 8, 10, 11, 12, 13, 14]), ) + def test_init_error(self): + # Check if correct errors are raised when verts/faces are on + # different devices + + clouds = self.init_cloud(10, 100, 5) + points_list = clouds.points_list() # all tensors on cuda:0 + points_list = [ + p.to("cpu") if random.uniform(0, 1) > 0.5 else p for p in points_list + ] + features_list = clouds.features_list() + normals_list = clouds.normals_list() + + with self.assertRaises(ValueError) as cm: + Pointclouds( + points=points_list, features=features_list, normals=normals_list + ) + self.assertTrue("same device" in cm.msg) + + points_list = clouds.points_list() + features_list = [ + f.to("cpu") if random.uniform(0, 1) > 0.2 else f for f in features_list + ] + with self.assertRaises(ValueError) as cm: + Pointclouds( + points=points_list, features=features_list, normals=normals_list + ) + self.assertTrue("same device" in cm.msg) + + points_padded = clouds.points_padded() # on cuda:0 + features_padded = clouds.features_padded().to("cpu") + normals_padded = clouds.normals_padded() + + with self.assertRaises(ValueError) as cm: + Pointclouds( + points=points_padded, features=features_padded, normals=normals_padded + ) + self.assertTrue("same device" in cm.msg) + def test_all_constructions(self): public_getters = [ "points_list",