Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-08-02 20:02:49 +08:00
Make cuda tensors contiguous in host function and remove contiguous check
Summary: Update the cuda kernels to:
- remove contiguous checks for the grad tensors and for cpu functions which use accessors
- for cuda implementations call `.contiguous()` on all tensors in the host function before invoking the kernel

Reviewed By: gkioxari
Differential Revision: D21598008
fbshipit-source-id: 9b97bda4582fd4269c8a00999874d4552a1aea2d
This commit is contained in:
parent a8377f1f06
commit 3fef506895
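The change is easiest to see as a pattern before reading the diff. Below is a minimal, self-contained sketch of that pattern, not code from this commit: ExampleKernel and ExampleCuda are hypothetical names and the shape/dtype handling is a placeholder. The dispatch layer only asserts that tensors live on the GPU (CHECK_CUDA rather than CHECK_CONTIGUOUS_CUDA), and the host function makes each tensor contiguous itself before taking a raw pointer for the kernel launch.

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

// Toy elementwise kernel operating on a flat, densely packed buffer.
__global__ void ExampleKernel(const float* in, float* out, int64_t n) {
  const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = 2.0f * in[i];
  }
}

// Hypothetical host function showing the pattern applied in this commit:
// no contiguity requirement at the check layer; .contiguous() applied here.
at::Tensor ExampleCuda(const at::Tensor& input) {
  // Make the (possibly strided) input densely packed before taking a pointer.
  const at::Tensor input_c = input.contiguous();
  at::Tensor output = at::zeros_like(input_c);

  const int64_t n = input_c.numel();
  const int threads = 256;
  const int blocks = static_cast<int>((n + threads - 1) / threads);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  ExampleKernel<<<blocks, threads, 0, stream>>>(
      input_c.data_ptr<float>(), output.data_ptr<float>(), n);
  AT_CUDA_CHECK(cudaGetLastError());
  return output;
}

In the diff itself the call is usually inlined, e.g. points.contiguous().data_ptr<float>(), which has the same effect of handing the kernel a densely packed buffer even when the caller passes a non-contiguous view.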
@@ -168,6 +168,8 @@ at::Tensor alphaCompositeCudaForward(
   // doubles. Currently, support is for floats only.
   alphaCompositeCudaForwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
       // clang-format off
+      // As we are using packed accessors here the tensors
+      // do not need to be made contiguous.
       result.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
       features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
       alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
@@ -211,6 +213,8 @@ std::tuple<at::Tensor, at::Tensor> alphaCompositeCudaBackward(
   // doubles. Currently, support is for floats only.
   alphaCompositeCudaBackwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
       // clang-format off
+      // As we are using packed accessors here the tensors
+      // do not need to be made contiguous.
       grad_features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
       grad_alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
       grad_outputs.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
@@ -60,18 +60,14 @@ torch::Tensor alphaCompositeForward(
 
   if (features.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(features);
-    CHECK_CONTIGUOUS_CUDA(alphas);
-    CHECK_CONTIGUOUS_CUDA(points_idx);
+    CHECK_CUDA(features);
+    CHECK_CUDA(alphas);
+    CHECK_CUDA(points_idx);
     return alphaCompositeCudaForward(features, alphas, points_idx);
 #else
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
-    CHECK_CONTIGUOUS(features);
-    CHECK_CONTIGUOUS(alphas);
-    CHECK_CONTIGUOUS(points_idx);
-
     return alphaCompositeCpuForward(features, alphas, points_idx);
   }
 }
@@ -88,10 +84,10 @@ std::tuple<torch::Tensor, torch::Tensor> alphaCompositeBackward(
 
   if (grad_outputs.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(grad_outputs);
-    CHECK_CONTIGUOUS_CUDA(features);
-    CHECK_CONTIGUOUS_CUDA(alphas);
-    CHECK_CONTIGUOUS_CUDA(points_idx);
+    CHECK_CUDA(grad_outputs);
+    CHECK_CUDA(features);
+    CHECK_CUDA(alphas);
+    CHECK_CUDA(points_idx);
 
     return alphaCompositeCudaBackward(
         grad_outputs, features, alphas, points_idx);
@@ -99,11 +95,6 @@ std::tuple<torch::Tensor, torch::Tensor> alphaCompositeBackward(
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
-    CHECK_CONTIGUOUS(grad_outputs);
-    CHECK_CONTIGUOUS(features);
-    CHECK_CONTIGUOUS(alphas);
-    CHECK_CONTIGUOUS(points_idx);
-
     return alphaCompositeCpuBackward(
         grad_outputs, features, alphas, points_idx);
   }
@@ -183,6 +183,8 @@ at::Tensor weightedSumNormCudaForward(
   // doubles. Currently, support is for floats only.
   // clang-format off
   weightedSumNormCudaForwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
+      // As we are using packed accessors here the tensors
+      // do not need to be made contiguous.
       result.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
       features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
       alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
@@ -227,6 +229,8 @@ std::tuple<at::Tensor, at::Tensor> weightedSumNormCudaBackward(
   // doubles. Currently, support is for floats only.
   weightedSumNormCudaBackwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
       // clang-format off
+      // As we are using packed accessors here the tensors
+      // do not need to be made contiguous.
       grad_features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
       grad_alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
       grad_outputs.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
@@ -58,19 +58,15 @@ torch::Tensor weightedSumNormForward(
 
   if (features.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(features);
-    CHECK_CONTIGUOUS_CUDA(alphas);
-    CHECK_CONTIGUOUS_CUDA(points_idx);
+    CHECK_CUDA(features);
+    CHECK_CUDA(alphas);
+    CHECK_CUDA(points_idx);
 
     return weightedSumNormCudaForward(features, alphas, points_idx);
 #else
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
-    CHECK_CONTIGUOUS(features);
-    CHECK_CONTIGUOUS(alphas);
-    CHECK_CONTIGUOUS(points_idx);
-
     return weightedSumNormCpuForward(features, alphas, points_idx);
   }
 }
@@ -87,10 +83,10 @@ std::tuple<torch::Tensor, torch::Tensor> weightedSumNormBackward(
 
   if (grad_outputs.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(grad_outputs);
-    CHECK_CONTIGUOUS_CUDA(features);
-    CHECK_CONTIGUOUS_CUDA(alphas);
-    CHECK_CONTIGUOUS_CUDA(points_idx);
+    CHECK_CUDA(grad_outputs);
+    CHECK_CUDA(features);
+    CHECK_CUDA(alphas);
+    CHECK_CUDA(points_idx);
 
     return weightedSumNormCudaBackward(
         grad_outputs, features, alphas, points_idx);
@@ -98,11 +94,6 @@ std::tuple<torch::Tensor, torch::Tensor> weightedSumNormBackward(
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
-    CHECK_CONTIGUOUS(grad_outputs);
-    CHECK_CONTIGUOUS(features);
-    CHECK_CONTIGUOUS(alphas);
-    CHECK_CONTIGUOUS(points_idx);
-
     return weightedSumNormCpuBackward(
         grad_outputs, features, alphas, points_idx);
   }
@@ -142,6 +142,8 @@ at::Tensor weightedSumCudaForward(
   // doubles. Currently, support is for floats only.
   weightedSumCudaForwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
       // clang-format off
+      // As we are using packed accessors here the tensors
+      // do not need to be made contiguous.
       result.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
       features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
       alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
@@ -185,6 +187,8 @@ std::tuple<at::Tensor, at::Tensor> weightedSumCudaBackward(
   // doubles. Currently, support is for floats only.
   weightedSumCudaBackwardKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
       // clang-format off
+      // As we are using packed accessors here the tensors
+      // do not need to be made contiguous.
       grad_features.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
       grad_alphas.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
       grad_outputs.packed_accessor64<float, 4, at::RestrictPtrTraits>(),
@@ -58,18 +58,14 @@ torch::Tensor weightedSumForward(
 
   if (features.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(features);
-    CHECK_CONTIGUOUS_CUDA(alphas);
-    CHECK_CONTIGUOUS_CUDA(points_idx);
+    CHECK_CUDA(features);
+    CHECK_CUDA(alphas);
+    CHECK_CUDA(points_idx);
     return weightedSumCudaForward(features, alphas, points_idx);
 #else
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
-    CHECK_CONTIGUOUS(features);
-    CHECK_CONTIGUOUS(alphas);
-    CHECK_CONTIGUOUS(points_idx);
-
     return weightedSumCpuForward(features, alphas, points_idx);
   }
 }
@@ -86,21 +82,16 @@ std::tuple<torch::Tensor, torch::Tensor> weightedSumBackward(
 
   if (grad_outputs.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(grad_outputs);
-    CHECK_CONTIGUOUS_CUDA(features);
-    CHECK_CONTIGUOUS_CUDA(alphas);
-    CHECK_CONTIGUOUS_CUDA(points_idx);
+    CHECK_CUDA(grad_outputs);
+    CHECK_CUDA(features);
+    CHECK_CUDA(alphas);
+    CHECK_CUDA(points_idx);
 
     return weightedSumCudaBackward(grad_outputs, features, alphas, points_idx);
 #else
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
-    CHECK_CONTIGUOUS(grad_outputs);
-    CHECK_CONTIGUOUS(features);
-    CHECK_CONTIGUOUS(alphas);
-    CHECK_CONTIGUOUS(points_idx);
-
     return weightedSumCpuBackward(grad_outputs, features, alphas, points_idx);
   }
 }
@@ -239,8 +239,8 @@ std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForwardCuda(
   AT_DISPATCH_FLOATING_TYPES(
       verts.scalar_type(), "face_areas_normals_forward_cuda", ([&] {
         FaceAreasNormalsForwardKernel<scalar_t><<<blocks, threads, 0, stream>>>(
-            verts.data_ptr<scalar_t>(),
-            faces.data_ptr<int64_t>(),
+            verts.contiguous().data_ptr<scalar_t>(),
+            faces.contiguous().data_ptr<int64_t>(),
             areas.data_ptr<scalar_t>(),
             normals.data_ptr<scalar_t>(),
             V,
@@ -282,10 +282,10 @@ at::Tensor FaceAreasNormalsBackwardCuda(
   // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
   // doubles. Currently, support is for floats only.
   FaceAreasNormalsBackwardKernel<<<blocks, threads, 0, stream>>>(
-      grad_areas.data_ptr<float>(),
-      grad_normals.data_ptr<float>(),
-      verts.data_ptr<float>(),
-      faces.data_ptr<int64_t>(),
+      grad_areas.contiguous().data_ptr<float>(),
+      grad_normals.contiguous().data_ptr<float>(),
+      verts.contiguous().data_ptr<float>(),
+      faces.contiguous().data_ptr<int64_t>(),
       grad_verts.data_ptr<float>(),
       V,
       F);
@@ -47,8 +47,8 @@ std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForward(
     const at::Tensor faces) {
   if (verts.is_cuda() && faces.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(verts);
-    CHECK_CONTIGUOUS_CUDA(faces);
+    CHECK_CUDA(verts);
+    CHECK_CUDA(faces);
     return FaceAreasNormalsForwardCuda(verts, faces);
 #else
     AT_ERROR("Not compiled with GPU support.");
@@ -65,10 +65,10 @@ at::Tensor FaceAreasNormalsBackward(
     const at::Tensor faces) {
   if (verts.is_cuda() && faces.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(verts);
-    CHECK_CONTIGUOUS_CUDA(faces);
-    CHECK_CONTIGUOUS_CUDA(grad_areas);
-    CHECK_CONTIGUOUS_CUDA(grad_normals);
+    CHECK_CUDA(verts);
+    CHECK_CUDA(faces);
+    CHECK_CUDA(grad_areas);
+    CHECK_CUDA(grad_normals);
     return FaceAreasNormalsBackwardCuda(grad_areas, grad_normals, verts, faces);
 #else
     AT_ERROR("Not compiled with GPU support.");
@@ -72,8 +72,8 @@ at::Tensor GatherScatterCuda(
   }
 
   GatherScatterCudaKernel<<<blocks, threads, 0, stream>>>(
-      input.data_ptr<float>(),
-      edges.data_ptr<int64_t>(),
+      input.contiguous().data_ptr<float>(),
+      edges.contiguous().data_ptr<int64_t>(),
       output.data_ptr<float>(),
       directed,
       backward,
@@ -35,8 +35,8 @@ at::Tensor GatherScatter(
     bool backward) {
   if (input.is_cuda() && edges.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(input);
-    CHECK_CONTIGUOUS_CUDA(edges);
+    CHECK_CUDA(input);
+    CHECK_CUDA(edges);
     return GatherScatterCuda(input, edges, directed, backward);
 #else
     AT_ERROR("Not compiled with GPU support.");
@@ -347,13 +347,13 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
   const size_t threads = 256;
   const size_t blocks = 256;
   if (version == 0) {
-    AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
-      KNearestNeighborKernelV0<scalar_t>
-          <<<blocks, threads, 0, stream>>>(
-              p1.data_ptr<scalar_t>(),
-              p2.data_ptr<scalar_t>(),
-              lengths1.data_ptr<int64_t>(),
-              lengths2.data_ptr<int64_t>(),
+    AT_DISPATCH_FLOATING_TYPES(
+        p1.scalar_type(), "knn_kernel_cuda", ([&] {
+          KNearestNeighborKernelV0<scalar_t><<<blocks, threads, 0, stream>>>(
+              p1.contiguous().data_ptr<scalar_t>(),
+              p2.contiguous().data_ptr<scalar_t>(),
+              lengths1.contiguous().data_ptr<int64_t>(),
+              lengths2.contiguous().data_ptr<int64_t>(),
               dists.data_ptr<scalar_t>(),
               idxs.data_ptr<int64_t>(),
               N,
@@ -372,10 +372,10 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
           D,
           blocks,
           threads,
-          p1.data_ptr<scalar_t>(),
-          p2.data_ptr<scalar_t>(),
-          lengths1.data_ptr<int64_t>(),
-          lengths2.data_ptr<int64_t>(),
+          p1.contiguous().data_ptr<scalar_t>(),
+          p2.contiguous().data_ptr<scalar_t>(),
+          lengths1.contiguous().data_ptr<int64_t>(),
+          lengths2.contiguous().data_ptr<int64_t>(),
          dists.data_ptr<scalar_t>(),
          idxs.data_ptr<int64_t>(),
          N,
@@ -396,10 +396,10 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
           K_64,
           blocks,
           threads,
-          p1.data_ptr<scalar_t>(),
-          p2.data_ptr<scalar_t>(),
-          lengths1.data_ptr<int64_t>(),
-          lengths2.data_ptr<int64_t>(),
+          p1.contiguous().data_ptr<scalar_t>(),
+          p2.contiguous().data_ptr<scalar_t>(),
+          lengths1.contiguous().data_ptr<int64_t>(),
+          lengths2.contiguous().data_ptr<int64_t>(),
          dists.data_ptr<scalar_t>(),
          idxs.data_ptr<int64_t>(),
          N,
@@ -419,10 +419,10 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
           K_64,
           blocks,
           threads,
-          p1.data_ptr<scalar_t>(),
-          p2.data_ptr<scalar_t>(),
-          lengths1.data_ptr<int64_t>(),
-          lengths2.data_ptr<int64_t>(),
+          p1.contiguous().data_ptr<scalar_t>(),
+          p2.contiguous().data_ptr<scalar_t>(),
+          lengths1.contiguous().data_ptr<int64_t>(),
+          lengths2.contiguous().data_ptr<int64_t>(),
          dists.data_ptr<scalar_t>(),
          idxs.data_ptr<int64_t>(),
          N,
@@ -525,12 +525,12 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
   const int threads = 512;
 
   KNearestNeighborBackwardKernel<<<blocks, threads, 0, stream>>>(
-      p1.data_ptr<float>(),
-      p2.data_ptr<float>(),
-      lengths1.data_ptr<int64_t>(),
-      lengths2.data_ptr<int64_t>(),
-      idxs.data_ptr<int64_t>(),
-      grad_dists.data_ptr<float>(),
+      p1.contiguous().data_ptr<float>(),
+      p2.contiguous().data_ptr<float>(),
+      lengths1.contiguous().data_ptr<int64_t>(),
+      lengths2.contiguous().data_ptr<int64_t>(),
+      idxs.contiguous().data_ptr<int64_t>(),
+      grad_dists.contiguous().data_ptr<float>(),
       grad_p1.data_ptr<float>(),
       grad_p2.data_ptr<float>(),
       N,
@@ -56,8 +56,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
     int version) {
   if (p1.is_cuda() || p2.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(p1);
-    CHECK_CONTIGUOUS_CUDA(p2);
+    CHECK_CUDA(p1);
+    CHECK_CUDA(p2);
     return KNearestNeighborIdxCuda(p1, p2, lengths1, lengths2, K, version);
 #else
     AT_ERROR("Not compiled with GPU support.");
@@ -117,8 +117,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackward(
     const at::Tensor& grad_dists) {
   if (p1.is_cuda() || p2.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(p1);
-    CHECK_CONTIGUOUS_CUDA(p2);
+    CHECK_CUDA(p1);
+    CHECK_CUDA(p2);
     return KNearestNeighborBackwardCuda(
         p1, p2, lengths1, lengths2, idxs, grad_dists);
 #else
@@ -146,8 +146,8 @@ at::Tensor PackedToPaddedCuda(
   AT_DISPATCH_FLOATING_TYPES(
       inputs_packed.scalar_type(), "packed_to_padded_d1_kernel", ([&] {
         PackedToPaddedKernelD1<scalar_t><<<blocks, threads, 0, stream>>>(
-            inputs_packed.data_ptr<scalar_t>(),
-            first_idxs.data_ptr<int64_t>(),
+            inputs_packed.contiguous().data_ptr<scalar_t>(),
+            first_idxs.contiguous().data_ptr<int64_t>(),
             inputs_padded.data_ptr<scalar_t>(),
             batch_size,
             max_size,
@@ -157,8 +157,8 @@ at::Tensor PackedToPaddedCuda(
   AT_DISPATCH_FLOATING_TYPES(
       inputs_packed.scalar_type(), "packed_to_padded_kernel", ([&] {
         PackedToPaddedKernel<scalar_t><<<blocks, threads, 0, stream>>>(
-            inputs_packed.data_ptr<scalar_t>(),
-            first_idxs.data_ptr<int64_t>(),
+            inputs_packed.contiguous().data_ptr<scalar_t>(),
+            first_idxs.contiguous().data_ptr<int64_t>(),
             inputs_padded.data_ptr<scalar_t>(),
             batch_size,
             max_size,
@@ -209,8 +209,8 @@ at::Tensor PaddedToPackedCuda(
   AT_DISPATCH_FLOATING_TYPES(
       inputs_padded.scalar_type(), "padded_to_packed_d1_kernel", ([&] {
         PaddedToPackedKernelD1<scalar_t><<<blocks, threads, 0, stream>>>(
-            inputs_padded.data_ptr<scalar_t>(),
-            first_idxs.data_ptr<int64_t>(),
+            inputs_padded.contiguous().data_ptr<scalar_t>(),
+            first_idxs.contiguous().data_ptr<int64_t>(),
             inputs_packed.data_ptr<scalar_t>(),
             batch_size,
             max_size,
@@ -220,8 +220,8 @@ at::Tensor PaddedToPackedCuda(
   AT_DISPATCH_FLOATING_TYPES(
       inputs_padded.scalar_type(), "padded_to_packed_kernel", ([&] {
         PaddedToPackedKernel<scalar_t><<<blocks, threads, 0, stream>>>(
-            inputs_padded.data_ptr<scalar_t>(),
-            first_idxs.data_ptr<int64_t>(),
+            inputs_padded.contiguous().data_ptr<scalar_t>(),
+            first_idxs.contiguous().data_ptr<int64_t>(),
             inputs_packed.data_ptr<scalar_t>(),
             batch_size,
             max_size,
@@ -75,8 +75,8 @@ at::Tensor PackedToPadded(
     const int64_t max_size) {
   if (inputs_packed.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(inputs_packed);
-    CHECK_CONTIGUOUS_CUDA(first_idxs);
+    CHECK_CUDA(inputs_packed);
+    CHECK_CUDA(first_idxs);
     return PackedToPaddedCuda(inputs_packed, first_idxs, max_size);
 #else
     AT_ERROR("Not compiled with GPU support.");
@@ -92,8 +92,8 @@ at::Tensor PaddedToPacked(
     const int64_t num_inputs) {
   if (inputs_padded.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(inputs_padded);
-    CHECK_CONTIGUOUS_CUDA(first_idxs);
+    CHECK_CUDA(inputs_padded);
+    CHECK_CUDA(first_idxs);
     return PaddedToPackedCuda(inputs_padded, first_idxs, num_inputs);
 #else
     AT_ERROR("Not compiled with GPU support.");
@@ -144,15 +144,16 @@ std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceForwardCuda(
   size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
 
   PointEdgeForwardKernel<<<blocks, threads, shared_size, stream>>>(
-      points.data_ptr<float>(),
-      points_first_idx.data_ptr<int64_t>(),
-      segms.data_ptr<float>(),
-      segms_first_idx.data_ptr<int64_t>(),
+      points.contiguous().data_ptr<float>(),
+      points_first_idx.contiguous().data_ptr<int64_t>(),
+      segms.contiguous().data_ptr<float>(),
+      segms_first_idx.contiguous().data_ptr<int64_t>(),
       dists.data_ptr<float>(),
       idxs.data_ptr<int64_t>(),
       B,
       P,
       S);
+
   AT_CUDA_CHECK(cudaGetLastError());
   return std::make_tuple(dists, idxs);
 }
@@ -240,10 +241,10 @@ std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceBackwardCuda(
   const int threads = 512;
 
   PointEdgeBackwardKernel<<<blocks, threads, 0, stream>>>(
-      points.data_ptr<float>(),
-      segms.data_ptr<float>(),
-      idx_points.data_ptr<int64_t>(),
-      grad_dists.data_ptr<float>(),
+      points.contiguous().data_ptr<float>(),
+      segms.contiguous().data_ptr<float>(),
+      idx_points.contiguous().data_ptr<int64_t>(),
+      grad_dists.contiguous().data_ptr<float>(),
       grad_points.data_ptr<float>(),
       grad_segms.data_ptr<float>(),
       P);
@@ -386,10 +387,10 @@ std::tuple<at::Tensor, at::Tensor> EdgePointDistanceForwardCuda(
   size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
 
   EdgePointForwardKernel<<<blocks, threads, shared_size, stream>>>(
-      points.data_ptr<float>(),
-      points_first_idx.data_ptr<int64_t>(),
-      segms.data_ptr<float>(),
-      segms_first_idx.data_ptr<int64_t>(),
+      points.contiguous().data_ptr<float>(),
+      points_first_idx.contiguous().data_ptr<int64_t>(),
+      segms.contiguous().data_ptr<float>(),
+      segms_first_idx.contiguous().data_ptr<int64_t>(),
       dists.data_ptr<float>(),
       idxs.data_ptr<int64_t>(),
       B,
@@ -478,10 +479,10 @@ std::tuple<at::Tensor, at::Tensor> EdgePointDistanceBackwardCuda(
   const int threads = 512;
 
   EdgePointBackwardKernel<<<blocks, threads, 0, stream>>>(
-      points.data_ptr<float>(),
-      segms.data_ptr<float>(),
-      idx_segms.data_ptr<int64_t>(),
-      grad_dists.data_ptr<float>(),
+      points.contiguous().data_ptr<float>(),
+      segms.contiguous().data_ptr<float>(),
+      idx_segms.contiguous().data_ptr<int64_t>(),
+      grad_dists.contiguous().data_ptr<float>(),
       grad_points.data_ptr<float>(),
       grad_segms.data_ptr<float>(),
       S);
@@ -550,8 +551,8 @@ at::Tensor PointEdgeArrayDistanceForwardCuda(
   const size_t threads = 64;
 
   PointEdgeArrayForwardKernel<<<blocks, threads, 0, stream>>>(
-      points.data_ptr<float>(),
-      segms.data_ptr<float>(),
+      points.contiguous().data_ptr<float>(),
+      segms.contiguous().data_ptr<float>(),
       dists.data_ptr<float>(),
       P,
       S);
@@ -638,9 +639,9 @@ std::tuple<at::Tensor, at::Tensor> PointEdgeArrayDistanceBackwardCuda(
   const size_t threads = 64;
 
   PointEdgeArrayBackwardKernel<<<blocks, threads, 0, stream>>>(
-      points.data_ptr<float>(),
-      segms.data_ptr<float>(),
-      grad_dists.data_ptr<float>(),
+      points.contiguous().data_ptr<float>(),
+      segms.contiguous().data_ptr<float>(),
+      grad_dists.contiguous().data_ptr<float>(),
       grad_points.data_ptr<float>(),
       grad_segms.data_ptr<float>(),
       P,
@@ -54,10 +54,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForward(
     const int64_t max_points) {
   if (points.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(points);
-    CHECK_CONTIGUOUS_CUDA(points_first_idx);
-    CHECK_CONTIGUOUS_CUDA(segms);
-    CHECK_CONTIGUOUS_CUDA(segms_first_idx);
+    CHECK_CUDA(points);
+    CHECK_CUDA(points_first_idx);
+    CHECK_CUDA(segms);
+    CHECK_CUDA(segms_first_idx);
     return PointEdgeDistanceForwardCuda(
         points, points_first_idx, segms, segms_first_idx, max_points);
 #else
@@ -98,10 +98,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackward(
     const torch::Tensor& grad_dists) {
   if (points.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(points);
-    CHECK_CONTIGUOUS_CUDA(segms);
-    CHECK_CONTIGUOUS_CUDA(idx_points);
-    CHECK_CONTIGUOUS_CUDA(grad_dists);
+    CHECK_CUDA(points);
+    CHECK_CUDA(segms);
+    CHECK_CUDA(idx_points);
+    CHECK_CUDA(grad_dists);
     return PointEdgeDistanceBackwardCuda(points, segms, idx_points, grad_dists);
 #else
     AT_ERROR("Not compiled with GPU support.");
@@ -158,10 +158,10 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForward(
     const int64_t max_segms) {
   if (points.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(points);
-    CHECK_CONTIGUOUS_CUDA(points_first_idx);
-    CHECK_CONTIGUOUS_CUDA(segms);
-    CHECK_CONTIGUOUS_CUDA(segms_first_idx);
+    CHECK_CUDA(points);
+    CHECK_CUDA(points_first_idx);
+    CHECK_CUDA(segms);
+    CHECK_CUDA(segms_first_idx);
     return EdgePointDistanceForwardCuda(
         points, points_first_idx, segms, segms_first_idx, max_segms);
 #else
@@ -202,10 +202,10 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackward(
     const torch::Tensor& grad_dists) {
   if (points.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(points);
-    CHECK_CONTIGUOUS_CUDA(segms);
-    CHECK_CONTIGUOUS_CUDA(idx_segms);
-    CHECK_CONTIGUOUS_CUDA(grad_dists);
+    CHECK_CUDA(points);
+    CHECK_CUDA(segms);
+    CHECK_CUDA(idx_segms);
+    CHECK_CUDA(grad_dists);
     return EdgePointDistanceBackwardCuda(points, segms, idx_segms, grad_dists);
 #else
     AT_ERROR("Not compiled with GPU support.");
@@ -247,8 +247,8 @@ torch::Tensor PointEdgeArrayDistanceForward(
     const torch::Tensor& segms) {
   if (points.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(points);
-    CHECK_CONTIGUOUS_CUDA(segms);
+    CHECK_CUDA(points);
+    CHECK_CUDA(segms);
     return PointEdgeArrayDistanceForwardCuda(points, segms);
 #else
     AT_ERROR("Not compiled with GPU support.");
@@ -283,9 +283,9 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackward(
     const torch::Tensor& grad_dists) {
   if (points.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(points);
-    CHECK_CONTIGUOUS_CUDA(segms);
-    CHECK_CONTIGUOUS_CUDA(grad_dists);
+    CHECK_CUDA(points);
+    CHECK_CUDA(segms);
+    CHECK_CUDA(grad_dists);
     return PointEdgeArrayDistanceBackwardCuda(points, segms, grad_dists);
 #else
     AT_ERROR("Not compiled with GPU support.");
@@ -145,10 +145,10 @@ std::tuple<at::Tensor, at::Tensor> PointFaceDistanceForwardCuda(
   size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
 
   PointFaceForwardKernel<<<blocks, threads, shared_size, stream>>>(
-      points.data_ptr<float>(),
-      points_first_idx.data_ptr<int64_t>(),
-      tris.data_ptr<float>(),
-      tris_first_idx.data_ptr<int64_t>(),
+      points.contiguous().data_ptr<float>(),
+      points_first_idx.contiguous().data_ptr<int64_t>(),
+      tris.contiguous().data_ptr<float>(),
+      tris_first_idx.contiguous().data_ptr<int64_t>(),
       dists.data_ptr<float>(),
       idxs.data_ptr<int64_t>(),
       B,
@@ -249,10 +249,10 @@ std::tuple<at::Tensor, at::Tensor> PointFaceDistanceBackwardCuda(
   const int threads = 512;
 
   PointFaceBackwardKernel<<<blocks, threads, 0, stream>>>(
-      points.data_ptr<float>(),
-      tris.data_ptr<float>(),
-      idx_points.data_ptr<int64_t>(),
-      grad_dists.data_ptr<float>(),
+      points.contiguous().data_ptr<float>(),
+      tris.contiguous().data_ptr<float>(),
+      idx_points.contiguous().data_ptr<int64_t>(),
+      grad_dists.contiguous().data_ptr<float>(),
       grad_points.data_ptr<float>(),
       grad_tris.data_ptr<float>(),
       P);
@@ -396,10 +396,10 @@ std::tuple<at::Tensor, at::Tensor> FacePointDistanceForwardCuda(
   size_t shared_size = threads * sizeof(size_t) + threads * sizeof(int64_t);
 
   FacePointForwardKernel<<<blocks, threads, shared_size, stream>>>(
-      points.data_ptr<float>(),
-      points_first_idx.data_ptr<int64_t>(),
-      tris.data_ptr<float>(),
-      tris_first_idx.data_ptr<int64_t>(),
+      points.contiguous().data_ptr<float>(),
+      points_first_idx.contiguous().data_ptr<int64_t>(),
+      tris.contiguous().data_ptr<float>(),
+      tris_first_idx.contiguous().data_ptr<int64_t>(),
       dists.data_ptr<float>(),
       idxs.data_ptr<int64_t>(),
       B,
@@ -501,10 +501,10 @@ std::tuple<at::Tensor, at::Tensor> FacePointDistanceBackwardCuda(
   const int threads = 512;
 
   FacePointBackwardKernel<<<blocks, threads, 0, stream>>>(
-      points.data_ptr<float>(),
-      tris.data_ptr<float>(),
-      idx_tris.data_ptr<int64_t>(),
-      grad_dists.data_ptr<float>(),
+      points.contiguous().data_ptr<float>(),
+      tris.contiguous().data_ptr<float>(),
+      idx_tris.contiguous().data_ptr<int64_t>(),
+      grad_dists.contiguous().data_ptr<float>(),
       grad_points.data_ptr<float>(),
       grad_tris.data_ptr<float>(),
       T);
@@ -575,8 +575,8 @@ at::Tensor PointFaceArrayDistanceForwardCuda(
   const size_t threads = 64;
 
   PointFaceArrayForwardKernel<<<blocks, threads, 0, stream>>>(
-      points.data_ptr<float>(),
-      tris.data_ptr<float>(),
+      points.contiguous().data_ptr<float>(),
+      tris.contiguous().data_ptr<float>(),
       dists.data_ptr<float>(),
       P,
       T);
@@ -672,9 +672,9 @@ std::tuple<at::Tensor, at::Tensor> PointFaceArrayDistanceBackwardCuda(
   const size_t threads = 64;
 
   PointFaceArrayBackwardKernel<<<blocks, threads, 0, stream>>>(
-      points.data_ptr<float>(),
-      tris.data_ptr<float>(),
-      grad_dists.data_ptr<float>(),
+      points.contiguous().data_ptr<float>(),
+      tris.contiguous().data_ptr<float>(),
+      grad_dists.contiguous().data_ptr<float>(),
       grad_points.data_ptr<float>(),
       grad_tris.data_ptr<float>(),
       P,
@@ -56,10 +56,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForward(
     const int64_t max_points) {
   if (points.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(points);
-    CHECK_CONTIGUOUS_CUDA(points_first_idx);
-    CHECK_CONTIGUOUS_CUDA(tris);
-    CHECK_CONTIGUOUS_CUDA(tris_first_idx);
+    CHECK_CUDA(points);
+    CHECK_CUDA(points_first_idx);
+    CHECK_CUDA(tris);
+    CHECK_CUDA(tris_first_idx);
     return PointFaceDistanceForwardCuda(
         points, points_first_idx, tris, tris_first_idx, max_points);
 #else
@@ -100,10 +100,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackward(
     const torch::Tensor& grad_dists) {
   if (points.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(points);
-    CHECK_CONTIGUOUS_CUDA(tris);
-    CHECK_CONTIGUOUS_CUDA(idx_points);
-    CHECK_CONTIGUOUS_CUDA(grad_dists);
+    CHECK_CUDA(points);
+    CHECK_CUDA(tris);
+    CHECK_CUDA(idx_points);
+    CHECK_CUDA(grad_dists);
     return PointFaceDistanceBackwardCuda(points, tris, idx_points, grad_dists);
 #else
     AT_ERROR("Not compiled with GPU support.");
@@ -160,10 +160,10 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForward(
     const int64_t max_tris) {
   if (points.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(points);
-    CHECK_CONTIGUOUS_CUDA(points_first_idx);
-    CHECK_CONTIGUOUS_CUDA(tris);
-    CHECK_CONTIGUOUS_CUDA(tris_first_idx);
+    CHECK_CUDA(points);
+    CHECK_CUDA(points_first_idx);
+    CHECK_CUDA(tris);
+    CHECK_CUDA(tris_first_idx);
     return FacePointDistanceForwardCuda(
         points, points_first_idx, tris, tris_first_idx, max_tris);
 #else
@@ -204,10 +204,10 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackward(
     const torch::Tensor& grad_dists) {
   if (points.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(points);
-    CHECK_CONTIGUOUS_CUDA(tris);
-    CHECK_CONTIGUOUS_CUDA(idx_tris);
-    CHECK_CONTIGUOUS_CUDA(grad_dists);
+    CHECK_CUDA(points);
+    CHECK_CUDA(tris);
+    CHECK_CUDA(idx_tris);
+    CHECK_CUDA(grad_dists);
     return FacePointDistanceBackwardCuda(points, tris, idx_tris, grad_dists);
 #else
     AT_ERROR("Not compiled with GPU support.");
@@ -250,8 +250,8 @@ torch::Tensor PointFaceArrayDistanceForward(
     const torch::Tensor& tris) {
   if (points.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(points);
-    CHECK_CONTIGUOUS_CUDA(tris);
+    CHECK_CUDA(points);
+    CHECK_CUDA(tris);
     return PointFaceArrayDistanceForwardCuda(points, tris);
 #else
     AT_ERROR("Not compiled with GPU support.");
@@ -285,9 +285,9 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackward(
     const torch::Tensor& grad_dists) {
   if (points.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(points);
-    CHECK_CONTIGUOUS_CUDA(tris);
-    CHECK_CONTIGUOUS_CUDA(grad_dists);
+    CHECK_CUDA(points);
+    CHECK_CUDA(tris);
+    CHECK_CUDA(grad_dists);
     return PointFaceArrayDistanceBackwardCuda(points, tris, grad_dists);
 #else
     AT_ERROR("Not compiled with GPU support.");
@@ -348,10 +348,10 @@ RasterizeMeshesNaiveCuda(
       H,
       W,
       K,
-      face_idxs.contiguous().data_ptr<int64_t>(),
-      zbuf.contiguous().data_ptr<float>(),
-      pix_dists.contiguous().data_ptr<float>(),
-      bary.contiguous().data_ptr<float>());
+      face_idxs.data_ptr<int64_t>(),
+      zbuf.data_ptr<float>(),
+      pix_dists.data_ptr<float>(),
+      bary.data_ptr<float>());
 
   AT_CUDA_CHECK(cudaGetLastError());
   return std::make_tuple(face_idxs, zbuf, bary, pix_dists);
@@ -530,7 +530,7 @@ at::Tensor RasterizeMeshesBackwardCuda(
       grad_zbuf.contiguous().data_ptr<float>(),
       grad_bary.contiguous().data_ptr<float>(),
       grad_dists.contiguous().data_ptr<float>(),
-      grad_face_verts.contiguous().data_ptr<float>());
+      grad_face_verts.data_ptr<float>());
 
   AT_CUDA_CHECK(cudaGetLastError());
   return grad_face_verts;
@@ -727,8 +727,8 @@ at::Tensor RasterizeMeshesCoarseCuda(
       bin_size,
       chunk_size,
       M,
-      faces_per_bin.contiguous().data_ptr<int32_t>(),
-      bin_faces.contiguous().data_ptr<int32_t>());
+      faces_per_bin.data_ptr<int32_t>(),
+      bin_faces.data_ptr<int32_t>());
 
   AT_CUDA_CHECK(cudaGetLastError());
   return bin_faces;
@@ -897,10 +897,10 @@ RasterizeMeshesFineCuda(
       H,
       W,
       K,
-      face_idxs.contiguous().data_ptr<int64_t>(),
-      zbuf.contiguous().data_ptr<float>(),
-      pix_dists.contiguous().data_ptr<float>(),
-      bary.contiguous().data_ptr<float>());
+      face_idxs.data_ptr<int64_t>(),
+      zbuf.data_ptr<float>(),
+      pix_dists.data_ptr<float>(),
+      bary.data_ptr<float>());
 
   return std::make_tuple(face_idxs, zbuf, bary, pix_dists);
 }
@@ -96,9 +96,9 @@ RasterizeMeshesNaive(
   // TODO: Better type checking.
   if (face_verts.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(face_verts);
-    CHECK_CONTIGUOUS_CUDA(mesh_to_face_first_idx);
-    CHECK_CONTIGUOUS_CUDA(num_faces_per_mesh);
+    CHECK_CUDA(face_verts);
+    CHECK_CUDA(mesh_to_face_first_idx);
+    CHECK_CUDA(num_faces_per_mesh);
     return RasterizeMeshesNaiveCuda(
         face_verts,
         mesh_to_face_first_idx,
@@ -179,11 +179,11 @@ torch::Tensor RasterizeMeshesBackward(
     const bool perspective_correct) {
   if (face_verts.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(face_verts);
-    CHECK_CONTIGUOUS_CUDA(pix_to_face);
-    CHECK_CONTIGUOUS_CUDA(grad_zbuf);
-    CHECK_CONTIGUOUS_CUDA(grad_bary);
-    CHECK_CONTIGUOUS_CUDA(grad_dists);
+    CHECK_CUDA(face_verts);
+    CHECK_CUDA(pix_to_face);
+    CHECK_CUDA(grad_zbuf);
+    CHECK_CUDA(grad_bary);
+    CHECK_CUDA(grad_dists);
     return RasterizeMeshesBackwardCuda(
         face_verts,
         pix_to_face,
@@ -260,9 +260,9 @@ torch::Tensor RasterizeMeshesCoarse(
     const int max_faces_per_bin) {
   if (face_verts.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(face_verts);
-    CHECK_CONTIGUOUS_CUDA(mesh_to_face_first_idx);
-    CHECK_CONTIGUOUS_CUDA(num_faces_per_mesh);
+    CHECK_CUDA(face_verts);
+    CHECK_CUDA(mesh_to_face_first_idx);
+    CHECK_CUDA(num_faces_per_mesh);
     return RasterizeMeshesCoarseCuda(
         face_verts,
         mesh_to_face_first_idx,
@@ -359,8 +359,8 @@ RasterizeMeshesFine(
     const bool cull_backfaces) {
   if (face_verts.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(face_verts);
-    CHECK_CONTIGUOUS_CUDA(bin_faces);
+    CHECK_CUDA(face_verts);
+    CHECK_CUDA(bin_faces);
     return RasterizeMeshesFineCuda(
         face_verts,
         bin_faces,
@@ -67,9 +67,9 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
   if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() &&
       num_points_per_cloud.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(points);
-    CHECK_CONTIGUOUS_CUDA(cloud_to_packed_first_idx);
-    CHECK_CONTIGUOUS_CUDA(num_points_per_cloud);
+    CHECK_CUDA(points);
+    CHECK_CUDA(cloud_to_packed_first_idx);
+    CHECK_CUDA(num_points_per_cloud);
     return RasterizePointsNaiveCuda(
         points,
         cloud_to_packed_first_idx,
@@ -144,9 +144,9 @@ torch::Tensor RasterizePointsCoarse(
   if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() &&
       num_points_per_cloud.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(points);
-    CHECK_CONTIGUOUS_CUDA(cloud_to_packed_first_idx);
-    CHECK_CONTIGUOUS_CUDA(num_points_per_cloud);
+    CHECK_CUDA(points);
+    CHECK_CUDA(cloud_to_packed_first_idx);
+    CHECK_CUDA(num_points_per_cloud);
     return RasterizePointsCoarseCuda(
         points,
         cloud_to_packed_first_idx,
@@ -215,8 +215,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFine(
     const int points_per_pixel) {
   if (points.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(points);
-    CHECK_CONTIGUOUS_CUDA(bin_points);
+    CHECK_CUDA(points);
+    CHECK_CUDA(bin_points);
     return RasterizePointsFineCuda(
         points, bin_points, image_size, radius, bin_size, points_per_pixel);
 #else
@@ -266,10 +266,10 @@ torch::Tensor RasterizePointsBackward(
     const torch::Tensor& grad_dists) {
   if (points.is_cuda()) {
 #ifdef WITH_CUDA
-    CHECK_CONTIGUOUS_CUDA(points);
-    CHECK_CONTIGUOUS_CUDA(idxs);
-    CHECK_CONTIGUOUS_CUDA(grad_zbuf);
-    CHECK_CONTIGUOUS_CUDA(grad_dists);
+    CHECK_CUDA(points);
+    CHECK_CUDA(idxs);
+    CHECK_CUDA(grad_zbuf);
+    CHECK_CUDA(grad_dists);
     return RasterizePointsBackwardCuda(points, idxs, grad_zbuf, grad_dists);
 #else
     AT_ERROR("Not compiled with GPU support");
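For the compositing kernels at the top of this diff, no .contiguous() call is added; instead a comment notes that the tensors "do not need to be made contiguous", because those kernels take packed_accessor64 arguments rather than raw pointers, and the accessor carries sizes and strides. A rough sketch of that accessor-based style follows, assuming a hypothetical ScaleKernel/ScaleCuda pair and a 2D float32 CUDA tensor (again, not pytorch3d code):

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

// Toy kernel: scale a 2D tensor. All indexing goes through the packed
// accessor, which stores sizes and strides, so the memory may be strided.
__global__ void ScaleKernel(
    at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> out,
    at::PackedTensorAccessor64<float, 2, at::RestrictPtrTraits> in,
    float scale) {
  const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < in.size(0)) {
    for (int64_t j = 0; j < in.size(1); ++j) {
      out[i][j] = scale * in[i][j];
    }
  }
}

at::Tensor ScaleCuda(const at::Tensor& input, float scale) {
  // Assumes a 2D float32 CUDA tensor; no .contiguous() is needed because
  // the accessor below resolves strided indices itself.
  at::Tensor output = at::empty_like(input);
  const int threads = 128;
  const int blocks =
      static_cast<int>((input.size(0) + threads - 1) / threads);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  ScaleKernel<<<blocks, threads, 0, stream>>>(
      output.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
      input.packed_accessor64<float, 2, at::RestrictPtrTraits>(),
      scale);
  AT_CUDA_CHECK(cudaGetLastError());
  return output;
}

Because every index goes through the accessor's strides, a transposed or sliced tensor can be passed directly, at the cost of potentially less coalesced memory access than a contiguous raw-pointer kernel.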