mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Consolidate mesh backward kernels
Summary: This diff creates the generic MeshBackwardKernel (named `DistanceBackwardKernel` in the code below), which can handle distance calculations between points, edges and faces in either direction. It replaces only the point_mesh_face code for now.
Reviewed By: gkioxari
Differential Revision: D24549374
fbshipit-source-id: 2853c1da1c2a6b6de8d0e40007ba0735b8959044
This commit is contained in:
parent
c41aff23f0
commit
8dcfe30f66
@ -212,86 +212,135 @@ std::tuple<at::Tensor, at::Tensor> PointFaceDistanceForwardCuda(
|
|||||||
points, 1, points_first_idx, tris, 3, tris_first_idx, max_points);
|
points, 1, points_first_idx, tris, 3, tris_first_idx, max_points);
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void PointFaceBackwardKernel(
|
__global__ void DistanceBackwardKernel(
|
||||||
const float* __restrict__ points, // (P, 3)
|
const float* __restrict__ objects, // (O * oD * 3)
|
||||||
const float* __restrict__ tris, // (T, 3, 3)
|
const size_t objects_size, // O
|
||||||
const int64_t* __restrict__ idx_points, // (P,)
|
const size_t objects_dim, // oD
|
||||||
const float* __restrict__ grad_dists, // (P,)
|
const float* __restrict__ targets, // (T * tD * 3)
|
||||||
float* __restrict__ grad_points, // (P, 3)
|
const size_t targets_dim, // tD
|
||||||
float* __restrict__ grad_tris, // (T, 3, 3)
|
const int64_t* __restrict__ idx_objects, // (O,)
|
||||||
const size_t P) {
|
const float* __restrict__ grad_dists, // (O,)
|
||||||
float3* points_f3 = (float3*)points;
|
float* __restrict__ grad_points, // ((O or T) * 3)
|
||||||
float3* tris_f3 = (float3*)tris;
|
float* __restrict__ grad_face) { // ((O or T) * max(oD, tD) * 3)
|
||||||
|
// This kernel is used interchangeably to compute bi-directional backward
|
||||||
|
// distances between points and triangles/lines. The direction of the distance
|
||||||
|
// computed, i.e. point to triangle/line or triangle/line to point, depends on
|
||||||
|
// the order of the input arguments and is inferred based on their shape. The
|
||||||
|
// shape is used to distinguish between triangles and lines. Note that
|
||||||
|
// grad_points will always be used for the point data and grad_face for the
|
||||||
|
// edge/triangle
|
||||||
|
|
||||||
|
// Set references to points/face based on whether objects/targets are which
|
||||||
|
float3* points_f3 = objects_dim == 1 ? (float3*)objects : (float3*)targets;
|
||||||
|
float3* face_f3 = objects_dim == 1 ? (float3*)targets : (float3*)objects;
|
||||||
|
|
||||||
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
const size_t stride = gridDim.x * blockDim.x;
|
const size_t stride = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
for (size_t p = tid; p < P; p += stride) {
|
for (size_t o = tid; o < objects_size; o += stride) {
|
||||||
const float3 p_f3 = points_f3[p];
|
const int64_t tidx = idx_objects[o];
|
||||||
|
|
||||||
const int64_t tidx = idx_points[p];
|
size_t point_index = objects_dim == 1 ? o : tidx;
|
||||||
const float3 v0 = tris_f3[tidx * 3 + 0];
|
size_t face_index = objects_dim == 1 ? tidx * targets_dim : o * objects_dim;
|
||||||
const float3 v1 = tris_f3[tidx * 3 + 1];
|
bool isTriangle = objects_dim == 3 || targets_dim == 3;
|
||||||
const float3 v2 = tris_f3[tidx * 3 + 2];
|
|
||||||
|
|
||||||
const float grad_dist = grad_dists[p];
|
float3 grad_point, grad_v0, grad_v1, grad_v2;
|
||||||
|
if (isTriangle) {
|
||||||
|
const auto grads = PointTriangle3DistanceBackward(
|
||||||
|
points_f3[point_index],
|
||||||
|
face_f3[face_index],
|
||||||
|
face_f3[face_index + 1],
|
||||||
|
face_f3[face_index + 2],
|
||||||
|
grad_dists[o]);
|
||||||
|
grad_point = thrust::get<0>(grads);
|
||||||
|
grad_v0 = thrust::get<1>(grads);
|
||||||
|
grad_v1 = thrust::get<2>(grads);
|
||||||
|
grad_v2 = thrust::get<3>(grads);
|
||||||
|
} else {
|
||||||
|
const auto grads = PointLine3DistanceBackward(
|
||||||
|
points_f3[point_index],
|
||||||
|
face_f3[face_index],
|
||||||
|
face_f3[face_index + 1],
|
||||||
|
grad_dists[o]);
|
||||||
|
grad_point = thrust::get<0>(grads);
|
||||||
|
grad_v0 = thrust::get<1>(grads);
|
||||||
|
grad_v1 = thrust::get<2>(grads);
|
||||||
|
}
|
||||||
|
|
||||||
const auto grads =
|
atomicAdd(grad_points + point_index * 3 + 0, grad_point.x);
|
||||||
PointTriangle3DistanceBackward(p_f3, v0, v1, v2, grad_dist);
|
atomicAdd(grad_points + point_index * 3 + 1, grad_point.y);
|
||||||
const float3 grad_point = thrust::get<0>(grads);
|
atomicAdd(grad_points + point_index * 3 + 2, grad_point.z);
|
||||||
const float3 grad_v0 = thrust::get<1>(grads);
|
|
||||||
const float3 grad_v1 = thrust::get<2>(grads);
|
|
||||||
const float3 grad_v2 = thrust::get<3>(grads);
|
|
||||||
|
|
||||||
atomicAdd(grad_points + p * 3 + 0, grad_point.x);
|
atomicAdd(grad_face + face_index * 3 + 0 * 3 + 0, grad_v0.x);
|
||||||
atomicAdd(grad_points + p * 3 + 1, grad_point.y);
|
atomicAdd(grad_face + face_index * 3 + 0 * 3 + 1, grad_v0.y);
|
||||||
atomicAdd(grad_points + p * 3 + 2, grad_point.z);
|
atomicAdd(grad_face + face_index * 3 + 0 * 3 + 2, grad_v0.z);
|
||||||
|
|
||||||
atomicAdd(grad_tris + tidx * 3 * 3 + 0 * 3 + 0, grad_v0.x);
|
atomicAdd(grad_face + face_index * 3 + 1 * 3 + 0, grad_v1.x);
|
||||||
atomicAdd(grad_tris + tidx * 3 * 3 + 0 * 3 + 1, grad_v0.y);
|
atomicAdd(grad_face + face_index * 3 + 1 * 3 + 1, grad_v1.y);
|
||||||
atomicAdd(grad_tris + tidx * 3 * 3 + 0 * 3 + 2, grad_v0.z);
|
atomicAdd(grad_face + face_index * 3 + 1 * 3 + 2, grad_v1.z);
|
||||||
|
|
||||||
atomicAdd(grad_tris + tidx * 3 * 3 + 1 * 3 + 0, grad_v1.x);
|
if (isTriangle) {
|
||||||
atomicAdd(grad_tris + tidx * 3 * 3 + 1 * 3 + 1, grad_v1.y);
|
atomicAdd(grad_face + face_index * 3 + 2 * 3 + 0, grad_v2.x);
|
||||||
atomicAdd(grad_tris + tidx * 3 * 3 + 1 * 3 + 2, grad_v1.z);
|
atomicAdd(grad_face + face_index * 3 + 2 * 3 + 1, grad_v2.y);
|
||||||
|
atomicAdd(grad_face + face_index * 3 + 2 * 3 + 2, grad_v2.z);
|
||||||
atomicAdd(grad_tris + tidx * 3 * 3 + 2 * 3 + 0, grad_v2.x);
|
}
|
||||||
atomicAdd(grad_tris + tidx * 3 * 3 + 2 * 3 + 1, grad_v2.y);
|
|
||||||
atomicAdd(grad_tris + tidx * 3 * 3 + 2 * 3 + 2, grad_v2.z);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<at::Tensor, at::Tensor> PointFaceDistanceBackwardCuda(
|
std::tuple<at::Tensor, at::Tensor> DistanceBackwardCuda(
|
||||||
const at::Tensor& points,
|
const at::Tensor& objects,
|
||||||
const at::Tensor& tris,
|
const size_t objects_dim,
|
||||||
const at::Tensor& idx_points,
|
const at::Tensor& targets,
|
||||||
|
const size_t targets_dim,
|
||||||
|
const at::Tensor& idx_objects,
|
||||||
const at::Tensor& grad_dists) {
|
const at::Tensor& grad_dists) {
|
||||||
// Check inputs are on the same device
|
// Check inputs are on the same device
|
||||||
at::TensorArg points_t{points, "points", 1},
|
at::TensorArg objects_t{objects, "objects", 1},
|
||||||
idx_points_t{idx_points, "idx_points", 2}, tris_t{tris, "tris", 3},
|
targets_t{targets, "targets", 2},
|
||||||
|
idx_objects_t{idx_objects, "idx_objects", 3},
|
||||||
grad_dists_t{grad_dists, "grad_dists", 4};
|
grad_dists_t{grad_dists, "grad_dists", 4};
|
||||||
at::CheckedFrom c = "PointFaceDistanceBackwardCuda";
|
at::CheckedFrom c = "DistanceBackwardCuda";
|
||||||
at::checkAllSameGPU(c, {points_t, idx_points_t, tris_t, grad_dists_t});
|
at::checkAllSameGPU(c, {objects_t, targets_t, idx_objects_t, grad_dists_t});
|
||||||
at::checkAllSameType(c, {points_t, tris_t, grad_dists_t});
|
at::checkAllSameType(c, {objects_t, targets_t, grad_dists_t});
|
||||||
|
|
||||||
// Set the device for the kernel launch based on the device of the input
|
// Set the device for the kernel launch based on the device of the input
|
||||||
at::cuda::CUDAGuard device_guard(points.device());
|
at::cuda::CUDAGuard device_guard(objects.device());
|
||||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
const int64_t P = points.size(0);
|
const int64_t objects_size = objects.size(0);
|
||||||
const int64_t T = tris.size(0);
|
const int64_t targets_size = targets.size(0);
|
||||||
|
|
||||||
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
|
at::Tensor grad_points;
|
||||||
TORCH_CHECK(
|
at::Tensor grad_tris;
|
||||||
(tris.size(1) == 3) && (tris.size(2) == 3),
|
|
||||||
"tris must be of shape Tx3x3");
|
|
||||||
TORCH_CHECK(idx_points.size(0) == P);
|
|
||||||
TORCH_CHECK(grad_dists.size(0) == P);
|
|
||||||
|
|
||||||
// clang-format off
|
TORCH_CHECK(idx_objects.size(0) == objects_size);
|
||||||
at::Tensor grad_points = at::zeros({P, 3}, points.options());
|
TORCH_CHECK(grad_dists.size(0) == objects_size);
|
||||||
at::Tensor grad_tris = at::zeros({T, 3, 3}, tris.options());
|
if (objects_dim == 1) {
|
||||||
// clang-format on
|
TORCH_CHECK(
|
||||||
|
targets_dim >= 2 && targets_dim <= 3,
|
||||||
|
"either object or target must be edge or face");
|
||||||
|
TORCH_CHECK(objects.size(1) == 3, "points must be of shape Px3");
|
||||||
|
TORCH_CHECK(
|
||||||
|
targets.size(2) == 3,
|
||||||
|
"face must be of shape Tx3x3, lines must be of shape Tx2x3");
|
||||||
|
// clang-format off
|
||||||
|
grad_points = at::zeros({objects_size, 3}, objects.options());
|
||||||
|
grad_tris = at::zeros({targets_size, int64_t(targets_dim), 3}, targets.options());
|
||||||
|
// clang-format on
|
||||||
|
} else {
|
||||||
|
TORCH_CHECK(targets_dim == 1, "either object or target must be point");
|
||||||
|
TORCH_CHECK(
|
||||||
|
objects_dim >= 2 && objects_dim <= 3,
|
||||||
|
"either object or target must be edge or face");
|
||||||
|
TORCH_CHECK(targets.size(1) == 3, "points must be of shape Px3");
|
||||||
|
TORCH_CHECK(
|
||||||
|
objects.size(2) == 3,
|
||||||
|
"face must be of shape Tx3x3, lines must be of shape Tx2x3");
|
||||||
|
// clang-format off
|
||||||
|
grad_points = at::zeros({targets_size, 3}, targets.options());
|
||||||
|
grad_tris = at::zeros({objects_size, int64_t(objects_dim), 3}, objects.options());
|
||||||
|
// clang-format on
|
||||||
|
}
|
||||||
|
|
||||||
if (grad_points.numel() == 0 || grad_tris.numel() == 0) {
|
if (grad_points.numel() == 0 || grad_tris.numel() == 0) {
|
||||||
AT_CUDA_CHECK(cudaGetLastError());
|
AT_CUDA_CHECK(cudaGetLastError());
|
||||||
@ -301,19 +350,29 @@ std::tuple<at::Tensor, at::Tensor> PointFaceDistanceBackwardCuda(
|
|||||||
const int blocks = 64;
|
const int blocks = 64;
|
||||||
const int threads = 512;
|
const int threads = 512;
|
||||||
|
|
||||||
PointFaceBackwardKernel<<<blocks, threads, 0, stream>>>(
|
DistanceBackwardKernel<<<blocks, threads, 0, stream>>>(
|
||||||
points.contiguous().data_ptr<float>(),
|
objects.contiguous().data_ptr<float>(),
|
||||||
tris.contiguous().data_ptr<float>(),
|
objects_size,
|
||||||
idx_points.contiguous().data_ptr<int64_t>(),
|
objects_dim,
|
||||||
|
targets.contiguous().data_ptr<float>(),
|
||||||
|
targets_dim,
|
||||||
|
idx_objects.contiguous().data_ptr<int64_t>(),
|
||||||
grad_dists.contiguous().data_ptr<float>(),
|
grad_dists.contiguous().data_ptr<float>(),
|
||||||
grad_points.data_ptr<float>(),
|
grad_points.data_ptr<float>(),
|
||||||
grad_tris.data_ptr<float>(),
|
grad_tris.data_ptr<float>());
|
||||||
P);
|
|
||||||
|
|
||||||
AT_CUDA_CHECK(cudaGetLastError());
|
AT_CUDA_CHECK(cudaGetLastError());
|
||||||
return std::make_tuple(grad_points, grad_tris);
|
return std::make_tuple(grad_points, grad_tris);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::tuple<at::Tensor, at::Tensor> PointFaceDistanceBackwardCuda(
|
||||||
|
const at::Tensor& points,
|
||||||
|
const at::Tensor& tris,
|
||||||
|
const at::Tensor& idx_points,
|
||||||
|
const at::Tensor& grad_dists) {
|
||||||
|
return DistanceBackwardCuda(points, 1, tris, 3, idx_points, grad_dists);
|
||||||
|
}
|
||||||
|
|
||||||
// ****************************************************************************
|
// ****************************************************************************
|
||||||
// * FacePointDistance *
|
// * FacePointDistance *
|
||||||
// ****************************************************************************
|
// ****************************************************************************
|
||||||
@ -328,107 +387,12 @@ std::tuple<at::Tensor, at::Tensor> FacePointDistanceForwardCuda(
|
|||||||
tris, 3, tris_first_idx, points, 1, points_first_idx, max_tris);
|
tris, 3, tris_first_idx, points, 1, points_first_idx, max_tris);
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void FacePointBackwardKernel(
|
|
||||||
const float* __restrict__ points, // (P, 3)
|
|
||||||
const float* __restrict__ tris, // (T, 3, 3)
|
|
||||||
const int64_t* __restrict__ idx_tris, // (T,)
|
|
||||||
const float* __restrict__ grad_dists, // (T,)
|
|
||||||
float* __restrict__ grad_points, // (P, 3)
|
|
||||||
float* __restrict__ grad_tris, // (T, 3, 3)
|
|
||||||
const size_t T) {
|
|
||||||
float3* points_f3 = (float3*)points;
|
|
||||||
float3* tris_f3 = (float3*)tris;
|
|
||||||
|
|
||||||
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
|
||||||
const size_t stride = gridDim.x * blockDim.x;
|
|
||||||
|
|
||||||
for (size_t t = tid; t < T; t += stride) {
|
|
||||||
const float3 v0 = tris_f3[t * 3 + 0];
|
|
||||||
const float3 v1 = tris_f3[t * 3 + 1];
|
|
||||||
const float3 v2 = tris_f3[t * 3 + 2];
|
|
||||||
|
|
||||||
const int64_t pidx = idx_tris[t];
|
|
||||||
|
|
||||||
const float3 p_f3 = points_f3[pidx];
|
|
||||||
|
|
||||||
const float grad_dist = grad_dists[t];
|
|
||||||
|
|
||||||
const auto grads =
|
|
||||||
PointTriangle3DistanceBackward(p_f3, v0, v1, v2, grad_dist);
|
|
||||||
const float3 grad_point = thrust::get<0>(grads);
|
|
||||||
const float3 grad_v0 = thrust::get<1>(grads);
|
|
||||||
const float3 grad_v1 = thrust::get<2>(grads);
|
|
||||||
const float3 grad_v2 = thrust::get<3>(grads);
|
|
||||||
|
|
||||||
atomicAdd(grad_points + pidx * 3 + 0, grad_point.x);
|
|
||||||
atomicAdd(grad_points + pidx * 3 + 1, grad_point.y);
|
|
||||||
atomicAdd(grad_points + pidx * 3 + 2, grad_point.z);
|
|
||||||
|
|
||||||
atomicAdd(grad_tris + t * 3 * 3 + 0 * 3 + 0, grad_v0.x);
|
|
||||||
atomicAdd(grad_tris + t * 3 * 3 + 0 * 3 + 1, grad_v0.y);
|
|
||||||
atomicAdd(grad_tris + t * 3 * 3 + 0 * 3 + 2, grad_v0.z);
|
|
||||||
|
|
||||||
atomicAdd(grad_tris + t * 3 * 3 + 1 * 3 + 0, grad_v1.x);
|
|
||||||
atomicAdd(grad_tris + t * 3 * 3 + 1 * 3 + 1, grad_v1.y);
|
|
||||||
atomicAdd(grad_tris + t * 3 * 3 + 1 * 3 + 2, grad_v1.z);
|
|
||||||
|
|
||||||
atomicAdd(grad_tris + t * 3 * 3 + 2 * 3 + 0, grad_v2.x);
|
|
||||||
atomicAdd(grad_tris + t * 3 * 3 + 2 * 3 + 1, grad_v2.y);
|
|
||||||
atomicAdd(grad_tris + t * 3 * 3 + 2 * 3 + 2, grad_v2.z);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::tuple<at::Tensor, at::Tensor> FacePointDistanceBackwardCuda(
|
std::tuple<at::Tensor, at::Tensor> FacePointDistanceBackwardCuda(
|
||||||
const at::Tensor& points,
|
const at::Tensor& points,
|
||||||
const at::Tensor& tris,
|
const at::Tensor& tris,
|
||||||
const at::Tensor& idx_tris,
|
const at::Tensor& idx_tris,
|
||||||
const at::Tensor& grad_dists) {
|
const at::Tensor& grad_dists) {
|
||||||
// Check inputs are on the same device
|
return DistanceBackwardCuda(tris, 3, points, 1, idx_tris, grad_dists);
|
||||||
at::TensorArg points_t{points, "points", 1},
|
|
||||||
idx_tris_t{idx_tris, "idx_tris", 2}, tris_t{tris, "tris", 3},
|
|
||||||
grad_dists_t{grad_dists, "grad_dists", 4};
|
|
||||||
at::CheckedFrom c = "FacePointDistanceBackwardCuda";
|
|
||||||
at::checkAllSameGPU(c, {points_t, idx_tris_t, tris_t, grad_dists_t});
|
|
||||||
at::checkAllSameType(c, {points_t, tris_t, grad_dists_t});
|
|
||||||
|
|
||||||
// Set the device for the kernel launch based on the device of the input
|
|
||||||
at::cuda::CUDAGuard device_guard(points.device());
|
|
||||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
|
||||||
|
|
||||||
const int64_t P = points.size(0);
|
|
||||||
const int64_t T = tris.size(0);
|
|
||||||
|
|
||||||
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
|
|
||||||
TORCH_CHECK(
|
|
||||||
(tris.size(1) == 3) && (tris.size(2) == 3),
|
|
||||||
"tris must be of shape Tx3x3");
|
|
||||||
TORCH_CHECK(idx_tris.size(0) == T);
|
|
||||||
TORCH_CHECK(grad_dists.size(0) == T);
|
|
||||||
|
|
||||||
// clang-format off
|
|
||||||
at::Tensor grad_points = at::zeros({P, 3}, points.options());
|
|
||||||
at::Tensor grad_tris = at::zeros({T, 3, 3}, tris.options());
|
|
||||||
// clang-format on
|
|
||||||
|
|
||||||
if (grad_points.numel() == 0 || grad_tris.numel() == 0) {
|
|
||||||
AT_CUDA_CHECK(cudaGetLastError());
|
|
||||||
return std::make_tuple(grad_points, grad_tris);
|
|
||||||
}
|
|
||||||
|
|
||||||
const int blocks = 64;
|
|
||||||
const int threads = 512;
|
|
||||||
|
|
||||||
FacePointBackwardKernel<<<blocks, threads, 0, stream>>>(
|
|
||||||
points.contiguous().data_ptr<float>(),
|
|
||||||
tris.contiguous().data_ptr<float>(),
|
|
||||||
idx_tris.contiguous().data_ptr<int64_t>(),
|
|
||||||
grad_dists.contiguous().data_ptr<float>(),
|
|
||||||
grad_points.data_ptr<float>(),
|
|
||||||
grad_tris.data_ptr<float>(),
|
|
||||||
T);
|
|
||||||
|
|
||||||
AT_CUDA_CHECK(cudaGetLastError());
|
|
||||||
return std::make_tuple(grad_points, grad_tris);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ****************************************************************************
|
// ****************************************************************************
|
||||||
|
Loading…
x
Reference in New Issue
Block a user