mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 22:30:35 +08:00
face areas backward
Summary: Added backward for mesh face areas & normals. Exposed it as a layer. Replaced the computation with the new op in Meshes and in Sample Points. Current issue: Circular imports. I moved the import of the op in meshes inside the function scope. Reviewed By: jcjohnson Differential Revision: D19920082 fbshipit-source-id: d213226d5e1d19a0c8452f4d32771d07e8b91c0a
This commit is contained in:
committed by
Facebook Github Bot
parent
9ca5489107
commit
a3baa367e3
@@ -9,7 +9,8 @@
|
||||
#include "rasterize_points/rasterize_points.h"
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("face_areas_normals", &FaceAreasNormals);
|
||||
m.def("face_areas_normals_forward", &FaceAreasNormalsForward);
|
||||
m.def("face_areas_normals_backward", &FaceAreasNormalsBackward);
|
||||
m.def("packed_to_padded", &PackedToPadded);
|
||||
m.def("padded_to_packed", &PaddedToPacked);
|
||||
m.def("nn_points_idx", &NearestNeighborIdx);
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
#include <tuple>
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void FaceAreasNormalsKernel(
|
||||
__global__ void FaceAreasNormalsForwardKernel(
|
||||
const scalar_t* __restrict__ verts,
|
||||
const long* __restrict__ faces,
|
||||
const int64_t* __restrict__ faces,
|
||||
scalar_t* __restrict__ face_areas,
|
||||
scalar_t* __restrict__ face_normals,
|
||||
const size_t V,
|
||||
@@ -18,9 +18,9 @@ __global__ void FaceAreasNormalsKernel(
|
||||
// Each thread computes the area & normal of its respective faces and adds it
|
||||
// to the global face_areas tensor.
|
||||
for (size_t f = tid; f < F; f += stride) {
|
||||
const long i0 = faces[3 * f + 0];
|
||||
const long i1 = faces[3 * f + 1];
|
||||
const long i2 = faces[3 * f + 2];
|
||||
const int64_t i0 = faces[3 * f + 0];
|
||||
const int64_t i1 = faces[3 * f + 1];
|
||||
const int64_t i2 = faces[3 * f + 2];
|
||||
|
||||
const scalar_t v0_x = verts[3 * i0 + 0];
|
||||
const scalar_t v0_y = verts[3 * i0 + 1];
|
||||
@@ -55,9 +55,161 @@ __global__ void FaceAreasNormalsKernel(
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsCuda(
|
||||
at::Tensor verts,
|
||||
at::Tensor faces) {
|
||||
// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
|
||||
// Currently, support is for floats only.
|
||||
__global__ void FaceAreasNormalsBackwardKernel(
|
||||
const float* __restrict__ grad_areas,
|
||||
const float* __restrict__ grad_normals,
|
||||
const float* __restrict__ verts,
|
||||
const int64_t* __restrict__ faces,
|
||||
float* __restrict__ grad_verts,
|
||||
const size_t V,
|
||||
const size_t F) {
|
||||
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const size_t stride = gridDim.x * blockDim.x;
|
||||
|
||||
// Faces split evenly over the number of threads in the grid.
|
||||
// Each thread computes the area & normal of its respective faces and adds it
|
||||
// to the global face_areas tensor.
|
||||
for (size_t f = tid; f < F; f += stride) {
|
||||
const int64_t i0 = faces[3 * f + 0];
|
||||
const int64_t i1 = faces[3 * f + 1];
|
||||
const int64_t i2 = faces[3 * f + 2];
|
||||
|
||||
const float v0_x = verts[3 * i0 + 0];
|
||||
const float v0_y = verts[3 * i0 + 1];
|
||||
const float v0_z = verts[3 * i0 + 2];
|
||||
|
||||
const float v1_x = verts[3 * i1 + 0];
|
||||
const float v1_y = verts[3 * i1 + 1];
|
||||
const float v1_z = verts[3 * i1 + 2];
|
||||
|
||||
const float v2_x = verts[3 * i2 + 0];
|
||||
const float v2_y = verts[3 * i2 + 1];
|
||||
const float v2_z = verts[3 * i2 + 2];
|
||||
|
||||
const float ax = v1_x - v0_x;
|
||||
const float ay = v1_y - v0_y;
|
||||
const float az = v1_z - v0_z;
|
||||
|
||||
const float bx = v2_x - v0_x;
|
||||
const float by = v2_y - v0_y;
|
||||
const float bz = v2_z - v0_z;
|
||||
|
||||
const float cx = ay * bz - az * by;
|
||||
const float cy = az * bx - ax * bz;
|
||||
const float cz = ax * by - ay * bx;
|
||||
|
||||
float norm = sqrt(cx * cx + cy * cy + cz * cz);
|
||||
norm = (norm < 1e-6) ? 1e-6 : norm; // max(norm, 1e-6)
|
||||
float inv_norm = 1. / norm;
|
||||
float inv_norm_2 = pow(inv_norm, 2.0f);
|
||||
float inv_norm_3 = pow(inv_norm, 3.0f);
|
||||
|
||||
// We compute gradients with respect to the input vertices.
|
||||
// For each vertex, gradients come from grad_areas and grad_normals.
|
||||
// eg, grad_v0_x = (d / d v0_x)
|
||||
// = \sum_f (d / d areas[f]) * (d areas[f] / d v0_x)
|
||||
// + (d / d normals[f, 0]) * (d normals[f, 0] / d v0_x)
|
||||
// + (d / d normals[f, 1]) * (d normals[f, 1] / d v0_x)
|
||||
// + (d / d normals[f, 2]) * (d normals[f, 2] / d v0_x)
|
||||
// with (d / d areas[f]) = grad_areas[f] and
|
||||
// (d / d normals[f, j]) = grad_normals[f][j].
|
||||
// The equations below are derived after taking
|
||||
// derivatives wrt to the vertices (fun times!).
|
||||
|
||||
// grad v0 coming from grad areas and grad normals
|
||||
const float grad_v0_x =
|
||||
((-az + bz) * cy + (-by + ay) * cz) / 2.0 * inv_norm * grad_areas[f] +
|
||||
-cx * ((-az + bz) * cy + (-by + ay) * cz) * inv_norm_3 *
|
||||
grad_normals[3 * f + 0] +
|
||||
((-az + bz) - cy * ((-az + bz) * cy + (-by + ay) * cz) * inv_norm_2) *
|
||||
inv_norm * grad_normals[3 * f + 1] +
|
||||
((-by + ay) - cz * ((-az + bz) * cy + (-by + ay) * cz) * inv_norm_2) *
|
||||
inv_norm * grad_normals[3 * f + 2];
|
||||
atomicAdd(grad_verts + 3 * i0 + 0, grad_v0_x);
|
||||
|
||||
const float grad_v0_y =
|
||||
((-bz + az) * cx + (-ax + bx) * cz) / 2.0 * inv_norm * grad_areas[f] +
|
||||
((-bz + az) - cx * ((-bz + az) * cx + (-ax + bx) * cz) * inv_norm_2) *
|
||||
inv_norm * grad_normals[3 * f + 0] +
|
||||
-cy * ((-bz + az) * cx + (-ax + bx) * cz) * inv_norm_3 *
|
||||
grad_normals[3 * f + 1] +
|
||||
((-ax + bx) - cz * ((-bz + az) * cx + (-ax + bx) * cz) * inv_norm_2) *
|
||||
inv_norm * grad_normals[3 * f + 2];
|
||||
atomicAdd(grad_verts + 3 * i0 + 1, grad_v0_y);
|
||||
|
||||
const float grad_v0_z =
|
||||
((-ay + by) * cx + (-bx + ax) * cy) / 2.0 * inv_norm * grad_areas[f] +
|
||||
((-ay + by) - cx * ((-ay + by) * cx + (-bx + ax) * cy) * inv_norm_2) *
|
||||
inv_norm * grad_normals[3 * f + 0] +
|
||||
((-bx + ax) - cy * ((-ay + by) * cx + (-bx + ax) * cy) * inv_norm_2) *
|
||||
inv_norm * grad_normals[3 * f + 1] +
|
||||
-cz * ((-ay + by) * cx + (-bx + ax) * cy) * inv_norm_3 *
|
||||
grad_normals[3 * f + 2];
|
||||
atomicAdd(grad_verts + 3 * i0 + 2, grad_v0_z);
|
||||
|
||||
// grad v1 coming from grad areas and grad normals
|
||||
const float grad_v1_x =
|
||||
(by * cz - bz * cy) / 2.0 * inv_norm * grad_areas[f] +
|
||||
-cx * (by * cz - bz * cy) * inv_norm_3 * grad_normals[3 * f + 0] +
|
||||
(-bz - cy * (by * cz - bz * cy) * inv_norm_2) * inv_norm *
|
||||
grad_normals[3 * f + 1] +
|
||||
(by - cz * (by * cz - bz * cy) * inv_norm_2) * inv_norm *
|
||||
grad_normals[3 * f + 2];
|
||||
atomicAdd(grad_verts + 3 * i1 + 0, grad_v1_x);
|
||||
|
||||
const float grad_v1_y =
|
||||
(bz * cx - bx * cz) / 2.0 * inv_norm * grad_areas[f] +
|
||||
(bz - cx * (bz * cx - bx * cz) * inv_norm_2) * inv_norm *
|
||||
grad_normals[3 * f + 0] +
|
||||
-cy * (bz * cx - bx * cz) * inv_norm_3 * grad_normals[3 * f + 1] +
|
||||
(-bx - cz * (bz * cx - bx * cz) * inv_norm_2) * inv_norm *
|
||||
grad_normals[3 * f + 2];
|
||||
atomicAdd(grad_verts + 3 * i1 + 1, grad_v1_y);
|
||||
|
||||
const float grad_v1_z =
|
||||
(bx * cy - by * cx) / 2.0 * inv_norm * grad_areas[f] +
|
||||
(-by - cx * (bx * cy - by * cx) * inv_norm_2) * inv_norm *
|
||||
grad_normals[3 * f + 0] +
|
||||
(bx - cx * (bx * cy - by * cx) * inv_norm_2) * inv_norm *
|
||||
grad_normals[3 * f + 1] +
|
||||
-cz * (bx * cy - by * cx) * inv_norm_3 * grad_normals[3 * f + 2];
|
||||
atomicAdd(grad_verts + 3 * i1 + 2, grad_v1_z);
|
||||
|
||||
// grad v2 coming from grad areas
|
||||
const float grad_v2_x =
|
||||
(az * cy - ay * cz) / 2.0 * inv_norm * grad_areas[f] +
|
||||
-cx * (az * cy - ay * cz) * inv_norm_3 * grad_normals[3 * f + 0] +
|
||||
(az - cy * (az * cy - ay * cz) * inv_norm_2) * inv_norm *
|
||||
grad_normals[3 * f + 1] +
|
||||
(-ay - cz * (az * cy - ay * cz) * inv_norm_2) * inv_norm *
|
||||
grad_normals[3 * f + 2];
|
||||
atomicAdd(grad_verts + 3 * i2 + 0, grad_v2_x);
|
||||
|
||||
const float grad_v2_y =
|
||||
(ax * cz - az * cx) / 2.0 * inv_norm * grad_areas[f] +
|
||||
(-az - cx * (ax * cz - az * cx) * inv_norm_2) * inv_norm *
|
||||
grad_normals[3 * f + 0] +
|
||||
-cy * (ax * cz - az * cx) * inv_norm_3 * grad_normals[3 * f + 1] +
|
||||
(ax - cz * (ax * cz - az * cx) * inv_norm_2) * inv_norm *
|
||||
grad_normals[3 * f + 2];
|
||||
atomicAdd(grad_verts + 3 * i2 + 1, grad_v2_y);
|
||||
|
||||
const float grad_v2_z =
|
||||
(ay * cx - ax * cy) / 2.0 * inv_norm * grad_areas[f] +
|
||||
(ay - cx * (ay * cx - ax * cy) * inv_norm_2) * inv_norm *
|
||||
grad_normals[3 * f + 0] +
|
||||
(-ax - cy * (ay * cx - ax * cy) * inv_norm_2) * inv_norm *
|
||||
grad_normals[3 * f + 1] +
|
||||
-cz * (ay * cx - ax * cy) * inv_norm_3 * grad_normals[3 * f + 2];
|
||||
atomicAdd(grad_verts + 3 * i2 + 2, grad_v2_z);
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForwardCuda(
|
||||
const at::Tensor verts,
|
||||
const at::Tensor faces) {
|
||||
const auto V = verts.size(0);
|
||||
const auto F = faces.size(0);
|
||||
|
||||
@@ -66,16 +218,42 @@ std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsCuda(
|
||||
|
||||
const int blocks = 64;
|
||||
const int threads = 512;
|
||||
AT_DISPATCH_FLOATING_TYPES(verts.type(), "face_areas_normals_cuda", ([&] {
|
||||
FaceAreasNormalsKernel<scalar_t>
|
||||
<<<blocks, threads>>>(
|
||||
verts.data_ptr<scalar_t>(),
|
||||
faces.data_ptr<long>(),
|
||||
areas.data_ptr<scalar_t>(),
|
||||
normals.data_ptr<scalar_t>(),
|
||||
V,
|
||||
F);
|
||||
}));
|
||||
AT_DISPATCH_FLOATING_TYPES(
|
||||
verts.type(), "face_areas_normals_forward_cuda", ([&] {
|
||||
FaceAreasNormalsForwardKernel<scalar_t><<<blocks, threads>>>(
|
||||
verts.data_ptr<scalar_t>(),
|
||||
faces.data_ptr<int64_t>(),
|
||||
areas.data_ptr<scalar_t>(),
|
||||
normals.data_ptr<scalar_t>(),
|
||||
V,
|
||||
F);
|
||||
}));
|
||||
|
||||
return std::make_tuple(areas, normals);
|
||||
}
|
||||
|
||||
at::Tensor FaceAreasNormalsBackwardCuda(
|
||||
const at::Tensor grad_areas,
|
||||
const at::Tensor grad_normals,
|
||||
const at::Tensor verts,
|
||||
const at::Tensor faces) {
|
||||
const auto V = verts.size(0);
|
||||
const auto F = faces.size(0);
|
||||
|
||||
at::Tensor grad_verts = at::zeros({V, 3}, grad_areas.options());
|
||||
|
||||
const int blocks = 64;
|
||||
const int threads = 512;
|
||||
// TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
|
||||
// doubles. Currently, support is for floats only.
|
||||
FaceAreasNormalsBackwardKernel<<<blocks, threads>>>(
|
||||
grad_areas.data_ptr<float>(),
|
||||
grad_normals.data_ptr<float>(),
|
||||
verts.data_ptr<float>(),
|
||||
faces.data_ptr<int64_t>(),
|
||||
grad_verts.data_ptr<float>(),
|
||||
V,
|
||||
F);
|
||||
|
||||
return grad_verts;
|
||||
}
|
||||
|
||||
@@ -17,27 +17,55 @@
|
||||
//
|
||||
|
||||
// Cpu implementation.
|
||||
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsCpu(
|
||||
at::Tensor verts,
|
||||
at::Tensor faces);
|
||||
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForwardCpu(
|
||||
const at::Tensor verts,
|
||||
const at::Tensor faces);
|
||||
// Cpu implementation
|
||||
at::Tensor FaceAreasNormalsBackwardCpu(
|
||||
const at::Tensor grad_areas,
|
||||
const at::Tensor grad_normals,
|
||||
const at::Tensor verts,
|
||||
const at::Tensor faces);
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
// Cuda implementation.
|
||||
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsCuda(
|
||||
at::Tensor verts,
|
||||
at::Tensor faces);
|
||||
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForwardCuda(
|
||||
const at::Tensor verts,
|
||||
const at::Tensor faces);
|
||||
// Cuda implementation.
|
||||
at::Tensor FaceAreasNormalsBackwardCuda(
|
||||
const at::Tensor grad_areas,
|
||||
const at::Tensor grad_normals,
|
||||
const at::Tensor verts,
|
||||
const at::Tensor faces);
|
||||
#endif
|
||||
|
||||
// Implementation which is exposed.
|
||||
std::tuple<at::Tensor, at::Tensor> FaceAreasNormals(
|
||||
at::Tensor verts,
|
||||
at::Tensor faces) {
|
||||
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForward(
|
||||
const at::Tensor verts,
|
||||
const at::Tensor faces) {
|
||||
if (verts.type().is_cuda() && faces.type().is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
return FaceAreasNormalsCuda(verts, faces);
|
||||
return FaceAreasNormalsForwardCuda(verts, faces);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
return FaceAreasNormalsCpu(verts, faces);
|
||||
return FaceAreasNormalsForwardCpu(verts, faces);
|
||||
}
|
||||
|
||||
// Implementation which is exposed.
|
||||
at::Tensor FaceAreasNormalsBackward(
|
||||
const at::Tensor grad_areas,
|
||||
const at::Tensor grad_normals,
|
||||
const at::Tensor verts,
|
||||
const at::Tensor faces) {
|
||||
if (verts.type().is_cuda() && faces.type().is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
return FaceAreasNormalsBackwardCuda(grad_areas, grad_normals, verts, faces);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
return FaceAreasNormalsBackwardCpu(grad_areas, grad_normals, verts, faces);
|
||||
}
|
||||
|
||||
@@ -3,9 +3,10 @@
|
||||
#include <torch/extension.h>
|
||||
#include <tuple>
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsCpu(
|
||||
at::Tensor verts,
|
||||
at::Tensor faces) {
|
||||
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForwardCpu(
|
||||
const at::Tensor verts,
|
||||
const at::Tensor faces) {
|
||||
const int V = verts.size(0);
|
||||
const int F = faces.size(0);
|
||||
|
||||
at::Tensor areas = at::empty({F}, verts.options());
|
||||
@@ -54,3 +55,156 @@ std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsCpu(
|
||||
}
|
||||
return std::make_tuple(areas, normals);
|
||||
}
|
||||
|
||||
at::Tensor FaceAreasNormalsBackwardCpu(
|
||||
const at::Tensor grad_areas,
|
||||
const at::Tensor grad_normals,
|
||||
const at::Tensor verts,
|
||||
const at::Tensor faces) {
|
||||
const int V = verts.size(0);
|
||||
const int F = faces.size(0);
|
||||
|
||||
at::Tensor grad_verts = at::zeros({V, 3}, grad_areas.options());
|
||||
|
||||
auto grad_areas_a = grad_areas.accessor<float, 1>();
|
||||
auto grad_normals_a = grad_normals.accessor<float, 2>();
|
||||
auto verts_a = verts.accessor<float, 2>();
|
||||
auto faces_a = faces.accessor<int64_t, 2>();
|
||||
auto grad_verts_a = grad_verts.accessor<float, 2>();
|
||||
|
||||
for (int f = 0; f < F; ++f) {
|
||||
const int64_t i0 = faces_a[f][0];
|
||||
const int64_t i1 = faces_a[f][1];
|
||||
const int64_t i2 = faces_a[f][2];
|
||||
|
||||
const float v0_x = verts_a[i0][0];
|
||||
const float v0_y = verts_a[i0][1];
|
||||
const float v0_z = verts_a[i0][2];
|
||||
|
||||
const float v1_x = verts_a[i1][0];
|
||||
const float v1_y = verts_a[i1][1];
|
||||
const float v1_z = verts_a[i1][2];
|
||||
|
||||
const float v2_x = verts_a[i2][0];
|
||||
const float v2_y = verts_a[i2][1];
|
||||
const float v2_z = verts_a[i2][2];
|
||||
|
||||
const float ax = v1_x - v0_x;
|
||||
const float ay = v1_y - v0_y;
|
||||
const float az = v1_z - v0_z;
|
||||
|
||||
const float bx = v2_x - v0_x;
|
||||
const float by = v2_y - v0_y;
|
||||
const float bz = v2_z - v0_z;
|
||||
|
||||
const float cx = ay * bz - az * by;
|
||||
const float cy = az * bx - ax * bz;
|
||||
const float cz = ax * by - ay * bx;
|
||||
|
||||
float norm = sqrt(cx * cx + cy * cy + cz * cz);
|
||||
norm = (norm < 1e-6) ? 1e-6 : norm; // max(norm, 1e-6)
|
||||
float inv_norm = 1. / norm;
|
||||
float inv_norm_2 = pow(inv_norm, 2.0f);
|
||||
float inv_norm_3 = pow(inv_norm, 3.0f);
|
||||
|
||||
// We compute gradients with respect to the input vertices.
|
||||
// For each vertex, gradients come from grad_areas and grad_normals.
|
||||
// eg, grad_v0_x = (d / d v0_x)
|
||||
// = \sum_f (d / d areas[f]) * (d areas[f] / d v0_x)
|
||||
// + (d / d normals[f, 0]) * (d normals[f, 0] / d v0_x)
|
||||
// + (d / d normals[f, 1]) * (d normals[f, 1] / d v0_x)
|
||||
// + (d / d normals[f, 2]) * (d normals[f, 2] / d v0_x)
|
||||
// with (d / d areas[f]) = grad_areas[f] and
|
||||
// (d / d normals[f, j]) = grad_normals[f][j].
|
||||
// The equations below are derived after taking
|
||||
// derivatives wrt to the vertices (fun times!).
|
||||
|
||||
// grad v0 coming from grad areas and grad normals
|
||||
const float grad_v0_x =
|
||||
((-az + bz) * cy + (-by + ay) * cz) / 2.0 * inv_norm * grad_areas_a[f] +
|
||||
-cx * ((-az + bz) * cy + (-by + ay) * cz) * inv_norm_3 *
|
||||
grad_normals_a[f][0] +
|
||||
((-az + bz) - cy * ((-az + bz) * cy + (-by + ay) * cz) * inv_norm_2) *
|
||||
inv_norm * grad_normals_a[f][1] +
|
||||
((-by + ay) - cz * ((-az + bz) * cy + (-by + ay) * cz) * inv_norm_2) *
|
||||
inv_norm * grad_normals_a[f][2];
|
||||
grad_verts_a[i0][0] += grad_v0_x;
|
||||
|
||||
const float grad_v0_y =
|
||||
((-bz + az) * cx + (-ax + bx) * cz) / 2.0 * inv_norm * grad_areas_a[f] +
|
||||
((-bz + az) - cx * ((-bz + az) * cx + (-ax + bx) * cz) * inv_norm_2) *
|
||||
inv_norm * grad_normals_a[f][0] +
|
||||
-cy * ((-bz + az) * cx + (-ax + bx) * cz) * inv_norm_3 *
|
||||
grad_normals_a[f][1] +
|
||||
((-ax + bx) - cz * ((-bz + az) * cx + (-ax + bx) * cz) * inv_norm_2) *
|
||||
inv_norm * grad_normals_a[f][2];
|
||||
grad_verts[i0][1] += grad_v0_y;
|
||||
|
||||
const float grad_v0_z =
|
||||
((-ay + by) * cx + (-bx + ax) * cy) / 2.0 * inv_norm * grad_areas_a[f] +
|
||||
((-ay + by) - cx * ((-ay + by) * cx + (-bx + ax) * cy) * inv_norm_2) *
|
||||
inv_norm * grad_normals_a[f][0] +
|
||||
((-bx + ax) - cy * ((-ay + by) * cx + (-bx + ax) * cy) * inv_norm_2) *
|
||||
inv_norm * grad_normals_a[f][1] +
|
||||
-cz * ((-ay + by) * cx + (-bx + ax) * cy) * inv_norm_3 *
|
||||
grad_normals_a[f][2];
|
||||
grad_verts[i0][2] += grad_v0_z;
|
||||
|
||||
// grad v1 coming from grad areas and grad normals
|
||||
const float grad_v1_x =
|
||||
(by * cz - bz * cy) / 2.0 * inv_norm * grad_areas_a[f] +
|
||||
-cx * (by * cz - bz * cy) * inv_norm_3 * grad_normals_a[f][0] +
|
||||
(-bz - cy * (by * cz - bz * cy) * inv_norm_2) * inv_norm *
|
||||
grad_normals_a[f][1] +
|
||||
(by - cz * (by * cz - bz * cy) * inv_norm_2) * inv_norm *
|
||||
grad_normals_a[f][2];
|
||||
grad_verts[i1][0] += grad_v1_x;
|
||||
|
||||
const float grad_v1_y =
|
||||
(bz * cx - bx * cz) / 2.0 * inv_norm * grad_areas_a[f] +
|
||||
(bz - cx * (bz * cx - bx * cz) * inv_norm_2) * inv_norm *
|
||||
grad_normals_a[f][0] +
|
||||
-cy * (bz * cx - bx * cz) * inv_norm_3 * grad_normals_a[f][1] +
|
||||
(-bx - cz * (bz * cx - bx * cz) * inv_norm_2) * inv_norm *
|
||||
grad_normals_a[f][2];
|
||||
grad_verts[i1][1] += grad_v1_y;
|
||||
|
||||
const float grad_v1_z =
|
||||
(bx * cy - by * cx) / 2.0 * inv_norm * grad_areas_a[f] +
|
||||
(-by - cx * (bx * cy - by * cx) * inv_norm_2) * inv_norm *
|
||||
grad_normals_a[f][0] +
|
||||
(bx - cx * (bx * cy - by * cx) * inv_norm_2) * inv_norm *
|
||||
grad_normals_a[f][1] +
|
||||
-cz * (bx * cy - by * cx) * inv_norm_3 * grad_normals_a[f][2];
|
||||
grad_verts[i1][2] += grad_v1_z;
|
||||
|
||||
// grad v2 coming from grad areas
|
||||
const float grad_v2_x =
|
||||
(az * cy - ay * cz) / 2.0 * inv_norm * grad_areas_a[f] +
|
||||
-cx * (az * cy - ay * cz) * inv_norm_3 * grad_normals_a[f][0] +
|
||||
(az - cy * (az * cy - ay * cz) * inv_norm_2) * inv_norm *
|
||||
grad_normals_a[f][1] +
|
||||
(-ay - cz * (az * cy - ay * cz) * inv_norm_2) * inv_norm *
|
||||
grad_normals_a[f][2];
|
||||
grad_verts[i2][0] += grad_v2_x;
|
||||
|
||||
const float grad_v2_y =
|
||||
(ax * cz - az * cx) / 2.0 * inv_norm * grad_areas_a[f] +
|
||||
(-az - cx * (ax * cz - az * cx) * inv_norm_2) * inv_norm *
|
||||
grad_normals_a[f][0] +
|
||||
-cy * (ax * cz - az * cx) * inv_norm_3 * grad_normals_a[f][1] +
|
||||
(ax - cz * (ax * cz - az * cx) * inv_norm_2) * inv_norm *
|
||||
grad_normals_a[f][2];
|
||||
grad_verts[i2][1] += grad_v2_y;
|
||||
|
||||
const float grad_v2_z =
|
||||
(ay * cx - ax * cy) / 2.0 * inv_norm * grad_areas_a[f] +
|
||||
(ay - cx * (ay * cx - ax * cy) * inv_norm_2) * inv_norm *
|
||||
grad_normals_a[f][0] +
|
||||
(-ax - cy * (ay * cx - ax * cy) * inv_norm_2) * inv_norm *
|
||||
grad_normals_a[f][1] +
|
||||
-cz * (ay * cx - ax * cy) * inv_norm_3 * grad_normals_a[f][2];
|
||||
grad_verts[i2][2] += grad_v2_z;
|
||||
}
|
||||
return grad_verts;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user