mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-06-17 04:28:54 +08:00
validate input and edges dtype in GatherScatter python wrapper
Summary: Add explicit dtype checks for input (torch.float32) and edges (torch.int64) in GatherScatter.forward and gather_scatter_python to match C++ TensorAccessor<float,2> and TensorAccessor<int64_t,2> expectations. Python previously validated ndim, shape, and input dtype in forward but not edges dtype, and gather_scatter_python lacked dtype checks entirely, relying on ATen error from accessor. This makes errors python-friendly and guards C++ accessor before TensorAccessor construction. ___ Differential Revision: D108140422 fbshipit-source-id: ba54e857279a480a02e2c8f27e316f2e23cc6092
This commit is contained in:
committed by
meta-codesync[bot]
parent
1f7f85c0a3
commit
d34e87ce52
@@ -114,6 +114,10 @@ def gather_scatter_python(input, edges, directed: bool = False):
|
||||
raise ValueError("edges can only have 2 dimensions.")
|
||||
if not (edges.shape[1] == 2):
|
||||
raise ValueError("edges must be of shape (num_edges, 2).")
|
||||
if not (input.dtype == torch.float32):
|
||||
raise ValueError("input has to be of type torch.float32.")
|
||||
if not (edges.dtype == torch.int64):
|
||||
raise ValueError("edges has to be of type torch.int64.")
|
||||
|
||||
num_vertices, input_feature_dim = input.shape
|
||||
num_edges = edges.shape[0]
|
||||
@@ -152,6 +156,8 @@ class GatherScatter(Function):
|
||||
raise ValueError("edges must be of shape (num_edges, 2).")
|
||||
if not (input.dtype == torch.float32):
|
||||
raise ValueError("input has to be of type torch.float32.")
|
||||
if not (edges.dtype == torch.int64):
|
||||
raise ValueError("edges has to be of type torch.int64.")
|
||||
|
||||
ctx.directed = directed
|
||||
input, edges = input.contiguous(), edges.contiguous()
|
||||
|
||||
Reference in New Issue
Block a user