mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-09-16 01:22:48 +08:00
fix device error
Summary: When using `sample_farthest_points` with `lengths`, it throws an error because of the device mismatch between `lengths` and `torch.rand(lengths.size())` on GPU. Reviewed By: bottler Differential Revision: D82378997 fbshipit-source-id: 8e929256177d543d1dd1249e8488f70e03e4101f
This commit is contained in:
parent
d098beb7a7
commit
7711bf34a8
@ -89,7 +89,9 @@ def sample_farthest_points(
|
|||||||
if constant_length:
|
if constant_length:
|
||||||
start_idxs = torch.randint(high=P, size=(N,), device=device)
|
start_idxs = torch.randint(high=P, size=(N,), device=device)
|
||||||
else:
|
else:
|
||||||
start_idxs = (lengths * torch.rand(lengths.size())).to(torch.int64)
|
start_idxs = (lengths * torch.rand(lengths.size(), device=device)).to(
|
||||||
|
torch.int64
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
start_idxs = torch.zeros_like(lengths)
|
start_idxs = torch.zeros_like(lengths)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user