Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2026-03-02 18:26:01 +08:00)
CUDA updates
Summary: Updates to:
- enable CUDA kernel launches on any GPU (not just the default)
- CUDA and contiguous checks for all kernels
- checks to ensure all tensors are on the same device
- error reporting in the CUDA kernels
- CUDA tests now run on a random device, not just the default

Reviewed By: jcjohnson, gkioxari

Differential Revision: D21215280

fbshipit-source-id: 1bedc9fe6c35e9e920bdc4d78ed12865b1005519
Committed by: Facebook GitHub Bot
Parent: c9267ab7af
Commit: c3d636dc8c
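Every entry point in the CUDA source below gets the same three-step treatment: verify that all inputs live on one GPU, make that GPU the current device, and launch on the current stream with an error check afterwards. A minimal standalone sketch of that pattern follows; ScaleKernel and ScaleCuda are illustrative stand-ins, not code from this commit.

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

// Illustrative kernel: doubles each element. It stands in for the real
// rasterization kernels changed in the diff below.
__global__ void ScaleKernel(const float* in, float* out, const int N) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < N) {
    out[i] = 2.0f * in[i];
  }
}

at::Tensor ScaleCuda(const at::Tensor& in, const at::Tensor& aux) {
  // Step 1: check inputs are on the same device (aux is only here to
  // show the multi-tensor form of the check).
  at::TensorArg in_t{in, "in", 1}, aux_t{aux, "aux", 2};
  at::CheckedFrom c = "ScaleCuda";
  at::checkAllSameGPU(c, {in_t, aux_t});

  // Step 2: set the device for allocations and the launch based on the
  // device of the input, so the op works on any GPU, not just device 0.
  at::cuda::CUDAGuard device_guard(in.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  at::Tensor out = at::zeros_like(in);
  if (out.numel() == 0) {
    // Early exit: nothing to compute, so skip the launch entirely.
    AT_CUDA_CHECK(cudaGetLastError());
    return out;
  }

  const int N = static_cast<int>(in.numel());
  const int threads = 64;
  const int blocks = (N + threads - 1) / threads;
  // Step 3: launch on the current stream and surface launch errors.
  ScaleKernel<<<blocks, threads, 0, stream>>>(
      in.contiguous().data_ptr<float>(), out.data_ptr<float>(), N);
  AT_CUDA_CHECK(cudaGetLastError());
  return out;
}

Without the CUDAGuard, allocations and launches would target the default device regardless of where the inputs live; with it, the same host code works on any visible GPU, which is what lets the tests run on a random device.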
@@ -1,6 +1,8 @@
 // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
 #include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <math.h>
 #include <cstdio>
 #include <sstream>
@@ -145,13 +147,25 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
     const int image_size,
     const float radius,
     const int points_per_pixel) {
-  if (points.ndimension() != 2 || points.size(1) != 3) {
-    AT_ERROR("points must have dimensions (num_points, 3)");
-  }
-  if (num_points_per_cloud.size(0) != cloud_to_packed_first_idx.size(0)) {
-    AT_ERROR(
-        "num_points_per_cloud must have same size first dimension as cloud_to_packed_first_idx");
-  }
+  // Check inputs are on the same device
+  at::TensorArg points_t{points, "points", 1},
+      cloud_to_packed_first_idx_t{
+          cloud_to_packed_first_idx, "cloud_to_packed_first_idx", 2},
+      num_points_per_cloud_t{num_points_per_cloud, "num_points_per_cloud", 3};
+  at::CheckedFrom c = "RasterizePointsNaiveCuda";
+  at::checkAllSameGPU(
+      c, {points_t, cloud_to_packed_first_idx_t, num_points_per_cloud_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(points.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  TORCH_CHECK(
+      points.ndimension() == 2 && points.size(1) == 3,
+      "points must have dimensions (num_points, 3)");
+  TORCH_CHECK(
+      num_points_per_cloud.size(0) == cloud_to_packed_first_idx.size(0),
+      "num_points_per_cloud must have same size first dimension as cloud_to_packed_first_idx");
 
   const int N = num_points_per_cloud.size(0); // batch size.
   const int S = image_size;
@@ -169,9 +183,14 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
   at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts);
   at::Tensor pix_dists = at::full({N, S, S, K}, -1, float_opts);
 
+  if (point_idxs.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return std::make_tuple(point_idxs, zbuf, pix_dists);
+  }
+
   const size_t blocks = 1024;
   const size_t threads = 64;
-  RasterizePointsNaiveCudaKernel<<<blocks, threads>>>(
+  RasterizePointsNaiveCudaKernel<<<blocks, threads, 0, stream>>>(
       points.contiguous().data_ptr<float>(),
       cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(),
       num_points_per_cloud.contiguous().data_ptr<int64_t>(),
@@ -182,6 +201,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
       point_idxs.contiguous().data_ptr<int32_t>(),
       zbuf.contiguous().data_ptr<float>(),
       pix_dists.contiguous().data_ptr<float>());
+
+  AT_CUDA_CHECK(cudaGetLastError());
   return std::make_tuple(point_idxs, zbuf, pix_dists);
 }
 
@@ -323,14 +344,28 @@ at::Tensor RasterizePointsCoarseCuda(
     const float radius,
     const int bin_size,
     const int max_points_per_bin) {
+  TORCH_CHECK(
+      points.ndimension() == 2 && points.size(1) == 3,
+      "points must have dimensions (num_points, 3)");
+
+  // Check inputs are on the same device
+  at::TensorArg points_t{points, "points", 1},
+      cloud_to_packed_first_idx_t{
+          cloud_to_packed_first_idx, "cloud_to_packed_first_idx", 2},
+      num_points_per_cloud_t{num_points_per_cloud, "num_points_per_cloud", 3};
+  at::CheckedFrom c = "RasterizePointsCoarseCuda";
+  at::checkAllSameGPU(
+      c, {points_t, cloud_to_packed_first_idx_t, num_points_per_cloud_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(points.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
   const int P = points.size(0);
   const int N = num_points_per_cloud.size(0);
   const int num_bins = 1 + (image_size - 1) / bin_size; // divide round up
   const int M = max_points_per_bin;
 
-  if (points.ndimension() != 2 || points.size(1) != 3) {
-    AT_ERROR("points must have dimensions (num_points, 3)");
-  }
   if (num_bins >= 22) {
     // Make sure we do not use too much shared memory.
     std::stringstream ss;
@@ -340,12 +375,18 @@ at::Tensor RasterizePointsCoarseCuda(
   auto opts = points.options().dtype(at::kInt);
   at::Tensor points_per_bin = at::zeros({N, num_bins, num_bins}, opts);
   at::Tensor bin_points = at::full({N, num_bins, num_bins, M}, -1, opts);
 
+  if (bin_points.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return bin_points;
+  }
+
   const int chunk_size = 512;
   const size_t shared_size = num_bins * num_bins * chunk_size / 8;
   const size_t blocks = 64;
   const size_t threads = 512;
-  RasterizePointsCoarseCudaKernel<<<blocks, threads, shared_size>>>(
+
+  RasterizePointsCoarseCudaKernel<<<blocks, threads, shared_size, stream>>>(
       points.contiguous().data_ptr<float>(),
       cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(),
       num_points_per_cloud.contiguous().data_ptr<int64_t>(),
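For scale: the dynamic shared memory requested above is one bit per point of the chunk per bin, i.e. num_bins * num_bins * chunk_size / 8 bytes. Under the num_bins < 22 cap enforced in the previous hunk, that is at most 21 * 21 * 512 / 8 = 28,224 bytes per block, which fits comfortably within the 48 KB of shared memory commonly available per block.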
@@ -358,6 +399,8 @@ at::Tensor RasterizePointsCoarseCuda(
       M,
       points_per_bin.contiguous().data_ptr<int32_t>(),
       bin_points.contiguous().data_ptr<int32_t>());
+
+  AT_CUDA_CHECK(cudaGetLastError());
   return bin_points;
 }
 
@@ -448,13 +491,23 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
     const float radius,
     const int bin_size,
     const int points_per_pixel) {
+  // Check inputs are on the same device
+  at::TensorArg points_t{points, "points", 1},
+      bin_points_t{bin_points, "bin_points", 2};
+  at::CheckedFrom c = "RasterizePointsFineCuda";
+  at::checkAllSameGPU(c, {points_t, bin_points_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(points.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
   const int N = bin_points.size(0);
   const int B = bin_points.size(1); // num_bins
   const int M = bin_points.size(3);
   const int S = image_size;
   const int K = points_per_pixel;
   if (K > kMaxPointsPerPixel) {
-    AT_ERROR("Must have num_closest <= 8");
+    AT_ERROR("Must have num_closest <= 150");
   }
   auto int_opts = points.options().dtype(at::kInt);
   auto float_opts = points.options().dtype(at::kFloat);
@@ -462,9 +515,14 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
   at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts);
   at::Tensor pix_dists = at::full({N, S, S, K}, -1, float_opts);
 
+  if (point_idxs.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return std::make_tuple(point_idxs, zbuf, pix_dists);
+  }
+
   const size_t blocks = 1024;
   const size_t threads = 64;
-  RasterizePointsFineCudaKernel<<<blocks, threads>>>(
+  RasterizePointsFineCudaKernel<<<blocks, threads, 0, stream>>>(
       points.contiguous().data_ptr<float>(),
       bin_points.contiguous().data_ptr<int32_t>(),
       radius,
@@ -478,6 +536,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
       zbuf.contiguous().data_ptr<float>(),
       pix_dists.contiguous().data_ptr<float>());
 
+  AT_CUDA_CHECK(cudaGetLastError());
   return std::make_tuple(point_idxs, zbuf, pix_dists);
 }
 
@@ -537,6 +596,19 @@ at::Tensor RasterizePointsBackwardCuda(
     const at::Tensor& idxs, // (N, H, W, K)
     const at::Tensor& grad_zbuf, // (N, H, W, K)
     const at::Tensor& grad_dists) { // (N, H, W, K)
+
+  // Check inputs are on the same device
+  at::TensorArg points_t{points, "points", 1}, idxs_t{idxs, "idxs", 2},
+      grad_zbuf_t{grad_zbuf, "grad_zbuf", 3},
+      grad_dists_t{grad_dists, "grad_dists", 4};
+  at::CheckedFrom c = "RasterizePointsBackwardCuda";
+  at::checkAllSameGPU(c, {points_t, idxs_t, grad_zbuf_t, grad_dists_t});
+  at::checkAllSameType(c, {points_t, grad_zbuf_t, grad_dists_t});
+
+  // Set the device for the kernel launch based on the device of the input
+  at::cuda::CUDAGuard device_guard(points.device());
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
   const int P = points.size(0);
   const int N = idxs.size(0);
   const int H = idxs.size(1);
@@ -544,10 +616,16 @@ at::Tensor RasterizePointsBackwardCuda(
   const int K = idxs.size(3);
 
   at::Tensor grad_points = at::zeros({P, 3}, points.options());
 
+  if (grad_points.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return grad_points;
+  }
+
   const size_t blocks = 1024;
   const size_t threads = 64;
-  RasterizePointsBackwardCudaKernel<<<blocks, threads>>>(
+
+  RasterizePointsBackwardCudaKernel<<<blocks, threads, 0, stream>>>(
       points.contiguous().data_ptr<float>(),
       idxs.contiguous().data_ptr<int32_t>(),
       N,
@@ -559,5 +637,6 @@ at::Tensor RasterizePointsBackwardCuda(
       grad_dists.contiguous().data_ptr<float>(),
       grad_points.contiguous().data_ptr<float>());
 
+  AT_CUDA_CHECK(cudaGetLastError());
   return grad_points;
 }
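Two patterns recur throughout the hunks above. First, AT_CUDA_CHECK(cudaGetLastError()) now follows every kernel launch, turning launch failures (for example, an invalid launch configuration) into C++ exceptions at the point of launch instead of letting the error surface later in unrelated code. Second, each function now returns early when its output tensor is empty, skipping the kernel launch entirely when there is nothing to compute. The remaining hunks below are in the C++ dispatch layer, where the CPU/CUDA routing and the new contiguity checks live.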
@@ -4,6 +4,7 @@
 #include <torch/extension.h>
 #include <cstdio>
 #include <tuple>
+#include "utils/pytorch3d_cutils.h"
 
 // ****************************************************************************
 // *                          NAIVE RASTERIZATION                             *
@@ -66,6 +67,9 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
   if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() &&
       num_points_per_cloud.is_cuda()) {
 #ifdef WITH_CUDA
+    CHECK_CONTIGUOUS_CUDA(points);
+    CHECK_CONTIGUOUS_CUDA(cloud_to_packed_first_idx);
+    CHECK_CONTIGUOUS_CUDA(num_points_per_cloud);
     return RasterizePointsNaiveCuda(
         points,
         cloud_to_packed_first_idx,
@@ -140,6 +144,9 @@ torch::Tensor RasterizePointsCoarse(
   if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() &&
       num_points_per_cloud.is_cuda()) {
 #ifdef WITH_CUDA
+    CHECK_CONTIGUOUS_CUDA(points);
+    CHECK_CONTIGUOUS_CUDA(cloud_to_packed_first_idx);
+    CHECK_CONTIGUOUS_CUDA(num_points_per_cloud);
     return RasterizePointsCoarseCuda(
         points,
         cloud_to_packed_first_idx,
@@ -208,6 +215,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFine(
     const int points_per_pixel) {
   if (points.is_cuda()) {
 #ifdef WITH_CUDA
+    CHECK_CONTIGUOUS_CUDA(points);
+    CHECK_CONTIGUOUS_CUDA(bin_points);
     return RasterizePointsFineCuda(
         points, bin_points, image_size, radius, bin_size, points_per_pixel);
 #else
@@ -257,6 +266,10 @@ torch::Tensor RasterizePointsBackward(
     const torch::Tensor& grad_dists) {
   if (points.is_cuda()) {
 #ifdef WITH_CUDA
+    CHECK_CONTIGUOUS_CUDA(points);
+    CHECK_CONTIGUOUS_CUDA(idxs);
+    CHECK_CONTIGUOUS_CUDA(grad_zbuf);
+    CHECK_CONTIGUOUS_CUDA(grad_dists);
     return RasterizePointsBackwardCuda(points, idxs, grad_zbuf, grad_dists);
 #else
     AT_ERROR("Not compiled with GPU support");
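The CHECK_CONTIGUOUS_CUDA calls added to the dispatch functions come from utils/pytorch3d_cutils.h, newly included at the top of the file but not shown in this diff. As a rough sketch of what such macros conventionally look like in a PyTorch C++ extension (these exact definitions are an assumption, not taken from the commit):

// Assumed, illustrative definitions in the style of pytorch3d_cutils.h;
// the real header may differ.
#include <torch/extension.h>

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous.")
#define CHECK_CONTIGUOUS_CUDA(x) \
  CHECK_CUDA(x);                 \
  CHECK_CONTIGUOUS(x)

Whatever the exact definitions, the effect in the dispatch code above is to fail fast, with an error message naming the offending tensor, when a caller passes a CPU-resident or non-contiguous tensor, before the CUDA implementation is ever entered.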