Fix Transform3d.stack of compositions

Summary:
Add a test for Transform3d.stack, and make it work with composed transformations.

Fixes https://github.com/facebookresearch/pytorch3d/issues/1072 .

Reviewed By: patricklabatut

Differential Revision: D34211920

fbshipit-source-id: bfbd0895494ca2ad3d08a61bc82ba23637e168cc
This commit is contained in:
Jeremy Reizenstein 2022-02-15 06:46:38 -08:00 committed by Facebook GitHub Bot
parent 2a1de3b610
commit c8f3d6bc0b
3 changed files with 64 additions and 26 deletions

View File

@ -1649,7 +1649,7 @@ def look_at_view_transform(
elev=0.0,
azim=0.0,
degrees: bool = True,
eye: Optional[Sequence] = None,
eye: Optional[Union[Sequence, torch.Tensor]] = None,
at=((0, 0, 0),), # (1, 3)
up=((0, 1, 0),), # (1, 3)
device: Device = "cpu",

View File

@ -196,10 +196,10 @@ class Transform3d:
index = [index]
return self.__class__(matrix=self.get_matrix()[index])
def compose(self, *others):
def compose(self, *others: "Transform3d") -> "Transform3d":
"""
Return a new Transform3d with the transforms to compose stored as
an internal list.
Return a new Transform3d representing the composition of self with the
given other transforms, which will be stored as an internal list.
Args:
*others: Any number of Transform3d objects
@ -216,7 +216,7 @@ class Transform3d:
out._transforms = self._transforms + list(others)
return out
def get_matrix(self):
def get_matrix(self) -> torch.Tensor:
"""
Return a matrix which is the result of composing this transform
with others stored in self.transforms. Where necessary transforms
@ -240,13 +240,13 @@ class Transform3d:
composed_matrix = _broadcast_bmm(composed_matrix, other_matrix)
return composed_matrix
def _get_matrix_inverse(self):
def _get_matrix_inverse(self) -> torch.Tensor:
"""
Return the inverse of self._matrix.
"""
return torch.inverse(self._matrix)
def inverse(self, invert_composed: bool = False):
def inverse(self, invert_composed: bool = False) -> "Transform3d":
"""
Returns a new Transform3d object that represents an inverse of the
current transformation.
@ -295,14 +295,24 @@ class Transform3d:
return tinv
def stack(self, *others):
def stack(self, *others: "Transform3d") -> "Transform3d":
"""
Return a new batched Transform3d representing the batch elements from
self and all the given other transforms all batched together.
Args:
*others: Any number of Transform3d objects
Returns:
A new Transform3d.
"""
transforms = [self] + list(others)
matrix = torch.cat([t._matrix for t in transforms], dim=0)
matrix = torch.cat([t.get_matrix() for t in transforms], dim=0)
out = Transform3d(dtype=self.dtype, device=self.device)
out._matrix = matrix
return out
def transform_points(self, points, eps: Optional[float] = None):
def transform_points(self, points, eps: Optional[float] = None) -> torch.Tensor:
"""
Use this transform to transform a set of 3D points. Assumes row major
ordering of the input points.
@ -347,7 +357,7 @@ class Transform3d:
return points_out
def transform_normals(self, normals):
def transform_normals(self, normals) -> torch.Tensor:
"""
Use this transform to transform a set of normal vectors.
@ -379,19 +389,19 @@ class Transform3d:
return normals_out
def translate(self, *args, **kwargs):
def translate(self, *args, **kwargs) -> "Transform3d":
return self.compose(Translate(device=self.device, *args, **kwargs))
def scale(self, *args, **kwargs):
def scale(self, *args, **kwargs) -> "Transform3d":
return self.compose(Scale(device=self.device, *args, **kwargs))
def rotate(self, *args, **kwargs):
def rotate(self, *args, **kwargs) -> "Transform3d":
return self.compose(Rotate(device=self.device, *args, **kwargs))
def rotate_axis_angle(self, *args, **kwargs):
def rotate_axis_angle(self, *args, **kwargs) -> "Transform3d":
return self.compose(RotateAxisAngle(device=self.device, *args, **kwargs))
def clone(self):
def clone(self) -> "Transform3d":
"""
Deep copy of Transforms object. All internal tensors are cloned
individually.
@ -411,7 +421,7 @@ class Transform3d:
device: Device,
copy: bool = False,
dtype: Optional[torch.dtype] = None,
):
) -> "Transform3d":
"""
Match functionality of torch.Tensor.to()
If copy = True or the self Tensor is on a different device, the
@ -448,10 +458,10 @@ class Transform3d:
]
return other
def cpu(self):
def cpu(self) -> "Transform3d":
return self.to("cpu")
def cuda(self):
def cuda(self) -> "Transform3d":
return self.to("cuda")
@ -486,7 +496,7 @@ class Translate(Transform3d):
mat[:, 3, :3] = xyz
self._matrix = mat
def _get_matrix_inverse(self):
def _get_matrix_inverse(self) -> torch.Tensor:
"""
Return the inverse of self._matrix.
"""
@ -533,7 +543,7 @@ class Scale(Transform3d):
mat[:, 2, 2] = xyz[:, 2]
self._matrix = mat
def _get_matrix_inverse(self):
def _get_matrix_inverse(self) -> torch.Tensor:
"""
Return the inverse of self._matrix.
"""
@ -575,7 +585,7 @@ class Rotate(Transform3d):
mat[:, :3, :3] = R
self._matrix = mat
def _get_matrix_inverse(self):
def _get_matrix_inverse(self) -> torch.Tensor:
"""
Return the inverse of self._matrix.
"""
@ -622,7 +632,7 @@ class RotateAxisAngle(Rotate):
super().__init__(device=angle.device, R=R)
def _handle_coord(c, dtype: torch.dtype, device: torch.device):
def _handle_coord(c, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
"""
Helper function for _handle_input.
@ -649,7 +659,7 @@ def _handle_input(
device: Optional[Device],
name: str,
allow_singleton: bool = False,
):
) -> torch.Tensor:
"""
Helper function to handle parsing logic for building transforms. The output
is always a tensor of shape (N, 3), but there are several types of allowed
@ -707,7 +717,9 @@ def _handle_input(
return xyz
def _handle_angle_input(x, dtype: torch.dtype, device: Optional[Device], name: str):
def _handle_angle_input(
x, dtype: torch.dtype, device: Optional[Device], name: str
) -> torch.Tensor:
"""
Helper function for building a rotation function using angles.
The output is always of shape (N,).
@ -725,7 +737,7 @@ def _handle_angle_input(x, dtype: torch.dtype, device: Optional[Device], name: s
return _handle_coord(x, dtype, device_)
def _broadcast_bmm(a, b):
def _broadcast_bmm(a, b) -> torch.Tensor:
"""
Batch multiply two matrices and broadcast if necessary.

View File

@ -10,6 +10,7 @@ import unittest
import torch
from common_testing import TestCaseMixin
from pytorch3d.transforms import random_rotations
from pytorch3d.transforms.so3 import so3_exp_map
from pytorch3d.transforms.transform3d import (
Rotate,
@ -21,6 +22,9 @@ from pytorch3d.transforms.transform3d import (
class TestTransform(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
torch.manual_seed(42)
def test_to(self):
tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]]))
R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
@ -406,6 +410,28 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
with self.assertRaises(IndexError):
t3d_selected = t3d[invalid_index]
def test_stack(self):
rotations = random_rotations(3)
transform3 = Transform3d().rotate(rotations).translate(torch.full((3, 3), 0.3))
transform1 = Scale(37)
transform4 = transform1.stack(transform3)
self.assertEqual(len(transform1), 1)
self.assertEqual(len(transform3), 3)
self.assertEqual(len(transform4), 4)
self.assertClose(
transform4.get_matrix(),
torch.cat([transform1.get_matrix(), transform3.get_matrix()]),
)
points = torch.rand(4, 5, 3)
new_points_expect = torch.cat(
[
transform1.transform_points(points[:1]),
transform3.transform_points(points[1:]),
]
)
new_points = transform4.transform_points(points)
self.assertClose(new_points, new_points_expect)
class TestTranslate(unittest.TestCase):
def test_python_scalar(self):