diff --git a/pytorch3d/renderer/mesh/textures.py b/pytorch3d/renderer/mesh/textures.py index 9af8611e..d3345296 100644 --- a/pytorch3d/renderer/mesh/textures.py +++ b/pytorch3d/renderer/mesh/textures.py @@ -140,10 +140,6 @@ def _pad_texture_maps( # we can allow the input textures to be any texture # type which is an instance of the base class. class TexturesBase: - def __init__(self): - self._N = 0 - self.valid = None - def isempty(self): if self._N is not None and self.valid is not None: return self._N == 0 or self.valid.eq(False).all() @@ -159,6 +155,7 @@ class TexturesBase: setattr(self, k, v) if torch.is_tensor(v) and v.device != device: setattr(self, k, v.to(device)) + self.device = device return self def _extend(self, N: int, props: List[str]) -> Dict[str, Union[torch.Tensor, List]]: @@ -634,7 +631,6 @@ class TexturesUV(TexturesBase): [0.0005, 0.9995] or if the second is outside the interval [0.005, 0.995]. """ - super().__init__() self.padding_mode = padding_mode self.align_corners = align_corners if isinstance(faces_uvs, (list, tuple)): diff --git a/pytorch3d/structures/meshes.py b/pytorch3d/structures/meshes.py index 9a6527d0..445264c0 100644 --- a/pytorch3d/structures/meshes.py +++ b/pytorch3d/structures/meshes.py @@ -1197,7 +1197,7 @@ class Meshes(object): if torch.is_tensor(v): setattr(other, k, v.to(device)) if self.textures is not None: - other.textures = self.textures.to(device) + other.textures = other.textures.to(device) return other def cpu(self): diff --git a/tests/test_texturing.py b/tests/test_texturing.py index d477f713..a9043881 100644 --- a/tests/test_texturing.py +++ b/tests/test_texturing.py @@ -793,9 +793,29 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase): ) device = torch.device("cuda:0") tex = tex.to(device) - self.assertTrue(tex._faces_uvs_padded.device == device) - self.assertTrue(tex._verts_uvs_padded.device == device) - self.assertTrue(tex._maps_padded.device == device) + self.assertEqual(tex._faces_uvs_padded.device, device) + self.assertEqual(tex._verts_uvs_padded.device, device) + self.assertEqual(tex._maps_padded.device, device) + + def test_mesh_to(self): + tex_cpu = TexturesUV( + maps=torch.ones((5, 16, 16, 3)), + faces_uvs=torch.randint(size=(5, 10, 3), high=15), + verts_uvs=torch.rand(size=(5, 15, 2)), + ) + verts = torch.rand(size=(5, 15, 3)) + faces = torch.randint(size=(5, 10, 3), high=15) + mesh_cpu = Meshes(faces=faces, verts=verts, textures=tex_cpu) + cpu = torch.device("cpu") + device = torch.device("cuda:0") + tex = mesh_cpu.to(device).textures + self.assertEqual(tex._faces_uvs_padded.device, device) + self.assertEqual(tex._verts_uvs_padded.device, device) + self.assertEqual(tex._maps_padded.device, device) + self.assertEqual(tex_cpu._verts_uvs_padded.device, cpu) + + self.assertEqual(tex_cpu.device, cpu) + self.assertEqual(tex.device, device) def test_getitem(self): N = 5