diff --git a/pytorch3d/renderer/cameras.py b/pytorch3d/renderer/cameras.py index 09cdc72b..837670de 100644 --- a/pytorch3d/renderer/cameras.py +++ b/pytorch3d/renderer/cameras.py @@ -383,7 +383,7 @@ class CamerasBase(TensorProperties): kwargs = {} - if not isinstance(index, (int, list, torch.LongTensor)): + if not isinstance(index, (int, list, torch.LongTensor, torch.cuda.LongTensor)): msg = "Invalid index type, expected int, List[int] or torch.LongTensor; got %r" raise ValueError(msg % type(index))