From d34e87ce5266634ffd2b155ffd08ba02542fa009 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Wed, 10 Jun 2026 11:29:55 -0700 Subject: [PATCH] validate input and edges dtype in GatherScatter python wrapper Summary: Add explicit dtype checks for input (torch.float32) and edges (torch.int64) in GatherScatter.forward and gather_scatter_python to match C++ TensorAccessor and TensorAccessor expectations. Python previously validated ndim, shape, and input dtype in forward but not edges dtype, and gather_scatter_python lacked dtype checks entirely, relying on ATen error from accessor. This makes errors python-friendly and guards C++ accessor before TensorAccessor construction. ___ Differential Revision: D108140422 fbshipit-source-id: ba54e857279a480a02e2c8f27e316f2e23cc6092 --- pytorch3d/ops/graph_conv.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pytorch3d/ops/graph_conv.py b/pytorch3d/ops/graph_conv.py index a43eeb3a..394266ee 100644 --- a/pytorch3d/ops/graph_conv.py +++ b/pytorch3d/ops/graph_conv.py @@ -114,6 +114,10 @@ def gather_scatter_python(input, edges, directed: bool = False): raise ValueError("edges can only have 2 dimensions.") if not (edges.shape[1] == 2): raise ValueError("edges must be of shape (num_edges, 2).") + if not (input.dtype == torch.float32): + raise ValueError("input has to be of type torch.float32.") + if not (edges.dtype == torch.int64): + raise ValueError("edges has to be of type torch.int64.") num_vertices, input_feature_dim = input.shape num_edges = edges.shape[0] @@ -152,6 +156,8 @@ class GatherScatter(Function): raise ValueError("edges must be of shape (num_edges, 2).") if not (input.dtype == torch.float32): raise ValueError("input has to be of type torch.float32.") + if not (edges.dtype == torch.int64): + raise ValueError("edges has to be of type torch.int64.") ctx.directed = directed input, edges = input.contiguous(), edges.contiguous()