mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Revert _sqrt_positive_part
change
Reviewed By: bottler Differential Revision: D77549647 fbshipit-source-id: a0ef0bc015c643ad7416c781886e2e23b5105bdd
This commit is contained in:
parent
177eec6378
commit
267bd8ef87
@ -95,7 +95,13 @@ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
|||||||
Returns torch.sqrt(torch.max(0, x))
|
Returns torch.sqrt(torch.max(0, x))
|
||||||
but with a zero subgradient where x is 0.
|
but with a zero subgradient where x is 0.
|
||||||
"""
|
"""
|
||||||
return torch.sqrt(F.relu(x))
|
ret = torch.zeros_like(x)
|
||||||
|
positive_mask = x > 0
|
||||||
|
if torch.is_grad_enabled():
|
||||||
|
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
||||||
|
else:
|
||||||
|
ret = torch.where(positive_mask, torch.sqrt(x), ret)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
|
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user