fix device error

Summary: When using `sample_farthest_points` with `lengths`, it throws an error because of the device mismatch between `lengths` and `torch.rand(lengths.size())` on GPU.

Reviewed By: bottler

Differential Revision: D82378997

fbshipit-source-id: 8e929256177d543d1dd1249e8488f70e03e4101f
This commit is contained in:
Kihyuk Sohn 2025-09-15 06:41:00 -07:00 committed by Facebook GitHub Bot
parent d098beb7a7
commit 7711bf34a8

View File

@ -89,7 +89,9 @@ def sample_farthest_points(
if constant_length: if constant_length:
start_idxs = torch.randint(high=P, size=(N,), device=device) start_idxs = torch.randint(high=P, size=(N,), device=device)
else: else:
start_idxs = (lengths * torch.rand(lengths.size())).to(torch.int64) start_idxs = (lengths * torch.rand(lengths.size(), device=device)).to(
torch.int64
)
else: else:
start_idxs = torch.zeros_like(lengths) start_idxs = torch.zeros_like(lengths)