diff --git a/pytorch3d/csrc/ext.cpp b/pytorch3d/csrc/ext.cpp index 6c5b51ab..0555afc9 100644 --- a/pytorch3d/csrc/ext.cpp +++ b/pytorch3d/csrc/ext.cpp @@ -9,7 +9,7 @@ #include "rasterize_points/rasterize_points.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("face_areas_normals", &face_areas_normals); + m.def("face_areas_normals", &FaceAreasNormals); m.def("packed_to_padded_tensor", &packed_to_padded_tensor); m.def("nn_points_idx", &NearestNeighborIdx); m.def("gather_scatter", &gather_scatter); diff --git a/pytorch3d/csrc/face_areas_normals/face_areas_normals.cu b/pytorch3d/csrc/face_areas_normals/face_areas_normals.cu index 3b3e4f22..671878f9 100644 --- a/pytorch3d/csrc/face_areas_normals/face_areas_normals.cu +++ b/pytorch3d/csrc/face_areas_normals/face_areas_normals.cu @@ -4,7 +4,7 @@ #include template -__global__ void face_areas_kernel( +__global__ void FaceAreasNormalsKernel( const scalar_t* __restrict__ verts, const long* __restrict__ faces, scalar_t* __restrict__ face_areas, @@ -55,7 +55,7 @@ __global__ void face_areas_kernel( } } -std::tuple face_areas_cuda( +std::tuple FaceAreasNormalsCuda( at::Tensor verts, at::Tensor faces) { const auto V = verts.size(0); @@ -66,14 +66,15 @@ std::tuple face_areas_cuda( const int blocks = 64; const int threads = 512; - AT_DISPATCH_FLOATING_TYPES(verts.type(), "face_areas_kernel", ([&] { - face_areas_kernel<<>>( - verts.data_ptr(), - faces.data_ptr(), - areas.data_ptr(), - normals.data_ptr(), - V, - F); + AT_DISPATCH_FLOATING_TYPES(verts.type(), "face_areas_normals_cuda", ([&] { + FaceAreasNormalsKernel + <<>>( + verts.data_ptr(), + faces.data_ptr(), + areas.data_ptr(), + normals.data_ptr(), + V, + F); })); return std::make_tuple(areas, normals); diff --git a/pytorch3d/csrc/face_areas_normals/face_areas_normals.h b/pytorch3d/csrc/face_areas_normals/face_areas_normals.h index 601fc934..0ef03cc4 100644 --- a/pytorch3d/csrc/face_areas_normals/face_areas_normals.h +++ b/pytorch3d/csrc/face_areas_normals/face_areas_normals.h @@ -16,21 +16,26 @@ // faces[f] // +// Cpu implementation. +std::tuple FaceAreasNormalsCpu( + at::Tensor verts, + at::Tensor faces); + // Cuda implementation. -std::tuple face_areas_cuda( +std::tuple FaceAreasNormalsCuda( at::Tensor verts, at::Tensor faces); // Implementation which is exposed. -std::tuple face_areas_normals( +std::tuple FaceAreasNormals( at::Tensor verts, at::Tensor faces) { if (verts.type().is_cuda() && faces.type().is_cuda()) { #ifdef WITH_CUDA - return face_areas_cuda(verts, faces); + return FaceAreasNormalsCuda(verts, faces); #else AT_ERROR("Not compiled with GPU support."); #endif } - AT_ERROR("Not implemented on the CPU."); + return FaceAreasNormalsCpu(verts, faces); } diff --git a/pytorch3d/csrc/face_areas_normals/face_areas_normals_cpu.cpp b/pytorch3d/csrc/face_areas_normals/face_areas_normals_cpu.cpp new file mode 100644 index 00000000..34b050c4 --- /dev/null +++ b/pytorch3d/csrc/face_areas_normals/face_areas_normals_cpu.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +#include +#include + +std::tuple FaceAreasNormalsCpu( + at::Tensor verts, + at::Tensor faces) { + const int V = verts.size(0); + const int F = faces.size(0); + + at::Tensor areas = at::empty({F}, verts.options()); + at::Tensor normals = at::empty({F, 3}, verts.options()); + + auto verts_a = verts.accessor(); + auto faces_a = faces.accessor(); + auto areas_a = areas.accessor(); + auto normals_a = normals.accessor(); + + for (int f = 0; f < F; ++f) { + const int64_t i0 = faces_a[f][0]; + const int64_t i1 = faces_a[f][1]; + const int64_t i2 = faces_a[f][2]; + + const float v0_x = verts_a[i0][0]; + const float v0_y = verts_a[i0][1]; + const float v0_z = verts_a[i0][2]; + + const float v1_x = verts_a[i1][0]; + const float v1_y = verts_a[i1][1]; + const float v1_z = verts_a[i1][2]; + + const float v2_x = verts_a[i2][0]; + const float v2_y = verts_a[i2][1]; + const float v2_z = verts_a[i2][2]; + + const float ax = v1_x - v0_x; + const float ay = v1_y - v0_y; + const float az = v1_z - v0_z; + + const float bx = v2_x - v0_x; + const float by = v2_y - v0_y; + const float bz = v2_z - v0_z; + + const float cx = ay * bz - az * by; + const float cy = az * bx - ax * bz; + const float cz = ax * by - ay * bx; + + float norm = sqrt(cx * cx + cy * cy + cz * cz); + areas_a[f] = norm / 2.0; + norm = (norm < 1e-6) ? 1e-6 : norm; // max(norm, 1e-6) + normals_a[f][0] = cx / norm; + normals_a[f][1] = cy / norm; + normals_a[f][2] = cz / norm; + } + return std::make_tuple(areas, normals); +} diff --git a/pytorch3d/structures/meshes.py b/pytorch3d/structures/meshes.py index fdb6bcb3..ec5d1e37 100644 --- a/pytorch3d/structures/meshes.py +++ b/pytorch3d/structures/meshes.py @@ -771,21 +771,9 @@ class Meshes(object): return faces_packed = self.faces_packed() verts_packed = self.verts_packed() - if verts_packed.is_cuda and faces_packed.is_cuda: - face_areas, face_normals = _C.face_areas_normals( - verts_packed, faces_packed - ) - else: - vertices_faces = verts_packed[faces_packed] # (F, 3, 3) - # vector pointing from v0 to v1 - v01 = vertices_faces[:, 1] - vertices_faces[:, 0] - # vector pointing from v0 to v2 - v02 = vertices_faces[:, 2] - vertices_faces[:, 0] - normals = torch.cross(v01, v02, dim=1) # (F, 3) - face_areas = normals.norm(dim=-1) / 2 - face_normals = torch.nn.functional.normalize( - normals, p=2, dim=1, eps=1e-6 - ) + face_areas, face_normals = _C.face_areas_normals( + verts_packed, faces_packed + ) self._faces_areas_packed = face_areas self._faces_normals_packed = face_normals diff --git a/tests/bm_face_areas_normals.py b/tests/bm_face_areas_normals.py new file mode 100644 index 00000000..fc3181aa --- /dev/null +++ b/tests/bm_face_areas_normals.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + + +from itertools import product +import torch +from fvcore.common.benchmark import benchmark + +from test_face_areas_normals import TestFaceAreasNormals + + +def bm_face_areas_normals() -> None: + kwargs_list = [] + backend_cuda = [False] + if torch.cuda.is_available(): + backend_cuda.append(True) + + num_meshes = [2, 10, 32] + num_verts = [100, 1000] + num_faces = [300, 3000] + + test_cases = product(num_meshes, num_verts, num_faces, backend_cuda) + for case in test_cases: + n, v, f, c = case + kwargs_list.append( + {"num_meshes": n, "num_verts": v, "num_faces": f, "cuda": c} + ) + benchmark( + TestFaceAreasNormals.face_areas_normals_with_init, + "FACE_AREAS_NORMALS", + kwargs_list, + warmup_iters=1, + ) + + benchmark( + TestFaceAreasNormals.face_areas_normals_with_init_torch, + "FACE_AREAS_NORMALS_TORCH", + kwargs_list, + warmup_iters=1, + ) diff --git a/tests/test_face_areas_normals.py b/tests/test_face_areas_normals.py new file mode 100644 index 00000000..57354254 --- /dev/null +++ b/tests/test_face_areas_normals.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + + +import unittest +import torch + +from pytorch3d import _C +from pytorch3d.structures.meshes import Meshes + +from common_testing import TestCaseMixin + + +class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase): + def setUp(self) -> None: + super().setUp() + torch.manual_seed(1) + + @staticmethod + def init_meshes( + num_meshes: int = 10, + num_verts: int = 1000, + num_faces: int = 3000, + device: str = "cpu", + ): + device = torch.device(device) + verts_list = [] + faces_list = [] + for _ in range(num_meshes): + verts = torch.rand( + (num_verts, 3), dtype=torch.float32, device=device + ) + faces = torch.randint( + num_verts, size=(num_faces, 3), dtype=torch.int64, device=device + ) + verts_list.append(verts) + faces_list.append(faces) + meshes = Meshes(verts_list, faces_list) + + return meshes + + @staticmethod + def face_areas_normals(verts, faces): + """ + Pytorch implementation for face areas & normals. + """ + vertices_faces = verts[faces] # (F, 3, 3) + # vector pointing from v0 to v1 + v01 = vertices_faces[:, 1] - vertices_faces[:, 0] + # vector pointing from v0 to v2 + v02 = vertices_faces[:, 2] - vertices_faces[:, 0] + normals = torch.cross(v01, v02, dim=1) # (F, 3) + face_areas = normals.norm(dim=-1) / 2 + face_normals = torch.nn.functional.normalize( + normals, p=2, dim=1, eps=1e-6 + ) + return face_areas, face_normals + + def _test_face_areas_normals_helper(self, device): + """ + Check the results from face_areas cuda/cpp and PyTorch implementation are + the same. + """ + meshes = self.init_meshes(10, 1000, 3000, device=device) + verts = meshes.verts_packed() + faces = meshes.faces_packed() + + areas_torch, normals_torch = self.face_areas_normals(verts, faces) + areas, normals = _C.face_areas_normals(verts, faces) + self.assertClose(areas_torch, areas, atol=1e-7) + # normals get normalized by area thus sensitivity increases as areas + # in our tests can be arbitrarily small. Thus we compare normals after + # multiplying with areas + unnormals = normals * areas.view(-1, 1) + unnormals_torch = normals_torch * areas_torch.view(-1, 1) + self.assertClose(unnormals_torch, unnormals, atol=1e-7) + + def test_face_areas_normals_cpu(self): + self._test_face_areas_normals_helper("cpu") + + def test_face_areas_normals_cuda(self): + self._test_face_areas_normals_helper("cuda:0") + + @staticmethod + def face_areas_normals_with_init( + num_meshes: int, num_verts: int, num_faces: int, cuda: bool = True + ): + device = "cuda:0" if cuda else "cpu" + meshes = TestFaceAreasNormals.init_meshes( + num_meshes, num_verts, num_faces, device + ) + verts = meshes.verts_packed() + faces = meshes.faces_packed() + torch.cuda.synchronize() + + def face_areas_normals(): + _C.face_areas_normals(verts, faces) + torch.cuda.synchronize() + + return face_areas_normals + + @staticmethod + def face_areas_normals_with_init_torch( + num_meshes: int, num_verts: int, num_faces: int, cuda: bool = True + ): + device = "cuda:0" if cuda else "cpu" + meshes = TestFaceAreasNormals.init_meshes( + num_meshes, num_verts, num_faces, device + ) + verts = meshes.verts_packed() + faces = meshes.faces_packed() + torch.cuda.synchronize() + + def face_areas_normals(): + TestFaceAreasNormals.face_areas_normals(verts, faces) + torch.cuda.synchronize() + + return face_areas_normals diff --git a/tests/test_sample_points_from_meshes.py b/tests/test_sample_points_from_meshes.py index 0f633def..d210731b 100644 --- a/tests/test_sample_points_from_meshes.py +++ b/tests/test_sample_points_from_meshes.py @@ -294,58 +294,6 @@ class TestSamplePoints(unittest.TestCase): return False return True - @staticmethod - def face_areas(verts, faces): - """ - Vectorized PyTorch implementation of triangle face area function. - """ - verts_faces = verts[faces] - v0x = verts_faces[:, 0::3, 0] - v0y = verts_faces[:, 0::3, 1] - v0z = verts_faces[:, 0::3, 2] - - v1x = verts_faces[:, 1::3, 0] - v1y = verts_faces[:, 1::3, 1] - v1z = verts_faces[:, 1::3, 2] - - v2x = verts_faces[:, 2::3, 0] - v2y = verts_faces[:, 2::3, 1] - v2z = verts_faces[:, 2::3, 2] - - ax = v0x - v2x - ay = v0y - v2y - az = v0z - v2z - - bx = v1x - v2x - by = v1y - v2y - bz = v1z - v2z - - cx = ay * bz - az * by - cy = az * bx - ax * bz - cz = ax * by - ay * bx - - # this gives the area of the parallelogram with sides a and b - area_sqr = cx * cx + cy * cy + cz * cz - # the area of the triangle is half - return torch.sqrt(area_sqr) / 2.0 - - def test_face_areas(self): - """ - Check the results from face_areas cuda and PyTorch implementions are - the same. Check that face_areas throws an error if cpu tensors are - given as input. - """ - meshes = self.init_meshes(10, 1000, 3000, device="cuda:0") - verts = meshes.verts_packed() - faces = meshes.faces_packed() - - areas_torch = self.face_areas(verts, faces).squeeze() - areas_cuda, _ = _C.face_areas_normals(verts, faces) - self.assertTrue(torch.allclose(areas_torch, areas_cuda, atol=5e-8)) - with self.assertRaises(Exception) as err: - _C.face_areas_normals(verts.cpu(), faces.cpu()) - self.assertTrue("Not implemented on the CPU" in str(err.exception)) - @staticmethod def packed_to_padded_tensor(inputs, first_idxs, max_size): """ @@ -419,27 +367,6 @@ class TestSamplePoints(unittest.TestCase): return sample_points - @staticmethod - def face_areas_with_init( - num_meshes: int, num_verts: int, num_faces: int, cuda: str = True - ): - device = "cuda" if cuda else "cpu" - meshes = TestSamplePoints.init_meshes( - num_meshes, num_verts, num_faces, device - ) - verts = meshes.verts_packed() - faces = meshes.faces_packed() - torch.cuda.synchronize() - - def face_areas(): - if cuda: - _C.face_areas_normals(verts, faces) - else: - TestSamplePoints.face_areas(verts, faces) - torch.cuda.synchronize() - - return face_areas - @staticmethod def packed_to_padded_with_init( num_meshes: int, num_verts: int, num_faces: int, cuda: str = True @@ -453,10 +380,7 @@ class TestSamplePoints(unittest.TestCase): mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx() max_faces = meshes.num_faces_per_mesh().max().item() - if cuda: - areas, _ = _C.face_areas_normals(verts, faces) - else: - areas = TestSamplePoints.face_areas(verts, faces) + areas, _ = _C.face_areas_normals(verts, faces) torch.cuda.synchronize() def packed_to_padded():