validate input and edges dtype in GatherScatter python wrapper

Summary:
Add explicit dtype checks for input (torch.float32) and edges (torch.int64) in GatherScatter.forward and gather_scatter_python to match C++ TensorAccessor<float,2> and TensorAccessor<int64_t,2> expectations.

Python previously validated ndim, shape, and input dtype in forward but not edges dtype, and gather_scatter_python lacked dtype checks entirely, relying on ATen error from accessor. This makes errors python-friendly and guards C++ accessor before TensorAccessor construction.

___

Differential Revision: D108140422

fbshipit-source-id: ba54e857279a480a02e2c8f27e316f2e23cc6092
This commit is contained in:
Jeremy Reizenstein
2026-06-10 11:29:55 -07:00
committed by meta-codesync[bot]
parent 1f7f85c0a3
commit d34e87ce52

View File

@@ -114,6 +114,10 @@ def gather_scatter_python(input, edges, directed: bool = False):
raise ValueError("edges can only have 2 dimensions.")
if not (edges.shape[1] == 2):
raise ValueError("edges must be of shape (num_edges, 2).")
if not (input.dtype == torch.float32):
raise ValueError("input has to be of type torch.float32.")
if not (edges.dtype == torch.int64):
raise ValueError("edges has to be of type torch.int64.")
num_vertices, input_feature_dim = input.shape
num_edges = edges.shape[0]
@@ -152,6 +156,8 @@ class GatherScatter(Function):
raise ValueError("edges must be of shape (num_edges, 2).")
if not (input.dtype == torch.float32):
raise ValueError("input has to be of type torch.float32.")
if not (edges.dtype == torch.int64):
raise ValueError("edges has to be of type torch.int64.")
ctx.directed = directed
input, edges = input.contiguous(), edges.contiguous()