mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-02-25 23:56:00 +08:00
fbcode/vision/fair/pytorch3d/pytorch3d/transforms/rotation_conversions.py
Reviewed By: bottler Differential Revision: D93712828 fbshipit-source-id: 3465af450104bb1e5f491e3c0ee0259698cf8ceb
This commit is contained in:
committed by
meta-codesync[bot]
parent
49f43402c6
commit
e43ed8c76e
@@ -52,8 +52,7 @@ def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
|
||||
Rotation matrices as tensor of shape (..., 3, 3).
|
||||
"""
|
||||
r, i, j, k = torch.unbind(quaternions, -1)
|
||||
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
|
||||
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
||||
two_s = torch.div(2.0, (quaternions * quaternions).sum(-1))
|
||||
|
||||
o = torch.stack(
|
||||
(
|
||||
@@ -137,18 +136,18 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
|
||||
# we produce the desired quaternion multiplied by each of r, i, j, k
|
||||
quat_by_rijk = torch.stack(
|
||||
[
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
||||
# `int`.
|
||||
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
||||
# `int`.
|
||||
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
||||
# `int`.
|
||||
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
||||
# `int`.
|
||||
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
|
||||
torch.stack(
|
||||
[torch.square(q_abs[..., 0]), m21 - m12, m02 - m20, m10 - m01], dim=-1
|
||||
),
|
||||
torch.stack(
|
||||
[m21 - m12, torch.square(q_abs[..., 1]), m10 + m01, m02 + m20], dim=-1
|
||||
),
|
||||
torch.stack(
|
||||
[m02 - m20, m10 + m01, torch.square(q_abs[..., 2]), m12 + m21], dim=-1
|
||||
),
|
||||
torch.stack(
|
||||
[m10 - m01, m20 + m02, m21 + m12, torch.square(q_abs[..., 3])], dim=-1
|
||||
),
|
||||
],
|
||||
dim=-2,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user