Correct "fast" matrix_to_axis_angle near pi (#1953)

Summary:
A continuation of https://github.com/facebookresearch/pytorch3d/issues/1948 -- this commit fixes a small numerical issue with `matrix_to_axis_angle(..., fast=True)` near `pi`.
bottler feel free to check this out, it's a single-line change.

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1953

Reviewed By: MichaelRamamonjisoa

Differential Revision: D70088251

Pulled By: bottler

fbshipit-source-id: 54cc7f946283db700cec2cd5575cf918456b7f32
This commit is contained in:
Alexandros Benetatos 2025-03-11 12:25:59 -07:00 committed by Facebook GitHub Bot
parent 21205730d9
commit 06a76ef8dd

View File

@ -542,9 +542,7 @@ def matrix_to_axis_angle(matrix: torch.Tensor, fast: bool = False) -> torch.Tens
zeros = torch.zeros(3, dtype=matrix.dtype, device=matrix.device)
omegas = torch.where(torch.isclose(angles, torch.zeros_like(angles)), zeros, omegas)
near_pi = torch.isclose(((traces - 1) / 2).abs(), torch.ones_like(traces)).squeeze(
-1
)
near_pi = angles.isclose(angles.new_full((1,), torch.pi)).squeeze(-1)
axis_angles = torch.empty_like(omegas)
axis_angles[~near_pi] = (