Pointclouds.subsample on Windows

Summary: Fix https://github.com/facebookresearch/pytorch3d/issues/1015. Stop relying on the fact that the dtype returned by np.random.choice (int64 on Linux, int32 on Windows) matches the dtype used by pytorch for indexing (int64 everywhere).

Reviewed By: patricklabatut

Differential Revision: D33428680

fbshipit-source-id: 716c857502cd54c563cb256f0eaca7dccd535c10
This commit is contained in:
Jeremy Reizenstein 2022-01-06 02:21:48 -08:00 committed by Facebook GitHub Bot
parent 49f93b6388
commit d6a12afbe7

View File

@ -891,7 +891,7 @@ class Pointclouds:
): ):
if n_points > max_: if n_points > max_:
keep_np = np.random.choice(n_points, max_, replace=False) keep_np = np.random.choice(n_points, max_, replace=False)
keep = torch.tensor(keep_np).to(points.device) keep = torch.tensor(keep_np, device=points.device, dtype=torch.int64)
points = points[keep] points = points[keep]
if features is not None: if features is not None:
features = features[keep] features = features[keep]