mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-01-17 03:40:34 +08:00
Add check for verts and faces being on same device and also checks for pointclouds/features/normals being on the same device (#384)
Summary: Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/384 Test Plan: `test_meshes` and `test_points` Reviewed By: gkioxari Differential Revision: D24730524 Pulled By: nikhilaravi fbshipit-source-id: acbd35be5d9f1b13b4d56f3db14f6e8c2c0f7596
This commit is contained in:
committed by
Facebook GitHub Bot
parent
19340462e4
commit
569e5229a9
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -162,6 +163,29 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
|
||||
torch.tensor([0, 3, 8], dtype=torch.int64),
|
||||
)
|
||||
|
||||
def test_init_error(self):
|
||||
# Check if correct errors are raised when verts/faces are on
|
||||
# different devices
|
||||
|
||||
mesh = TestMeshes.init_mesh(10, 10, 100)
|
||||
verts_list = mesh.verts_list() # all tensors on cpu
|
||||
verts_list = [
|
||||
v.to("cuda:0") if random.uniform(0, 1) > 0.5 else v for v in verts_list
|
||||
]
|
||||
faces_list = mesh.faces_list()
|
||||
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
Meshes(verts=verts_list, faces=faces_list)
|
||||
self.assertTrue("same device" in cm.msg)
|
||||
|
||||
verts_padded = mesh.verts_padded() # on cpu
|
||||
verts_padded = verts_padded.to("cuda:0")
|
||||
faces_padded = mesh.faces_padded()
|
||||
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
Meshes(verts=verts_padded, faces=faces_padded)
|
||||
self.assertTrue("same device" in cm.msg)
|
||||
|
||||
def test_simple_random_meshes(self):
|
||||
|
||||
# Define the test mesh object either as a list or tensor of faces/verts.
|
||||
|
||||
Reference in New Issue
Block a user