mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-07-31 10:52:50 +08:00
Revert _sqrt_positive_part
change
Reviewed By: bottler Differential Revision: D77549647 fbshipit-source-id: a0ef0bc015c643ad7416c781886e2e23b5105bdd
This commit is contained in:
parent
177eec6378
commit
267bd8ef87
@ -95,7 +95,13 @@ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
||||
Returns torch.sqrt(torch.max(0, x))
|
||||
but with a zero subgradient where x is 0.
|
||||
"""
|
||||
return torch.sqrt(F.relu(x))
|
||||
ret = torch.zeros_like(x)
|
||||
positive_mask = x > 0
|
||||
if torch.is_grad_enabled():
|
||||
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
||||
else:
|
||||
ret = torch.where(positive_mask, torch.sqrt(x), ret)
|
||||
return ret
|
||||
|
||||
|
||||
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
|
||||
|
Loading…
x
Reference in New Issue
Block a user