allow matrix_to_quaternion onnx export

Summary: Attempt to allow torch.onnx.dynamo_export(matrix_to_quaternion) to work.

Differential Revision: D59812279

fbshipit-source-id: 4497e5b543bec9d5c2bdccfb779d154750a075ad
This commit is contained in:
Jeremy Reizenstein 2024-07-16 11:30:20 -07:00 committed by Facebook GitHub Bot
parent d0d0e02007
commit 7edaee71a9

View File

@ -97,7 +97,10 @@ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
""" """
ret = torch.zeros_like(x) ret = torch.zeros_like(x)
positive_mask = x > 0 positive_mask = x > 0
ret[positive_mask] = torch.sqrt(x[positive_mask]) if torch.is_grad_enabled():
ret[positive_mask] = torch.sqrt(x[positive_mask])
else:
ret = torch.where(positive_mask, torch.sqrt(x), ret)
return ret return ret