mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 05:40:34 +08:00
Remove pytorch3d's wrappers for eigh, solve, lstsq, qr
Summary: Remove the compat functions eigh, solve, lstsq, and qr. Migrate callers to use torch.linalg directly. Reviewed By: bottler Differential Revision: D39172949 fbshipit-source-id: 484230a553237808f06ee5cdfde64651cba91c4c
This commit is contained in:
committed by
Facebook GitHub Bot
parent
9a1213e0e5
commit
d4a1051e0f
@@ -9,7 +9,6 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch3d.common.compat import lstsq
|
||||
from pytorch3d.transforms import acos_linear_extrapolation
|
||||
|
||||
from .common_testing import TestCaseMixin
|
||||
@@ -66,7 +65,7 @@ class TestAcosLinearExtrapolation(TestCaseMixin, unittest.TestCase):
|
||||
bound_t = torch.tensor(bound, device=x.device, dtype=x.dtype)
|
||||
# fit a line: slope * x + bias = y
|
||||
x_1 = torch.stack([x, torch.ones_like(x)], dim=-1)
|
||||
slope, bias = lstsq(x_1, y[:, None]).view(-1)[:2]
|
||||
slope, bias = torch.linalg.lstsq(x_1, y[:, None]).solution.view(-1)[:2]
|
||||
desired_slope = (-1.0) / torch.sqrt(1.0 - bound_t**2)
|
||||
# test that the desired slope is the same as the fitted one
|
||||
self.assertClose(desired_slope.view(1), slope.view(1), atol=1e-2)
|
||||
|
||||
@@ -9,7 +9,6 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch3d.common.compat import qr
|
||||
from pytorch3d.transforms.rotation_conversions import random_rotations
|
||||
from pytorch3d.transforms.se3 import se3_exp_map, se3_log_map
|
||||
from pytorch3d.transforms.so3 import so3_exp_map, so3_log_map, so3_rotation_angle
|
||||
@@ -199,7 +198,7 @@ class TestSE3(TestCaseMixin, unittest.TestCase):
|
||||
r = [identity, rot180]
|
||||
r.extend(
|
||||
[
|
||||
qr(identity + torch.randn_like(identity) * 1e-6)[0]
|
||||
torch.linalg.qr(identity + torch.randn_like(identity) * 1e-6)[0]
|
||||
+ float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-8
|
||||
# this adds random noise to the second half
|
||||
# of the random orthogonal matrices to generate
|
||||
|
||||
@@ -11,7 +11,6 @@ from distutils.version import LooseVersion
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch3d.common.compat import qr
|
||||
from pytorch3d.transforms.so3 import (
|
||||
hat,
|
||||
so3_exp_map,
|
||||
@@ -49,7 +48,7 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
|
||||
# TODO(dnovotny): replace with random_rotation from random_rotation.py
|
||||
rot = []
|
||||
for _ in range(batch_size):
|
||||
r = qr(torch.randn((3, 3), device=device))[0]
|
||||
r = torch.linalg.qr(torch.randn((3, 3), device=device))[0]
|
||||
f = torch.randint(2, (3,), device=device, dtype=torch.float32)
|
||||
if f.sum() % 2 == 0:
|
||||
f = 1 - f
|
||||
@@ -145,7 +144,7 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
|
||||
# add random rotations and random almost orthonormal matrices
|
||||
r.extend(
|
||||
[
|
||||
qr(identity + torch.randn_like(identity) * 1e-4)[0]
|
||||
torch.linalg.qr(identity + torch.randn_like(identity) * 1e-4)[0]
|
||||
+ float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-3
|
||||
# this adds random noise to the second half
|
||||
# of the random orthogonal matrices to generate
|
||||
@@ -245,7 +244,7 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
|
||||
r = [identity, rot180]
|
||||
r.extend(
|
||||
[
|
||||
qr(identity + torch.randn_like(identity) * 1e-4)[0]
|
||||
torch.linalg.qr(identity + torch.randn_like(identity) * 1e-4)[0]
|
||||
for _ in range(batch_size - 2)
|
||||
]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user