From 7711bf34a86bef32556eee09bc8ca715f677cd6b Mon Sep 17 00:00:00 2001 From: Kihyuk Sohn Date: Mon, 15 Sep 2025 06:41:00 -0700 Subject: [PATCH] fix device error Summary: When using `sample_farthest_points` with `lengths`, it throws an error because of the device mismatch between `lengths` and `torch.rand(lengths.size())` on GPU. Reviewed By: bottler Differential Revision: D82378997 fbshipit-source-id: 8e929256177d543d1dd1249e8488f70e03e4101f --- pytorch3d/ops/sample_farthest_points.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch3d/ops/sample_farthest_points.py b/pytorch3d/ops/sample_farthest_points.py index 999aa0e4..15324964 100644 --- a/pytorch3d/ops/sample_farthest_points.py +++ b/pytorch3d/ops/sample_farthest_points.py @@ -89,7 +89,9 @@ def sample_farthest_points( if constant_length: start_idxs = torch.randint(high=P, size=(N,), device=device) else: - start_idxs = (lengths * torch.rand(lengths.size())).to(torch.int64) + start_idxs = (lengths * torch.rand(lengths.size(), device=device)).to( + torch.int64 + ) else: start_idxs = torch.zeros_like(lengths)