Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-08-02 03:42:50 +08:00)
Summary: Added a CPU implementation for face areas normals. Moved the test and the benchmark to separate functions.

```
Benchmark                                    Avg Time(μs)  Peak Time(μs)  Iterations
--------------------------------------------------------------------------------
FACE_AREAS_NORMALS_2_100_300_False           196           268            2550
FACE_AREAS_NORMALS_2_100_300_True            106           179            4733
FACE_AREAS_NORMALS_2_100_3000_False          1447          1630           346
FACE_AREAS_NORMALS_2_100_3000_True           107           178            4674
FACE_AREAS_NORMALS_2_1000_300_False          201           309            2486
FACE_AREAS_NORMALS_2_1000_300_True           107           186            4673
FACE_AREAS_NORMALS_2_1000_3000_False         1451          1636           345
FACE_AREAS_NORMALS_2_1000_3000_True          107           186            4655
FACE_AREAS_NORMALS_10_100_300_False          767           918            653
FACE_AREAS_NORMALS_10_100_300_True           106           167            4712
FACE_AREAS_NORMALS_10_100_3000_False         7036          7754           72
FACE_AREAS_NORMALS_10_100_3000_True          113           164            4445
FACE_AREAS_NORMALS_10_1000_300_False         748           947            669
FACE_AREAS_NORMALS_10_1000_300_True          108           169            4638
FACE_AREAS_NORMALS_10_1000_3000_False        7069          7783           71
FACE_AREAS_NORMALS_10_1000_3000_True         108           172            4646
FACE_AREAS_NORMALS_32_100_300_False          2286          2496           219
FACE_AREAS_NORMALS_32_100_300_True           108           180            4631
FACE_AREAS_NORMALS_32_100_3000_False         23184         24369          22
FACE_AREAS_NORMALS_32_100_3000_True          159           213            3147
FACE_AREAS_NORMALS_32_1000_300_False         2414          2645           208
FACE_AREAS_NORMALS_32_1000_300_True          112           197            4480
FACE_AREAS_NORMALS_32_1000_3000_False        21687         22964          24
FACE_AREAS_NORMALS_32_1000_3000_True         141           211            3540
--------------------------------------------------------------------------------

Benchmark                                    Avg Time(μs)  Peak Time(μs)  Iterations
--------------------------------------------------------------------------------
FACE_AREAS_NORMALS_TORCH_2_100_300_False     5465          5782           92
FACE_AREAS_NORMALS_TORCH_2_100_300_True      1198          1351           418
FACE_AREAS_NORMALS_TORCH_2_100_3000_False    48228         48869          11
FACE_AREAS_NORMALS_TORCH_2_100_3000_True     1186          1304           422
FACE_AREAS_NORMALS_TORCH_2_1000_300_False    5556          6097           90
FACE_AREAS_NORMALS_TORCH_2_1000_300_True     1200          1328           417
FACE_AREAS_NORMALS_TORCH_2_1000_3000_False   48683         50016          11
FACE_AREAS_NORMALS_TORCH_2_1000_3000_True    1185          1306           422
FACE_AREAS_NORMALS_TORCH_10_100_300_False    24215         25097          21
FACE_AREAS_NORMALS_TORCH_10_100_300_True     1150          1314           435
FACE_AREAS_NORMALS_TORCH_10_100_3000_False   232605        234952         3
FACE_AREAS_NORMALS_TORCH_10_100_3000_True    1193          1314           420
FACE_AREAS_NORMALS_TORCH_10_1000_300_False   24912         25343          21
FACE_AREAS_NORMALS_TORCH_10_1000_300_True    1216          1330           412
FACE_AREAS_NORMALS_TORCH_10_1000_3000_False  239907        241253         3
FACE_AREAS_NORMALS_TORCH_10_1000_3000_True   1226          1333           408
FACE_AREAS_NORMALS_TORCH_32_100_300_False    73991         75776          7
FACE_AREAS_NORMALS_TORCH_32_100_300_True     1193          1339           420
FACE_AREAS_NORMALS_TORCH_32_100_3000_False   728932        728932         1
FACE_AREAS_NORMALS_TORCH_32_100_3000_True    1186          1359           422
FACE_AREAS_NORMALS_TORCH_32_1000_300_False   76385         79129          7
FACE_AREAS_NORMALS_TORCH_32_1000_300_True    1165          1310           430
FACE_AREAS_NORMALS_TORCH_32_1000_3000_False  753276        753276         1
FACE_AREAS_NORMALS_TORCH_32_1000_3000_True   1205          1340           415
--------------------------------------------------------------------------------
```

Reviewed By: bottler, jcjohnson
Differential Revision: D19864385
fbshipit-source-id: 3a87ae41a8e3ab5560febcb94961798f2e09dfb8
119 lines · 3.8 KiB · Python
#!/usr/bin/env python3
|
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
|
|
|
|
import unittest
|
|
import torch
|
|
|
|
from pytorch3d import _C
|
|
from pytorch3d.structures.meshes import Meshes
|
|
|
|
from common_testing import TestCaseMixin
|
|
|
|
|
|
class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
    """
    Tests and benchmarks for the compiled ``_C.face_areas_normals`` op.

    The compiled op is checked against a pure PyTorch reference
    implementation (``face_areas_normals`` below) on random meshes.
    """

    def setUp(self) -> None:
        super().setUp()
        # Seed the RNG so the random meshes are reproducible across runs.
        torch.manual_seed(1)

    @staticmethod
    def init_meshes(
        num_meshes: int = 10,
        num_verts: int = 1000,
        num_faces: int = 3000,
        device: str = "cpu",
    ):
        """
        Create a batch of random meshes.

        Args:
            num_meshes: number of meshes in the batch.
            num_verts: number of vertices per mesh. Vertex coordinates are
                sampled uniformly from [0, 1) as float32.
            num_faces: number of faces per mesh. Each face holds three int64
                vertex indices sampled independently, so faces may be
                degenerate (repeated vertices).
            device: device on which the tensors are created.

        Returns:
            A ``Meshes`` object holding the generated meshes.
        """
        device = torch.device(device)
        verts_list = []
        faces_list = []
        # Verts and faces are sampled in alternation per mesh; keep this
        # order so the data generated under a fixed seed does not change.
        for _ in range(num_meshes):
            verts = torch.rand(
                (num_verts, 3), dtype=torch.float32, device=device
            )
            faces = torch.randint(
                num_verts, size=(num_faces, 3), dtype=torch.int64, device=device
            )
            verts_list.append(verts)
            faces_list.append(faces)
        return Meshes(verts_list, faces_list)

    @staticmethod
    def face_areas_normals(verts, faces):
        """
        Pytorch reference implementation for face areas & normals.

        Args:
            verts: (V, 3) tensor of vertex positions.
            faces: (F, 3) tensor of indices into ``verts``.

        Returns:
            face_areas: (F,) tensor with the area of each face.
            face_normals: (F, 3) tensor with the normal of each face. Faces
                whose cross product has norm below 1e-6 are divided by
                1e-6 instead of their norm (``normalize`` eps), so their
                normals are not unit length.
        """
        vertices_faces = verts[faces]  # (F, 3, 3)
        # vector pointing from v0 to v1
        v01 = vertices_faces[:, 1] - vertices_faces[:, 0]
        # vector pointing from v0 to v2
        v02 = vertices_faces[:, 2] - vertices_faces[:, 0]
        normals = torch.cross(v01, v02, dim=1)  # (F, 3)
        # The cross product's magnitude is twice the triangle's area.
        face_areas = normals.norm(dim=1) / 2
        face_normals = torch.nn.functional.normalize(
            normals, p=2, dim=1, eps=1e-6
        )
        return face_areas, face_normals

    def _test_face_areas_normals_helper(self, device):
        """
        Check that the compiled (C++/CUDA) op and the PyTorch reference
        implementation agree on random meshes created on ``device``.
        """
        meshes = self.init_meshes(10, 1000, 3000, device=device)
        verts = meshes.verts_packed()
        faces = meshes.faces_packed()

        areas_torch, normals_torch = self.face_areas_normals(verts, faces)
        areas, normals = _C.face_areas_normals(verts, faces)
        self.assertClose(areas_torch, areas, atol=1e-7)
        # Normals are divided by their norm, so their sensitivity grows as
        # face areas shrink, and random faces here can be arbitrarily small.
        # Compare normals only after multiplying them by the areas.
        unnormals = normals * areas.view(-1, 1)
        unnormals_torch = normals_torch * areas_torch.view(-1, 1)
        self.assertClose(unnormals_torch, unnormals, atol=1e-7)

    def test_face_areas_normals_cpu(self):
        self._test_face_areas_normals_helper("cpu")

    # Skip rather than error on machines without a CUDA device.
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available")
    def test_face_areas_normals_cuda(self):
        self._test_face_areas_normals_helper("cuda:0")

    @staticmethod
    def _maybe_synchronize(cuda: bool) -> None:
        """Wait for pending CUDA work when benchmarking on the GPU."""
        # torch.cuda.synchronize() raises on machines without CUDA, and
        # CPU benchmark runs have no GPU work to wait for.
        if cuda:
            torch.cuda.synchronize()

    @staticmethod
    def _init_packed_inputs(
        num_meshes: int, num_verts: int, num_faces: int, cuda: bool
    ):
        """
        Build random meshes for a benchmark and return their packed
        ``(verts, faces)`` tensors, on "cuda:0" if ``cuda`` else "cpu".
        """
        device = "cuda:0" if cuda else "cpu"
        meshes = TestFaceAreasNormals.init_meshes(
            num_meshes, num_verts, num_faces, device
        )
        verts = meshes.verts_packed()
        faces = meshes.faces_packed()
        TestFaceAreasNormals._maybe_synchronize(cuda)
        return verts, faces

    @staticmethod
    def face_areas_normals_with_init(
        num_meshes: int, num_verts: int, num_faces: int, cuda: bool = True
    ):
        """
        Benchmark setup for the compiled ``_C.face_areas_normals`` op.

        Returns:
            A zero-argument callable that runs the op once on the generated
            inputs and, for CUDA runs, waits for it to complete.
        """
        verts, faces = TestFaceAreasNormals._init_packed_inputs(
            num_meshes, num_verts, num_faces, cuda
        )

        def face_areas_normals():
            _C.face_areas_normals(verts, faces)
            TestFaceAreasNormals._maybe_synchronize(cuda)

        return face_areas_normals

    @staticmethod
    def face_areas_normals_with_init_torch(
        num_meshes: int, num_verts: int, num_faces: int, cuda: bool = True
    ):
        """
        Benchmark setup for the PyTorch reference implementation.

        Returns:
            A zero-argument callable that runs the reference implementation
            once on the generated inputs and, for CUDA runs, waits for it to
            complete.
        """
        verts, faces = TestFaceAreasNormals._init_packed_inputs(
            num_meshes, num_verts, num_faces, cuda
        )

        def face_areas_normals():
            TestFaceAreasNormals.face_areas_normals(verts, faces)
            TestFaceAreasNormals._maybe_synchronize(cuda)

        return face_areas_normals
|