mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-25 16:50:36 +08:00
Add check for verts and faces being on same device and also checks for pointclouds/features/normals being on the same device (#384)
Summary: Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/384 Test Plan: `test_meshes` and `test_points` Reviewed By: gkioxari Differential Revision: D24730524 Pulled By: nikhilaravi fbshipit-source-id: acbd35be5d9f1b13b4d56f3db14f6e8c2c0f7596
This commit is contained in:
committed by
Facebook GitHub Bot
parent
19340462e4
commit
569e5229a9
@@ -325,6 +325,13 @@ class Meshes(object):
|
||||
self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device)
|
||||
if self._N > 0:
|
||||
self.device = self._verts_list[0].device
|
||||
if not (
|
||||
all(v.device == self.device for v in verts)
|
||||
and all(f.device == self.device for f in faces)
|
||||
):
|
||||
raise ValueError(
|
||||
"All Verts and Faces tensors should be on same device."
|
||||
)
|
||||
self._num_verts_per_mesh = torch.tensor(
|
||||
[len(v) for v in self._verts_list], device=self.device
|
||||
)
|
||||
@@ -341,7 +348,6 @@ class Meshes(object):
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
if (len(self._num_verts_per_mesh.unique()) == 1) and (
|
||||
len(self._num_faces_per_mesh.unique()) == 1
|
||||
):
|
||||
@@ -355,6 +361,10 @@ class Meshes(object):
|
||||
self._N = self._verts_padded.shape[0]
|
||||
self._V = self._verts_padded.shape[1]
|
||||
|
||||
if verts.device != faces.device:
|
||||
msg = "Verts and Faces tensors should be on same device. \n Got {} and {}."
|
||||
raise ValueError(msg.format(verts.device, faces.device))
|
||||
|
||||
self.device = self._verts_padded.device
|
||||
self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device)
|
||||
if self._N > 0:
|
||||
|
||||
Reference in New Issue
Block a user