Remove torch:: usage from CUDA code

Summary: Keep using at:: instead of torch:: so we don't need torch/extension.h and can keep other compilers happy.

Reviewed By: patricklabatut

Differential Revision: D31688436

fbshipit-source-id: 1825503da0104acaf1558d17300c02ef663bf538
This commit is contained in:
Jeremy Reizenstein 2021-10-18 03:37:08 -07:00 committed by Facebook GitHub Bot
parent 1a7442a483
commit 3953de47ee

View File

@@ -9,10 +9,9 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>
using torch::PackedTensorAccessor64;
using torch::RestrictPtrTraits;
using at::PackedTensorAccessor64;
using at::RestrictPtrTraits;
// A chunk of work is blocksize-many points.
// There are N clouds in the batch, and P points in each cloud.
@@ -117,12 +116,12 @@ __global__ void PointsToVolumesForwardKernel(
}
void PointsToVolumesForwardCuda(
const torch::Tensor& points_3d,
const torch::Tensor& points_features,
const torch::Tensor& volume_densities,
const torch::Tensor& volume_features,
const torch::Tensor& grid_sizes,
const torch::Tensor& mask,
const at::Tensor& points_3d,
const at::Tensor& points_features,
const at::Tensor& volume_densities,
const at::Tensor& volume_features,
const at::Tensor& grid_sizes,
const at::Tensor& mask,
const float point_weight,
const bool align_corners,
const bool splat) {
@@ -285,17 +284,17 @@ __global__ void PointsToVolumesBackwardKernel(
}
void PointsToVolumesBackwardCuda(
const torch::Tensor& points_3d,
const torch::Tensor& points_features,
const torch::Tensor& grid_sizes,
const torch::Tensor& mask,
const at::Tensor& points_3d,
const at::Tensor& points_features,
const at::Tensor& grid_sizes,
const at::Tensor& mask,
const float point_weight,
const bool align_corners,
const bool splat,
const torch::Tensor& grad_volume_densities,
const torch::Tensor& grad_volume_features,
const torch::Tensor& grad_points_3d,
const torch::Tensor& grad_points_features) {
const at::Tensor& grad_volume_densities,
const at::Tensor& grad_volume_features,
const at::Tensor& grad_points_3d,
const at::Tensor& grad_points_features) {
// Check inputs are on the same device
at::TensorArg points_3d_t{points_3d, "points_3d", 1},
points_features_t{points_features, "points_features", 2},