diff --git a/pytorch3d/ops/sample_farthest_points.py b/pytorch3d/ops/sample_farthest_points.py index 999aa0e4..15324964 100644 --- a/pytorch3d/ops/sample_farthest_points.py +++ b/pytorch3d/ops/sample_farthest_points.py @@ -89,7 +89,9 @@ def sample_farthest_points( if constant_length: start_idxs = torch.randint(high=P, size=(N,), device=device) else: - start_idxs = (lengths * torch.rand(lengths.size())).to(torch.int64) + start_idxs = (lengths * torch.rand(lengths.size(), device=device)).to( + torch.int64 + ) else: start_idxs = torch.zeros_like(lengths)