Remove pytorch3d's wrappers for eigh, solve, lstsq, qr

Summary: Remove the compat functions eigh, solve, lstsq, and qr. Migrate callers to use torch.linalg directly.

Reviewed By: bottler

Differential Revision: D39172949

fbshipit-source-id: 484230a553237808f06ee5cdfde64651cba91c4c
This commit is contained in:
Chris Lambert
2022-08-31 13:04:07 -07:00
committed by Facebook GitHub Bot
parent 9a1213e0e5
commit d4a1051e0f
9 changed files with 11 additions and 67 deletions

View File

@@ -11,7 +11,6 @@ from distutils.version import LooseVersion
import numpy as np
import torch
from pytorch3d.common.compat import qr
from pytorch3d.transforms.so3 import (
hat,
so3_exp_map,
@@ -49,7 +48,7 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
# TODO(dnovotny): replace with random_rotation from random_rotation.py
rot = []
for _ in range(batch_size):
r = qr(torch.randn((3, 3), device=device))[0]
r = torch.linalg.qr(torch.randn((3, 3), device=device))[0]
f = torch.randint(2, (3,), device=device, dtype=torch.float32)
if f.sum() % 2 == 0:
f = 1 - f
@@ -145,7 +144,7 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
# add random rotations and random almost orthonormal matrices
r.extend(
[
qr(identity + torch.randn_like(identity) * 1e-4)[0]
torch.linalg.qr(identity + torch.randn_like(identity) * 1e-4)[0]
+ float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-3
# this adds random noise to the second half
# of the random orthogonal matrices to generate
@@ -245,7 +244,7 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
r = [identity, rot180]
r.extend(
[
qr(identity + torch.randn_like(identity) * 1e-4)[0]
torch.linalg.qr(identity + torch.randn_like(identity) * 1e-4)[0]
for _ in range(batch_size - 2)
]
)