CPU implem for face areas normals

Summary:
Added cpu implementation for face areas normals. Moved test and bm to separate functions.

```
Benchmark                                   Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
FACE_AREAS_NORMALS_2_100_300_False                196             268           2550
FACE_AREAS_NORMALS_2_100_300_True                 106             179           4733
FACE_AREAS_NORMALS_2_100_3000_False              1447            1630            346
FACE_AREAS_NORMALS_2_100_3000_True                107             178           4674
FACE_AREAS_NORMALS_2_1000_300_False               201             309           2486
FACE_AREAS_NORMALS_2_1000_300_True                107             186           4673
FACE_AREAS_NORMALS_2_1000_3000_False             1451            1636            345
FACE_AREAS_NORMALS_2_1000_3000_True               107             186           4655
FACE_AREAS_NORMALS_10_100_300_False               767             918            653
FACE_AREAS_NORMALS_10_100_300_True                106             167           4712
FACE_AREAS_NORMALS_10_100_3000_False             7036            7754             72
FACE_AREAS_NORMALS_10_100_3000_True               113             164           4445
FACE_AREAS_NORMALS_10_1000_300_False              748             947            669
FACE_AREAS_NORMALS_10_1000_300_True               108             169           4638
FACE_AREAS_NORMALS_10_1000_3000_False            7069            7783             71
FACE_AREAS_NORMALS_10_1000_3000_True              108             172           4646
FACE_AREAS_NORMALS_32_100_300_False              2286            2496            219
FACE_AREAS_NORMALS_32_100_300_True                108             180           4631
FACE_AREAS_NORMALS_32_100_3000_False            23184           24369             22
FACE_AREAS_NORMALS_32_100_3000_True               159             213           3147
FACE_AREAS_NORMALS_32_1000_300_False             2414            2645            208
FACE_AREAS_NORMALS_32_1000_300_True               112             197           4480
FACE_AREAS_NORMALS_32_1000_3000_False           21687           22964             24
FACE_AREAS_NORMALS_32_1000_3000_True              141             211           3540
--------------------------------------------------------------------------------

Benchmark                                         Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
FACE_AREAS_NORMALS_TORCH_2_100_300_False               5465            5782             92
FACE_AREAS_NORMALS_TORCH_2_100_300_True                1198            1351            418
FACE_AREAS_NORMALS_TORCH_2_100_3000_False             48228           48869             11
FACE_AREAS_NORMALS_TORCH_2_100_3000_True               1186            1304            422
FACE_AREAS_NORMALS_TORCH_2_1000_300_False              5556            6097             90
FACE_AREAS_NORMALS_TORCH_2_1000_300_True               1200            1328            417
FACE_AREAS_NORMALS_TORCH_2_1000_3000_False            48683           50016             11
FACE_AREAS_NORMALS_TORCH_2_1000_3000_True              1185            1306            422
FACE_AREAS_NORMALS_TORCH_10_100_300_False             24215           25097             21
FACE_AREAS_NORMALS_TORCH_10_100_300_True               1150            1314            435
FACE_AREAS_NORMALS_TORCH_10_100_3000_False           232605          234952              3
FACE_AREAS_NORMALS_TORCH_10_100_3000_True              1193            1314            420
FACE_AREAS_NORMALS_TORCH_10_1000_300_False            24912           25343             21
FACE_AREAS_NORMALS_TORCH_10_1000_300_True              1216            1330            412
FACE_AREAS_NORMALS_TORCH_10_1000_3000_False          239907          241253              3
FACE_AREAS_NORMALS_TORCH_10_1000_3000_True             1226            1333            408
FACE_AREAS_NORMALS_TORCH_32_100_300_False             73991           75776              7
FACE_AREAS_NORMALS_TORCH_32_100_300_True               1193            1339            420
FACE_AREAS_NORMALS_TORCH_32_100_3000_False           728932          728932              1
FACE_AREAS_NORMALS_TORCH_32_100_3000_True              1186            1359            422
FACE_AREAS_NORMALS_TORCH_32_1000_300_False            76385           79129              7
FACE_AREAS_NORMALS_TORCH_32_1000_300_True              1165            1310            430
FACE_AREAS_NORMALS_TORCH_32_1000_3000_False          753276          753276              1
FACE_AREAS_NORMALS_TORCH_32_1000_3000_True             1205            1340            415
--------------------------------------------------------------------------------
```

Reviewed By: bottler, jcjohnson

Differential Revision: D19864385

fbshipit-source-id: 3a87ae41a8e3ab5560febcb94961798f2e09dfb8
This commit is contained in:
Georgia Gkioxari 2020-02-13 11:40:52 -08:00 committed by Facebook Github Bot
parent 8fe65d5f56
commit 29cd181a83
8 changed files with 240 additions and 107 deletions

View File

@ -9,7 +9,7 @@
#include "rasterize_points/rasterize_points.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("face_areas_normals", &face_areas_normals);
m.def("face_areas_normals", &FaceAreasNormals);
m.def("packed_to_padded_tensor", &packed_to_padded_tensor);
m.def("nn_points_idx", &NearestNeighborIdx);
m.def("gather_scatter", &gather_scatter);

View File

@ -4,7 +4,7 @@
#include <tuple>
template <typename scalar_t>
__global__ void face_areas_kernel(
__global__ void FaceAreasNormalsKernel(
const scalar_t* __restrict__ verts,
const long* __restrict__ faces,
scalar_t* __restrict__ face_areas,
@ -55,7 +55,7 @@ __global__ void face_areas_kernel(
}
}
std::tuple<at::Tensor, at::Tensor> face_areas_cuda(
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsCuda(
at::Tensor verts,
at::Tensor faces) {
const auto V = verts.size(0);
@ -66,14 +66,15 @@ std::tuple<at::Tensor, at::Tensor> face_areas_cuda(
const int blocks = 64;
const int threads = 512;
AT_DISPATCH_FLOATING_TYPES(verts.type(), "face_areas_kernel", ([&] {
face_areas_kernel<scalar_t><<<blocks, threads>>>(
verts.data_ptr<scalar_t>(),
faces.data_ptr<long>(),
areas.data_ptr<scalar_t>(),
normals.data_ptr<scalar_t>(),
V,
F);
AT_DISPATCH_FLOATING_TYPES(verts.type(), "face_areas_normals_cuda", ([&] {
FaceAreasNormalsKernel<scalar_t>
<<<blocks, threads>>>(
verts.data_ptr<scalar_t>(),
faces.data_ptr<long>(),
areas.data_ptr<scalar_t>(),
normals.data_ptr<scalar_t>(),
V,
F);
}));
return std::make_tuple(areas, normals);

View File

@ -16,21 +16,26 @@
// faces[f]
//
// Cpu implementation.
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsCpu(
at::Tensor verts,
at::Tensor faces);
// Cuda implementation.
std::tuple<at::Tensor, at::Tensor> face_areas_cuda(
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsCuda(
at::Tensor verts,
at::Tensor faces);
// Implementation which is exposed.
std::tuple<at::Tensor, at::Tensor> face_areas_normals(
std::tuple<at::Tensor, at::Tensor> FaceAreasNormals(
at::Tensor verts,
at::Tensor faces) {
if (verts.type().is_cuda() && faces.type().is_cuda()) {
#ifdef WITH_CUDA
return face_areas_cuda(verts, faces);
return FaceAreasNormalsCuda(verts, faces);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
AT_ERROR("Not implemented on the CPU.");
return FaceAreasNormalsCpu(verts, faces);
}

View File

@ -0,0 +1,57 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
#include <tuple>
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsCpu(
at::Tensor verts,
at::Tensor faces) {
const int V = verts.size(0);
const int F = faces.size(0);
at::Tensor areas = at::empty({F}, verts.options());
at::Tensor normals = at::empty({F, 3}, verts.options());
auto verts_a = verts.accessor<float, 2>();
auto faces_a = faces.accessor<int64_t, 2>();
auto areas_a = areas.accessor<float, 1>();
auto normals_a = normals.accessor<float, 2>();
for (int f = 0; f < F; ++f) {
const int64_t i0 = faces_a[f][0];
const int64_t i1 = faces_a[f][1];
const int64_t i2 = faces_a[f][2];
const float v0_x = verts_a[i0][0];
const float v0_y = verts_a[i0][1];
const float v0_z = verts_a[i0][2];
const float v1_x = verts_a[i1][0];
const float v1_y = verts_a[i1][1];
const float v1_z = verts_a[i1][2];
const float v2_x = verts_a[i2][0];
const float v2_y = verts_a[i2][1];
const float v2_z = verts_a[i2][2];
const float ax = v1_x - v0_x;
const float ay = v1_y - v0_y;
const float az = v1_z - v0_z;
const float bx = v2_x - v0_x;
const float by = v2_y - v0_y;
const float bz = v2_z - v0_z;
const float cx = ay * bz - az * by;
const float cy = az * bx - ax * bz;
const float cz = ax * by - ay * bx;
float norm = sqrt(cx * cx + cy * cy + cz * cz);
areas_a[f] = norm / 2.0;
norm = (norm < 1e-6) ? 1e-6 : norm; // max(norm, 1e-6)
normals_a[f][0] = cx / norm;
normals_a[f][1] = cy / norm;
normals_a[f][2] = cz / norm;
}
return std::make_tuple(areas, normals);
}

View File

@ -771,21 +771,9 @@ class Meshes(object):
return
faces_packed = self.faces_packed()
verts_packed = self.verts_packed()
if verts_packed.is_cuda and faces_packed.is_cuda:
face_areas, face_normals = _C.face_areas_normals(
verts_packed, faces_packed
)
else:
vertices_faces = verts_packed[faces_packed] # (F, 3, 3)
# vector pointing from v0 to v1
v01 = vertices_faces[:, 1] - vertices_faces[:, 0]
# vector pointing from v0 to v2
v02 = vertices_faces[:, 2] - vertices_faces[:, 0]
normals = torch.cross(v01, v02, dim=1) # (F, 3)
face_areas = normals.norm(dim=-1) / 2
face_normals = torch.nn.functional.normalize(
normals, p=2, dim=1, eps=1e-6
)
face_areas, face_normals = _C.face_areas_normals(
verts_packed, faces_packed
)
self._faces_areas_packed = face_areas
self._faces_normals_packed = face_normals

View File

@ -0,0 +1,40 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from itertools import product
import torch
from fvcore.common.benchmark import benchmark
from test_face_areas_normals import TestFaceAreasNormals
def bm_face_areas_normals() -> None:
kwargs_list = []
backend_cuda = [False]
if torch.cuda.is_available():
backend_cuda.append(True)
num_meshes = [2, 10, 32]
num_verts = [100, 1000]
num_faces = [300, 3000]
test_cases = product(num_meshes, num_verts, num_faces, backend_cuda)
for case in test_cases:
n, v, f, c = case
kwargs_list.append(
{"num_meshes": n, "num_verts": v, "num_faces": f, "cuda": c}
)
benchmark(
TestFaceAreasNormals.face_areas_normals_with_init,
"FACE_AREAS_NORMALS",
kwargs_list,
warmup_iters=1,
)
benchmark(
TestFaceAreasNormals.face_areas_normals_with_init_torch,
"FACE_AREAS_NORMALS_TORCH",
kwargs_list,
warmup_iters=1,
)

View File

@ -0,0 +1,118 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
import torch
from pytorch3d import _C
from pytorch3d.structures.meshes import Meshes
from common_testing import TestCaseMixin
class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(1)
@staticmethod
def init_meshes(
num_meshes: int = 10,
num_verts: int = 1000,
num_faces: int = 3000,
device: str = "cpu",
):
device = torch.device(device)
verts_list = []
faces_list = []
for _ in range(num_meshes):
verts = torch.rand(
(num_verts, 3), dtype=torch.float32, device=device
)
faces = torch.randint(
num_verts, size=(num_faces, 3), dtype=torch.int64, device=device
)
verts_list.append(verts)
faces_list.append(faces)
meshes = Meshes(verts_list, faces_list)
return meshes
@staticmethod
def face_areas_normals(verts, faces):
"""
Pytorch implementation for face areas & normals.
"""
vertices_faces = verts[faces] # (F, 3, 3)
# vector pointing from v0 to v1
v01 = vertices_faces[:, 1] - vertices_faces[:, 0]
# vector pointing from v0 to v2
v02 = vertices_faces[:, 2] - vertices_faces[:, 0]
normals = torch.cross(v01, v02, dim=1) # (F, 3)
face_areas = normals.norm(dim=-1) / 2
face_normals = torch.nn.functional.normalize(
normals, p=2, dim=1, eps=1e-6
)
return face_areas, face_normals
def _test_face_areas_normals_helper(self, device):
"""
Check the results from face_areas cuda/cpp and PyTorch implementation are
the same.
"""
meshes = self.init_meshes(10, 1000, 3000, device=device)
verts = meshes.verts_packed()
faces = meshes.faces_packed()
areas_torch, normals_torch = self.face_areas_normals(verts, faces)
areas, normals = _C.face_areas_normals(verts, faces)
self.assertClose(areas_torch, areas, atol=1e-7)
# normals get normalized by area thus sensitivity increases as areas
# in our tests can be arbitrarily small. Thus we compare normals after
# multiplying with areas
unnormals = normals * areas.view(-1, 1)
unnormals_torch = normals_torch * areas_torch.view(-1, 1)
self.assertClose(unnormals_torch, unnormals, atol=1e-7)
def test_face_areas_normals_cpu(self):
self._test_face_areas_normals_helper("cpu")
def test_face_areas_normals_cuda(self):
self._test_face_areas_normals_helper("cuda:0")
@staticmethod
def face_areas_normals_with_init(
num_meshes: int, num_verts: int, num_faces: int, cuda: bool = True
):
device = "cuda:0" if cuda else "cpu"
meshes = TestFaceAreasNormals.init_meshes(
num_meshes, num_verts, num_faces, device
)
verts = meshes.verts_packed()
faces = meshes.faces_packed()
torch.cuda.synchronize()
def face_areas_normals():
_C.face_areas_normals(verts, faces)
torch.cuda.synchronize()
return face_areas_normals
@staticmethod
def face_areas_normals_with_init_torch(
num_meshes: int, num_verts: int, num_faces: int, cuda: bool = True
):
device = "cuda:0" if cuda else "cpu"
meshes = TestFaceAreasNormals.init_meshes(
num_meshes, num_verts, num_faces, device
)
verts = meshes.verts_packed()
faces = meshes.faces_packed()
torch.cuda.synchronize()
def face_areas_normals():
TestFaceAreasNormals.face_areas_normals(verts, faces)
torch.cuda.synchronize()
return face_areas_normals

View File

@ -294,58 +294,6 @@ class TestSamplePoints(unittest.TestCase):
return False
return True
@staticmethod
def face_areas(verts, faces):
"""
Vectorized PyTorch implementation of triangle face area function.
"""
verts_faces = verts[faces]
v0x = verts_faces[:, 0::3, 0]
v0y = verts_faces[:, 0::3, 1]
v0z = verts_faces[:, 0::3, 2]
v1x = verts_faces[:, 1::3, 0]
v1y = verts_faces[:, 1::3, 1]
v1z = verts_faces[:, 1::3, 2]
v2x = verts_faces[:, 2::3, 0]
v2y = verts_faces[:, 2::3, 1]
v2z = verts_faces[:, 2::3, 2]
ax = v0x - v2x
ay = v0y - v2y
az = v0z - v2z
bx = v1x - v2x
by = v1y - v2y
bz = v1z - v2z
cx = ay * bz - az * by
cy = az * bx - ax * bz
cz = ax * by - ay * bx
# this gives the area of the parallelogram with sides a and b
area_sqr = cx * cx + cy * cy + cz * cz
# the area of the triangle is half
return torch.sqrt(area_sqr) / 2.0
def test_face_areas(self):
"""
Check the results from face_areas cuda and PyTorch implementions are
the same. Check that face_areas throws an error if cpu tensors are
given as input.
"""
meshes = self.init_meshes(10, 1000, 3000, device="cuda:0")
verts = meshes.verts_packed()
faces = meshes.faces_packed()
areas_torch = self.face_areas(verts, faces).squeeze()
areas_cuda, _ = _C.face_areas_normals(verts, faces)
self.assertTrue(torch.allclose(areas_torch, areas_cuda, atol=5e-8))
with self.assertRaises(Exception) as err:
_C.face_areas_normals(verts.cpu(), faces.cpu())
self.assertTrue("Not implemented on the CPU" in str(err.exception))
@staticmethod
def packed_to_padded_tensor(inputs, first_idxs, max_size):
"""
@ -419,27 +367,6 @@ class TestSamplePoints(unittest.TestCase):
return sample_points
@staticmethod
def face_areas_with_init(
num_meshes: int, num_verts: int, num_faces: int, cuda: str = True
):
device = "cuda" if cuda else "cpu"
meshes = TestSamplePoints.init_meshes(
num_meshes, num_verts, num_faces, device
)
verts = meshes.verts_packed()
faces = meshes.faces_packed()
torch.cuda.synchronize()
def face_areas():
if cuda:
_C.face_areas_normals(verts, faces)
else:
TestSamplePoints.face_areas(verts, faces)
torch.cuda.synchronize()
return face_areas
@staticmethod
def packed_to_padded_with_init(
num_meshes: int, num_verts: int, num_faces: int, cuda: str = True
@ -453,10 +380,7 @@ class TestSamplePoints(unittest.TestCase):
mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx()
max_faces = meshes.num_faces_per_mesh().max().item()
if cuda:
areas, _ = _C.face_areas_normals(verts, faces)
else:
areas = TestSamplePoints.face_areas(verts, faces)
areas, _ = _C.face_areas_normals(verts, faces)
torch.cuda.synchronize()
def packed_to_padded():