mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
face areas backward
Summary: Added backward for mesh face areas & normals. Exposed it as a layer. Replaced the computation with the new op in Meshes and in Sample Points. Current issue: Circular imports. I moved the import of the op in meshes inside the function scope. Reviewed By: jcjohnson Differential Revision: D19920082 fbshipit-source-id: d213226d5e1d19a0c8452f4d32771d07e8b91c0a
This commit is contained in:
parent
9ca5489107
commit
a3baa367e3
@ -9,7 +9,8 @@
|
|||||||
#include "rasterize_points/rasterize_points.h"
|
#include "rasterize_points/rasterize_points.h"
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
m.def("face_areas_normals", &FaceAreasNormals);
|
m.def("face_areas_normals_forward", &FaceAreasNormalsForward);
|
||||||
|
m.def("face_areas_normals_backward", &FaceAreasNormalsBackward);
|
||||||
m.def("packed_to_padded", &PackedToPadded);
|
m.def("packed_to_padded", &PackedToPadded);
|
||||||
m.def("padded_to_packed", &PaddedToPacked);
|
m.def("padded_to_packed", &PaddedToPacked);
|
||||||
m.def("nn_points_idx", &NearestNeighborIdx);
|
m.def("nn_points_idx", &NearestNeighborIdx);
|
||||||
|
@ -4,9 +4,9 @@
|
|||||||
#include <tuple>
|
#include <tuple>
|
||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
__global__ void FaceAreasNormalsKernel(
|
__global__ void FaceAreasNormalsForwardKernel(
|
||||||
const scalar_t* __restrict__ verts,
|
const scalar_t* __restrict__ verts,
|
||||||
const long* __restrict__ faces,
|
const int64_t* __restrict__ faces,
|
||||||
scalar_t* __restrict__ face_areas,
|
scalar_t* __restrict__ face_areas,
|
||||||
scalar_t* __restrict__ face_normals,
|
scalar_t* __restrict__ face_normals,
|
||||||
const size_t V,
|
const size_t V,
|
||||||
@ -18,9 +18,9 @@ __global__ void FaceAreasNormalsKernel(
|
|||||||
// Each thread computes the area & normal of its respective faces and adds it
|
// Each thread computes the area & normal of its respective faces and adds it
|
||||||
// to the global face_areas tensor.
|
// to the global face_areas tensor.
|
||||||
for (size_t f = tid; f < F; f += stride) {
|
for (size_t f = tid; f < F; f += stride) {
|
||||||
const long i0 = faces[3 * f + 0];
|
const int64_t i0 = faces[3 * f + 0];
|
||||||
const long i1 = faces[3 * f + 1];
|
const int64_t i1 = faces[3 * f + 1];
|
||||||
const long i2 = faces[3 * f + 2];
|
const int64_t i2 = faces[3 * f + 2];
|
||||||
|
|
||||||
const scalar_t v0_x = verts[3 * i0 + 0];
|
const scalar_t v0_x = verts[3 * i0 + 0];
|
||||||
const scalar_t v0_y = verts[3 * i0 + 1];
|
const scalar_t v0_y = verts[3 * i0 + 1];
|
||||||
@ -55,9 +55,161 @@ __global__ void FaceAreasNormalsKernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsCuda(
|
// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
|
||||||
at::Tensor verts,
|
// Currently, support is for floats only.
|
||||||
at::Tensor faces) {
|
__global__ void FaceAreasNormalsBackwardKernel(
|
||||||
|
const float* __restrict__ grad_areas,
|
||||||
|
const float* __restrict__ grad_normals,
|
||||||
|
const float* __restrict__ verts,
|
||||||
|
const int64_t* __restrict__ faces,
|
||||||
|
float* __restrict__ grad_verts,
|
||||||
|
const size_t V,
|
||||||
|
const size_t F) {
|
||||||
|
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
const size_t stride = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
|
// Faces split evenly over the number of threads in the grid.
|
||||||
|
// Each thread computes the area & normal of its respective faces and adds it
|
||||||
|
// to the global face_areas tensor.
|
||||||
|
for (size_t f = tid; f < F; f += stride) {
|
||||||
|
const int64_t i0 = faces[3 * f + 0];
|
||||||
|
const int64_t i1 = faces[3 * f + 1];
|
||||||
|
const int64_t i2 = faces[3 * f + 2];
|
||||||
|
|
||||||
|
const float v0_x = verts[3 * i0 + 0];
|
||||||
|
const float v0_y = verts[3 * i0 + 1];
|
||||||
|
const float v0_z = verts[3 * i0 + 2];
|
||||||
|
|
||||||
|
const float v1_x = verts[3 * i1 + 0];
|
||||||
|
const float v1_y = verts[3 * i1 + 1];
|
||||||
|
const float v1_z = verts[3 * i1 + 2];
|
||||||
|
|
||||||
|
const float v2_x = verts[3 * i2 + 0];
|
||||||
|
const float v2_y = verts[3 * i2 + 1];
|
||||||
|
const float v2_z = verts[3 * i2 + 2];
|
||||||
|
|
||||||
|
const float ax = v1_x - v0_x;
|
||||||
|
const float ay = v1_y - v0_y;
|
||||||
|
const float az = v1_z - v0_z;
|
||||||
|
|
||||||
|
const float bx = v2_x - v0_x;
|
||||||
|
const float by = v2_y - v0_y;
|
||||||
|
const float bz = v2_z - v0_z;
|
||||||
|
|
||||||
|
const float cx = ay * bz - az * by;
|
||||||
|
const float cy = az * bx - ax * bz;
|
||||||
|
const float cz = ax * by - ay * bx;
|
||||||
|
|
||||||
|
float norm = sqrt(cx * cx + cy * cy + cz * cz);
|
||||||
|
norm = (norm < 1e-6) ? 1e-6 : norm; // max(norm, 1e-6)
|
||||||
|
float inv_norm = 1. / norm;
|
||||||
|
float inv_norm_2 = pow(inv_norm, 2.0f);
|
||||||
|
float inv_norm_3 = pow(inv_norm, 3.0f);
|
||||||
|
|
||||||
|
// We compute gradients with respect to the input vertices.
|
||||||
|
// For each vertex, gradients come from grad_areas and grad_normals.
|
||||||
|
// eg, grad_v0_x = (d / d v0_x)
|
||||||
|
// = \sum_f (d / d areas[f]) * (d areas[f] / d v0_x)
|
||||||
|
// + (d / d normals[f, 0]) * (d normals[f, 0] / d v0_x)
|
||||||
|
// + (d / d normals[f, 1]) * (d normals[f, 1] / d v0_x)
|
||||||
|
// + (d / d normals[f, 2]) * (d normals[f, 2] / d v0_x)
|
||||||
|
// with (d / d areas[f]) = grad_areas[f] and
|
||||||
|
// (d / d normals[f, j]) = grad_normals[f][j].
|
||||||
|
// The equations below are derived after taking
|
||||||
|
// derivatives wrt to the vertices (fun times!).
|
||||||
|
|
||||||
|
// grad v0 coming from grad areas and grad normals
|
||||||
|
const float grad_v0_x =
|
||||||
|
((-az + bz) * cy + (-by + ay) * cz) / 2.0 * inv_norm * grad_areas[f] +
|
||||||
|
-cx * ((-az + bz) * cy + (-by + ay) * cz) * inv_norm_3 *
|
||||||
|
grad_normals[3 * f + 0] +
|
||||||
|
((-az + bz) - cy * ((-az + bz) * cy + (-by + ay) * cz) * inv_norm_2) *
|
||||||
|
inv_norm * grad_normals[3 * f + 1] +
|
||||||
|
((-by + ay) - cz * ((-az + bz) * cy + (-by + ay) * cz) * inv_norm_2) *
|
||||||
|
inv_norm * grad_normals[3 * f + 2];
|
||||||
|
atomicAdd(grad_verts + 3 * i0 + 0, grad_v0_x);
|
||||||
|
|
||||||
|
const float grad_v0_y =
|
||||||
|
((-bz + az) * cx + (-ax + bx) * cz) / 2.0 * inv_norm * grad_areas[f] +
|
||||||
|
((-bz + az) - cx * ((-bz + az) * cx + (-ax + bx) * cz) * inv_norm_2) *
|
||||||
|
inv_norm * grad_normals[3 * f + 0] +
|
||||||
|
-cy * ((-bz + az) * cx + (-ax + bx) * cz) * inv_norm_3 *
|
||||||
|
grad_normals[3 * f + 1] +
|
||||||
|
((-ax + bx) - cz * ((-bz + az) * cx + (-ax + bx) * cz) * inv_norm_2) *
|
||||||
|
inv_norm * grad_normals[3 * f + 2];
|
||||||
|
atomicAdd(grad_verts + 3 * i0 + 1, grad_v0_y);
|
||||||
|
|
||||||
|
const float grad_v0_z =
|
||||||
|
((-ay + by) * cx + (-bx + ax) * cy) / 2.0 * inv_norm * grad_areas[f] +
|
||||||
|
((-ay + by) - cx * ((-ay + by) * cx + (-bx + ax) * cy) * inv_norm_2) *
|
||||||
|
inv_norm * grad_normals[3 * f + 0] +
|
||||||
|
((-bx + ax) - cy * ((-ay + by) * cx + (-bx + ax) * cy) * inv_norm_2) *
|
||||||
|
inv_norm * grad_normals[3 * f + 1] +
|
||||||
|
-cz * ((-ay + by) * cx + (-bx + ax) * cy) * inv_norm_3 *
|
||||||
|
grad_normals[3 * f + 2];
|
||||||
|
atomicAdd(grad_verts + 3 * i0 + 2, grad_v0_z);
|
||||||
|
|
||||||
|
// grad v1 coming from grad areas and grad normals
|
||||||
|
const float grad_v1_x =
|
||||||
|
(by * cz - bz * cy) / 2.0 * inv_norm * grad_areas[f] +
|
||||||
|
-cx * (by * cz - bz * cy) * inv_norm_3 * grad_normals[3 * f + 0] +
|
||||||
|
(-bz - cy * (by * cz - bz * cy) * inv_norm_2) * inv_norm *
|
||||||
|
grad_normals[3 * f + 1] +
|
||||||
|
(by - cz * (by * cz - bz * cy) * inv_norm_2) * inv_norm *
|
||||||
|
grad_normals[3 * f + 2];
|
||||||
|
atomicAdd(grad_verts + 3 * i1 + 0, grad_v1_x);
|
||||||
|
|
||||||
|
const float grad_v1_y =
|
||||||
|
(bz * cx - bx * cz) / 2.0 * inv_norm * grad_areas[f] +
|
||||||
|
(bz - cx * (bz * cx - bx * cz) * inv_norm_2) * inv_norm *
|
||||||
|
grad_normals[3 * f + 0] +
|
||||||
|
-cy * (bz * cx - bx * cz) * inv_norm_3 * grad_normals[3 * f + 1] +
|
||||||
|
(-bx - cz * (bz * cx - bx * cz) * inv_norm_2) * inv_norm *
|
||||||
|
grad_normals[3 * f + 2];
|
||||||
|
atomicAdd(grad_verts + 3 * i1 + 1, grad_v1_y);
|
||||||
|
|
||||||
|
const float grad_v1_z =
|
||||||
|
(bx * cy - by * cx) / 2.0 * inv_norm * grad_areas[f] +
|
||||||
|
(-by - cx * (bx * cy - by * cx) * inv_norm_2) * inv_norm *
|
||||||
|
grad_normals[3 * f + 0] +
|
||||||
|
(bx - cx * (bx * cy - by * cx) * inv_norm_2) * inv_norm *
|
||||||
|
grad_normals[3 * f + 1] +
|
||||||
|
-cz * (bx * cy - by * cx) * inv_norm_3 * grad_normals[3 * f + 2];
|
||||||
|
atomicAdd(grad_verts + 3 * i1 + 2, grad_v1_z);
|
||||||
|
|
||||||
|
// grad v2 coming from grad areas
|
||||||
|
const float grad_v2_x =
|
||||||
|
(az * cy - ay * cz) / 2.0 * inv_norm * grad_areas[f] +
|
||||||
|
-cx * (az * cy - ay * cz) * inv_norm_3 * grad_normals[3 * f + 0] +
|
||||||
|
(az - cy * (az * cy - ay * cz) * inv_norm_2) * inv_norm *
|
||||||
|
grad_normals[3 * f + 1] +
|
||||||
|
(-ay - cz * (az * cy - ay * cz) * inv_norm_2) * inv_norm *
|
||||||
|
grad_normals[3 * f + 2];
|
||||||
|
atomicAdd(grad_verts + 3 * i2 + 0, grad_v2_x);
|
||||||
|
|
||||||
|
const float grad_v2_y =
|
||||||
|
(ax * cz - az * cx) / 2.0 * inv_norm * grad_areas[f] +
|
||||||
|
(-az - cx * (ax * cz - az * cx) * inv_norm_2) * inv_norm *
|
||||||
|
grad_normals[3 * f + 0] +
|
||||||
|
-cy * (ax * cz - az * cx) * inv_norm_3 * grad_normals[3 * f + 1] +
|
||||||
|
(ax - cz * (ax * cz - az * cx) * inv_norm_2) * inv_norm *
|
||||||
|
grad_normals[3 * f + 2];
|
||||||
|
atomicAdd(grad_verts + 3 * i2 + 1, grad_v2_y);
|
||||||
|
|
||||||
|
const float grad_v2_z =
|
||||||
|
(ay * cx - ax * cy) / 2.0 * inv_norm * grad_areas[f] +
|
||||||
|
(ay - cx * (ay * cx - ax * cy) * inv_norm_2) * inv_norm *
|
||||||
|
grad_normals[3 * f + 0] +
|
||||||
|
(-ax - cy * (ay * cx - ax * cy) * inv_norm_2) * inv_norm *
|
||||||
|
grad_normals[3 * f + 1] +
|
||||||
|
-cz * (ay * cx - ax * cy) * inv_norm_3 * grad_normals[3 * f + 2];
|
||||||
|
atomicAdd(grad_verts + 3 * i2 + 2, grad_v2_z);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForwardCuda(
|
||||||
|
const at::Tensor verts,
|
||||||
|
const at::Tensor faces) {
|
||||||
const auto V = verts.size(0);
|
const auto V = verts.size(0);
|
||||||
const auto F = faces.size(0);
|
const auto F = faces.size(0);
|
||||||
|
|
||||||
@ -66,16 +218,42 @@ std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsCuda(
|
|||||||
|
|
||||||
const int blocks = 64;
|
const int blocks = 64;
|
||||||
const int threads = 512;
|
const int threads = 512;
|
||||||
AT_DISPATCH_FLOATING_TYPES(verts.type(), "face_areas_normals_cuda", ([&] {
|
AT_DISPATCH_FLOATING_TYPES(
|
||||||
FaceAreasNormalsKernel<scalar_t>
|
verts.type(), "face_areas_normals_forward_cuda", ([&] {
|
||||||
<<<blocks, threads>>>(
|
FaceAreasNormalsForwardKernel<scalar_t><<<blocks, threads>>>(
|
||||||
verts.data_ptr<scalar_t>(),
|
verts.data_ptr<scalar_t>(),
|
||||||
faces.data_ptr<long>(),
|
faces.data_ptr<int64_t>(),
|
||||||
areas.data_ptr<scalar_t>(),
|
areas.data_ptr<scalar_t>(),
|
||||||
normals.data_ptr<scalar_t>(),
|
normals.data_ptr<scalar_t>(),
|
||||||
V,
|
V,
|
||||||
F);
|
F);
|
||||||
}));
|
}));
|
||||||
|
|
||||||
return std::make_tuple(areas, normals);
|
return std::make_tuple(areas, normals);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
at::Tensor FaceAreasNormalsBackwardCuda(
|
||||||
|
const at::Tensor grad_areas,
|
||||||
|
const at::Tensor grad_normals,
|
||||||
|
const at::Tensor verts,
|
||||||
|
const at::Tensor faces) {
|
||||||
|
const auto V = verts.size(0);
|
||||||
|
const auto F = faces.size(0);
|
||||||
|
|
||||||
|
at::Tensor grad_verts = at::zeros({V, 3}, grad_areas.options());
|
||||||
|
|
||||||
|
const int blocks = 64;
|
||||||
|
const int threads = 512;
|
||||||
|
// TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
|
||||||
|
// doubles. Currently, support is for floats only.
|
||||||
|
FaceAreasNormalsBackwardKernel<<<blocks, threads>>>(
|
||||||
|
grad_areas.data_ptr<float>(),
|
||||||
|
grad_normals.data_ptr<float>(),
|
||||||
|
verts.data_ptr<float>(),
|
||||||
|
faces.data_ptr<int64_t>(),
|
||||||
|
grad_verts.data_ptr<float>(),
|
||||||
|
V,
|
||||||
|
F);
|
||||||
|
|
||||||
|
return grad_verts;
|
||||||
|
}
|
||||||
|
@ -17,27 +17,55 @@
|
|||||||
//
|
//
|
||||||
|
|
||||||
// Cpu implementation.
|
// Cpu implementation.
|
||||||
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsCpu(
|
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForwardCpu(
|
||||||
at::Tensor verts,
|
const at::Tensor verts,
|
||||||
at::Tensor faces);
|
const at::Tensor faces);
|
||||||
|
// Cpu implementation
|
||||||
|
at::Tensor FaceAreasNormalsBackwardCpu(
|
||||||
|
const at::Tensor grad_areas,
|
||||||
|
const at::Tensor grad_normals,
|
||||||
|
const at::Tensor verts,
|
||||||
|
const at::Tensor faces);
|
||||||
|
|
||||||
#ifdef WITH_CUDA
|
#ifdef WITH_CUDA
|
||||||
// Cuda implementation.
|
// Cuda implementation.
|
||||||
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsCuda(
|
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForwardCuda(
|
||||||
at::Tensor verts,
|
const at::Tensor verts,
|
||||||
at::Tensor faces);
|
const at::Tensor faces);
|
||||||
|
// Cuda implementation.
|
||||||
|
at::Tensor FaceAreasNormalsBackwardCuda(
|
||||||
|
const at::Tensor grad_areas,
|
||||||
|
const at::Tensor grad_normals,
|
||||||
|
const at::Tensor verts,
|
||||||
|
const at::Tensor faces);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Implementation which is exposed.
|
// Implementation which is exposed.
|
||||||
std::tuple<at::Tensor, at::Tensor> FaceAreasNormals(
|
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForward(
|
||||||
at::Tensor verts,
|
const at::Tensor verts,
|
||||||
at::Tensor faces) {
|
const at::Tensor faces) {
|
||||||
if (verts.type().is_cuda() && faces.type().is_cuda()) {
|
if (verts.type().is_cuda() && faces.type().is_cuda()) {
|
||||||
#ifdef WITH_CUDA
|
#ifdef WITH_CUDA
|
||||||
return FaceAreasNormalsCuda(verts, faces);
|
return FaceAreasNormalsForwardCuda(verts, faces);
|
||||||
#else
|
#else
|
||||||
AT_ERROR("Not compiled with GPU support.");
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
return FaceAreasNormalsCpu(verts, faces);
|
return FaceAreasNormalsForwardCpu(verts, faces);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implementation which is exposed.
|
||||||
|
at::Tensor FaceAreasNormalsBackward(
|
||||||
|
const at::Tensor grad_areas,
|
||||||
|
const at::Tensor grad_normals,
|
||||||
|
const at::Tensor verts,
|
||||||
|
const at::Tensor faces) {
|
||||||
|
if (verts.type().is_cuda() && faces.type().is_cuda()) {
|
||||||
|
#ifdef WITH_CUDA
|
||||||
|
return FaceAreasNormalsBackwardCuda(grad_areas, grad_normals, verts, faces);
|
||||||
|
#else
|
||||||
|
AT_ERROR("Not compiled with GPU support.");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
return FaceAreasNormalsBackwardCpu(grad_areas, grad_normals, verts, faces);
|
||||||
}
|
}
|
||||||
|
@ -3,9 +3,10 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
|
|
||||||
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsCpu(
|
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForwardCpu(
|
||||||
at::Tensor verts,
|
const at::Tensor verts,
|
||||||
at::Tensor faces) {
|
const at::Tensor faces) {
|
||||||
|
const int V = verts.size(0);
|
||||||
const int F = faces.size(0);
|
const int F = faces.size(0);
|
||||||
|
|
||||||
at::Tensor areas = at::empty({F}, verts.options());
|
at::Tensor areas = at::empty({F}, verts.options());
|
||||||
@ -54,3 +55,156 @@ std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsCpu(
|
|||||||
}
|
}
|
||||||
return std::make_tuple(areas, normals);
|
return std::make_tuple(areas, normals);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
at::Tensor FaceAreasNormalsBackwardCpu(
|
||||||
|
const at::Tensor grad_areas,
|
||||||
|
const at::Tensor grad_normals,
|
||||||
|
const at::Tensor verts,
|
||||||
|
const at::Tensor faces) {
|
||||||
|
const int V = verts.size(0);
|
||||||
|
const int F = faces.size(0);
|
||||||
|
|
||||||
|
at::Tensor grad_verts = at::zeros({V, 3}, grad_areas.options());
|
||||||
|
|
||||||
|
auto grad_areas_a = grad_areas.accessor<float, 1>();
|
||||||
|
auto grad_normals_a = grad_normals.accessor<float, 2>();
|
||||||
|
auto verts_a = verts.accessor<float, 2>();
|
||||||
|
auto faces_a = faces.accessor<int64_t, 2>();
|
||||||
|
auto grad_verts_a = grad_verts.accessor<float, 2>();
|
||||||
|
|
||||||
|
for (int f = 0; f < F; ++f) {
|
||||||
|
const int64_t i0 = faces_a[f][0];
|
||||||
|
const int64_t i1 = faces_a[f][1];
|
||||||
|
const int64_t i2 = faces_a[f][2];
|
||||||
|
|
||||||
|
const float v0_x = verts_a[i0][0];
|
||||||
|
const float v0_y = verts_a[i0][1];
|
||||||
|
const float v0_z = verts_a[i0][2];
|
||||||
|
|
||||||
|
const float v1_x = verts_a[i1][0];
|
||||||
|
const float v1_y = verts_a[i1][1];
|
||||||
|
const float v1_z = verts_a[i1][2];
|
||||||
|
|
||||||
|
const float v2_x = verts_a[i2][0];
|
||||||
|
const float v2_y = verts_a[i2][1];
|
||||||
|
const float v2_z = verts_a[i2][2];
|
||||||
|
|
||||||
|
const float ax = v1_x - v0_x;
|
||||||
|
const float ay = v1_y - v0_y;
|
||||||
|
const float az = v1_z - v0_z;
|
||||||
|
|
||||||
|
const float bx = v2_x - v0_x;
|
||||||
|
const float by = v2_y - v0_y;
|
||||||
|
const float bz = v2_z - v0_z;
|
||||||
|
|
||||||
|
const float cx = ay * bz - az * by;
|
||||||
|
const float cy = az * bx - ax * bz;
|
||||||
|
const float cz = ax * by - ay * bx;
|
||||||
|
|
||||||
|
float norm = sqrt(cx * cx + cy * cy + cz * cz);
|
||||||
|
norm = (norm < 1e-6) ? 1e-6 : norm; // max(norm, 1e-6)
|
||||||
|
float inv_norm = 1. / norm;
|
||||||
|
float inv_norm_2 = pow(inv_norm, 2.0f);
|
||||||
|
float inv_norm_3 = pow(inv_norm, 3.0f);
|
||||||
|
|
||||||
|
// We compute gradients with respect to the input vertices.
|
||||||
|
// For each vertex, gradients come from grad_areas and grad_normals.
|
||||||
|
// eg, grad_v0_x = (d / d v0_x)
|
||||||
|
// = \sum_f (d / d areas[f]) * (d areas[f] / d v0_x)
|
||||||
|
// + (d / d normals[f, 0]) * (d normals[f, 0] / d v0_x)
|
||||||
|
// + (d / d normals[f, 1]) * (d normals[f, 1] / d v0_x)
|
||||||
|
// + (d / d normals[f, 2]) * (d normals[f, 2] / d v0_x)
|
||||||
|
// with (d / d areas[f]) = grad_areas[f] and
|
||||||
|
// (d / d normals[f, j]) = grad_normals[f][j].
|
||||||
|
// The equations below are derived after taking
|
||||||
|
// derivatives wrt to the vertices (fun times!).
|
||||||
|
|
||||||
|
// grad v0 coming from grad areas and grad normals
|
||||||
|
const float grad_v0_x =
|
||||||
|
((-az + bz) * cy + (-by + ay) * cz) / 2.0 * inv_norm * grad_areas_a[f] +
|
||||||
|
-cx * ((-az + bz) * cy + (-by + ay) * cz) * inv_norm_3 *
|
||||||
|
grad_normals_a[f][0] +
|
||||||
|
((-az + bz) - cy * ((-az + bz) * cy + (-by + ay) * cz) * inv_norm_2) *
|
||||||
|
inv_norm * grad_normals_a[f][1] +
|
||||||
|
((-by + ay) - cz * ((-az + bz) * cy + (-by + ay) * cz) * inv_norm_2) *
|
||||||
|
inv_norm * grad_normals_a[f][2];
|
||||||
|
grad_verts_a[i0][0] += grad_v0_x;
|
||||||
|
|
||||||
|
const float grad_v0_y =
|
||||||
|
((-bz + az) * cx + (-ax + bx) * cz) / 2.0 * inv_norm * grad_areas_a[f] +
|
||||||
|
((-bz + az) - cx * ((-bz + az) * cx + (-ax + bx) * cz) * inv_norm_2) *
|
||||||
|
inv_norm * grad_normals_a[f][0] +
|
||||||
|
-cy * ((-bz + az) * cx + (-ax + bx) * cz) * inv_norm_3 *
|
||||||
|
grad_normals_a[f][1] +
|
||||||
|
((-ax + bx) - cz * ((-bz + az) * cx + (-ax + bx) * cz) * inv_norm_2) *
|
||||||
|
inv_norm * grad_normals_a[f][2];
|
||||||
|
grad_verts[i0][1] += grad_v0_y;
|
||||||
|
|
||||||
|
const float grad_v0_z =
|
||||||
|
((-ay + by) * cx + (-bx + ax) * cy) / 2.0 * inv_norm * grad_areas_a[f] +
|
||||||
|
((-ay + by) - cx * ((-ay + by) * cx + (-bx + ax) * cy) * inv_norm_2) *
|
||||||
|
inv_norm * grad_normals_a[f][0] +
|
||||||
|
((-bx + ax) - cy * ((-ay + by) * cx + (-bx + ax) * cy) * inv_norm_2) *
|
||||||
|
inv_norm * grad_normals_a[f][1] +
|
||||||
|
-cz * ((-ay + by) * cx + (-bx + ax) * cy) * inv_norm_3 *
|
||||||
|
grad_normals_a[f][2];
|
||||||
|
grad_verts[i0][2] += grad_v0_z;
|
||||||
|
|
||||||
|
// grad v1 coming from grad areas and grad normals
|
||||||
|
const float grad_v1_x =
|
||||||
|
(by * cz - bz * cy) / 2.0 * inv_norm * grad_areas_a[f] +
|
||||||
|
-cx * (by * cz - bz * cy) * inv_norm_3 * grad_normals_a[f][0] +
|
||||||
|
(-bz - cy * (by * cz - bz * cy) * inv_norm_2) * inv_norm *
|
||||||
|
grad_normals_a[f][1] +
|
||||||
|
(by - cz * (by * cz - bz * cy) * inv_norm_2) * inv_norm *
|
||||||
|
grad_normals_a[f][2];
|
||||||
|
grad_verts[i1][0] += grad_v1_x;
|
||||||
|
|
||||||
|
const float grad_v1_y =
|
||||||
|
(bz * cx - bx * cz) / 2.0 * inv_norm * grad_areas_a[f] +
|
||||||
|
(bz - cx * (bz * cx - bx * cz) * inv_norm_2) * inv_norm *
|
||||||
|
grad_normals_a[f][0] +
|
||||||
|
-cy * (bz * cx - bx * cz) * inv_norm_3 * grad_normals_a[f][1] +
|
||||||
|
(-bx - cz * (bz * cx - bx * cz) * inv_norm_2) * inv_norm *
|
||||||
|
grad_normals_a[f][2];
|
||||||
|
grad_verts[i1][1] += grad_v1_y;
|
||||||
|
|
||||||
|
const float grad_v1_z =
|
||||||
|
(bx * cy - by * cx) / 2.0 * inv_norm * grad_areas_a[f] +
|
||||||
|
(-by - cx * (bx * cy - by * cx) * inv_norm_2) * inv_norm *
|
||||||
|
grad_normals_a[f][0] +
|
||||||
|
(bx - cx * (bx * cy - by * cx) * inv_norm_2) * inv_norm *
|
||||||
|
grad_normals_a[f][1] +
|
||||||
|
-cz * (bx * cy - by * cx) * inv_norm_3 * grad_normals_a[f][2];
|
||||||
|
grad_verts[i1][2] += grad_v1_z;
|
||||||
|
|
||||||
|
// grad v2 coming from grad areas
|
||||||
|
const float grad_v2_x =
|
||||||
|
(az * cy - ay * cz) / 2.0 * inv_norm * grad_areas_a[f] +
|
||||||
|
-cx * (az * cy - ay * cz) * inv_norm_3 * grad_normals_a[f][0] +
|
||||||
|
(az - cy * (az * cy - ay * cz) * inv_norm_2) * inv_norm *
|
||||||
|
grad_normals_a[f][1] +
|
||||||
|
(-ay - cz * (az * cy - ay * cz) * inv_norm_2) * inv_norm *
|
||||||
|
grad_normals_a[f][2];
|
||||||
|
grad_verts[i2][0] += grad_v2_x;
|
||||||
|
|
||||||
|
const float grad_v2_y =
|
||||||
|
(ax * cz - az * cx) / 2.0 * inv_norm * grad_areas_a[f] +
|
||||||
|
(-az - cx * (ax * cz - az * cx) * inv_norm_2) * inv_norm *
|
||||||
|
grad_normals_a[f][0] +
|
||||||
|
-cy * (ax * cz - az * cx) * inv_norm_3 * grad_normals_a[f][1] +
|
||||||
|
(ax - cz * (ax * cz - az * cx) * inv_norm_2) * inv_norm *
|
||||||
|
grad_normals_a[f][2];
|
||||||
|
grad_verts[i2][1] += grad_v2_y;
|
||||||
|
|
||||||
|
const float grad_v2_z =
|
||||||
|
(ay * cx - ax * cy) / 2.0 * inv_norm * grad_areas_a[f] +
|
||||||
|
(ay - cx * (ay * cx - ax * cy) * inv_norm_2) * inv_norm *
|
||||||
|
grad_normals_a[f][0] +
|
||||||
|
(-ax - cy * (ay * cx - ax * cy) * inv_norm_2) * inv_norm *
|
||||||
|
grad_normals_a[f][1] +
|
||||||
|
-cz * (ay * cx - ax * cy) * inv_norm_3 * grad_normals_a[f][2];
|
||||||
|
grad_verts[i2][2] += grad_v2_z;
|
||||||
|
}
|
||||||
|
return grad_verts;
|
||||||
|
}
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
from .cubify import cubify
|
from .cubify import cubify
|
||||||
from .graph_conv import GraphConv
|
from .graph_conv import GraphConv
|
||||||
|
from .mesh_face_areas_normals import mesh_face_areas_normals
|
||||||
from .nearest_neighbor_points import nn_points_idx
|
from .nearest_neighbor_points import nn_points_idx
|
||||||
from .packed_to_padded import packed_to_padded, padded_to_packed
|
from .packed_to_padded import packed_to_padded, padded_to_packed
|
||||||
from .sample_points_from_meshes import sample_points_from_meshes
|
from .sample_points_from_meshes import sample_points_from_meshes
|
||||||
|
64
pytorch3d/ops/mesh_face_areas_normals.py
Normal file
64
pytorch3d/ops/mesh_face_areas_normals.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.autograd import Function
|
||||||
|
from torch.autograd.function import once_differentiable
|
||||||
|
|
||||||
|
from pytorch3d import _C
|
||||||
|
|
||||||
|
|
||||||
|
class _MeshFaceAreasNormals(Function):
|
||||||
|
"""
|
||||||
|
Torch autograd Function wrapper for face areas & normals C++/CUDA implementations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, verts, faces):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
ctx: Context object used to calculate gradients.
|
||||||
|
verts: FloatTensor of shape (V, 3), representing the packed
|
||||||
|
batch verts tensor.
|
||||||
|
faces: LongTensor of shape (F, 3), representing the packed
|
||||||
|
batch faces tensor
|
||||||
|
Returns:
|
||||||
|
areas: FloatTensor of shape (F,) with the areas of each face
|
||||||
|
normals: FloatTensor of shape (F,3) with the normals of each face
|
||||||
|
"""
|
||||||
|
if not (verts.dim() == 2):
|
||||||
|
raise ValueError("verts need to be of shape Vx3.")
|
||||||
|
if not (verts.shape[1] == 3):
|
||||||
|
raise ValueError("verts need to be of shape Vx3.")
|
||||||
|
if not (faces.dim() == 2):
|
||||||
|
raise ValueError("faces need to be of shape Fx3.")
|
||||||
|
if not (faces.shape[1] == 3):
|
||||||
|
raise ValueError("faces need to be of shape Fx3.")
|
||||||
|
if not (faces.dtype == torch.int64):
|
||||||
|
raise ValueError("faces need to be of type torch.int64.")
|
||||||
|
# TODO(gkioxari) Change cast to floats once we add support for doubles.
|
||||||
|
if not (verts.dtype == torch.float32):
|
||||||
|
verts = verts.float()
|
||||||
|
|
||||||
|
ctx.save_for_backward(verts, faces)
|
||||||
|
areas, normals = _C.face_areas_normals_forward(verts, faces)
|
||||||
|
return areas, normals
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@once_differentiable
|
||||||
|
def backward(ctx, grad_areas, grad_normals):
|
||||||
|
grad_areas = grad_areas.contiguous()
|
||||||
|
grad_normals = grad_normals.contiguous()
|
||||||
|
verts, faces = ctx.saved_tensors
|
||||||
|
# TODO(gkioxari) Change cast to floats once we add support for doubles.
|
||||||
|
if not (grad_areas.dtype == torch.float32):
|
||||||
|
grad_areas = grad_areas.float()
|
||||||
|
if not (grad_normals.dtype == torch.float32):
|
||||||
|
grad_normals = grad_normals.float()
|
||||||
|
grad_verts = _C.face_areas_normals_backward(
|
||||||
|
grad_areas, grad_normals, verts, faces
|
||||||
|
)
|
||||||
|
return grad_verts, None
|
||||||
|
|
||||||
|
|
||||||
|
mesh_face_areas_normals = _MeshFaceAreasNormals.apply
|
@ -10,9 +10,8 @@ import sys
|
|||||||
from typing import Tuple, Union
|
from typing import Tuple, Union
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch3d import _C
|
from pytorch3d.ops.mesh_face_areas_normals import mesh_face_areas_normals
|
||||||
|
from pytorch3d.ops.packed_to_padded import packed_to_padded
|
||||||
from .packed_to_padded import packed_to_padded
|
|
||||||
|
|
||||||
|
|
||||||
def sample_points_from_meshes(
|
def sample_points_from_meshes(
|
||||||
@ -53,7 +52,7 @@ def sample_points_from_meshes(
|
|||||||
|
|
||||||
# Only compute samples for non empty meshes
|
# Only compute samples for non empty meshes
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
areas, _ = _C.face_areas_normals(
|
areas, _ = mesh_face_areas_normals(
|
||||||
verts, faces
|
verts, faces
|
||||||
) # Face areas can be zero.
|
) # Face areas can be zero.
|
||||||
max_faces = meshes.num_faces_per_mesh().max().item()
|
max_faces = meshes.num_faces_per_mesh().max().item()
|
||||||
|
@ -4,8 +4,6 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch3d import _C
|
|
||||||
|
|
||||||
from . import utils as struct_utils
|
from . import utils as struct_utils
|
||||||
from .textures import Textures
|
from .textures import Textures
|
||||||
|
|
||||||
@ -761,6 +759,8 @@ class Meshes(object):
|
|||||||
refresh: Set to True to force recomputation of face areas.
|
refresh: Set to True to force recomputation of face areas.
|
||||||
Default: False.
|
Default: False.
|
||||||
"""
|
"""
|
||||||
|
from ..ops.mesh_face_areas_normals import mesh_face_areas_normals
|
||||||
|
|
||||||
if not (
|
if not (
|
||||||
refresh
|
refresh
|
||||||
or any(
|
or any(
|
||||||
@ -771,7 +771,7 @@ class Meshes(object):
|
|||||||
return
|
return
|
||||||
faces_packed = self.faces_packed()
|
faces_packed = self.faces_packed()
|
||||||
verts_packed = self.verts_packed()
|
verts_packed = self.verts_packed()
|
||||||
face_areas, face_normals = _C.face_areas_normals(
|
face_areas, face_normals = mesh_face_areas_normals(
|
||||||
verts_packed, faces_packed
|
verts_packed, faces_packed
|
||||||
)
|
)
|
||||||
self._faces_areas_packed = face_areas
|
self._faces_areas_packed = face_areas
|
||||||
|
@ -11,19 +11,19 @@ from test_face_areas_normals import TestFaceAreasNormals
|
|||||||
|
|
||||||
def bm_face_areas_normals() -> None:
|
def bm_face_areas_normals() -> None:
|
||||||
kwargs_list = []
|
kwargs_list = []
|
||||||
backend_cuda = [False]
|
backend = ["cpu"]
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
backend_cuda.append(True)
|
backend.append("cuda:0")
|
||||||
|
|
||||||
num_meshes = [2, 10, 32]
|
num_meshes = [2, 10, 32]
|
||||||
num_verts = [100, 1000]
|
num_verts = [100, 1000]
|
||||||
num_faces = [300, 3000]
|
num_faces = [300, 3000]
|
||||||
|
|
||||||
test_cases = product(num_meshes, num_verts, num_faces, backend_cuda)
|
test_cases = product(num_meshes, num_verts, num_faces, backend)
|
||||||
for case in test_cases:
|
for case in test_cases:
|
||||||
n, v, f, c = case
|
n, v, f, d = case
|
||||||
kwargs_list.append(
|
kwargs_list.append(
|
||||||
{"num_meshes": n, "num_verts": v, "num_faces": f, "cuda": c}
|
{"num_meshes": n, "num_verts": v, "num_faces": f, "device": d}
|
||||||
)
|
)
|
||||||
benchmark(
|
benchmark(
|
||||||
TestFaceAreasNormals.face_areas_normals_with_init,
|
TestFaceAreasNormals.face_areas_normals_with_init,
|
||||||
|
@ -5,7 +5,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch3d import _C
|
from pytorch3d.ops import mesh_face_areas_normals
|
||||||
from pytorch3d.structures.meshes import Meshes
|
from pytorch3d.structures.meshes import Meshes
|
||||||
|
|
||||||
from common_testing import TestCaseMixin
|
from common_testing import TestCaseMixin
|
||||||
@ -28,7 +28,10 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
|
|||||||
faces_list = []
|
faces_list = []
|
||||||
for _ in range(num_meshes):
|
for _ in range(num_meshes):
|
||||||
verts = torch.rand(
|
verts = torch.rand(
|
||||||
(num_verts, 3), dtype=torch.float32, device=device
|
(num_verts, 3),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device,
|
||||||
|
requires_grad=True,
|
||||||
)
|
)
|
||||||
faces = torch.randint(
|
faces = torch.randint(
|
||||||
num_verts, size=(num_faces, 3), dtype=torch.int64, device=device
|
num_verts, size=(num_faces, 3), dtype=torch.int64, device=device
|
||||||
@ -40,10 +43,12 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
|
|||||||
return meshes
|
return meshes
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def face_areas_normals(verts, faces):
|
def face_areas_normals_python(verts, faces):
|
||||||
"""
|
"""
|
||||||
Pytorch implementation for face areas & normals.
|
Pytorch implementation for face areas & normals.
|
||||||
"""
|
"""
|
||||||
|
# TODO(gkioxari) Change cast to floats once we add support for doubles.
|
||||||
|
verts = verts.float()
|
||||||
vertices_faces = verts[faces] # (F, 3, 3)
|
vertices_faces = verts[faces] # (F, 3, 3)
|
||||||
# vector pointing from v0 to v1
|
# vector pointing from v0 to v1
|
||||||
v01 = vertices_faces[:, 1] - vertices_faces[:, 0]
|
v01 = vertices_faces[:, 1] - vertices_faces[:, 0]
|
||||||
@ -56,24 +61,41 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
return face_areas, face_normals
|
return face_areas, face_normals
|
||||||
|
|
||||||
def _test_face_areas_normals_helper(self, device):
|
def _test_face_areas_normals_helper(self, device, dtype=torch.float32):
|
||||||
"""
|
"""
|
||||||
Check the results from face_areas cuda/cpp and PyTorch implementation are
|
Check the results from face_areas cuda/cpp and PyTorch implementation are
|
||||||
the same.
|
the same.
|
||||||
"""
|
"""
|
||||||
meshes = self.init_meshes(10, 1000, 3000, device=device)
|
meshes = self.init_meshes(10, 200, 400, device=device)
|
||||||
verts = meshes.verts_packed()
|
# make them leaf nodes
|
||||||
faces = meshes.faces_packed()
|
verts = meshes.verts_packed().detach().clone().to(dtype)
|
||||||
|
verts.requires_grad = True
|
||||||
|
faces = meshes.faces_packed().detach().clone()
|
||||||
|
|
||||||
areas_torch, normals_torch = self.face_areas_normals(verts, faces)
|
# forward
|
||||||
areas, normals = _C.face_areas_normals(verts, faces)
|
areas, normals = mesh_face_areas_normals(verts, faces)
|
||||||
|
verts_torch = verts.detach().clone().to(dtype)
|
||||||
|
verts_torch.requires_grad = True
|
||||||
|
faces_torch = faces.detach().clone()
|
||||||
|
areas_torch, normals_torch = TestFaceAreasNormals.face_areas_normals_python(
|
||||||
|
verts_torch, faces_torch
|
||||||
|
)
|
||||||
self.assertClose(areas_torch, areas, atol=1e-7)
|
self.assertClose(areas_torch, areas, atol=1e-7)
|
||||||
# normals get normalized by area thus sensitivity increases as areas
|
# normals get normalized by area thus sensitivity increases as areas
|
||||||
# in our tests can be arbitrarily small. Thus we compare normals after
|
# in our tests can be arbitrarily small. Thus we compare normals after
|
||||||
# multiplying with areas
|
# multiplying with areas
|
||||||
unnormals = normals * areas.view(-1, 1)
|
unnormals = normals * areas.view(-1, 1)
|
||||||
unnormals_torch = normals_torch * areas_torch.view(-1, 1)
|
unnormals_torch = normals_torch * areas_torch.view(-1, 1)
|
||||||
self.assertClose(unnormals_torch, unnormals, atol=1e-7)
|
self.assertClose(unnormals_torch, unnormals, atol=1e-6)
|
||||||
|
|
||||||
|
# backward
|
||||||
|
grad_areas = torch.rand(areas.shape, device=device, dtype=dtype)
|
||||||
|
grad_normals = torch.rand(normals.shape, device=device, dtype=dtype)
|
||||||
|
areas.backward((grad_areas, grad_normals))
|
||||||
|
grad_verts = verts.grad
|
||||||
|
areas_torch.backward((grad_areas, grad_normals))
|
||||||
|
grad_verts_torch = verts_torch.grad
|
||||||
|
self.assertClose(grad_verts_torch, grad_verts, atol=1e-6)
|
||||||
|
|
||||||
def test_face_areas_normals_cpu(self):
|
def test_face_areas_normals_cpu(self):
|
||||||
self._test_face_areas_normals_helper("cpu")
|
self._test_face_areas_normals_helper("cpu")
|
||||||
@ -81,11 +103,16 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
|
|||||||
def test_face_areas_normals_cuda(self):
|
def test_face_areas_normals_cuda(self):
|
||||||
self._test_face_areas_normals_helper("cuda:0")
|
self._test_face_areas_normals_helper("cuda:0")
|
||||||
|
|
||||||
|
def test_nonfloats_cpu(self):
|
||||||
|
self._test_face_areas_normals_helper("cpu", dtype=torch.double)
|
||||||
|
|
||||||
|
def test_nonfloats_cuda(self):
|
||||||
|
self._test_face_areas_normals_helper("cuda:0", dtype=torch.double)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def face_areas_normals_with_init(
|
def face_areas_normals_with_init(
|
||||||
num_meshes: int, num_verts: int, num_faces: int, cuda: bool = True
|
num_meshes: int, num_verts: int, num_faces: int, device: str = "cpu"
|
||||||
):
|
):
|
||||||
device = "cuda:0" if cuda else "cpu"
|
|
||||||
meshes = TestFaceAreasNormals.init_meshes(
|
meshes = TestFaceAreasNormals.init_meshes(
|
||||||
num_meshes, num_verts, num_faces, device
|
num_meshes, num_verts, num_faces, device
|
||||||
)
|
)
|
||||||
@ -94,16 +121,15 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
|
|||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
def face_areas_normals():
|
def face_areas_normals():
|
||||||
_C.face_areas_normals(verts, faces)
|
mesh_face_areas_normals(verts, faces)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
return face_areas_normals
|
return face_areas_normals
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def face_areas_normals_with_init_torch(
|
def face_areas_normals_with_init_torch(
|
||||||
num_meshes: int, num_verts: int, num_faces: int, cuda: bool = True
|
num_meshes: int, num_verts: int, num_faces: int, device: str = "cpu"
|
||||||
):
|
):
|
||||||
device = "cuda:0" if cuda else "cpu"
|
|
||||||
meshes = TestFaceAreasNormals.init_meshes(
|
meshes = TestFaceAreasNormals.init_meshes(
|
||||||
num_meshes, num_verts, num_faces, device
|
num_meshes, num_verts, num_faces, device
|
||||||
)
|
)
|
||||||
@ -112,7 +138,7 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
|
|||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
def face_areas_normals():
|
def face_areas_normals():
|
||||||
TestFaceAreasNormals.face_areas_normals(verts, faces)
|
TestFaceAreasNormals.face_areas_normals_python(verts, faces)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
return face_areas_normals
|
return face_areas_normals
|
||||||
|
@ -6,8 +6,7 @@ import unittest
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch3d import _C
|
from pytorch3d.ops import sample_points_from_meshes
|
||||||
from pytorch3d.ops.sample_points_from_meshes import sample_points_from_meshes
|
|
||||||
from pytorch3d.structures.meshes import Meshes
|
from pytorch3d.structures.meshes import Meshes
|
||||||
from pytorch3d.utils.ico_sphere import ico_sphere
|
from pytorch3d.utils.ico_sphere import ico_sphere
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user