mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-01 03:12:49 +08:00
CPU implementation for face areas and normals
Summary: Added cpu implementation for face areas normals. Moved test and bm to separate functions. ``` Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- FACE_AREAS_NORMALS_2_100_300_False 196 268 2550 FACE_AREAS_NORMALS_2_100_300_True 106 179 4733 FACE_AREAS_NORMALS_2_100_3000_False 1447 1630 346 FACE_AREAS_NORMALS_2_100_3000_True 107 178 4674 FACE_AREAS_NORMALS_2_1000_300_False 201 309 2486 FACE_AREAS_NORMALS_2_1000_300_True 107 186 4673 FACE_AREAS_NORMALS_2_1000_3000_False 1451 1636 345 FACE_AREAS_NORMALS_2_1000_3000_True 107 186 4655 FACE_AREAS_NORMALS_10_100_300_False 767 918 653 FACE_AREAS_NORMALS_10_100_300_True 106 167 4712 FACE_AREAS_NORMALS_10_100_3000_False 7036 7754 72 FACE_AREAS_NORMALS_10_100_3000_True 113 164 4445 FACE_AREAS_NORMALS_10_1000_300_False 748 947 669 FACE_AREAS_NORMALS_10_1000_300_True 108 169 4638 FACE_AREAS_NORMALS_10_1000_3000_False 7069 7783 71 FACE_AREAS_NORMALS_10_1000_3000_True 108 172 4646 FACE_AREAS_NORMALS_32_100_300_False 2286 2496 219 FACE_AREAS_NORMALS_32_100_300_True 108 180 4631 FACE_AREAS_NORMALS_32_100_3000_False 23184 24369 22 FACE_AREAS_NORMALS_32_100_3000_True 159 213 3147 FACE_AREAS_NORMALS_32_1000_300_False 2414 2645 208 FACE_AREAS_NORMALS_32_1000_300_True 112 197 4480 FACE_AREAS_NORMALS_32_1000_3000_False 21687 22964 24 FACE_AREAS_NORMALS_32_1000_3000_True 141 211 3540 -------------------------------------------------------------------------------- Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- FACE_AREAS_NORMALS_TORCH_2_100_300_False 5465 5782 92 FACE_AREAS_NORMALS_TORCH_2_100_300_True 1198 1351 418 FACE_AREAS_NORMALS_TORCH_2_100_3000_False 48228 48869 11 FACE_AREAS_NORMALS_TORCH_2_100_3000_True 1186 1304 422 FACE_AREAS_NORMALS_TORCH_2_1000_300_False 5556 6097 90 FACE_AREAS_NORMALS_TORCH_2_1000_300_True 1200 1328 417 
FACE_AREAS_NORMALS_TORCH_2_1000_3000_False 48683 50016 11 FACE_AREAS_NORMALS_TORCH_2_1000_3000_True 1185 1306 422 FACE_AREAS_NORMALS_TORCH_10_100_300_False 24215 25097 21 FACE_AREAS_NORMALS_TORCH_10_100_300_True 1150 1314 435 FACE_AREAS_NORMALS_TORCH_10_100_3000_False 232605 234952 3 FACE_AREAS_NORMALS_TORCH_10_100_3000_True 1193 1314 420 FACE_AREAS_NORMALS_TORCH_10_1000_300_False 24912 25343 21 FACE_AREAS_NORMALS_TORCH_10_1000_300_True 1216 1330 412 FACE_AREAS_NORMALS_TORCH_10_1000_3000_False 239907 241253 3 FACE_AREAS_NORMALS_TORCH_10_1000_3000_True 1226 1333 408 FACE_AREAS_NORMALS_TORCH_32_100_300_False 73991 75776 7 FACE_AREAS_NORMALS_TORCH_32_100_300_True 1193 1339 420 FACE_AREAS_NORMALS_TORCH_32_100_3000_False 728932 728932 1 FACE_AREAS_NORMALS_TORCH_32_100_3000_True 1186 1359 422 FACE_AREAS_NORMALS_TORCH_32_1000_300_False 76385 79129 7 FACE_AREAS_NORMALS_TORCH_32_1000_300_True 1165 1310 430 FACE_AREAS_NORMALS_TORCH_32_1000_3000_False 753276 753276 1 FACE_AREAS_NORMALS_TORCH_32_1000_3000_True 1205 1340 415 -------------------------------------------------------------------------------- ``` Reviewed By: bottler, jcjohnson Differential Revision: D19864385 fbshipit-source-id: 3a87ae41a8e3ab5560febcb94961798f2e09dfb8
This commit is contained in:
parent
8fe65d5f56
commit
29cd181a83
@ -9,7 +9,7 @@
|
||||
#include "rasterize_points/rasterize_points.h"
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("face_areas_normals", &face_areas_normals);
|
||||
m.def("face_areas_normals", &FaceAreasNormals);
|
||||
m.def("packed_to_padded_tensor", &packed_to_padded_tensor);
|
||||
m.def("nn_points_idx", &NearestNeighborIdx);
|
||||
m.def("gather_scatter", &gather_scatter);
|
||||
|
@ -4,7 +4,7 @@
|
||||
#include <tuple>
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void face_areas_kernel(
|
||||
__global__ void FaceAreasNormalsKernel(
|
||||
const scalar_t* __restrict__ verts,
|
||||
const long* __restrict__ faces,
|
||||
scalar_t* __restrict__ face_areas,
|
||||
@ -55,7 +55,7 @@ __global__ void face_areas_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> face_areas_cuda(
|
||||
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsCuda(
|
||||
at::Tensor verts,
|
||||
at::Tensor faces) {
|
||||
const auto V = verts.size(0);
|
||||
@ -66,14 +66,15 @@ std::tuple<at::Tensor, at::Tensor> face_areas_cuda(
|
||||
|
||||
const int blocks = 64;
|
||||
const int threads = 512;
|
||||
AT_DISPATCH_FLOATING_TYPES(verts.type(), "face_areas_kernel", ([&] {
|
||||
face_areas_kernel<scalar_t><<<blocks, threads>>>(
|
||||
verts.data_ptr<scalar_t>(),
|
||||
faces.data_ptr<long>(),
|
||||
areas.data_ptr<scalar_t>(),
|
||||
normals.data_ptr<scalar_t>(),
|
||||
V,
|
||||
F);
|
||||
AT_DISPATCH_FLOATING_TYPES(verts.type(), "face_areas_normals_cuda", ([&] {
|
||||
FaceAreasNormalsKernel<scalar_t>
|
||||
<<<blocks, threads>>>(
|
||||
verts.data_ptr<scalar_t>(),
|
||||
faces.data_ptr<long>(),
|
||||
areas.data_ptr<scalar_t>(),
|
||||
normals.data_ptr<scalar_t>(),
|
||||
V,
|
||||
F);
|
||||
}));
|
||||
|
||||
return std::make_tuple(areas, normals);
|
||||
|
@ -16,21 +16,26 @@
|
||||
// faces[f]
|
||||
//
|
||||
|
||||
// Cpu implementation.
|
||||
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsCpu(
|
||||
at::Tensor verts,
|
||||
at::Tensor faces);
|
||||
|
||||
// Cuda implementation.
|
||||
std::tuple<at::Tensor, at::Tensor> face_areas_cuda(
|
||||
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsCuda(
|
||||
at::Tensor verts,
|
||||
at::Tensor faces);
|
||||
|
||||
// Implementation which is exposed.
|
||||
std::tuple<at::Tensor, at::Tensor> face_areas_normals(
|
||||
std::tuple<at::Tensor, at::Tensor> FaceAreasNormals(
|
||||
at::Tensor verts,
|
||||
at::Tensor faces) {
|
||||
if (verts.type().is_cuda() && faces.type().is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
return face_areas_cuda(verts, faces);
|
||||
return FaceAreasNormalsCuda(verts, faces);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
AT_ERROR("Not implemented on the CPU.");
|
||||
return FaceAreasNormalsCpu(verts, faces);
|
||||
}
|
||||
|
57
pytorch3d/csrc/face_areas_normals/face_areas_normals_cpu.cpp
Normal file
57
pytorch3d/csrc/face_areas_normals/face_areas_normals_cpu.cpp
Normal file
@ -0,0 +1,57 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#include <torch/extension.h>

#include <cmath>
#include <tuple>
|
||||
|
||||
// CPU implementation of face areas and normals.
//
// Args:
//   verts: 2D floating point CPU tensor of shape (V, 3); verts[v] holds the
//       x, y, z coordinates of vertex v.
//   faces: 2D int64 CPU tensor of shape (F, 3); faces[f] holds the indices
//       into verts of the three corners of face f.
//
// Returns:
//   A tuple (areas, normals) allocated with the options of verts:
//     areas: shape (F,), areas[f] is half the length of the cross product
//         (v1 - v0) x (v2 - v0) of face f.
//     normals: shape (F, 3), the same cross product divided by
//         max(length, 1e-6), so degenerate faces do not divide by zero.
//
// Raises a c10::Error if the inputs are not CPU tensors of the shapes and
// dtypes above, or if a face references a vertex index outside [0, V).
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsCpu(
    at::Tensor verts,
    at::Tensor faces) {
  // Accessors dereference host pointers, so reject device tensors up front
  // (the dispatcher sends mixed-device inputs down this path).
  TORCH_CHECK(
      !verts.is_cuda() && !faces.is_cuda(),
      "FaceAreasNormalsCpu expects verts and faces to be CPU tensors.");
  TORCH_CHECK(
      verts.dim() == 2 && verts.size(1) == 3,
      "verts must have shape (V, 3).");
  TORCH_CHECK(
      faces.dim() == 2 && faces.size(1) == 3,
      "faces must have shape (F, 3).");
  TORCH_CHECK(
      faces.scalar_type() == at::kLong, "faces must be an int64 tensor.");

  // Keep sizes 64-bit: tensor sizes are int64_t and must not be truncated.
  const int64_t V = verts.size(0);
  const int64_t F = faces.size(0);

  at::Tensor areas = at::empty({F}, verts.options());
  at::Tensor normals = at::empty({F, 3}, verts.options());

  const auto faces_a = faces.accessor<int64_t, 2>();

  // Dispatch on the dtype of verts so that double inputs work as well as
  // float ones instead of failing inside accessor<float, ...>.
  AT_DISPATCH_FLOATING_TYPES(
      verts.scalar_type(), "face_areas_normals_cpu", ([&] {
        const auto verts_a = verts.accessor<scalar_t, 2>();
        auto areas_a = areas.accessor<scalar_t, 1>();
        auto normals_a = normals.accessor<scalar_t, 2>();

        // Lower bound for the divisor used when normalizing.
        const scalar_t eps = static_cast<scalar_t>(1e-6);

        for (int64_t f = 0; f < F; ++f) {
          const int64_t i0 = faces_a[f][0];
          const int64_t i1 = faces_a[f][1];
          const int64_t i2 = faces_a[f][2];

          // An invalid index would otherwise read outside the verts buffer.
          TORCH_CHECK(
              0 <= i0 && i0 < V && 0 <= i1 && i1 < V && 0 <= i2 && i2 < V,
              "faces[",
              f,
              "] references a vertex index outside [0, ",
              V,
              ").");

          // Edge vectors a = v1 - v0 and b = v2 - v0.
          const scalar_t ax = verts_a[i1][0] - verts_a[i0][0];
          const scalar_t ay = verts_a[i1][1] - verts_a[i0][1];
          const scalar_t az = verts_a[i1][2] - verts_a[i0][2];

          const scalar_t bx = verts_a[i2][0] - verts_a[i0][0];
          const scalar_t by = verts_a[i2][1] - verts_a[i0][1];
          const scalar_t bz = verts_a[i2][2] - verts_a[i0][2];

          // Cross product c = a x b.
          const scalar_t cx = ay * bz - az * by;
          const scalar_t cy = az * bx - ax * bz;
          const scalar_t cz = ax * by - ay * bx;

          const scalar_t length = std::sqrt(cx * cx + cy * cy + cz * cz);
          areas_a[f] = length / static_cast<scalar_t>(2);

          // max(length, eps)
          const scalar_t divisor = (length < eps) ? eps : length;
          normals_a[f][0] = cx / divisor;
          normals_a[f][1] = cy / divisor;
          normals_a[f][2] = cz / divisor;
        }
      }));

  return std::make_tuple(areas, normals);
}
|
@ -771,21 +771,9 @@ class Meshes(object):
|
||||
return
|
||||
faces_packed = self.faces_packed()
|
||||
verts_packed = self.verts_packed()
|
||||
if verts_packed.is_cuda and faces_packed.is_cuda:
|
||||
face_areas, face_normals = _C.face_areas_normals(
|
||||
verts_packed, faces_packed
|
||||
)
|
||||
else:
|
||||
vertices_faces = verts_packed[faces_packed] # (F, 3, 3)
|
||||
# vector pointing from v0 to v1
|
||||
v01 = vertices_faces[:, 1] - vertices_faces[:, 0]
|
||||
# vector pointing from v0 to v2
|
||||
v02 = vertices_faces[:, 2] - vertices_faces[:, 0]
|
||||
normals = torch.cross(v01, v02, dim=1) # (F, 3)
|
||||
face_areas = normals.norm(dim=-1) / 2
|
||||
face_normals = torch.nn.functional.normalize(
|
||||
normals, p=2, dim=1, eps=1e-6
|
||||
)
|
||||
face_areas, face_normals = _C.face_areas_normals(
|
||||
verts_packed, faces_packed
|
||||
)
|
||||
self._faces_areas_packed = face_areas
|
||||
self._faces_normals_packed = face_normals
|
||||
|
||||
|
40
tests/bm_face_areas_normals.py
Normal file
40
tests/bm_face_areas_normals.py
Normal file
@ -0,0 +1,40 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
from itertools import product
|
||||
import torch
|
||||
from fvcore.common.benchmark import benchmark
|
||||
|
||||
from test_face_areas_normals import TestFaceAreasNormals
|
||||
|
||||
|
||||
def bm_face_areas_normals() -> None:
    """
    Benchmark the compiled face areas/normals op and the PyTorch reference.

    Every combination of mesh count, vertex count and face count is timed on
    the CPU, and additionally on CUDA when a CUDA device is available.
    """
    # CPU is always benchmarked; CUDA only if torch reports a usable device.
    use_cuda_options = [False, True] if torch.cuda.is_available() else [False]

    mesh_counts = [2, 10, 32]
    vert_counts = [100, 1000]
    face_counts = [300, 3000]

    kwargs_list = [
        {"num_meshes": n, "num_verts": v, "num_faces": f, "cuda": c}
        for n, v, f, c in product(
            mesh_counts, vert_counts, face_counts, use_cuda_options
        )
    ]

    # Same configurations for both implementations so the rows are comparable.
    targets = (
        (TestFaceAreasNormals.face_areas_normals_with_init, "FACE_AREAS_NORMALS"),
        (
            TestFaceAreasNormals.face_areas_normals_with_init_torch,
            "FACE_AREAS_NORMALS_TORCH",
        ),
    )
    for make_fn, label in targets:
        benchmark(make_fn, label, kwargs_list, warmup_iters=1)
|
118
tests/test_face_areas_normals.py
Normal file
118
tests/test_face_areas_normals.py
Normal file
@ -0,0 +1,118 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
import unittest
|
||||
import torch
|
||||
|
||||
from pytorch3d import _C
|
||||
from pytorch3d.structures.meshes import Meshes
|
||||
|
||||
from common_testing import TestCaseMixin
|
||||
|
||||
|
||||
class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
    """
    Compares `_C.face_areas_normals` against a pure PyTorch reference and
    provides the closures used by the face areas/normals benchmark.
    """

    def setUp(self) -> None:
        super().setUp()
        # Fixed seed so the random meshes are reproducible across runs.
        torch.manual_seed(1)

    @staticmethod
    def init_meshes(
        num_meshes: int = 10,
        num_verts: int = 1000,
        num_faces: int = 3000,
        device: str = "cpu",
    ):
        """
        Build a Meshes object holding `num_meshes` random meshes.

        Each mesh has `num_verts` float32 vertices drawn with `torch.rand` and
        `num_faces` int64 faces whose indices are drawn with `torch.randint`
        from [0, num_verts). All tensors are created on `device`.
        """
        device = torch.device(device)
        verts_list = []
        faces_list = []
        for _ in range(num_meshes):
            verts = torch.rand(
                (num_verts, 3), dtype=torch.float32, device=device
            )
            faces = torch.randint(
                num_verts, size=(num_faces, 3), dtype=torch.int64, device=device
            )
            verts_list.append(verts)
            faces_list.append(faces)
        meshes = Meshes(verts_list, faces_list)

        return meshes

    @staticmethod
    def face_areas_normals(verts, faces):
        """
        Pytorch implementation for face areas & normals.

        Returns a tuple (face_areas, face_normals): half the norm of the
        cross product of the two edges leaving the first vertex of each face,
        and that cross product normalized with eps=1e-6.
        """
        vertices_faces = verts[faces]  # (F, 3, 3)
        # vector pointing from v0 to v1
        v01 = vertices_faces[:, 1] - vertices_faces[:, 0]
        # vector pointing from v0 to v2
        v02 = vertices_faces[:, 2] - vertices_faces[:, 0]
        normals = torch.cross(v01, v02, dim=1)  # (F, 3)
        face_areas = normals.norm(dim=-1) / 2
        face_normals = torch.nn.functional.normalize(
            normals, p=2, dim=1, eps=1e-6
        )
        return face_areas, face_normals

    def _test_face_areas_normals_helper(self, device):
        """
        Check the results from face_areas cuda/cpp and PyTorch implementation are
        the same.
        """
        meshes = self.init_meshes(10, 1000, 3000, device=device)
        verts = meshes.verts_packed()
        faces = meshes.faces_packed()

        areas_torch, normals_torch = self.face_areas_normals(verts, faces)
        areas, normals = _C.face_areas_normals(verts, faces)
        self.assertClose(areas_torch, areas, atol=1e-7)
        # normals get normalized by area thus sensitivity increases as areas
        # in our tests can be arbitrarily small. Thus we compare normals after
        # multiplying with areas
        unnormals = normals * areas.view(-1, 1)
        unnormals_torch = normals_torch * areas_torch.view(-1, 1)
        self.assertClose(unnormals_torch, unnormals, atol=1e-7)

    def test_face_areas_normals_cpu(self):
        self._test_face_areas_normals_helper("cpu")

    # Skip rather than error on hosts without a GPU.
    @unittest.skipUnless(torch.cuda.is_available(), "CUDA is not available")
    def test_face_areas_normals_cuda(self):
        self._test_face_areas_normals_helper("cuda:0")

    @staticmethod
    def face_areas_normals_with_init(
        num_meshes: int, num_verts: int, num_faces: int, cuda: bool = True
    ):
        """
        Create random meshes and return a zero-argument closure that runs
        `_C.face_areas_normals` on their packed verts and faces.
        """
        device = "cuda:0" if cuda else "cpu"
        meshes = TestFaceAreasNormals.init_meshes(
            num_meshes, num_verts, num_faces, device
        )
        verts = meshes.verts_packed()
        faces = meshes.faces_packed()
        # Only touch the CUDA runtime when benchmarking on CUDA; calling
        # synchronize on a host without CUDA raises.
        if cuda:
            torch.cuda.synchronize()

        def face_areas_normals():
            _C.face_areas_normals(verts, faces)
            if cuda:
                torch.cuda.synchronize()

        return face_areas_normals

    @staticmethod
    def face_areas_normals_with_init_torch(
        num_meshes: int, num_verts: int, num_faces: int, cuda: bool = True
    ):
        """
        Create random meshes and return a zero-argument closure that runs the
        PyTorch reference implementation on their packed verts and faces.
        """
        device = "cuda:0" if cuda else "cpu"
        meshes = TestFaceAreasNormals.init_meshes(
            num_meshes, num_verts, num_faces, device
        )
        verts = meshes.verts_packed()
        faces = meshes.faces_packed()
        # Only touch the CUDA runtime when benchmarking on CUDA; calling
        # synchronize on a host without CUDA raises.
        if cuda:
            torch.cuda.synchronize()

        def face_areas_normals():
            TestFaceAreasNormals.face_areas_normals(verts, faces)
            if cuda:
                torch.cuda.synchronize()

        return face_areas_normals
|
@ -294,58 +294,6 @@ class TestSamplePoints(unittest.TestCase):
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def face_areas(verts, faces):
|
||||
"""
|
||||
Vectorized PyTorch implementation of triangle face area function.
|
||||
"""
|
||||
verts_faces = verts[faces]
|
||||
v0x = verts_faces[:, 0::3, 0]
|
||||
v0y = verts_faces[:, 0::3, 1]
|
||||
v0z = verts_faces[:, 0::3, 2]
|
||||
|
||||
v1x = verts_faces[:, 1::3, 0]
|
||||
v1y = verts_faces[:, 1::3, 1]
|
||||
v1z = verts_faces[:, 1::3, 2]
|
||||
|
||||
v2x = verts_faces[:, 2::3, 0]
|
||||
v2y = verts_faces[:, 2::3, 1]
|
||||
v2z = verts_faces[:, 2::3, 2]
|
||||
|
||||
ax = v0x - v2x
|
||||
ay = v0y - v2y
|
||||
az = v0z - v2z
|
||||
|
||||
bx = v1x - v2x
|
||||
by = v1y - v2y
|
||||
bz = v1z - v2z
|
||||
|
||||
cx = ay * bz - az * by
|
||||
cy = az * bx - ax * bz
|
||||
cz = ax * by - ay * bx
|
||||
|
||||
# this gives the area of the parallelogram with sides a and b
|
||||
area_sqr = cx * cx + cy * cy + cz * cz
|
||||
# the area of the triangle is half
|
||||
return torch.sqrt(area_sqr) / 2.0
|
||||
|
||||
def test_face_areas(self):
|
||||
"""
|
||||
Check the results from face_areas cuda and PyTorch implementions are
|
||||
the same. Check that face_areas throws an error if cpu tensors are
|
||||
given as input.
|
||||
"""
|
||||
meshes = self.init_meshes(10, 1000, 3000, device="cuda:0")
|
||||
verts = meshes.verts_packed()
|
||||
faces = meshes.faces_packed()
|
||||
|
||||
areas_torch = self.face_areas(verts, faces).squeeze()
|
||||
areas_cuda, _ = _C.face_areas_normals(verts, faces)
|
||||
self.assertTrue(torch.allclose(areas_torch, areas_cuda, atol=5e-8))
|
||||
with self.assertRaises(Exception) as err:
|
||||
_C.face_areas_normals(verts.cpu(), faces.cpu())
|
||||
self.assertTrue("Not implemented on the CPU" in str(err.exception))
|
||||
|
||||
@staticmethod
|
||||
def packed_to_padded_tensor(inputs, first_idxs, max_size):
|
||||
"""
|
||||
@ -419,27 +367,6 @@ class TestSamplePoints(unittest.TestCase):
|
||||
|
||||
return sample_points
|
||||
|
||||
@staticmethod
|
||||
def face_areas_with_init(
|
||||
num_meshes: int, num_verts: int, num_faces: int, cuda: str = True
|
||||
):
|
||||
device = "cuda" if cuda else "cpu"
|
||||
meshes = TestSamplePoints.init_meshes(
|
||||
num_meshes, num_verts, num_faces, device
|
||||
)
|
||||
verts = meshes.verts_packed()
|
||||
faces = meshes.faces_packed()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def face_areas():
|
||||
if cuda:
|
||||
_C.face_areas_normals(verts, faces)
|
||||
else:
|
||||
TestSamplePoints.face_areas(verts, faces)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return face_areas
|
||||
|
||||
@staticmethod
|
||||
def packed_to_padded_with_init(
|
||||
num_meshes: int, num_verts: int, num_faces: int, cuda: str = True
|
||||
@ -453,10 +380,7 @@ class TestSamplePoints(unittest.TestCase):
|
||||
mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx()
|
||||
max_faces = meshes.num_faces_per_mesh().max().item()
|
||||
|
||||
if cuda:
|
||||
areas, _ = _C.face_areas_normals(verts, faces)
|
||||
else:
|
||||
areas = TestSamplePoints.face_areas(verts, faces)
|
||||
areas, _ = _C.face_areas_normals(verts, faces)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def packed_to_padded():
|
||||
|
Loading…
x
Reference in New Issue
Block a user