Add check for verts and faces being on same device and also checks for pointclouds/features/normals being on the same device (#384)

Summary: Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/384

Test Plan: `test_meshes` and `test_points`

Reviewed By: gkioxari

Differential Revision: D24730524

Pulled By: nikhilaravi

fbshipit-source-id: acbd35be5d9f1b13b4d56f3db14f6e8c2c0f7596
This commit is contained in:
Evgeniy Zheltonozhskiy
2020-12-14 16:17:23 -08:00
committed by Facebook GitHub Bot
parent 19340462e4
commit 569e5229a9
4 changed files with 85 additions and 2 deletions

View File

@@ -325,6 +325,13 @@ class Meshes(object):
self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device)
if self._N > 0:
self.device = self._verts_list[0].device
if not (
all(v.device == self.device for v in verts)
and all(f.device == self.device for f in faces)
):
raise ValueError(
"All Verts and Faces tensors should be on same device."
)
self._num_verts_per_mesh = torch.tensor(
[len(v) for v in self._verts_list], device=self.device
)
@@ -341,7 +348,6 @@ class Meshes(object):
dtype=torch.bool,
device=self.device,
)
if (len(self._num_verts_per_mesh.unique()) == 1) and (
len(self._num_faces_per_mesh.unique()) == 1
):
@@ -355,6 +361,10 @@ class Meshes(object):
self._N = self._verts_padded.shape[0]
self._V = self._verts_padded.shape[1]
if verts.device != faces.device:
msg = "Verts and Faces tensors should be on same device. \n Got {} and {}."
raise ValueError(msg.format(verts.device, faces.device))
self.device = self._verts_padded.device
self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device)
if self._N > 0:

View File

@@ -180,11 +180,13 @@ class Pointclouds(object):
self._num_points_per_cloud = []
if self._N > 0:
self.device = self._points_list[0].device
for p in self._points_list:
if len(p) > 0 and (p.dim() != 2 or p.shape[1] != 3):
raise ValueError("Clouds in list must be of shape Px3 or empty")
if p.device != self.device:
raise ValueError("All points must be on the same device")
self.device = self._points_list[0].device
num_points_per_cloud = torch.tensor(
[len(p) for p in self._points_list], device=self.device
)
@@ -261,6 +263,10 @@ class Pointclouds(object):
raise ValueError(
"A cloud has mismatched numbers of points and inputs"
)
if d.device != self.device:
raise ValueError(
"All auxillary inputs must be on the same device as the points."
)
if p > 0:
if d.dim() != 2:
raise ValueError(
@@ -283,6 +289,10 @@ class Pointclouds(object):
"Inputs tensor must have the right maximum \
number of points in each cloud."
)
if aux_input.device != self.device:
raise ValueError(
"All auxillary inputs must be on the same device as the points."
)
aux_input_C = aux_input.shape[2]
return None, aux_input, aux_input_C
else: