Fix Transform3d.stack of compositions

Summary:
Add a test for Transform3d.stack, and make it work with composed transformations.

Fixes https://github.com/facebookresearch/pytorch3d/issues/1072 .

Reviewed By: patricklabatut

Differential Revision: D34211920

fbshipit-source-id: bfbd0895494ca2ad3d08a61bc82ba23637e168cc
This commit is contained in:
Jeremy Reizenstein
2022-02-15 06:46:38 -08:00
committed by Facebook GitHub Bot
parent 2a1de3b610
commit c8f3d6bc0b
3 changed files with 64 additions and 26 deletions

View File

@@ -10,6 +10,7 @@ import unittest
import torch
from common_testing import TestCaseMixin
from pytorch3d.transforms import random_rotations
from pytorch3d.transforms.so3 import so3_exp_map
from pytorch3d.transforms.transform3d import (
Rotate,
@@ -21,6 +22,9 @@ from pytorch3d.transforms.transform3d import (
class TestTransform(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
torch.manual_seed(42)
def test_to(self):
tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]]))
R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
@@ -406,6 +410,28 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
with self.assertRaises(IndexError):
t3d_selected = t3d[invalid_index]
def test_stack(self):
rotations = random_rotations(3)
transform3 = Transform3d().rotate(rotations).translate(torch.full((3, 3), 0.3))
transform1 = Scale(37)
transform4 = transform1.stack(transform3)
self.assertEqual(len(transform1), 1)
self.assertEqual(len(transform3), 3)
self.assertEqual(len(transform4), 4)
self.assertClose(
transform4.get_matrix(),
torch.cat([transform1.get_matrix(), transform3.get_matrix()]),
)
points = torch.rand(4, 5, 3)
new_points_expect = torch.cat(
[
transform1.transform_points(points[:1]),
transform3.transform_points(points[1:]),
]
)
new_points = transform4.transform_points(points)
self.assertClose(new_points, new_points_expect)
class TestTranslate(unittest.TestCase):
def test_python_scalar(self):