mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
avoid converting a TensorOptions from float to integer
Summary: pytorch is adding checks that mean integer tensors with requires_grad=True need to be avoided. Fix accidentally creating them. Reviewed By: jcjohnson, gkioxari Differential Revision: D21576712 fbshipit-source-id: 008218997986800a36d93caa1a032ee91f2bffcd
This commit is contained in:
parent
6a365d203f
commit
728179e848
@ -319,7 +319,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
||||
const int64_t K_64 = K;
|
||||
|
||||
TORCH_CHECK(p2.size(2) == D, "Point sets must have the same last dimension");
|
||||
auto long_dtype = p1.options().dtype(at::kLong);
|
||||
auto long_dtype = lengths1.options().dtype(at::kLong);
|
||||
auto idxs = at::zeros({N, P1, K}, long_dtype);
|
||||
auto dists = at::zeros({N, P1, K}, p1.options());
|
||||
|
||||
|
@ -14,7 +14,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
|
||||
const int P1 = p1.size(1);
|
||||
const int D = p1.size(2);
|
||||
|
||||
auto long_opts = p1.options().dtype(torch::kInt64);
|
||||
auto long_opts = lengths1.options().dtype(torch::kInt64);
|
||||
torch::Tensor idxs = torch::full({N, P1, K}, 0, long_opts);
|
||||
torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());
|
||||
|
||||
|
@ -321,7 +321,7 @@ RasterizeMeshesNaiveCuda(
|
||||
const int W = image_size;
|
||||
const int K = num_closest;
|
||||
|
||||
auto long_opts = face_verts.options().dtype(at::kLong);
|
||||
auto long_opts = num_faces_per_mesh.options().dtype(at::kLong);
|
||||
auto float_opts = face_verts.options().dtype(at::kFloat);
|
||||
|
||||
at::Tensor face_idxs = at::full({N, H, W, K}, -1, long_opts);
|
||||
@ -701,7 +701,7 @@ at::Tensor RasterizeMeshesCoarseCuda(
|
||||
ss << "Got " << num_bins << "; that's too many!";
|
||||
AT_ERROR(ss.str());
|
||||
}
|
||||
auto opts = face_verts.options().dtype(at::kInt);
|
||||
auto opts = num_faces_per_mesh.options().dtype(at::kInt);
|
||||
at::Tensor faces_per_bin = at::zeros({N, num_bins, num_bins}, opts);
|
||||
at::Tensor bin_faces = at::full({N, num_bins, num_bins, M}, -1, opts);
|
||||
|
||||
@ -868,7 +868,7 @@ RasterizeMeshesFineCuda(
|
||||
if (K > kMaxPointsPerPixel) {
|
||||
AT_ERROR("Must have num_closest <= 150");
|
||||
}
|
||||
auto long_opts = face_verts.options().dtype(at::kLong);
|
||||
auto long_opts = bin_faces.options().dtype(at::kLong);
|
||||
auto float_opts = face_verts.options().dtype(at::kFloat);
|
||||
|
||||
at::Tensor face_idxs = at::full({N, H, W, K}, -1, long_opts);
|
||||
|
@ -123,7 +123,7 @@ RasterizeMeshesNaiveCpu(
|
||||
const int W = image_size;
|
||||
const int K = faces_per_pixel;
|
||||
|
||||
auto long_opts = face_verts.options().dtype(torch::kInt64);
|
||||
auto long_opts = num_faces_per_mesh.options().dtype(torch::kInt64);
|
||||
auto float_opts = face_verts.options().dtype(torch::kFloat32);
|
||||
|
||||
// Initialize output tensors.
|
||||
@ -418,7 +418,7 @@ torch::Tensor RasterizeMeshesCoarseCpu(
|
||||
const int BH = 1 + (height - 1) / bin_size; // Integer division round up.
|
||||
const int BW = 1 + (width - 1) / bin_size; // Integer division round up.
|
||||
|
||||
auto opts = face_verts.options().dtype(torch::kInt32);
|
||||
auto opts = num_faces_per_mesh.options().dtype(torch::kInt32);
|
||||
torch::Tensor faces_per_bin = torch::zeros({N, BH, BW}, opts);
|
||||
torch::Tensor bin_faces = torch::full({N, BH, BW, M}, -1, opts);
|
||||
auto bin_faces_a = bin_faces.accessor<int32_t, 4>();
|
||||
|
@ -177,7 +177,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
|
||||
AT_ERROR(ss.str());
|
||||
}
|
||||
|
||||
auto int_opts = points.options().dtype(at::kInt);
|
||||
auto int_opts = num_points_per_cloud.options().dtype(at::kInt);
|
||||
auto float_opts = points.options().dtype(at::kFloat);
|
||||
at::Tensor point_idxs = at::full({N, S, S, K}, -1, int_opts);
|
||||
at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts);
|
||||
@ -372,7 +372,7 @@ at::Tensor RasterizePointsCoarseCuda(
|
||||
ss << "Got " << num_bins << "; that's too many!";
|
||||
AT_ERROR(ss.str());
|
||||
}
|
||||
auto opts = points.options().dtype(at::kInt);
|
||||
auto opts = num_points_per_cloud.options().dtype(at::kInt);
|
||||
at::Tensor points_per_bin = at::zeros({N, num_bins, num_bins}, opts);
|
||||
at::Tensor bin_points = at::full({N, num_bins, num_bins, M}, -1, opts);
|
||||
|
||||
@ -509,7 +509,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
|
||||
if (K > kMaxPointsPerPixel) {
|
||||
AT_ERROR("Must have num_closest <= 150");
|
||||
}
|
||||
auto int_opts = points.options().dtype(at::kInt);
|
||||
auto int_opts = bin_points.options().dtype(at::kInt);
|
||||
auto float_opts = points.options().dtype(at::kFloat);
|
||||
at::Tensor point_idxs = at::full({N, S, S, K}, -1, int_opts);
|
||||
at::Tensor zbuf = at::full({N, S, S, K}, -1, float_opts);
|
||||
|
@ -25,7 +25,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
|
||||
const int K = points_per_pixel;
|
||||
|
||||
// Initialize output tensors.
|
||||
auto int_opts = points.options().dtype(torch::kInt32);
|
||||
auto int_opts = num_points_per_cloud.options().dtype(torch::kInt32);
|
||||
auto float_opts = points.options().dtype(torch::kFloat32);
|
||||
torch::Tensor point_idxs = torch::full({N, S, S, K}, -1, int_opts);
|
||||
torch::Tensor zbuf = torch::full({N, S, S, K}, -1, float_opts);
|
||||
@ -105,7 +105,7 @@ torch::Tensor RasterizePointsCoarseCpu(
|
||||
|
||||
const int B = 1 + (image_size - 1) / bin_size; // Integer division round up
|
||||
const int M = max_points_per_bin;
|
||||
auto opts = points.options().dtype(torch::kInt32);
|
||||
auto opts = num_points_per_cloud.options().dtype(torch::kInt32);
|
||||
torch::Tensor points_per_bin = torch::zeros({N, B, B}, opts);
|
||||
torch::Tensor bin_points = torch::full({N, B, B, M}, -1, opts);
|
||||
|
||||
|
@ -67,6 +67,7 @@ class _knn_points(Function):
|
||||
idx = idx.gather(2, sort_idx)
|
||||
|
||||
ctx.save_for_backward(p1, p2, lengths1, lengths2, idx)
|
||||
ctx.mark_non_differentiable(idx)
|
||||
return dists, idx
|
||||
|
||||
@staticmethod
|
||||
|
@ -197,6 +197,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
|
||||
cull_backfaces,
|
||||
)
|
||||
ctx.save_for_backward(face_verts, pix_to_face)
|
||||
ctx.mark_non_differentiable(pix_to_face)
|
||||
ctx.perspective_correct = perspective_correct
|
||||
return pix_to_face, zbuf, barycentric_coords, dists
|
||||
|
||||
|
@ -140,6 +140,7 @@ class _RasterizePoints(torch.autograd.Function):
|
||||
)
|
||||
idx, zbuf, dists = _C.rasterize_points(*args)
|
||||
ctx.save_for_backward(points, idx)
|
||||
ctx.mark_non_differentiable(idx)
|
||||
return idx, zbuf, dists
|
||||
|
||||
@staticmethod
|
||||
|
Loading…
x
Reference in New Issue
Block a user