mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
remove torch from cuda
Summary: Keep using at:: instead of torch:: so we don't need torch/extension.h and can keep other compilers happy. Reviewed By: patricklabatut Differential Revision: D31688436 fbshipit-source-id: 1825503da0104acaf1558d17300c02ef663bf538
This commit is contained in:
parent
1a7442a483
commit
3953de47ee
@ -9,10 +9,9 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
using torch::PackedTensorAccessor64;
|
||||
using torch::RestrictPtrTraits;
|
||||
using at::PackedTensorAccessor64;
|
||||
using at::RestrictPtrTraits;
|
||||
|
||||
// A chunk of work is blocksize-many points.
|
||||
// There are N clouds in the batch, and P points in each cloud.
|
||||
@ -117,12 +116,12 @@ __global__ void PointsToVolumesForwardKernel(
|
||||
}
|
||||
|
||||
void PointsToVolumesForwardCuda(
|
||||
const torch::Tensor& points_3d,
|
||||
const torch::Tensor& points_features,
|
||||
const torch::Tensor& volume_densities,
|
||||
const torch::Tensor& volume_features,
|
||||
const torch::Tensor& grid_sizes,
|
||||
const torch::Tensor& mask,
|
||||
const at::Tensor& points_3d,
|
||||
const at::Tensor& points_features,
|
||||
const at::Tensor& volume_densities,
|
||||
const at::Tensor& volume_features,
|
||||
const at::Tensor& grid_sizes,
|
||||
const at::Tensor& mask,
|
||||
const float point_weight,
|
||||
const bool align_corners,
|
||||
const bool splat) {
|
||||
@ -285,17 +284,17 @@ __global__ void PointsToVolumesBackwardKernel(
|
||||
}
|
||||
|
||||
void PointsToVolumesBackwardCuda(
|
||||
const torch::Tensor& points_3d,
|
||||
const torch::Tensor& points_features,
|
||||
const torch::Tensor& grid_sizes,
|
||||
const torch::Tensor& mask,
|
||||
const at::Tensor& points_3d,
|
||||
const at::Tensor& points_features,
|
||||
const at::Tensor& grid_sizes,
|
||||
const at::Tensor& mask,
|
||||
const float point_weight,
|
||||
const bool align_corners,
|
||||
const bool splat,
|
||||
const torch::Tensor& grad_volume_densities,
|
||||
const torch::Tensor& grad_volume_features,
|
||||
const torch::Tensor& grad_points_3d,
|
||||
const torch::Tensor& grad_points_features) {
|
||||
const at::Tensor& grad_volume_densities,
|
||||
const at::Tensor& grad_volume_features,
|
||||
const at::Tensor& grad_points_3d,
|
||||
const at::Tensor& grad_points_features) {
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg points_3d_t{points_3d, "points_3d", 1},
|
||||
points_features_t{points_features, "points_features", 2},
|
||||
|
Loading…
x
Reference in New Issue
Block a user