diff --git a/pytorch3d/csrc/point_mesh/point_mesh_face.cu b/pytorch3d/csrc/point_mesh/point_mesh_face.cu index ec3e9f89..1ba42ddf 100644 --- a/pytorch3d/csrc/point_mesh/point_mesh_face.cu +++ b/pytorch3d/csrc/point_mesh/point_mesh_face.cu @@ -212,86 +212,135 @@ std::tuple PointFaceDistanceForwardCuda( points, 1, points_first_idx, tris, 3, tris_first_idx, max_points); } -__global__ void PointFaceBackwardKernel( - const float* __restrict__ points, // (P, 3) - const float* __restrict__ tris, // (T, 3, 3) - const int64_t* __restrict__ idx_points, // (P,) - const float* __restrict__ grad_dists, // (P,) - float* __restrict__ grad_points, // (P, 3) - float* __restrict__ grad_tris, // (T, 3, 3) - const size_t P) { - float3* points_f3 = (float3*)points; - float3* tris_f3 = (float3*)tris; +__global__ void DistanceBackwardKernel( + const float* __restrict__ objects, // (O * oD * 3) + const size_t objects_size, // O + const size_t objects_dim, // oD + const float* __restrict__ targets, // (T * tD * 3) + const size_t targets_dim, // tD + const int64_t* __restrict__ idx_objects, // (O,) + const float* __restrict__ grad_dists, // (O,) + float* __restrict__ grad_points, // ((O or T) * 3) + float* __restrict__ grad_face) { // ((O or T) * max(oD, tD) * 3) + // This kernel is used interchangeably to compute bi-directional backward + // distances between points and triangles/lines. The direction of the distance + // computed, i.e. point to triangle/line or triangle/line to point, depends on + // the order of the input arguments and is inferred based on their shape. The + // shape is used to distinguish between triangles and lines. Note that + // grad_points will always be used for the point data and grad_face for the + // edge/triangle + + // Set references to points/face based on whether objects/targets are which + float3* points_f3 = objects_dim == 1 ? (float3*)objects : (float3*)targets; + float3* face_f3 = objects_dim == 1 ? (float3*)targets : (float3*)objects; const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; const size_t stride = gridDim.x * blockDim.x; - for (size_t p = tid; p < P; p += stride) { - const float3 p_f3 = points_f3[p]; + for (size_t o = tid; o < objects_size; o += stride) { + const int64_t tidx = idx_objects[o]; - const int64_t tidx = idx_points[p]; - const float3 v0 = tris_f3[tidx * 3 + 0]; - const float3 v1 = tris_f3[tidx * 3 + 1]; - const float3 v2 = tris_f3[tidx * 3 + 2]; + size_t point_index = objects_dim == 1 ? o : tidx; + size_t face_index = objects_dim == 1 ? tidx * targets_dim : o * objects_dim; + bool isTriangle = objects_dim == 3 || targets_dim == 3; - const float grad_dist = grad_dists[p]; + float3 grad_point, grad_v0, grad_v1, grad_v2; + if (isTriangle) { + const auto grads = PointTriangle3DistanceBackward( + points_f3[point_index], + face_f3[face_index], + face_f3[face_index + 1], + face_f3[face_index + 2], + grad_dists[o]); + grad_point = thrust::get<0>(grads); + grad_v0 = thrust::get<1>(grads); + grad_v1 = thrust::get<2>(grads); + grad_v2 = thrust::get<3>(grads); + } else { + const auto grads = PointLine3DistanceBackward( + points_f3[point_index], + face_f3[face_index], + face_f3[face_index + 1], + grad_dists[o]); + grad_point = thrust::get<0>(grads); + grad_v0 = thrust::get<1>(grads); + grad_v1 = thrust::get<2>(grads); + } - const auto grads = - PointTriangle3DistanceBackward(p_f3, v0, v1, v2, grad_dist); - const float3 grad_point = thrust::get<0>(grads); - const float3 grad_v0 = thrust::get<1>(grads); - const float3 grad_v1 = thrust::get<2>(grads); - const float3 grad_v2 = thrust::get<3>(grads); + atomicAdd(grad_points + point_index * 3 + 0, grad_point.x); + atomicAdd(grad_points + point_index * 3 + 1, grad_point.y); + atomicAdd(grad_points + point_index * 3 + 2, grad_point.z); - atomicAdd(grad_points + p * 3 + 0, grad_point.x); - atomicAdd(grad_points + p * 3 + 1, grad_point.y); - atomicAdd(grad_points + p * 3 + 2, grad_point.z); + atomicAdd(grad_face + face_index * 3 + 0 * 3 + 0, grad_v0.x); + atomicAdd(grad_face + face_index * 3 + 0 * 3 + 1, grad_v0.y); + atomicAdd(grad_face + face_index * 3 + 0 * 3 + 2, grad_v0.z); - atomicAdd(grad_tris + tidx * 3 * 3 + 0 * 3 + 0, grad_v0.x); - atomicAdd(grad_tris + tidx * 3 * 3 + 0 * 3 + 1, grad_v0.y); - atomicAdd(grad_tris + tidx * 3 * 3 + 0 * 3 + 2, grad_v0.z); + atomicAdd(grad_face + face_index * 3 + 1 * 3 + 0, grad_v1.x); + atomicAdd(grad_face + face_index * 3 + 1 * 3 + 1, grad_v1.y); + atomicAdd(grad_face + face_index * 3 + 1 * 3 + 2, grad_v1.z); - atomicAdd(grad_tris + tidx * 3 * 3 + 1 * 3 + 0, grad_v1.x); - atomicAdd(grad_tris + tidx * 3 * 3 + 1 * 3 + 1, grad_v1.y); - atomicAdd(grad_tris + tidx * 3 * 3 + 1 * 3 + 2, grad_v1.z); - - atomicAdd(grad_tris + tidx * 3 * 3 + 2 * 3 + 0, grad_v2.x); - atomicAdd(grad_tris + tidx * 3 * 3 + 2 * 3 + 1, grad_v2.y); - atomicAdd(grad_tris + tidx * 3 * 3 + 2 * 3 + 2, grad_v2.z); + if (isTriangle) { + atomicAdd(grad_face + face_index * 3 + 2 * 3 + 0, grad_v2.x); + atomicAdd(grad_face + face_index * 3 + 2 * 3 + 1, grad_v2.y); + atomicAdd(grad_face + face_index * 3 + 2 * 3 + 2, grad_v2.z); + } } } -std::tuple PointFaceDistanceBackwardCuda( - const at::Tensor& points, - const at::Tensor& tris, - const at::Tensor& idx_points, +std::tuple DistanceBackwardCuda( + const at::Tensor& objects, + const size_t objects_dim, + const at::Tensor& targets, + const size_t targets_dim, + const at::Tensor& idx_objects, const at::Tensor& grad_dists) { // Check inputs are on the same device - at::TensorArg points_t{points, "points", 1}, - idx_points_t{idx_points, "idx_points", 2}, tris_t{tris, "tris", 3}, + at::TensorArg objects_t{objects, "objects", 1}, + targets_t{targets, "targets", 2}, + idx_objects_t{idx_objects, "idx_objects", 3}, grad_dists_t{grad_dists, "grad_dists", 4}; - at::CheckedFrom c = "PointFaceDistanceBackwardCuda"; - at::checkAllSameGPU(c, {points_t, idx_points_t, tris_t, grad_dists_t}); - at::checkAllSameType(c, {points_t, tris_t, grad_dists_t}); + at::CheckedFrom c = "DistanceBackwardCuda"; + at::checkAllSameGPU(c, {objects_t, targets_t, idx_objects_t, grad_dists_t}); + at::checkAllSameType(c, {objects_t, targets_t, grad_dists_t}); // Set the device for the kernel launch based on the device of the input - at::cuda::CUDAGuard device_guard(points.device()); + at::cuda::CUDAGuard device_guard(objects.device()); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - const int64_t P = points.size(0); - const int64_t T = tris.size(0); + const int64_t objects_size = objects.size(0); + const int64_t targets_size = targets.size(0); - TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3"); - TORCH_CHECK( - (tris.size(1) == 3) && (tris.size(2) == 3), - "tris must be of shape Tx3x3"); - TORCH_CHECK(idx_points.size(0) == P); - TORCH_CHECK(grad_dists.size(0) == P); + at::Tensor grad_points; + at::Tensor grad_tris; - // clang-format off - at::Tensor grad_points = at::zeros({P, 3}, points.options()); - at::Tensor grad_tris = at::zeros({T, 3, 3}, tris.options()); - // clang-format on + TORCH_CHECK(idx_objects.size(0) == objects_size); + TORCH_CHECK(grad_dists.size(0) == objects_size); + if (objects_dim == 1) { + TORCH_CHECK( + targets_dim >= 2 && targets_dim <= 3, + "either object or target must be edge or face"); + TORCH_CHECK(objects.size(1) == 3, "points must be of shape Px3"); + TORCH_CHECK( + targets.size(2) == 3, + "face must be of shape Tx3x3, lines must be of shape Tx2x3"); + // clang-format off + grad_points = at::zeros({objects_size, 3}, objects.options()); + grad_tris = at::zeros({targets_size, int64_t(targets_dim), 3}, targets.options()); + // clang-format on + } else { + TORCH_CHECK(targets_dim == 1, "either object or target must be point"); + TORCH_CHECK( + objects_dim >= 2 && objects_dim <= 3, + "either object or target must be edge or face"); + TORCH_CHECK(targets.size(1) == 3, "points must be of shape Px3"); + TORCH_CHECK( + objects.size(2) == 3, + "face must be of shape Tx3x3, lines must be of shape Tx2x3"); + // clang-format off + grad_points = at::zeros({targets_size, 3}, targets.options()); + grad_tris = at::zeros({objects_size, int64_t(objects_dim), 3}, objects.options()); + // clang-format on + } if (grad_points.numel() == 0 || grad_tris.numel() == 0) { AT_CUDA_CHECK(cudaGetLastError()); @@ -301,19 +350,29 @@ std::tuple PointFaceDistanceBackwardCuda( const int blocks = 64; const int threads = 512; - PointFaceBackwardKernel<<>>( - points.contiguous().data_ptr(), - tris.contiguous().data_ptr(), - idx_points.contiguous().data_ptr(), + DistanceBackwardKernel<<>>( + objects.contiguous().data_ptr(), + objects_size, + objects_dim, + targets.contiguous().data_ptr(), + targets_dim, + idx_objects.contiguous().data_ptr(), grad_dists.contiguous().data_ptr(), grad_points.data_ptr(), - grad_tris.data_ptr(), - P); + grad_tris.data_ptr()); AT_CUDA_CHECK(cudaGetLastError()); return std::make_tuple(grad_points, grad_tris); } +std::tuple PointFaceDistanceBackwardCuda( + const at::Tensor& points, + const at::Tensor& tris, + const at::Tensor& idx_points, + const at::Tensor& grad_dists) { + return DistanceBackwardCuda(points, 1, tris, 3, idx_points, grad_dists); +} + // **************************************************************************** // * FacePointDistance * // **************************************************************************** @@ -328,107 +387,12 @@ std::tuple FacePointDistanceForwardCuda( tris, 3, tris_first_idx, points, 1, points_first_idx, max_tris); } -__global__ void FacePointBackwardKernel( - const float* __restrict__ points, // (P, 3) - const float* __restrict__ tris, // (T, 3, 3) - const int64_t* __restrict__ idx_tris, // (T,) - const float* __restrict__ grad_dists, // (T,) - float* __restrict__ grad_points, // (P, 3) - float* __restrict__ grad_tris, // (T, 3, 3) - const size_t T) { - float3* points_f3 = (float3*)points; - float3* tris_f3 = (float3*)tris; - - const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; - const size_t stride = gridDim.x * blockDim.x; - - for (size_t t = tid; t < T; t += stride) { - const float3 v0 = tris_f3[t * 3 + 0]; - const float3 v1 = tris_f3[t * 3 + 1]; - const float3 v2 = tris_f3[t * 3 + 2]; - - const int64_t pidx = idx_tris[t]; - - const float3 p_f3 = points_f3[pidx]; - - const float grad_dist = grad_dists[t]; - - const auto grads = - PointTriangle3DistanceBackward(p_f3, v0, v1, v2, grad_dist); - const float3 grad_point = thrust::get<0>(grads); - const float3 grad_v0 = thrust::get<1>(grads); - const float3 grad_v1 = thrust::get<2>(grads); - const float3 grad_v2 = thrust::get<3>(grads); - - atomicAdd(grad_points + pidx * 3 + 0, grad_point.x); - atomicAdd(grad_points + pidx * 3 + 1, grad_point.y); - atomicAdd(grad_points + pidx * 3 + 2, grad_point.z); - - atomicAdd(grad_tris + t * 3 * 3 + 0 * 3 + 0, grad_v0.x); - atomicAdd(grad_tris + t * 3 * 3 + 0 * 3 + 1, grad_v0.y); - atomicAdd(grad_tris + t * 3 * 3 + 0 * 3 + 2, grad_v0.z); - - atomicAdd(grad_tris + t * 3 * 3 + 1 * 3 + 0, grad_v1.x); - atomicAdd(grad_tris + t * 3 * 3 + 1 * 3 + 1, grad_v1.y); - atomicAdd(grad_tris + t * 3 * 3 + 1 * 3 + 2, grad_v1.z); - - atomicAdd(grad_tris + t * 3 * 3 + 2 * 3 + 0, grad_v2.x); - atomicAdd(grad_tris + t * 3 * 3 + 2 * 3 + 1, grad_v2.y); - atomicAdd(grad_tris + t * 3 * 3 + 2 * 3 + 2, grad_v2.z); - } -} - std::tuple FacePointDistanceBackwardCuda( const at::Tensor& points, const at::Tensor& tris, const at::Tensor& idx_tris, const at::Tensor& grad_dists) { - // Check inputs are on the same device - at::TensorArg points_t{points, "points", 1}, - idx_tris_t{idx_tris, "idx_tris", 2}, tris_t{tris, "tris", 3}, - grad_dists_t{grad_dists, "grad_dists", 4}; - at::CheckedFrom c = "FacePointDistanceBackwardCuda"; - at::checkAllSameGPU(c, {points_t, idx_tris_t, tris_t, grad_dists_t}); - at::checkAllSameType(c, {points_t, tris_t, grad_dists_t}); - - // Set the device for the kernel launch based on the device of the input - at::cuda::CUDAGuard device_guard(points.device()); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - const int64_t P = points.size(0); - const int64_t T = tris.size(0); - - TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3"); - TORCH_CHECK( - (tris.size(1) == 3) && (tris.size(2) == 3), - "tris must be of shape Tx3x3"); - TORCH_CHECK(idx_tris.size(0) == T); - TORCH_CHECK(grad_dists.size(0) == T); - - // clang-format off - at::Tensor grad_points = at::zeros({P, 3}, points.options()); - at::Tensor grad_tris = at::zeros({T, 3, 3}, tris.options()); - // clang-format on - - if (grad_points.numel() == 0 || grad_tris.numel() == 0) { - AT_CUDA_CHECK(cudaGetLastError()); - return std::make_tuple(grad_points, grad_tris); - } - - const int blocks = 64; - const int threads = 512; - - FacePointBackwardKernel<<>>( - points.contiguous().data_ptr(), - tris.contiguous().data_ptr(), - idx_tris.contiguous().data_ptr(), - grad_dists.contiguous().data_ptr(), - grad_points.data_ptr(), - grad_tris.data_ptr(), - T); - - AT_CUDA_CHECK(cudaGetLastError()); - return std::make_tuple(grad_points, grad_tris); + return DistanceBackwardCuda(tris, 3, points, 1, idx_tris, grad_dists); } // ****************************************************************************