mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Removing dynamic shape ops and boolean indexing in matrix_to_quaternion
Summary: The current implementation of `matrix_to_quaternion` and `_sqrt_positive_part` uses boolean indexing, which can slow down performance and cause incompatibility with `torch.compile` unless `torch._dynamo.config.capture_dynamic_output_shape_ops` is set to `True`. To enhance performance and compatibility, I recommend using `torch.gather` to select the best-conditioned quaternions and `F.relu` instead of `x>0` (bottler's suggestion) For a detailed comparison of the implementation differences when using `torch.compile`, please refer to my Bento notebook N7438339. Reviewed By: bottler Differential Revision: D77176230 fbshipit-source-id: 9a6a2e0015b5865056297d5f45badc3c425b93ce
This commit is contained in:
parent
6020323d94
commit
71db7a0ea2
@ -95,13 +95,7 @@ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
|||||||
Returns torch.sqrt(torch.max(0, x))
|
Returns torch.sqrt(torch.max(0, x))
|
||||||
but with a zero subgradient where x is 0.
|
but with a zero subgradient where x is 0.
|
||||||
"""
|
"""
|
||||||
ret = torch.zeros_like(x)
|
return torch.sqrt(F.relu(x))
|
||||||
positive_mask = x > 0
|
|
||||||
if torch.is_grad_enabled():
|
|
||||||
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
|
||||||
else:
|
|
||||||
ret = torch.where(positive_mask, torch.sqrt(x), ret)
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
|
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
|
||||||
@ -160,9 +154,10 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
|
|||||||
|
|
||||||
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
|
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
|
||||||
# forall i; we pick the best-conditioned one (with the largest denominator)
|
# forall i; we pick the best-conditioned one (with the largest denominator)
|
||||||
out = quat_candidates[
|
indices = q_abs.argmax(dim=-1, keepdim=True)
|
||||||
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
|
expand_dims = list(batch_dim) + [1, 4]
|
||||||
].reshape(batch_dim + (4,))
|
gather_indices = indices.unsqueeze(-1).expand(expand_dims)
|
||||||
|
out = torch.gather(quat_candidates, -2, gather_indices).squeeze(-2)
|
||||||
return standardize_quaternion(out)
|
return standardize_quaternion(out)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user