mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
textures device consistency
Summary: Ensure that `mesh2 = mesh.to(device)` doesn't change the device of `mesh.textures`. Reviewed By: nikhilaravi Differential Revision: D25978610 fbshipit-source-id: 0558cd62132965d8693ebeea05e42b8c1d16cfbf
This commit is contained in:
parent
e58a730e6a
commit
d173a2f8da
@ -140,10 +140,6 @@ def _pad_texture_maps(
|
||||
# we can allow the input textures to be any texture
|
||||
# type which is an instance of the base class.
|
||||
class TexturesBase:
|
||||
def __init__(self):
|
||||
self._N = 0
|
||||
self.valid = None
|
||||
|
||||
def isempty(self):
|
||||
if self._N is not None and self.valid is not None:
|
||||
return self._N == 0 or self.valid.eq(False).all()
|
||||
@ -159,6 +155,7 @@ class TexturesBase:
|
||||
setattr(self, k, v)
|
||||
if torch.is_tensor(v) and v.device != device:
|
||||
setattr(self, k, v.to(device))
|
||||
self.device = device
|
||||
return self
|
||||
|
||||
def _extend(self, N: int, props: List[str]) -> Dict[str, Union[torch.Tensor, List]]:
|
||||
@ -634,7 +631,6 @@ class TexturesUV(TexturesBase):
|
||||
[0.0005, 0.9995] or if the second is outside the interval
|
||||
[0.005, 0.995].
|
||||
"""
|
||||
super().__init__()
|
||||
self.padding_mode = padding_mode
|
||||
self.align_corners = align_corners
|
||||
if isinstance(faces_uvs, (list, tuple)):
|
||||
|
@ -1197,7 +1197,7 @@ class Meshes(object):
|
||||
if torch.is_tensor(v):
|
||||
setattr(other, k, v.to(device))
|
||||
if self.textures is not None:
|
||||
other.textures = self.textures.to(device)
|
||||
other.textures = other.textures.to(device)
|
||||
return other
|
||||
|
||||
def cpu(self):
|
||||
|
@ -793,9 +793,29 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
device = torch.device("cuda:0")
|
||||
tex = tex.to(device)
|
||||
self.assertTrue(tex._faces_uvs_padded.device == device)
|
||||
self.assertTrue(tex._verts_uvs_padded.device == device)
|
||||
self.assertTrue(tex._maps_padded.device == device)
|
||||
self.assertEqual(tex._faces_uvs_padded.device, device)
|
||||
self.assertEqual(tex._verts_uvs_padded.device, device)
|
||||
self.assertEqual(tex._maps_padded.device, device)
|
||||
|
||||
def test_mesh_to(self):
|
||||
tex_cpu = TexturesUV(
|
||||
maps=torch.ones((5, 16, 16, 3)),
|
||||
faces_uvs=torch.randint(size=(5, 10, 3), high=15),
|
||||
verts_uvs=torch.rand(size=(5, 15, 2)),
|
||||
)
|
||||
verts = torch.rand(size=(5, 15, 3))
|
||||
faces = torch.randint(size=(5, 10, 3), high=15)
|
||||
mesh_cpu = Meshes(faces=faces, verts=verts, textures=tex_cpu)
|
||||
cpu = torch.device("cpu")
|
||||
device = torch.device("cuda:0")
|
||||
tex = mesh_cpu.to(device).textures
|
||||
self.assertEqual(tex._faces_uvs_padded.device, device)
|
||||
self.assertEqual(tex._verts_uvs_padded.device, device)
|
||||
self.assertEqual(tex._maps_padded.device, device)
|
||||
self.assertEqual(tex_cpu._verts_uvs_padded.device, cpu)
|
||||
|
||||
self.assertEqual(tex_cpu.device, cpu)
|
||||
self.assertEqual(tex.device, device)
|
||||
|
||||
def test_getitem(self):
|
||||
N = 5
|
||||
|
Loading…
x
Reference in New Issue
Block a user