mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-12 07:15:58 +08:00
Make cuda tensors contiguous in host function and remove contiguous check
Summary: Update the cuda kernels to: - remove contiguous checks for the grad tensors and for cpu functions which use accessors - for cuda implementations call `.contiguous()` on all tensors in the host function before invoking the kernel Reviewed By: gkioxari Differential Revision: D21598008 fbshipit-source-id: 9b97bda4582fd4269c8a00999874d4552a1aea2d
This commit is contained in:
committed by
Facebook GitHub Bot
parent
a8377f1f06
commit
3fef506895
@@ -348,10 +348,10 @@ RasterizeMeshesNaiveCuda(
|
||||
H,
|
||||
W,
|
||||
K,
|
||||
face_idxs.contiguous().data_ptr<int64_t>(),
|
||||
zbuf.contiguous().data_ptr<float>(),
|
||||
pix_dists.contiguous().data_ptr<float>(),
|
||||
bary.contiguous().data_ptr<float>());
|
||||
face_idxs.data_ptr<int64_t>(),
|
||||
zbuf.data_ptr<float>(),
|
||||
pix_dists.data_ptr<float>(),
|
||||
bary.data_ptr<float>());
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return std::make_tuple(face_idxs, zbuf, bary, pix_dists);
|
||||
@@ -530,7 +530,7 @@ at::Tensor RasterizeMeshesBackwardCuda(
|
||||
grad_zbuf.contiguous().data_ptr<float>(),
|
||||
grad_bary.contiguous().data_ptr<float>(),
|
||||
grad_dists.contiguous().data_ptr<float>(),
|
||||
grad_face_verts.contiguous().data_ptr<float>());
|
||||
grad_face_verts.data_ptr<float>());
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return grad_face_verts;
|
||||
@@ -727,8 +727,8 @@ at::Tensor RasterizeMeshesCoarseCuda(
|
||||
bin_size,
|
||||
chunk_size,
|
||||
M,
|
||||
faces_per_bin.contiguous().data_ptr<int32_t>(),
|
||||
bin_faces.contiguous().data_ptr<int32_t>());
|
||||
faces_per_bin.data_ptr<int32_t>(),
|
||||
bin_faces.data_ptr<int32_t>());
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return bin_faces;
|
||||
@@ -897,10 +897,10 @@ RasterizeMeshesFineCuda(
|
||||
H,
|
||||
W,
|
||||
K,
|
||||
face_idxs.contiguous().data_ptr<int64_t>(),
|
||||
zbuf.contiguous().data_ptr<float>(),
|
||||
pix_dists.contiguous().data_ptr<float>(),
|
||||
bary.contiguous().data_ptr<float>());
|
||||
face_idxs.data_ptr<int64_t>(),
|
||||
zbuf.data_ptr<float>(),
|
||||
pix_dists.data_ptr<float>(),
|
||||
bary.data_ptr<float>());
|
||||
|
||||
return std::make_tuple(face_idxs, zbuf, bary, pix_dists);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user