mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-07 04:36:00 +08:00
Make _sqrt_positive_part ONNX-exportable
Summary: Replace boolean indexing and torch.is_grad_enabled() control flow in _sqrt_positive_part with a pure torch.where implementation. The old code used ret[positive_mask] = torch.sqrt(x[positive_mask]) which produces an incorrect ONNX Where/index_put node with mismatched broadcast shapes when the model is exported via torch.onnx.export. The new implementation substitutes 1.0 for non-positive values before sqrt (avoiding infinite gradient at sqrt(0)) and masks the result back to 0, preserving the zero-subgradient-at-zero property. Fixes https://github.com/facebookresearch/pytorch3d/issues/2020 Reviewed By: sgrigory Differential Revision: D94365479 fbshipit-source-id: a1ebe8dc077573f83efc262520b6669159b83ef0
This commit is contained in:
committed by
meta-codesync[bot]
parent
7a6157e38e
commit
61cc79aa34
@@ -94,13 +94,9 @@ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
||||
Returns torch.sqrt(torch.max(0, x))
|
||||
but with a zero subgradient where x is 0.
|
||||
"""
|
||||
ret = torch.zeros_like(x)
|
||||
positive_mask = x > 0
|
||||
if torch.is_grad_enabled():
|
||||
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
||||
else:
|
||||
ret = torch.where(positive_mask, torch.sqrt(x), ret)
|
||||
return ret
|
||||
safe_x = torch.where(positive_mask, x, 1.0)
|
||||
return torch.where(positive_mask, torch.sqrt(safe_x), 0.0)
|
||||
|
||||
|
||||
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
Reference in New Issue
Block a user