mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Fix returning a proper rotation in levelling; supporting batches and default centroid
Summary: `get_rotation_to_best_fit_xy` is useful to expose externally, however there was a bug (which we probably did not care about for our use case): it could return a rotation matrix with det(R) == −1. The diff fixes that, and also makes centroid optional (it can be computed from points). Reviewed By: bottler Differential Revision: D39926791 fbshipit-source-id: 5120c7892815b829f3ddcc23e93d4a5ec0ca0013
This commit is contained in:
		
							parent
							
								
									de98c9cc2f
								
							
						
					
					
						commit
						74bbd6fd76
					
				@ -12,22 +12,30 @@ from typing import Optional
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_rotation_to_best_fit_xy(
 | 
			
		||||
    points: torch.Tensor, centroid: torch.Tensor
 | 
			
		||||
def get_rotation_to_best_fit_xy(
 | 
			
		||||
    points: torch.Tensor, centroid: Optional[torch.Tensor] = None
 | 
			
		||||
) -> torch.Tensor:
 | 
			
		||||
    """
 | 
			
		||||
    Returns a rotation r such that points @ r has a best fit plane
 | 
			
		||||
    Returns a rotation R such that `points @ R` has a best fit plane
 | 
			
		||||
    parallel to the xy plane
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        points: (N, 3) tensor of points in 3D
 | 
			
		||||
        centroid: (3,) their centroid
 | 
			
		||||
        points: (*, N, 3) tensor of points in 3D
 | 
			
		||||
        centroid: (*, 1, 3), (3,) or scalar: their centroid
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        (3,3) tensor rotation matrix
 | 
			
		||||
        (*, 3, 3) tensor rotation matrix
 | 
			
		||||
    """
 | 
			
		||||
    points_centered = points - centroid[None]
 | 
			
		||||
    return torch.linalg.eigh(points_centered.t() @ points_centered)[1][:, [1, 2, 0]]
 | 
			
		||||
    if centroid is None:
 | 
			
		||||
        centroid = points.mean(dim=-2, keepdim=True)
 | 
			
		||||
 | 
			
		||||
    points_centered = points - centroid
 | 
			
		||||
    _, evec = torch.linalg.eigh(points_centered.transpose(-1, -2) @ points_centered)
 | 
			
		||||
    # in general, evec can form either right- or left-handed basis,
 | 
			
		||||
    # but we need the former to have a proper rotation (not reflection)
 | 
			
		||||
    return torch.cat(
 | 
			
		||||
        (evec[..., 1:], torch.cross(evec[..., 1], evec[..., 2])[..., None]), dim=-1
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _signed_area(path: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
@ -191,7 +199,7 @@ def fit_circle_in_3d(
 | 
			
		||||
        Circle3D object
 | 
			
		||||
    """
 | 
			
		||||
    centroid = points.mean(0)
 | 
			
		||||
    r = _get_rotation_to_best_fit_xy(points, centroid)
 | 
			
		||||
    r = get_rotation_to_best_fit_xy(points, centroid)
 | 
			
		||||
    normal = r[:, 2]
 | 
			
		||||
    rotated_points = (points - centroid) @ r
 | 
			
		||||
    result_2d = fit_circle_in_2d(
 | 
			
		||||
 | 
			
		||||
@ -12,8 +12,9 @@ from pytorch3d.implicitron.tools.circle_fitting import (
 | 
			
		||||
    _signed_area,
 | 
			
		||||
    fit_circle_in_2d,
 | 
			
		||||
    fit_circle_in_3d,
 | 
			
		||||
    get_rotation_to_best_fit_xy,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.transforms import random_rotation
 | 
			
		||||
from pytorch3d.transforms import random_rotation, random_rotations
 | 
			
		||||
from tests.common_testing import TestCaseMixin
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -28,6 +29,32 @@ class TestCircleFitting(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        """
 | 
			
		||||
        self.assertClose(torch.cross(a, b, dim=-1), torch.zeros_like(a), **kwargs)
 | 
			
		||||
 | 
			
		||||
    def test_plane_levelling(self):
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        B = 16
 | 
			
		||||
        N = 1024
 | 
			
		||||
        random = torch.randn((B, N, 3), device=device)
 | 
			
		||||
 | 
			
		||||
        # first, check that we always return a vaild rotation
 | 
			
		||||
        rot = get_rotation_to_best_fit_xy(random)
 | 
			
		||||
        self.assertClose(rot.det(), torch.ones_like(rot[:, 0, 0]))
 | 
			
		||||
        self.assertClose(rot.norm(dim=-1), torch.ones_like(rot[:, 0]))
 | 
			
		||||
 | 
			
		||||
        # then, check the result is what we expect
 | 
			
		||||
        z_squeeze = 0.1
 | 
			
		||||
        random[..., -1] *= z_squeeze
 | 
			
		||||
        rot_gt = random_rotations(B, device=device)
 | 
			
		||||
        rotated = random @ rot_gt.transpose(-1, -2)
 | 
			
		||||
        rot_hat = get_rotation_to_best_fit_xy(rotated)
 | 
			
		||||
        self.assertClose(rot.det(), torch.ones_like(rot[:, 0, 0]))
 | 
			
		||||
        self.assertClose(rot.norm(dim=-1), torch.ones_like(rot[:, 0]))
 | 
			
		||||
        # covariance matrix of the levelled points is by design diag(1, 1, z_squeeze²)
 | 
			
		||||
        self.assertClose(
 | 
			
		||||
            (rotated @ rot_hat)[..., -1].std(dim=-1),
 | 
			
		||||
            torch.ones_like(rot_hat[:, 0, 0]) * z_squeeze,
 | 
			
		||||
            rtol=0.1,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_simple_3d(self):
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        for _ in range(7):
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user