mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 14:50:36 +08:00
Fix returning a proper rotation in levelling; supporting batches and default centroid
Summary: `get_rotation_to_best_fit_xy` is useful to expose externally, however there was a bug (which we probably did not care about for our use case): it could return a rotation matrix with det(R) == −1. The diff fixes that, and also makes centroid optional (it can be computed from points). Reviewed By: bottler Differential Revision: D39926791 fbshipit-source-id: 5120c7892815b829f3ddcc23e93d4a5ec0ca0013
This commit is contained in:
committed by
Facebook GitHub Bot
parent
de98c9cc2f
commit
74bbd6fd76
@@ -12,22 +12,30 @@ from typing import Optional
|
||||
import torch
|
||||
|
||||
|
||||
def _get_rotation_to_best_fit_xy(
|
||||
points: torch.Tensor, centroid: torch.Tensor
|
||||
def get_rotation_to_best_fit_xy(
|
||||
points: torch.Tensor, centroid: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Returns a rotation r such that points @ r has a best fit plane
|
||||
Returns a rotation R such that `points @ R` has a best fit plane
|
||||
parallel to the xy plane
|
||||
|
||||
Args:
|
||||
points: (N, 3) tensor of points in 3D
|
||||
centroid: (3,) their centroid
|
||||
points: (*, N, 3) tensor of points in 3D
|
||||
centroid: (*, 1, 3), (3,) or scalar: their centroid
|
||||
|
||||
Returns:
|
||||
(3,3) tensor rotation matrix
|
||||
(*, 3, 3) tensor rotation matrix
|
||||
"""
|
||||
points_centered = points - centroid[None]
|
||||
return torch.linalg.eigh(points_centered.t() @ points_centered)[1][:, [1, 2, 0]]
|
||||
if centroid is None:
|
||||
centroid = points.mean(dim=-2, keepdim=True)
|
||||
|
||||
points_centered = points - centroid
|
||||
_, evec = torch.linalg.eigh(points_centered.transpose(-1, -2) @ points_centered)
|
||||
# in general, evec can form either right- or left-handed basis,
|
||||
# but we need the former to have a proper rotation (not reflection)
|
||||
return torch.cat(
|
||||
(evec[..., 1:], torch.cross(evec[..., 1], evec[..., 2])[..., None]), dim=-1
|
||||
)
|
||||
|
||||
|
||||
def _signed_area(path: torch.Tensor) -> torch.Tensor:
|
||||
@@ -191,7 +199,7 @@ def fit_circle_in_3d(
|
||||
Circle3D object
|
||||
"""
|
||||
centroid = points.mean(0)
|
||||
r = _get_rotation_to_best_fit_xy(points, centroid)
|
||||
r = get_rotation_to_best_fit_xy(points, centroid)
|
||||
normal = r[:, 2]
|
||||
rotated_points = (points - centroid) @ r
|
||||
result_2d = fit_circle_in_2d(
|
||||
|
||||
Reference in New Issue
Block a user