mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 05:40:34 +08:00
Fix Transform3d.stack of compositions
Summary: Add a test for Transform3d.stack, and make it work with composed transformations. Fixes https://github.com/facebookresearch/pytorch3d/issues/1072 . Reviewed By: patricklabatut Differential Revision: D34211920 fbshipit-source-id: bfbd0895494ca2ad3d08a61bc82ba23637e168cc
This commit is contained in:
committed by
Facebook GitHub Bot
parent
2a1de3b610
commit
c8f3d6bc0b
@@ -10,6 +10,7 @@ import unittest
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.transforms import random_rotations
|
||||
from pytorch3d.transforms.so3 import so3_exp_map
|
||||
from pytorch3d.transforms.transform3d import (
|
||||
Rotate,
|
||||
@@ -21,6 +22,9 @@ from pytorch3d.transforms.transform3d import (
|
||||
|
||||
|
||||
class TestTransform(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
torch.manual_seed(42)
|
||||
|
||||
def test_to(self):
|
||||
tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]]))
|
||||
R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
|
||||
@@ -406,6 +410,28 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
|
||||
with self.assertRaises(IndexError):
|
||||
t3d_selected = t3d[invalid_index]
|
||||
|
||||
def test_stack(self):
|
||||
rotations = random_rotations(3)
|
||||
transform3 = Transform3d().rotate(rotations).translate(torch.full((3, 3), 0.3))
|
||||
transform1 = Scale(37)
|
||||
transform4 = transform1.stack(transform3)
|
||||
self.assertEqual(len(transform1), 1)
|
||||
self.assertEqual(len(transform3), 3)
|
||||
self.assertEqual(len(transform4), 4)
|
||||
self.assertClose(
|
||||
transform4.get_matrix(),
|
||||
torch.cat([transform1.get_matrix(), transform3.get_matrix()]),
|
||||
)
|
||||
points = torch.rand(4, 5, 3)
|
||||
new_points_expect = torch.cat(
|
||||
[
|
||||
transform1.transform_points(points[:1]),
|
||||
transform3.transform_points(points[1:]),
|
||||
]
|
||||
)
|
||||
new_points = transform4.transform_points(points)
|
||||
self.assertClose(new_points, new_points_expect)
|
||||
|
||||
|
||||
class TestTranslate(unittest.TestCase):
|
||||
def test_python_scalar(self):
|
||||
|
||||
Reference in New Issue
Block a user