diff --git a/pytorch3d/csrc/ext.cpp b/pytorch3d/csrc/ext.cpp index 4d3dd4e2..88993107 100644 --- a/pytorch3d/csrc/ext.cpp +++ b/pytorch3d/csrc/ext.cpp @@ -9,7 +9,8 @@ #include "rasterize_points/rasterize_points.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("face_areas_normals", &FaceAreasNormals); + m.def("face_areas_normals_forward", &FaceAreasNormalsForward); + m.def("face_areas_normals_backward", &FaceAreasNormalsBackward); m.def("packed_to_padded", &PackedToPadded); m.def("padded_to_packed", &PaddedToPacked); m.def("nn_points_idx", &NearestNeighborIdx); diff --git a/pytorch3d/csrc/face_areas_normals/face_areas_normals.cu b/pytorch3d/csrc/face_areas_normals/face_areas_normals.cu index 671878f9..6b5c44de 100644 --- a/pytorch3d/csrc/face_areas_normals/face_areas_normals.cu +++ b/pytorch3d/csrc/face_areas_normals/face_areas_normals.cu @@ -4,9 +4,9 @@ #include template -__global__ void FaceAreasNormalsKernel( +__global__ void FaceAreasNormalsForwardKernel( const scalar_t* __restrict__ verts, - const long* __restrict__ faces, + const int64_t* __restrict__ faces, scalar_t* __restrict__ face_areas, scalar_t* __restrict__ face_normals, const size_t V, @@ -18,9 +18,9 @@ __global__ void FaceAreasNormalsKernel( // Each thread computes the area & normal of its respective faces and adds it // to the global face_areas tensor. for (size_t f = tid; f < F; f += stride) { - const long i0 = faces[3 * f + 0]; - const long i1 = faces[3 * f + 1]; - const long i2 = faces[3 * f + 2]; + const int64_t i0 = faces[3 * f + 0]; + const int64_t i1 = faces[3 * f + 1]; + const int64_t i2 = faces[3 * f + 2]; const scalar_t v0_x = verts[3 * i0 + 0]; const scalar_t v0_y = verts[3 * i0 + 1]; @@ -55,9 +55,161 @@ __global__ void FaceAreasNormalsKernel( } } -std::tuple FaceAreasNormalsCuda( - at::Tensor verts, - at::Tensor faces) { +// TODO(gkioxari) support all data types once AtomicAdd supports doubles. +// Currently, support is for floats only. +__global__ void FaceAreasNormalsBackwardKernel( + const float* __restrict__ grad_areas, + const float* __restrict__ grad_normals, + const float* __restrict__ verts, + const int64_t* __restrict__ faces, + float* __restrict__ grad_verts, + const size_t V, + const size_t F) { + const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = gridDim.x * blockDim.x; + + // Faces split evenly over the number of threads in the grid. + // Each thread computes the area & normal of its respective faces and adds it + // to the global face_areas tensor. + for (size_t f = tid; f < F; f += stride) { + const int64_t i0 = faces[3 * f + 0]; + const int64_t i1 = faces[3 * f + 1]; + const int64_t i2 = faces[3 * f + 2]; + + const float v0_x = verts[3 * i0 + 0]; + const float v0_y = verts[3 * i0 + 1]; + const float v0_z = verts[3 * i0 + 2]; + + const float v1_x = verts[3 * i1 + 0]; + const float v1_y = verts[3 * i1 + 1]; + const float v1_z = verts[3 * i1 + 2]; + + const float v2_x = verts[3 * i2 + 0]; + const float v2_y = verts[3 * i2 + 1]; + const float v2_z = verts[3 * i2 + 2]; + + const float ax = v1_x - v0_x; + const float ay = v1_y - v0_y; + const float az = v1_z - v0_z; + + const float bx = v2_x - v0_x; + const float by = v2_y - v0_y; + const float bz = v2_z - v0_z; + + const float cx = ay * bz - az * by; + const float cy = az * bx - ax * bz; + const float cz = ax * by - ay * bx; + + float norm = sqrt(cx * cx + cy * cy + cz * cz); + norm = (norm < 1e-6) ? 1e-6 : norm; // max(norm, 1e-6) + float inv_norm = 1. / norm; + float inv_norm_2 = pow(inv_norm, 2.0f); + float inv_norm_3 = pow(inv_norm, 3.0f); + + // We compute gradients with respect to the input vertices. + // For each vertex, gradients come from grad_areas and grad_normals. + // eg, grad_v0_x = (d / d v0_x) + // = \sum_f (d / d areas[f]) * (d areas[f] / d v0_x) + // + (d / d normals[f, 0]) * (d normals[f, 0] / d v0_x) + // + (d / d normals[f, 1]) * (d normals[f, 1] / d v0_x) + // + (d / d normals[f, 2]) * (d normals[f, 2] / d v0_x) + // with (d / d areas[f]) = grad_areas[f] and + // (d / d normals[f, j]) = grad_normals[f][j]. + // The equations below are derived after taking + // derivatives wrt to the vertices (fun times!). + + // grad v0 coming from grad areas and grad normals + const float grad_v0_x = + ((-az + bz) * cy + (-by + ay) * cz) / 2.0 * inv_norm * grad_areas[f] + + -cx * ((-az + bz) * cy + (-by + ay) * cz) * inv_norm_3 * + grad_normals[3 * f + 0] + + ((-az + bz) - cy * ((-az + bz) * cy + (-by + ay) * cz) * inv_norm_2) * + inv_norm * grad_normals[3 * f + 1] + + ((-by + ay) - cz * ((-az + bz) * cy + (-by + ay) * cz) * inv_norm_2) * + inv_norm * grad_normals[3 * f + 2]; + atomicAdd(grad_verts + 3 * i0 + 0, grad_v0_x); + + const float grad_v0_y = + ((-bz + az) * cx + (-ax + bx) * cz) / 2.0 * inv_norm * grad_areas[f] + + ((-bz + az) - cx * ((-bz + az) * cx + (-ax + bx) * cz) * inv_norm_2) * + inv_norm * grad_normals[3 * f + 0] + + -cy * ((-bz + az) * cx + (-ax + bx) * cz) * inv_norm_3 * + grad_normals[3 * f + 1] + + ((-ax + bx) - cz * ((-bz + az) * cx + (-ax + bx) * cz) * inv_norm_2) * + inv_norm * grad_normals[3 * f + 2]; + atomicAdd(grad_verts + 3 * i0 + 1, grad_v0_y); + + const float grad_v0_z = + ((-ay + by) * cx + (-bx + ax) * cy) / 2.0 * inv_norm * grad_areas[f] + + ((-ay + by) - cx * ((-ay + by) * cx + (-bx + ax) * cy) * inv_norm_2) * + inv_norm * grad_normals[3 * f + 0] + + ((-bx + ax) - cy * ((-ay + by) * cx + (-bx + ax) * cy) * inv_norm_2) * + inv_norm * grad_normals[3 * f + 1] + + -cz * ((-ay + by) * cx + (-bx + ax) * cy) * inv_norm_3 * + grad_normals[3 * f + 2]; + atomicAdd(grad_verts + 3 * i0 + 2, grad_v0_z); + + // grad v1 coming from grad areas and grad normals + const float grad_v1_x = + (by * cz - bz * cy) / 2.0 * inv_norm * grad_areas[f] + + -cx * (by * cz - bz * cy) * inv_norm_3 * grad_normals[3 * f + 0] + + (-bz - cy * (by * cz - bz * cy) * inv_norm_2) * inv_norm * + grad_normals[3 * f + 1] + + (by - cz * (by * cz - bz * cy) * inv_norm_2) * inv_norm * + grad_normals[3 * f + 2]; + atomicAdd(grad_verts + 3 * i1 + 0, grad_v1_x); + + const float grad_v1_y = + (bz * cx - bx * cz) / 2.0 * inv_norm * grad_areas[f] + + (bz - cx * (bz * cx - bx * cz) * inv_norm_2) * inv_norm * + grad_normals[3 * f + 0] + + -cy * (bz * cx - bx * cz) * inv_norm_3 * grad_normals[3 * f + 1] + + (-bx - cz * (bz * cx - bx * cz) * inv_norm_2) * inv_norm * + grad_normals[3 * f + 2]; + atomicAdd(grad_verts + 3 * i1 + 1, grad_v1_y); + + const float grad_v1_z = + (bx * cy - by * cx) / 2.0 * inv_norm * grad_areas[f] + + (-by - cx * (bx * cy - by * cx) * inv_norm_2) * inv_norm * + grad_normals[3 * f + 0] + + (bx - cx * (bx * cy - by * cx) * inv_norm_2) * inv_norm * + grad_normals[3 * f + 1] + + -cz * (bx * cy - by * cx) * inv_norm_3 * grad_normals[3 * f + 2]; + atomicAdd(grad_verts + 3 * i1 + 2, grad_v1_z); + + // grad v2 coming from grad areas + const float grad_v2_x = + (az * cy - ay * cz) / 2.0 * inv_norm * grad_areas[f] + + -cx * (az * cy - ay * cz) * inv_norm_3 * grad_normals[3 * f + 0] + + (az - cy * (az * cy - ay * cz) * inv_norm_2) * inv_norm * + grad_normals[3 * f + 1] + + (-ay - cz * (az * cy - ay * cz) * inv_norm_2) * inv_norm * + grad_normals[3 * f + 2]; + atomicAdd(grad_verts + 3 * i2 + 0, grad_v2_x); + + const float grad_v2_y = + (ax * cz - az * cx) / 2.0 * inv_norm * grad_areas[f] + + (-az - cx * (ax * cz - az * cx) * inv_norm_2) * inv_norm * + grad_normals[3 * f + 0] + + -cy * (ax * cz - az * cx) * inv_norm_3 * grad_normals[3 * f + 1] + + (ax - cz * (ax * cz - az * cx) * inv_norm_2) * inv_norm * + grad_normals[3 * f + 2]; + atomicAdd(grad_verts + 3 * i2 + 1, grad_v2_y); + + const float grad_v2_z = + (ay * cx - ax * cy) / 2.0 * inv_norm * grad_areas[f] + + (ay - cx * (ay * cx - ax * cy) * inv_norm_2) * inv_norm * + grad_normals[3 * f + 0] + + (-ax - cy * (ay * cx - ax * cy) * inv_norm_2) * inv_norm * + grad_normals[3 * f + 1] + + -cz * (ay * cx - ax * cy) * inv_norm_3 * grad_normals[3 * f + 2]; + atomicAdd(grad_verts + 3 * i2 + 2, grad_v2_z); + } +} + +std::tuple FaceAreasNormalsForwardCuda( + const at::Tensor verts, + const at::Tensor faces) { const auto V = verts.size(0); const auto F = faces.size(0); @@ -66,16 +218,42 @@ std::tuple FaceAreasNormalsCuda( const int blocks = 64; const int threads = 512; - AT_DISPATCH_FLOATING_TYPES(verts.type(), "face_areas_normals_cuda", ([&] { - FaceAreasNormalsKernel - <<>>( - verts.data_ptr(), - faces.data_ptr(), - areas.data_ptr(), - normals.data_ptr(), - V, - F); - })); + AT_DISPATCH_FLOATING_TYPES( + verts.type(), "face_areas_normals_forward_cuda", ([&] { + FaceAreasNormalsForwardKernel<<>>( + verts.data_ptr(), + faces.data_ptr(), + areas.data_ptr(), + normals.data_ptr(), + V, + F); + })); return std::make_tuple(areas, normals); } + +at::Tensor FaceAreasNormalsBackwardCuda( + const at::Tensor grad_areas, + const at::Tensor grad_normals, + const at::Tensor verts, + const at::Tensor faces) { + const auto V = verts.size(0); + const auto F = faces.size(0); + + at::Tensor grad_verts = at::zeros({V, 3}, grad_areas.options()); + + const int blocks = 64; + const int threads = 512; + // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports + // doubles. Currently, support is for floats only. + FaceAreasNormalsBackwardKernel<<>>( + grad_areas.data_ptr(), + grad_normals.data_ptr(), + verts.data_ptr(), + faces.data_ptr(), + grad_verts.data_ptr(), + V, + F); + + return grad_verts; +} diff --git a/pytorch3d/csrc/face_areas_normals/face_areas_normals.h b/pytorch3d/csrc/face_areas_normals/face_areas_normals.h index 28958407..0617368e 100644 --- a/pytorch3d/csrc/face_areas_normals/face_areas_normals.h +++ b/pytorch3d/csrc/face_areas_normals/face_areas_normals.h @@ -17,27 +17,55 @@ // // Cpu implementation. -std::tuple FaceAreasNormalsCpu( - at::Tensor verts, - at::Tensor faces); +std::tuple FaceAreasNormalsForwardCpu( + const at::Tensor verts, + const at::Tensor faces); +// Cpu implementation +at::Tensor FaceAreasNormalsBackwardCpu( + const at::Tensor grad_areas, + const at::Tensor grad_normals, + const at::Tensor verts, + const at::Tensor faces); #ifdef WITH_CUDA // Cuda implementation. -std::tuple FaceAreasNormalsCuda( - at::Tensor verts, - at::Tensor faces); +std::tuple FaceAreasNormalsForwardCuda( + const at::Tensor verts, + const at::Tensor faces); +// Cuda implementation. +at::Tensor FaceAreasNormalsBackwardCuda( + const at::Tensor grad_areas, + const at::Tensor grad_normals, + const at::Tensor verts, + const at::Tensor faces); #endif // Implementation which is exposed. -std::tuple FaceAreasNormals( - at::Tensor verts, - at::Tensor faces) { +std::tuple FaceAreasNormalsForward( + const at::Tensor verts, + const at::Tensor faces) { if (verts.type().is_cuda() && faces.type().is_cuda()) { #ifdef WITH_CUDA - return FaceAreasNormalsCuda(verts, faces); + return FaceAreasNormalsForwardCuda(verts, faces); #else AT_ERROR("Not compiled with GPU support."); #endif } - return FaceAreasNormalsCpu(verts, faces); + return FaceAreasNormalsForwardCpu(verts, faces); +} + +// Implementation which is exposed. +at::Tensor FaceAreasNormalsBackward( + const at::Tensor grad_areas, + const at::Tensor grad_normals, + const at::Tensor verts, + const at::Tensor faces) { + if (verts.type().is_cuda() && faces.type().is_cuda()) { +#ifdef WITH_CUDA + return FaceAreasNormalsBackwardCuda(grad_areas, grad_normals, verts, faces); +#else + AT_ERROR("Not compiled with GPU support."); +#endif + } + return FaceAreasNormalsBackwardCpu(grad_areas, grad_normals, verts, faces); } diff --git a/pytorch3d/csrc/face_areas_normals/face_areas_normals_cpu.cpp b/pytorch3d/csrc/face_areas_normals/face_areas_normals_cpu.cpp index f760ec30..09535947 100644 --- a/pytorch3d/csrc/face_areas_normals/face_areas_normals_cpu.cpp +++ b/pytorch3d/csrc/face_areas_normals/face_areas_normals_cpu.cpp @@ -3,9 +3,10 @@ #include #include -std::tuple FaceAreasNormalsCpu( - at::Tensor verts, - at::Tensor faces) { +std::tuple FaceAreasNormalsForwardCpu( + const at::Tensor verts, + const at::Tensor faces) { + const int V = verts.size(0); const int F = faces.size(0); at::Tensor areas = at::empty({F}, verts.options()); @@ -54,3 +55,156 @@ std::tuple FaceAreasNormalsCpu( } return std::make_tuple(areas, normals); } + +at::Tensor FaceAreasNormalsBackwardCpu( + const at::Tensor grad_areas, + const at::Tensor grad_normals, + const at::Tensor verts, + const at::Tensor faces) { + const int V = verts.size(0); + const int F = faces.size(0); + + at::Tensor grad_verts = at::zeros({V, 3}, grad_areas.options()); + + auto grad_areas_a = grad_areas.accessor(); + auto grad_normals_a = grad_normals.accessor(); + auto verts_a = verts.accessor(); + auto faces_a = faces.accessor(); + auto grad_verts_a = grad_verts.accessor(); + + for (int f = 0; f < F; ++f) { + const int64_t i0 = faces_a[f][0]; + const int64_t i1 = faces_a[f][1]; + const int64_t i2 = faces_a[f][2]; + + const float v0_x = verts_a[i0][0]; + const float v0_y = verts_a[i0][1]; + const float v0_z = verts_a[i0][2]; + + const float v1_x = verts_a[i1][0]; + const float v1_y = verts_a[i1][1]; + const float v1_z = verts_a[i1][2]; + + const float v2_x = verts_a[i2][0]; + const float v2_y = verts_a[i2][1]; + const float v2_z = verts_a[i2][2]; + + const float ax = v1_x - v0_x; + const float ay = v1_y - v0_y; + const float az = v1_z - v0_z; + + const float bx = v2_x - v0_x; + const float by = v2_y - v0_y; + const float bz = v2_z - v0_z; + + const float cx = ay * bz - az * by; + const float cy = az * bx - ax * bz; + const float cz = ax * by - ay * bx; + + float norm = sqrt(cx * cx + cy * cy + cz * cz); + norm = (norm < 1e-6) ? 1e-6 : norm; // max(norm, 1e-6) + float inv_norm = 1. / norm; + float inv_norm_2 = pow(inv_norm, 2.0f); + float inv_norm_3 = pow(inv_norm, 3.0f); + + // We compute gradients with respect to the input vertices. + // For each vertex, gradients come from grad_areas and grad_normals. + // eg, grad_v0_x = (d / d v0_x) + // = \sum_f (d / d areas[f]) * (d areas[f] / d v0_x) + // + (d / d normals[f, 0]) * (d normals[f, 0] / d v0_x) + // + (d / d normals[f, 1]) * (d normals[f, 1] / d v0_x) + // + (d / d normals[f, 2]) * (d normals[f, 2] / d v0_x) + // with (d / d areas[f]) = grad_areas[f] and + // (d / d normals[f, j]) = grad_normals[f][j]. + // The equations below are derived after taking + // derivatives wrt to the vertices (fun times!). + + // grad v0 coming from grad areas and grad normals + const float grad_v0_x = + ((-az + bz) * cy + (-by + ay) * cz) / 2.0 * inv_norm * grad_areas_a[f] + + -cx * ((-az + bz) * cy + (-by + ay) * cz) * inv_norm_3 * + grad_normals_a[f][0] + + ((-az + bz) - cy * ((-az + bz) * cy + (-by + ay) * cz) * inv_norm_2) * + inv_norm * grad_normals_a[f][1] + + ((-by + ay) - cz * ((-az + bz) * cy + (-by + ay) * cz) * inv_norm_2) * + inv_norm * grad_normals_a[f][2]; + grad_verts_a[i0][0] += grad_v0_x; + + const float grad_v0_y = + ((-bz + az) * cx + (-ax + bx) * cz) / 2.0 * inv_norm * grad_areas_a[f] + + ((-bz + az) - cx * ((-bz + az) * cx + (-ax + bx) * cz) * inv_norm_2) * + inv_norm * grad_normals_a[f][0] + + -cy * ((-bz + az) * cx + (-ax + bx) * cz) * inv_norm_3 * + grad_normals_a[f][1] + + ((-ax + bx) - cz * ((-bz + az) * cx + (-ax + bx) * cz) * inv_norm_2) * + inv_norm * grad_normals_a[f][2]; + grad_verts[i0][1] += grad_v0_y; + + const float grad_v0_z = + ((-ay + by) * cx + (-bx + ax) * cy) / 2.0 * inv_norm * grad_areas_a[f] + + ((-ay + by) - cx * ((-ay + by) * cx + (-bx + ax) * cy) * inv_norm_2) * + inv_norm * grad_normals_a[f][0] + + ((-bx + ax) - cy * ((-ay + by) * cx + (-bx + ax) * cy) * inv_norm_2) * + inv_norm * grad_normals_a[f][1] + + -cz * ((-ay + by) * cx + (-bx + ax) * cy) * inv_norm_3 * + grad_normals_a[f][2]; + grad_verts[i0][2] += grad_v0_z; + + // grad v1 coming from grad areas and grad normals + const float grad_v1_x = + (by * cz - bz * cy) / 2.0 * inv_norm * grad_areas_a[f] + + -cx * (by * cz - bz * cy) * inv_norm_3 * grad_normals_a[f][0] + + (-bz - cy * (by * cz - bz * cy) * inv_norm_2) * inv_norm * + grad_normals_a[f][1] + + (by - cz * (by * cz - bz * cy) * inv_norm_2) * inv_norm * + grad_normals_a[f][2]; + grad_verts[i1][0] += grad_v1_x; + + const float grad_v1_y = + (bz * cx - bx * cz) / 2.0 * inv_norm * grad_areas_a[f] + + (bz - cx * (bz * cx - bx * cz) * inv_norm_2) * inv_norm * + grad_normals_a[f][0] + + -cy * (bz * cx - bx * cz) * inv_norm_3 * grad_normals_a[f][1] + + (-bx - cz * (bz * cx - bx * cz) * inv_norm_2) * inv_norm * + grad_normals_a[f][2]; + grad_verts[i1][1] += grad_v1_y; + + const float grad_v1_z = + (bx * cy - by * cx) / 2.0 * inv_norm * grad_areas_a[f] + + (-by - cx * (bx * cy - by * cx) * inv_norm_2) * inv_norm * + grad_normals_a[f][0] + + (bx - cx * (bx * cy - by * cx) * inv_norm_2) * inv_norm * + grad_normals_a[f][1] + + -cz * (bx * cy - by * cx) * inv_norm_3 * grad_normals_a[f][2]; + grad_verts[i1][2] += grad_v1_z; + + // grad v2 coming from grad areas + const float grad_v2_x = + (az * cy - ay * cz) / 2.0 * inv_norm * grad_areas_a[f] + + -cx * (az * cy - ay * cz) * inv_norm_3 * grad_normals_a[f][0] + + (az - cy * (az * cy - ay * cz) * inv_norm_2) * inv_norm * + grad_normals_a[f][1] + + (-ay - cz * (az * cy - ay * cz) * inv_norm_2) * inv_norm * + grad_normals_a[f][2]; + grad_verts[i2][0] += grad_v2_x; + + const float grad_v2_y = + (ax * cz - az * cx) / 2.0 * inv_norm * grad_areas_a[f] + + (-az - cx * (ax * cz - az * cx) * inv_norm_2) * inv_norm * + grad_normals_a[f][0] + + -cy * (ax * cz - az * cx) * inv_norm_3 * grad_normals_a[f][1] + + (ax - cz * (ax * cz - az * cx) * inv_norm_2) * inv_norm * + grad_normals_a[f][2]; + grad_verts[i2][1] += grad_v2_y; + + const float grad_v2_z = + (ay * cx - ax * cy) / 2.0 * inv_norm * grad_areas_a[f] + + (ay - cx * (ay * cx - ax * cy) * inv_norm_2) * inv_norm * + grad_normals_a[f][0] + + (-ax - cy * (ay * cx - ax * cy) * inv_norm_2) * inv_norm * + grad_normals_a[f][1] + + -cz * (ay * cx - ax * cy) * inv_norm_3 * grad_normals_a[f][2]; + grad_verts[i2][2] += grad_v2_z; + } + return grad_verts; +} diff --git a/pytorch3d/ops/__init__.py b/pytorch3d/ops/__init__.py index 0a0fe2f5..98ca1735 100644 --- a/pytorch3d/ops/__init__.py +++ b/pytorch3d/ops/__init__.py @@ -3,6 +3,7 @@ from .cubify import cubify from .graph_conv import GraphConv +from .mesh_face_areas_normals import mesh_face_areas_normals from .nearest_neighbor_points import nn_points_idx from .packed_to_padded import packed_to_padded, padded_to_packed from .sample_points_from_meshes import sample_points_from_meshes diff --git a/pytorch3d/ops/mesh_face_areas_normals.py b/pytorch3d/ops/mesh_face_areas_normals.py new file mode 100644 index 00000000..956bf191 --- /dev/null +++ b/pytorch3d/ops/mesh_face_areas_normals.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import torch +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +from pytorch3d import _C + + +class _MeshFaceAreasNormals(Function): + """ + Torch autograd Function wrapper for face areas & normals C++/CUDA implementations. + """ + + @staticmethod + def forward(ctx, verts, faces): + """ + Args: + ctx: Context object used to calculate gradients. + verts: FloatTensor of shape (V, 3), representing the packed + batch verts tensor. + faces: LongTensor of shape (F, 3), representing the packed + batch faces tensor + Returns: + areas: FloatTensor of shape (F,) with the areas of each face + normals: FloatTensor of shape (F,3) with the normals of each face + """ + if not (verts.dim() == 2): + raise ValueError("verts need to be of shape Vx3.") + if not (verts.shape[1] == 3): + raise ValueError("verts need to be of shape Vx3.") + if not (faces.dim() == 2): + raise ValueError("faces need to be of shape Fx3.") + if not (faces.shape[1] == 3): + raise ValueError("faces need to be of shape Fx3.") + if not (faces.dtype == torch.int64): + raise ValueError("faces need to be of type torch.int64.") + # TODO(gkioxari) Change cast to floats once we add support for doubles. + if not (verts.dtype == torch.float32): + verts = verts.float() + + ctx.save_for_backward(verts, faces) + areas, normals = _C.face_areas_normals_forward(verts, faces) + return areas, normals + + @staticmethod + @once_differentiable + def backward(ctx, grad_areas, grad_normals): + grad_areas = grad_areas.contiguous() + grad_normals = grad_normals.contiguous() + verts, faces = ctx.saved_tensors + # TODO(gkioxari) Change cast to floats once we add support for doubles. + if not (grad_areas.dtype == torch.float32): + grad_areas = grad_areas.float() + if not (grad_normals.dtype == torch.float32): + grad_normals = grad_normals.float() + grad_verts = _C.face_areas_normals_backward( + grad_areas, grad_normals, verts, faces + ) + return grad_verts, None + + +mesh_face_areas_normals = _MeshFaceAreasNormals.apply diff --git a/pytorch3d/ops/sample_points_from_meshes.py b/pytorch3d/ops/sample_points_from_meshes.py index abe0f25c..c5756623 100644 --- a/pytorch3d/ops/sample_points_from_meshes.py +++ b/pytorch3d/ops/sample_points_from_meshes.py @@ -10,9 +10,8 @@ import sys from typing import Tuple, Union import torch -from pytorch3d import _C - -from .packed_to_padded import packed_to_padded +from pytorch3d.ops.mesh_face_areas_normals import mesh_face_areas_normals +from pytorch3d.ops.packed_to_padded import packed_to_padded def sample_points_from_meshes( @@ -53,7 +52,7 @@ def sample_points_from_meshes( # Only compute samples for non empty meshes with torch.no_grad(): - areas, _ = _C.face_areas_normals( + areas, _ = mesh_face_areas_normals( verts, faces ) # Face areas can be zero. max_faces = meshes.num_faces_per_mesh().max().item() diff --git a/pytorch3d/structures/meshes.py b/pytorch3d/structures/meshes.py index ec5d1e37..a0f34a6f 100644 --- a/pytorch3d/structures/meshes.py +++ b/pytorch3d/structures/meshes.py @@ -4,8 +4,6 @@ from typing import List import torch -from pytorch3d import _C - from . import utils as struct_utils from .textures import Textures @@ -761,6 +759,8 @@ class Meshes(object): refresh: Set to True to force recomputation of face areas. Default: False. """ + from ..ops.mesh_face_areas_normals import mesh_face_areas_normals + if not ( refresh or any( @@ -771,7 +771,7 @@ class Meshes(object): return faces_packed = self.faces_packed() verts_packed = self.verts_packed() - face_areas, face_normals = _C.face_areas_normals( + face_areas, face_normals = mesh_face_areas_normals( verts_packed, faces_packed ) self._faces_areas_packed = face_areas diff --git a/tests/bm_face_areas_normals.py b/tests/bm_face_areas_normals.py index fc3181aa..6aaf3ce1 100644 --- a/tests/bm_face_areas_normals.py +++ b/tests/bm_face_areas_normals.py @@ -11,19 +11,19 @@ from test_face_areas_normals import TestFaceAreasNormals def bm_face_areas_normals() -> None: kwargs_list = [] - backend_cuda = [False] + backend = ["cpu"] if torch.cuda.is_available(): - backend_cuda.append(True) + backend.append("cuda:0") num_meshes = [2, 10, 32] num_verts = [100, 1000] num_faces = [300, 3000] - test_cases = product(num_meshes, num_verts, num_faces, backend_cuda) + test_cases = product(num_meshes, num_verts, num_faces, backend) for case in test_cases: - n, v, f, c = case + n, v, f, d = case kwargs_list.append( - {"num_meshes": n, "num_verts": v, "num_faces": f, "cuda": c} + {"num_meshes": n, "num_verts": v, "num_faces": f, "device": d} ) benchmark( TestFaceAreasNormals.face_areas_normals_with_init, diff --git a/tests/test_face_areas_normals.py b/tests/test_face_areas_normals.py index 57354254..496b3555 100644 --- a/tests/test_face_areas_normals.py +++ b/tests/test_face_areas_normals.py @@ -5,7 +5,7 @@ import unittest import torch -from pytorch3d import _C +from pytorch3d.ops import mesh_face_areas_normals from pytorch3d.structures.meshes import Meshes from common_testing import TestCaseMixin @@ -28,7 +28,10 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase): faces_list = [] for _ in range(num_meshes): verts = torch.rand( - (num_verts, 3), dtype=torch.float32, device=device + (num_verts, 3), + dtype=torch.float32, + device=device, + requires_grad=True, ) faces = torch.randint( num_verts, size=(num_faces, 3), dtype=torch.int64, device=device @@ -40,10 +43,12 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase): return meshes @staticmethod - def face_areas_normals(verts, faces): + def face_areas_normals_python(verts, faces): """ Pytorch implementation for face areas & normals. """ + # TODO(gkioxari) Change cast to floats once we add support for doubles. + verts = verts.float() vertices_faces = verts[faces] # (F, 3, 3) # vector pointing from v0 to v1 v01 = vertices_faces[:, 1] - vertices_faces[:, 0] @@ -56,24 +61,41 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase): ) return face_areas, face_normals - def _test_face_areas_normals_helper(self, device): + def _test_face_areas_normals_helper(self, device, dtype=torch.float32): """ Check the results from face_areas cuda/cpp and PyTorch implementation are the same. """ - meshes = self.init_meshes(10, 1000, 3000, device=device) - verts = meshes.verts_packed() - faces = meshes.faces_packed() + meshes = self.init_meshes(10, 200, 400, device=device) + # make them leaf nodes + verts = meshes.verts_packed().detach().clone().to(dtype) + verts.requires_grad = True + faces = meshes.faces_packed().detach().clone() - areas_torch, normals_torch = self.face_areas_normals(verts, faces) - areas, normals = _C.face_areas_normals(verts, faces) + # forward + areas, normals = mesh_face_areas_normals(verts, faces) + verts_torch = verts.detach().clone().to(dtype) + verts_torch.requires_grad = True + faces_torch = faces.detach().clone() + areas_torch, normals_torch = TestFaceAreasNormals.face_areas_normals_python( + verts_torch, faces_torch + ) self.assertClose(areas_torch, areas, atol=1e-7) # normals get normalized by area thus sensitivity increases as areas # in our tests can be arbitrarily small. Thus we compare normals after # multiplying with areas unnormals = normals * areas.view(-1, 1) unnormals_torch = normals_torch * areas_torch.view(-1, 1) - self.assertClose(unnormals_torch, unnormals, atol=1e-7) + self.assertClose(unnormals_torch, unnormals, atol=1e-6) + + # backward + grad_areas = torch.rand(areas.shape, device=device, dtype=dtype) + grad_normals = torch.rand(normals.shape, device=device, dtype=dtype) + areas.backward((grad_areas, grad_normals)) + grad_verts = verts.grad + areas_torch.backward((grad_areas, grad_normals)) + grad_verts_torch = verts_torch.grad + self.assertClose(grad_verts_torch, grad_verts, atol=1e-6) def test_face_areas_normals_cpu(self): self._test_face_areas_normals_helper("cpu") @@ -81,11 +103,16 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase): def test_face_areas_normals_cuda(self): self._test_face_areas_normals_helper("cuda:0") + def test_nonfloats_cpu(self): + self._test_face_areas_normals_helper("cpu", dtype=torch.double) + + def test_nonfloats_cuda(self): + self._test_face_areas_normals_helper("cuda:0", dtype=torch.double) + @staticmethod def face_areas_normals_with_init( - num_meshes: int, num_verts: int, num_faces: int, cuda: bool = True + num_meshes: int, num_verts: int, num_faces: int, device: str = "cpu" ): - device = "cuda:0" if cuda else "cpu" meshes = TestFaceAreasNormals.init_meshes( num_meshes, num_verts, num_faces, device ) @@ -94,16 +121,15 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase): torch.cuda.synchronize() def face_areas_normals(): - _C.face_areas_normals(verts, faces) + mesh_face_areas_normals(verts, faces) torch.cuda.synchronize() return face_areas_normals @staticmethod def face_areas_normals_with_init_torch( - num_meshes: int, num_verts: int, num_faces: int, cuda: bool = True + num_meshes: int, num_verts: int, num_faces: int, device: str = "cpu" ): - device = "cuda:0" if cuda else "cpu" meshes = TestFaceAreasNormals.init_meshes( num_meshes, num_verts, num_faces, device ) @@ -112,7 +138,7 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase): torch.cuda.synchronize() def face_areas_normals(): - TestFaceAreasNormals.face_areas_normals(verts, faces) + TestFaceAreasNormals.face_areas_normals_python(verts, faces) torch.cuda.synchronize() return face_areas_normals diff --git a/tests/test_sample_points_from_meshes.py b/tests/test_sample_points_from_meshes.py index 90758124..46676296 100644 --- a/tests/test_sample_points_from_meshes.py +++ b/tests/test_sample_points_from_meshes.py @@ -6,8 +6,7 @@ import unittest from pathlib import Path import torch -from pytorch3d import _C -from pytorch3d.ops.sample_points_from_meshes import sample_points_from_meshes +from pytorch3d.ops import sample_points_from_meshes from pytorch3d.structures.meshes import Meshes from pytorch3d.utils.ico_sphere import ico_sphere