mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
textures device consistency
Summary: Ensure that `mesh2 = mesh.to(device)` doesn't change the device of `mesh.textures`. Reviewed By: nikhilaravi Differential Revision: D25978610 fbshipit-source-id: 0558cd62132965d8693ebeea05e42b8c1d16cfbf
This commit is contained in:
parent
e58a730e6a
commit
d173a2f8da
@ -140,10 +140,6 @@ def _pad_texture_maps(
|
|||||||
# we can allow the input textures to be any texture
|
# we can allow the input textures to be any texture
|
||||||
# type which is an instance of the base class.
|
# type which is an instance of the base class.
|
||||||
class TexturesBase:
|
class TexturesBase:
|
||||||
def __init__(self):
|
|
||||||
self._N = 0
|
|
||||||
self.valid = None
|
|
||||||
|
|
||||||
def isempty(self):
|
def isempty(self):
|
||||||
if self._N is not None and self.valid is not None:
|
if self._N is not None and self.valid is not None:
|
||||||
return self._N == 0 or self.valid.eq(False).all()
|
return self._N == 0 or self.valid.eq(False).all()
|
||||||
@ -159,6 +155,7 @@ class TexturesBase:
|
|||||||
setattr(self, k, v)
|
setattr(self, k, v)
|
||||||
if torch.is_tensor(v) and v.device != device:
|
if torch.is_tensor(v) and v.device != device:
|
||||||
setattr(self, k, v.to(device))
|
setattr(self, k, v.to(device))
|
||||||
|
self.device = device
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _extend(self, N: int, props: List[str]) -> Dict[str, Union[torch.Tensor, List]]:
|
def _extend(self, N: int, props: List[str]) -> Dict[str, Union[torch.Tensor, List]]:
|
||||||
@ -634,7 +631,6 @@ class TexturesUV(TexturesBase):
|
|||||||
[0.0005, 0.9995] or if the second is outside the interval
|
[0.0005, 0.9995] or if the second is outside the interval
|
||||||
[0.005, 0.995].
|
[0.005, 0.995].
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
|
||||||
self.padding_mode = padding_mode
|
self.padding_mode = padding_mode
|
||||||
self.align_corners = align_corners
|
self.align_corners = align_corners
|
||||||
if isinstance(faces_uvs, (list, tuple)):
|
if isinstance(faces_uvs, (list, tuple)):
|
||||||
|
@ -1197,7 +1197,7 @@ class Meshes(object):
|
|||||||
if torch.is_tensor(v):
|
if torch.is_tensor(v):
|
||||||
setattr(other, k, v.to(device))
|
setattr(other, k, v.to(device))
|
||||||
if self.textures is not None:
|
if self.textures is not None:
|
||||||
other.textures = self.textures.to(device)
|
other.textures = other.textures.to(device)
|
||||||
return other
|
return other
|
||||||
|
|
||||||
def cpu(self):
|
def cpu(self):
|
||||||
|
@ -793,9 +793,29 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
tex = tex.to(device)
|
tex = tex.to(device)
|
||||||
self.assertTrue(tex._faces_uvs_padded.device == device)
|
self.assertEqual(tex._faces_uvs_padded.device, device)
|
||||||
self.assertTrue(tex._verts_uvs_padded.device == device)
|
self.assertEqual(tex._verts_uvs_padded.device, device)
|
||||||
self.assertTrue(tex._maps_padded.device == device)
|
self.assertEqual(tex._maps_padded.device, device)
|
||||||
|
|
||||||
|
def test_mesh_to(self):
|
||||||
|
tex_cpu = TexturesUV(
|
||||||
|
maps=torch.ones((5, 16, 16, 3)),
|
||||||
|
faces_uvs=torch.randint(size=(5, 10, 3), high=15),
|
||||||
|
verts_uvs=torch.rand(size=(5, 15, 2)),
|
||||||
|
)
|
||||||
|
verts = torch.rand(size=(5, 15, 3))
|
||||||
|
faces = torch.randint(size=(5, 10, 3), high=15)
|
||||||
|
mesh_cpu = Meshes(faces=faces, verts=verts, textures=tex_cpu)
|
||||||
|
cpu = torch.device("cpu")
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
tex = mesh_cpu.to(device).textures
|
||||||
|
self.assertEqual(tex._faces_uvs_padded.device, device)
|
||||||
|
self.assertEqual(tex._verts_uvs_padded.device, device)
|
||||||
|
self.assertEqual(tex._maps_padded.device, device)
|
||||||
|
self.assertEqual(tex_cpu._verts_uvs_padded.device, cpu)
|
||||||
|
|
||||||
|
self.assertEqual(tex_cpu.device, cpu)
|
||||||
|
self.assertEqual(tex.device, device)
|
||||||
|
|
||||||
def test_getitem(self):
|
def test_getitem(self):
|
||||||
N = 5
|
N = 5
|
||||||
|
Loading…
x
Reference in New Issue
Block a user