Make cuda tensors contiguous in host function and remove contiguous check

Summary:
Update the cuda kernels to:
- remove contiguous checks for the grad tensors and for cpu functions which use accessors
- for cuda implementations call `.contiguous()` on all tensors in the host function before invoking the kernel

Reviewed By: gkioxari

Differential Revision: D21598008

fbshipit-source-id: 9b97bda4582fd4269c8a00999874d4552a1aea2d
This commit is contained in:
Nikhila Ravi
2020-05-15 14:58:04 -07:00
committed by Facebook GitHub Bot
parent a8377f1f06
commit 3fef506895
21 changed files with 219 additions and 233 deletions

View File

@@ -348,10 +348,10 @@ RasterizeMeshesNaiveCuda(
H,
W,
K,
face_idxs.contiguous().data_ptr<int64_t>(),
zbuf.contiguous().data_ptr<float>(),
pix_dists.contiguous().data_ptr<float>(),
bary.contiguous().data_ptr<float>());
face_idxs.data_ptr<int64_t>(),
zbuf.data_ptr<float>(),
pix_dists.data_ptr<float>(),
bary.data_ptr<float>());
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(face_idxs, zbuf, bary, pix_dists);
@@ -530,7 +530,7 @@ at::Tensor RasterizeMeshesBackwardCuda(
grad_zbuf.contiguous().data_ptr<float>(),
grad_bary.contiguous().data_ptr<float>(),
grad_dists.contiguous().data_ptr<float>(),
grad_face_verts.contiguous().data_ptr<float>());
grad_face_verts.data_ptr<float>());
AT_CUDA_CHECK(cudaGetLastError());
return grad_face_verts;
@@ -727,8 +727,8 @@ at::Tensor RasterizeMeshesCoarseCuda(
bin_size,
chunk_size,
M,
faces_per_bin.contiguous().data_ptr<int32_t>(),
bin_faces.contiguous().data_ptr<int32_t>());
faces_per_bin.data_ptr<int32_t>(),
bin_faces.data_ptr<int32_t>());
AT_CUDA_CHECK(cudaGetLastError());
return bin_faces;
@@ -897,10 +897,10 @@ RasterizeMeshesFineCuda(
H,
W,
K,
face_idxs.contiguous().data_ptr<int64_t>(),
zbuf.contiguous().data_ptr<float>(),
pix_dists.contiguous().data_ptr<float>(),
bary.contiguous().data_ptr<float>());
face_idxs.data_ptr<int64_t>(),
zbuf.data_ptr<float>(),
pix_dists.data_ptr<float>(),
bary.data_ptr<float>());
return std::make_tuple(face_idxs, zbuf, bary, pix_dists);
}