mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Add check for verts and faces being on same device and also checks for pointclouds/features/normals being on the same device (#384)
Summary: Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/384 Test Plan: `test_meshes` and `test_points` Reviewed By: gkioxari Differential Revision: D24730524 Pulled By: nikhilaravi fbshipit-source-id: acbd35be5d9f1b13b4d56f3db14f6e8c2c0f7596
This commit is contained in:
parent
19340462e4
commit
569e5229a9
@ -325,6 +325,13 @@ class Meshes(object):
|
||||
self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device)
|
||||
if self._N > 0:
|
||||
self.device = self._verts_list[0].device
|
||||
if not (
|
||||
all(v.device == self.device for v in verts)
|
||||
and all(f.device == self.device for f in faces)
|
||||
):
|
||||
raise ValueError(
|
||||
"All Verts and Faces tensors should be on same device."
|
||||
)
|
||||
self._num_verts_per_mesh = torch.tensor(
|
||||
[len(v) for v in self._verts_list], device=self.device
|
||||
)
|
||||
@ -341,7 +348,6 @@ class Meshes(object):
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
if (len(self._num_verts_per_mesh.unique()) == 1) and (
|
||||
len(self._num_faces_per_mesh.unique()) == 1
|
||||
):
|
||||
@ -355,6 +361,10 @@ class Meshes(object):
|
||||
self._N = self._verts_padded.shape[0]
|
||||
self._V = self._verts_padded.shape[1]
|
||||
|
||||
if verts.device != faces.device:
|
||||
msg = "Verts and Faces tensors should be on same device. \n Got {} and {}."
|
||||
raise ValueError(msg.format(verts.device, faces.device))
|
||||
|
||||
self.device = self._verts_padded.device
|
||||
self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device)
|
||||
if self._N > 0:
|
||||
|
@ -180,11 +180,13 @@ class Pointclouds(object):
|
||||
self._num_points_per_cloud = []
|
||||
|
||||
if self._N > 0:
|
||||
self.device = self._points_list[0].device
|
||||
for p in self._points_list:
|
||||
if len(p) > 0 and (p.dim() != 2 or p.shape[1] != 3):
|
||||
raise ValueError("Clouds in list must be of shape Px3 or empty")
|
||||
if p.device != self.device:
|
||||
raise ValueError("All points must be on the same device")
|
||||
|
||||
self.device = self._points_list[0].device
|
||||
num_points_per_cloud = torch.tensor(
|
||||
[len(p) for p in self._points_list], device=self.device
|
||||
)
|
||||
@ -261,6 +263,10 @@ class Pointclouds(object):
|
||||
raise ValueError(
|
||||
"A cloud has mismatched numbers of points and inputs"
|
||||
)
|
||||
if d.device != self.device:
|
||||
raise ValueError(
|
||||
"All auxillary inputs must be on the same device as the points."
|
||||
)
|
||||
if p > 0:
|
||||
if d.dim() != 2:
|
||||
raise ValueError(
|
||||
@ -283,6 +289,10 @@ class Pointclouds(object):
|
||||
"Inputs tensor must have the right maximum \
|
||||
number of points in each cloud."
|
||||
)
|
||||
if aux_input.device != self.device:
|
||||
raise ValueError(
|
||||
"All auxillary inputs must be on the same device as the points."
|
||||
)
|
||||
aux_input_C = aux_input.shape[2]
|
||||
return None, aux_input, aux_input_C
|
||||
else:
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@ -162,6 +163,29 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
|
||||
torch.tensor([0, 3, 8], dtype=torch.int64),
|
||||
)
|
||||
|
||||
def test_init_error(self):
|
||||
# Check if correct errors are raised when verts/faces are on
|
||||
# different devices
|
||||
|
||||
mesh = TestMeshes.init_mesh(10, 10, 100)
|
||||
verts_list = mesh.verts_list() # all tensors on cpu
|
||||
verts_list = [
|
||||
v.to("cuda:0") if random.uniform(0, 1) > 0.5 else v for v in verts_list
|
||||
]
|
||||
faces_list = mesh.faces_list()
|
||||
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
Meshes(verts=verts_list, faces=faces_list)
|
||||
self.assertTrue("same device" in cm.msg)
|
||||
|
||||
verts_padded = mesh.verts_padded() # on cpu
|
||||
verts_padded = verts_padded.to("cuda:0")
|
||||
faces_padded = mesh.faces_padded()
|
||||
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
Meshes(verts=verts_padded, faces=faces_padded)
|
||||
self.assertTrue("same device" in cm.msg)
|
||||
|
||||
def test_simple_random_meshes(self):
|
||||
|
||||
# Define the test mesh object either as a list or tensor of faces/verts.
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@ -126,6 +127,44 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
|
||||
torch.tensor([0, 1, 2, 5, 6, 7, 8, 10, 11, 12, 13, 14]),
|
||||
)
|
||||
|
||||
def test_init_error(self):
|
||||
# Check if correct errors are raised when verts/faces are on
|
||||
# different devices
|
||||
|
||||
clouds = self.init_cloud(10, 100, 5)
|
||||
points_list = clouds.points_list() # all tensors on cuda:0
|
||||
points_list = [
|
||||
p.to("cpu") if random.uniform(0, 1) > 0.5 else p for p in points_list
|
||||
]
|
||||
features_list = clouds.features_list()
|
||||
normals_list = clouds.normals_list()
|
||||
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
Pointclouds(
|
||||
points=points_list, features=features_list, normals=normals_list
|
||||
)
|
||||
self.assertTrue("same device" in cm.msg)
|
||||
|
||||
points_list = clouds.points_list()
|
||||
features_list = [
|
||||
f.to("cpu") if random.uniform(0, 1) > 0.2 else f for f in features_list
|
||||
]
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
Pointclouds(
|
||||
points=points_list, features=features_list, normals=normals_list
|
||||
)
|
||||
self.assertTrue("same device" in cm.msg)
|
||||
|
||||
points_padded = clouds.points_padded() # on cuda:0
|
||||
features_padded = clouds.features_padded().to("cpu")
|
||||
normals_padded = clouds.normals_padded()
|
||||
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
Pointclouds(
|
||||
points=points_padded, features=features_padded, normals=normals_padded
|
||||
)
|
||||
self.assertTrue("same device" in cm.msg)
|
||||
|
||||
def test_all_constructions(self):
|
||||
public_getters = [
|
||||
"points_list",
|
||||
|
Loading…
x
Reference in New Issue
Block a user