mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 06:40:35 +08:00
CPU implem for face areas normals
Summary: Added cpu implementation for face areas normals. Moved test and bm to separate functions. ``` Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- FACE_AREAS_NORMALS_2_100_300_False 196 268 2550 FACE_AREAS_NORMALS_2_100_300_True 106 179 4733 FACE_AREAS_NORMALS_2_100_3000_False 1447 1630 346 FACE_AREAS_NORMALS_2_100_3000_True 107 178 4674 FACE_AREAS_NORMALS_2_1000_300_False 201 309 2486 FACE_AREAS_NORMALS_2_1000_300_True 107 186 4673 FACE_AREAS_NORMALS_2_1000_3000_False 1451 1636 345 FACE_AREAS_NORMALS_2_1000_3000_True 107 186 4655 FACE_AREAS_NORMALS_10_100_300_False 767 918 653 FACE_AREAS_NORMALS_10_100_300_True 106 167 4712 FACE_AREAS_NORMALS_10_100_3000_False 7036 7754 72 FACE_AREAS_NORMALS_10_100_3000_True 113 164 4445 FACE_AREAS_NORMALS_10_1000_300_False 748 947 669 FACE_AREAS_NORMALS_10_1000_300_True 108 169 4638 FACE_AREAS_NORMALS_10_1000_3000_False 7069 7783 71 FACE_AREAS_NORMALS_10_1000_3000_True 108 172 4646 FACE_AREAS_NORMALS_32_100_300_False 2286 2496 219 FACE_AREAS_NORMALS_32_100_300_True 108 180 4631 FACE_AREAS_NORMALS_32_100_3000_False 23184 24369 22 FACE_AREAS_NORMALS_32_100_3000_True 159 213 3147 FACE_AREAS_NORMALS_32_1000_300_False 2414 2645 208 FACE_AREAS_NORMALS_32_1000_300_True 112 197 4480 FACE_AREAS_NORMALS_32_1000_3000_False 21687 22964 24 FACE_AREAS_NORMALS_32_1000_3000_True 141 211 3540 -------------------------------------------------------------------------------- Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- FACE_AREAS_NORMALS_TORCH_2_100_300_False 5465 5782 92 FACE_AREAS_NORMALS_TORCH_2_100_300_True 1198 1351 418 FACE_AREAS_NORMALS_TORCH_2_100_3000_False 48228 48869 11 FACE_AREAS_NORMALS_TORCH_2_100_3000_True 1186 1304 422 FACE_AREAS_NORMALS_TORCH_2_1000_300_False 5556 6097 90 FACE_AREAS_NORMALS_TORCH_2_1000_300_True 1200 1328 417 FACE_AREAS_NORMALS_TORCH_2_1000_3000_False 48683 50016 11 FACE_AREAS_NORMALS_TORCH_2_1000_3000_True 1185 1306 422 FACE_AREAS_NORMALS_TORCH_10_100_300_False 24215 25097 21 FACE_AREAS_NORMALS_TORCH_10_100_300_True 1150 1314 435 FACE_AREAS_NORMALS_TORCH_10_100_3000_False 232605 234952 3 FACE_AREAS_NORMALS_TORCH_10_100_3000_True 1193 1314 420 FACE_AREAS_NORMALS_TORCH_10_1000_300_False 24912 25343 21 FACE_AREAS_NORMALS_TORCH_10_1000_300_True 1216 1330 412 FACE_AREAS_NORMALS_TORCH_10_1000_3000_False 239907 241253 3 FACE_AREAS_NORMALS_TORCH_10_1000_3000_True 1226 1333 408 FACE_AREAS_NORMALS_TORCH_32_100_300_False 73991 75776 7 FACE_AREAS_NORMALS_TORCH_32_100_300_True 1193 1339 420 FACE_AREAS_NORMALS_TORCH_32_100_3000_False 728932 728932 1 FACE_AREAS_NORMALS_TORCH_32_100_3000_True 1186 1359 422 FACE_AREAS_NORMALS_TORCH_32_1000_300_False 76385 79129 7 FACE_AREAS_NORMALS_TORCH_32_1000_300_True 1165 1310 430 FACE_AREAS_NORMALS_TORCH_32_1000_3000_False 753276 753276 1 FACE_AREAS_NORMALS_TORCH_32_1000_3000_True 1205 1340 415 -------------------------------------------------------------------------------- ``` Reviewed By: bottler, jcjohnson Differential Revision: D19864385 fbshipit-source-id: 3a87ae41a8e3ab5560febcb94961798f2e09dfb8
This commit is contained in:
committed by
Facebook Github Bot
parent
8fe65d5f56
commit
29cd181a83
@@ -294,58 +294,6 @@ class TestSamplePoints(unittest.TestCase):
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def face_areas(verts, faces):
|
||||
"""
|
||||
Vectorized PyTorch implementation of triangle face area function.
|
||||
"""
|
||||
verts_faces = verts[faces]
|
||||
v0x = verts_faces[:, 0::3, 0]
|
||||
v0y = verts_faces[:, 0::3, 1]
|
||||
v0z = verts_faces[:, 0::3, 2]
|
||||
|
||||
v1x = verts_faces[:, 1::3, 0]
|
||||
v1y = verts_faces[:, 1::3, 1]
|
||||
v1z = verts_faces[:, 1::3, 2]
|
||||
|
||||
v2x = verts_faces[:, 2::3, 0]
|
||||
v2y = verts_faces[:, 2::3, 1]
|
||||
v2z = verts_faces[:, 2::3, 2]
|
||||
|
||||
ax = v0x - v2x
|
||||
ay = v0y - v2y
|
||||
az = v0z - v2z
|
||||
|
||||
bx = v1x - v2x
|
||||
by = v1y - v2y
|
||||
bz = v1z - v2z
|
||||
|
||||
cx = ay * bz - az * by
|
||||
cy = az * bx - ax * bz
|
||||
cz = ax * by - ay * bx
|
||||
|
||||
# this gives the area of the parallelogram with sides a and b
|
||||
area_sqr = cx * cx + cy * cy + cz * cz
|
||||
# the area of the triangle is half
|
||||
return torch.sqrt(area_sqr) / 2.0
|
||||
|
||||
def test_face_areas(self):
|
||||
"""
|
||||
Check the results from face_areas cuda and PyTorch implementions are
|
||||
the same. Check that face_areas throws an error if cpu tensors are
|
||||
given as input.
|
||||
"""
|
||||
meshes = self.init_meshes(10, 1000, 3000, device="cuda:0")
|
||||
verts = meshes.verts_packed()
|
||||
faces = meshes.faces_packed()
|
||||
|
||||
areas_torch = self.face_areas(verts, faces).squeeze()
|
||||
areas_cuda, _ = _C.face_areas_normals(verts, faces)
|
||||
self.assertTrue(torch.allclose(areas_torch, areas_cuda, atol=5e-8))
|
||||
with self.assertRaises(Exception) as err:
|
||||
_C.face_areas_normals(verts.cpu(), faces.cpu())
|
||||
self.assertTrue("Not implemented on the CPU" in str(err.exception))
|
||||
|
||||
@staticmethod
|
||||
def packed_to_padded_tensor(inputs, first_idxs, max_size):
|
||||
"""
|
||||
@@ -419,27 +367,6 @@ class TestSamplePoints(unittest.TestCase):
|
||||
|
||||
return sample_points
|
||||
|
||||
@staticmethod
|
||||
def face_areas_with_init(
|
||||
num_meshes: int, num_verts: int, num_faces: int, cuda: str = True
|
||||
):
|
||||
device = "cuda" if cuda else "cpu"
|
||||
meshes = TestSamplePoints.init_meshes(
|
||||
num_meshes, num_verts, num_faces, device
|
||||
)
|
||||
verts = meshes.verts_packed()
|
||||
faces = meshes.faces_packed()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def face_areas():
|
||||
if cuda:
|
||||
_C.face_areas_normals(verts, faces)
|
||||
else:
|
||||
TestSamplePoints.face_areas(verts, faces)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return face_areas
|
||||
|
||||
@staticmethod
|
||||
def packed_to_padded_with_init(
|
||||
num_meshes: int, num_verts: int, num_faces: int, cuda: str = True
|
||||
@@ -453,10 +380,7 @@ class TestSamplePoints(unittest.TestCase):
|
||||
mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx()
|
||||
max_faces = meshes.num_faces_per_mesh().max().item()
|
||||
|
||||
if cuda:
|
||||
areas, _ = _C.face_areas_normals(verts, faces)
|
||||
else:
|
||||
areas = TestSamplePoints.face_areas(verts, faces)
|
||||
areas, _ = _C.face_areas_normals(verts, faces)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def packed_to_padded():
|
||||
|
||||
Reference in New Issue
Block a user