diff --git a/pytorch3d/ops/graph_conv.py b/pytorch3d/ops/graph_conv.py index a43eeb3a..394266ee 100644 --- a/pytorch3d/ops/graph_conv.py +++ b/pytorch3d/ops/graph_conv.py @@ -114,6 +114,10 @@ def gather_scatter_python(input, edges, directed: bool = False): raise ValueError("edges can only have 2 dimensions.") if not (edges.shape[1] == 2): raise ValueError("edges must be of shape (num_edges, 2).") + if not (input.dtype == torch.float32): + raise ValueError("input has to be of type torch.float32.") + if not (edges.dtype == torch.int64): + raise ValueError("edges has to be of type torch.int64.") num_vertices, input_feature_dim = input.shape num_edges = edges.shape[0] @@ -152,6 +156,8 @@ class GatherScatter(Function): raise ValueError("edges must be of shape (num_edges, 2).") if not (input.dtype == torch.float32): raise ValueError("input has to be of type torch.float32.") + if not (edges.dtype == torch.int64): + raise ValueError("edges has to be of type torch.int64.") ctx.directed = directed input, edges = input.contiguous(), edges.contiguous()