mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-07 04:36:00 +08:00
Summary: Adds CHECK_CPU macro calls that check whether a tensor is on the CPU device throughout the csrc directories up to `marching_cubes`. Directories updated include `gather_scatter`, `interp_face_attrs`, `iou_box3d`, `knn`, and `marching_cubes`. Note that this is the second part of a larger change; to keep diffs better organized, subsequent diffs will update the remaining directories. Reviewed By: bottler Differential Revision: D77558550 fbshipit-source-id: 762a0fe88548dc8d0901b198a11c40d0c36e173f
162 lines
5.4 KiB
C++
162 lines
5.4 KiB
C++
/*
|
|
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
* All rights reserved.
|
|
*
|
|
* This source code is licensed under the BSD-style license found in the
|
|
* LICENSE file in the root directory of this source tree.
|
|
*/
|
|
|
|
#pragma once
|
|
#include <torch/extension.h>
|
|
#include <tuple>
|
|
#include "utils/pytorch3d_cutils.h"
|
|
|
|
// Compute indices of K nearest neighbors in pointcloud p2 to points
|
|
// in pointcloud p1.
|
|
//
|
|
// Args:
|
|
// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
|
|
// containing P1 points of dimension D.
|
|
// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
|
|
// containing P2 points of dimension D.
|
|
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
|
|
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
|
|
// norm: int specifying the norm for the distance (1 for L1, 2 for L2)
|
|
// K: int giving the number of nearest points to return.
|
|
// version: Integer telling which implementation to use.
|
|
//
|
|
// Returns:
|
|
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
|
|
// p1_neighbor_idx[n, i, k] = j means that the kth nearest
|
|
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
|
|
// It is padded with zeros so that it can be used easily in a later
|
|
// gather() operation.
|
|
//
|
|
// p1_neighbor_dists: FloatTensor of shape (N, P1, K) containing the squared
|
|
// distance from each point p1[n, p, :] to its K neighbors
|
|
// p2[n, p1_neighbor_idx[n, p, k], :].
|
|
|
|
// CPU implementation.
|
|
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
|
|
const at::Tensor& p1,
|
|
const at::Tensor& p2,
|
|
const at::Tensor& lengths1,
|
|
const at::Tensor& lengths2,
|
|
const int norm,
|
|
const int K);
|
|
|
|
// CUDA implementation
|
|
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
|
const at::Tensor& p1,
|
|
const at::Tensor& p2,
|
|
const at::Tensor& lengths1,
|
|
const at::Tensor& lengths2,
|
|
const int norm,
|
|
const int K,
|
|
const int version);
|
|
|
|
// Implementation which is exposed.
|
|
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
|
|
const at::Tensor& p1,
|
|
const at::Tensor& p2,
|
|
const at::Tensor& lengths1,
|
|
const at::Tensor& lengths2,
|
|
const int norm,
|
|
const int K,
|
|
const int version) {
|
|
if (p1.is_cuda() || p2.is_cuda()) {
|
|
#ifdef WITH_CUDA
|
|
CHECK_CUDA(p1);
|
|
CHECK_CUDA(p2);
|
|
return KNearestNeighborIdxCuda(
|
|
p1, p2, lengths1, lengths2, norm, K, version);
|
|
#else
|
|
AT_ERROR("Not compiled with GPU support.");
|
|
#endif
|
|
}
|
|
CHECK_CPU(p1);
|
|
CHECK_CPU(p2);
|
|
return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, norm, K);
|
|
}
|
|
|
|
// Compute gradients with respect to p1 and p2
|
|
//
|
|
// Args:
|
|
// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
|
|
// containing P1 points of dimension D.
|
|
// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
|
|
// containing P2 points of dimension D.
|
|
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
|
|
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
|
|
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
|
|
// p1_neighbor_idx[n, i, k] = j means that the kth nearest
|
|
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
|
|
// It is padded with zeros so that it can be used easily in a later
|
|
// gather() operation. This is computed from the forward pass.
|
|
// norm: int specifying the norm for the distance (1 for L1, 2 for L2)
|
|
// grad_dists: FloatTensor of shape (N, P1, K) which contains the input
|
|
// gradients.
|
|
//
|
|
// Returns:
|
|
// grad_p1: FloatTensor of shape (N, P1, D) containing the output gradients
|
|
// wrt p1.
|
|
// grad_p2: FloatTensor of shape (N, P2, D) containing the output gradients
|
|
// wrt p2.
|
|
|
|
// CPU implementation.
|
|
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
|
|
const at::Tensor& p1,
|
|
const at::Tensor& p2,
|
|
const at::Tensor& lengths1,
|
|
const at::Tensor& lengths2,
|
|
const at::Tensor& idxs,
|
|
const int norm,
|
|
const at::Tensor& grad_dists);
|
|
|
|
// CUDA implementation
|
|
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
|
|
const at::Tensor& p1,
|
|
const at::Tensor& p2,
|
|
const at::Tensor& lengths1,
|
|
const at::Tensor& lengths2,
|
|
const at::Tensor& idxs,
|
|
const int norm,
|
|
const at::Tensor& grad_dists);
|
|
|
|
// Implementation which is exposed.
|
|
std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackward(
|
|
const at::Tensor& p1,
|
|
const at::Tensor& p2,
|
|
const at::Tensor& lengths1,
|
|
const at::Tensor& lengths2,
|
|
const at::Tensor& idxs,
|
|
const int norm,
|
|
const at::Tensor& grad_dists) {
|
|
if (p1.is_cuda() || p2.is_cuda()) {
|
|
#ifdef WITH_CUDA
|
|
CHECK_CUDA(p1);
|
|
CHECK_CUDA(p2);
|
|
return KNearestNeighborBackwardCuda(
|
|
p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
|
|
#else
|
|
AT_ERROR("Not compiled with GPU support.");
|
|
#endif
|
|
}
|
|
CHECK_CPU(p1);
|
|
CHECK_CPU(p2);
|
|
return KNearestNeighborBackwardCpu(
|
|
p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
|
|
}
|
|
|
|
// Utility to check whether a KNN version can be used.
|
|
//
|
|
// Args:
|
|
// version: Integer in the range 0 <= version <= 3 indicating one of our
|
|
// KNN implementations.
|
|
// D: Number of dimensions for the input and query point clouds
|
|
// K: Number of neighbors to be found
|
|
//
|
|
// Returns:
|
|
// Whether the indicated KNN version can be used.
|
|
// Declaration only: no body is given in this header.
// NOTE(review): presumably defined next to the versioned KNN kernels; confirm.
bool KnnCheckVersion(int version, const int64_t D, const int64_t K);
|