work with old linalg

Summary: solve and lstsq have moved around in torch. Cope with both.

Reviewed By: patricklabatut

Differential Revision: D29302316

fbshipit-source-id: b34f0b923e90a357f20df359635929241eba6e74
This commit is contained in:
Jeremy Reizenstein
2021-06-28 06:30:27 -07:00
committed by Facebook GitHub Bot
parent 5284de6e97
commit b8790474f1
7 changed files with 65 additions and 14 deletions

View File

@@ -10,6 +10,7 @@ import unittest
import numpy as np
import torch
from common_testing import TestCaseMixin
from pytorch3d.common.compat import lstsq
from pytorch3d.transforms import acos_linear_extrapolation
@@ -64,8 +65,7 @@ class TestAcosLinearExtrapolation(TestCaseMixin, unittest.TestCase):
bound_t = torch.tensor(bound, device=x.device, dtype=x.dtype)
# fit a line: slope * x + bias = y
x_1 = torch.stack([x, torch.ones_like(x)], dim=-1)
solution = torch.linalg.lstsq(x_1, y[:, None]).solution
slope, bias = solution.view(-1)[:2]
slope, bias = lstsq(x_1, y[:, None]).view(-1)[:2]
desired_slope = (-1.0) / torch.sqrt(1.0 - bound_t ** 2)
# test that the desired slope is the same as the fitted one
self.assertClose(desired_slope.view(1), slope.view(1), atol=1e-2)

View File

@@ -10,13 +10,10 @@ import unittest
import numpy as np
import torch
from common_testing import TestCaseMixin
from pytorch3d.common.compat import qr
from pytorch3d.transforms.rotation_conversions import random_rotations
from pytorch3d.transforms.se3 import se3_exp_map, se3_log_map
from pytorch3d.transforms.so3 import (
so3_exp_map,
so3_log_map,
so3_rotation_angle,
)
from pytorch3d.transforms.so3 import so3_exp_map, so3_log_map, so3_rotation_angle
class TestSE3(TestCaseMixin, unittest.TestCase):
@@ -201,7 +198,7 @@ class TestSE3(TestCaseMixin, unittest.TestCase):
r = [identity, rot180]
r.extend(
[
torch.qr(identity + torch.randn_like(identity) * 1e-6)[0]
qr(identity + torch.randn_like(identity) * 1e-6)[0]
+ float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-8
# this adds random noise to the second half
# of the random orthogonal matrices to generate

View File

@@ -11,6 +11,7 @@ import unittest
import numpy as np
import torch
from common_testing import TestCaseMixin
from pytorch3d.common.compat import qr
from pytorch3d.transforms.so3 import (
hat,
so3_exp_map,
@@ -46,7 +47,7 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
# TODO(dnovotny): replace with random_rotation from random_rotation.py
rot = []
for _ in range(batch_size):
r = torch.qr(torch.randn((3, 3), device=device))[0]
r = qr(torch.randn((3, 3), device=device))[0]
f = torch.randint(2, (3,), device=device, dtype=torch.float32)
if f.sum() % 2 == 0:
f = 1 - f
@@ -142,7 +143,7 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
# add random rotations and random almost orthonormal matrices
r.extend(
[
torch.qr(identity + torch.randn_like(identity) * 1e-4)[0]
qr(identity + torch.randn_like(identity) * 1e-4)[0]
+ float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-3
# this adds random noise to the second half
# of the random orthogonal matrices to generate
@@ -242,7 +243,7 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
r = [identity, rot180]
r.extend(
[
torch.qr(identity + torch.randn_like(identity) * 1e-4)[0]
qr(identity + torch.randn_like(identity) * 1e-4)[0]
for _ in range(batch_size - 2)
]
)