mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-07-31 10:52:50 +08:00
Removing dynamic shape ops and boolean indexing in matrix_to_quaternion
Summary: The current implementations of `matrix_to_quaternion` and `_sqrt_positive_part` use boolean indexing, which can slow down performance and causes incompatibility with `torch.compile` unless `torch._dynamo.config.capture_dynamic_output_shape_ops` is set to `True`. To improve performance and compatibility, I recommend using `torch.gather` to select the best-conditioned quaternions, and `F.relu` instead of the `x > 0` boolean mask (bottler's suggestion). For a detailed comparison of the implementation differences when using `torch.compile`, please refer to my Bento notebook N7438339. Reviewed By: bottler Differential Revision: D77176230 fbshipit-source-id: 9a6a2e0015b5865056297d5f45badc3c425b93ce
This commit is contained in:
parent
6020323d94
commit
71db7a0ea2
@ -95,13 +95,7 @@ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
||||
Returns torch.sqrt(torch.max(0, x))
|
||||
but with a zero subgradient where x is 0.
|
||||
"""
|
||||
ret = torch.zeros_like(x)
|
||||
positive_mask = x > 0
|
||||
if torch.is_grad_enabled():
|
||||
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
||||
else:
|
||||
ret = torch.where(positive_mask, torch.sqrt(x), ret)
|
||||
return ret
|
||||
return torch.sqrt(F.relu(x))
|
||||
|
||||
|
||||
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
|
||||
@ -160,9 +154,10 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
|
||||
# forall i; we pick the best-conditioned one (with the largest denominator)
|
||||
out = quat_candidates[
|
||||
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
|
||||
].reshape(batch_dim + (4,))
|
||||
indices = q_abs.argmax(dim=-1, keepdim=True)
|
||||
expand_dims = list(batch_dim) + [1, 4]
|
||||
gather_indices = indices.unsqueeze(-1).expand(expand_dims)
|
||||
out = torch.gather(quat_candidates, -2, gather_indices).squeeze(-2)
|
||||
return standardize_quaternion(out)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user