textures device consistency

Summary: Ensure that `mesh2 = mesh.to(device)` doesn't change the device of `mesh.textures`.

Reviewed By: nikhilaravi

Differential Revision: D25978610

fbshipit-source-id: 0558cd62132965d8693ebeea05e42b8c1d16cfbf
This commit is contained in:
Jeremy Reizenstein 2021-01-25 06:08:09 -08:00 committed by Facebook GitHub Bot
parent e58a730e6a
commit d173a2f8da
3 changed files with 25 additions and 9 deletions

View File

@ -140,10 +140,6 @@ def _pad_texture_maps(
# we can allow the input textures to be any texture
# type which is an instance of the base class.
class TexturesBase:
def __init__(self):
self._N = 0
self.valid = None
def isempty(self):
if self._N is not None and self.valid is not None:
return self._N == 0 or self.valid.eq(False).all()
@ -159,6 +155,7 @@ class TexturesBase:
setattr(self, k, v)
if torch.is_tensor(v) and v.device != device:
setattr(self, k, v.to(device))
self.device = device
return self
def _extend(self, N: int, props: List[str]) -> Dict[str, Union[torch.Tensor, List]]:
@ -634,7 +631,6 @@ class TexturesUV(TexturesBase):
[0.0005, 0.9995] or if the second is outside the interval
[0.005, 0.995].
"""
super().__init__()
self.padding_mode = padding_mode
self.align_corners = align_corners
if isinstance(faces_uvs, (list, tuple)):

View File

@ -1197,7 +1197,7 @@ class Meshes(object):
if torch.is_tensor(v):
setattr(other, k, v.to(device))
if self.textures is not None:
other.textures = self.textures.to(device)
other.textures = other.textures.to(device)
return other
def cpu(self):

View File

@ -793,9 +793,29 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
)
device = torch.device("cuda:0")
tex = tex.to(device)
self.assertTrue(tex._faces_uvs_padded.device == device)
self.assertTrue(tex._verts_uvs_padded.device == device)
self.assertTrue(tex._maps_padded.device == device)
self.assertEqual(tex._faces_uvs_padded.device, device)
self.assertEqual(tex._verts_uvs_padded.device, device)
self.assertEqual(tex._maps_padded.device, device)
def test_mesh_to(self):
tex_cpu = TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
faces_uvs=torch.randint(size=(5, 10, 3), high=15),
verts_uvs=torch.rand(size=(5, 15, 2)),
)
verts = torch.rand(size=(5, 15, 3))
faces = torch.randint(size=(5, 10, 3), high=15)
mesh_cpu = Meshes(faces=faces, verts=verts, textures=tex_cpu)
cpu = torch.device("cpu")
device = torch.device("cuda:0")
tex = mesh_cpu.to(device).textures
self.assertEqual(tex._faces_uvs_padded.device, device)
self.assertEqual(tex._verts_uvs_padded.device, device)
self.assertEqual(tex._maps_padded.device, device)
self.assertEqual(tex_cpu._verts_uvs_padded.device, cpu)
self.assertEqual(tex_cpu.device, cpu)
self.assertEqual(tex.device, device)
def test_getitem(self):
N = 5