mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Fix Transform3d.stack of compositions
Summary: Add a test for Transform3d.stack, and make it work with composed transformations. Fixes https://github.com/facebookresearch/pytorch3d/issues/1072 . Reviewed By: patricklabatut Differential Revision: D34211920 fbshipit-source-id: bfbd0895494ca2ad3d08a61bc82ba23637e168cc
This commit is contained in:
parent
2a1de3b610
commit
c8f3d6bc0b
@ -1649,7 +1649,7 @@ def look_at_view_transform(
|
||||
elev=0.0,
|
||||
azim=0.0,
|
||||
degrees: bool = True,
|
||||
eye: Optional[Sequence] = None,
|
||||
eye: Optional[Union[Sequence, torch.Tensor]] = None,
|
||||
at=((0, 0, 0),), # (1, 3)
|
||||
up=((0, 1, 0),), # (1, 3)
|
||||
device: Device = "cpu",
|
||||
|
@ -196,10 +196,10 @@ class Transform3d:
|
||||
index = [index]
|
||||
return self.__class__(matrix=self.get_matrix()[index])
|
||||
|
||||
def compose(self, *others):
|
||||
def compose(self, *others: "Transform3d") -> "Transform3d":
|
||||
"""
|
||||
Return a new Transform3d with the transforms to compose stored as
|
||||
an internal list.
|
||||
Return a new Transform3d representing the composition of self with the
|
||||
given other transforms, which will be stored as an internal list.
|
||||
|
||||
Args:
|
||||
*others: Any number of Transform3d objects
|
||||
@ -216,7 +216,7 @@ class Transform3d:
|
||||
out._transforms = self._transforms + list(others)
|
||||
return out
|
||||
|
||||
def get_matrix(self):
|
||||
def get_matrix(self) -> torch.Tensor:
|
||||
"""
|
||||
Return a matrix which is the result of composing this transform
|
||||
with others stored in self.transforms. Where necessary transforms
|
||||
@ -240,13 +240,13 @@ class Transform3d:
|
||||
composed_matrix = _broadcast_bmm(composed_matrix, other_matrix)
|
||||
return composed_matrix
|
||||
|
||||
def _get_matrix_inverse(self):
|
||||
def _get_matrix_inverse(self) -> torch.Tensor:
|
||||
"""
|
||||
Return the inverse of self._matrix.
|
||||
"""
|
||||
return torch.inverse(self._matrix)
|
||||
|
||||
def inverse(self, invert_composed: bool = False):
|
||||
def inverse(self, invert_composed: bool = False) -> "Transform3d":
|
||||
"""
|
||||
Returns a new Transform3d object that represents an inverse of the
|
||||
current transformation.
|
||||
@ -295,14 +295,24 @@ class Transform3d:
|
||||
|
||||
return tinv
|
||||
|
||||
def stack(self, *others):
|
||||
def stack(self, *others: "Transform3d") -> "Transform3d":
|
||||
"""
|
||||
Return a new batched Transform3d representing the batch elements from
|
||||
self and all the given other transforms all batched together.
|
||||
|
||||
Args:
|
||||
*others: Any number of Transform3d objects
|
||||
|
||||
Returns:
|
||||
A new Transform3d.
|
||||
"""
|
||||
transforms = [self] + list(others)
|
||||
matrix = torch.cat([t._matrix for t in transforms], dim=0)
|
||||
matrix = torch.cat([t.get_matrix() for t in transforms], dim=0)
|
||||
out = Transform3d(dtype=self.dtype, device=self.device)
|
||||
out._matrix = matrix
|
||||
return out
|
||||
|
||||
def transform_points(self, points, eps: Optional[float] = None):
|
||||
def transform_points(self, points, eps: Optional[float] = None) -> torch.Tensor:
|
||||
"""
|
||||
Use this transform to transform a set of 3D points. Assumes row major
|
||||
ordering of the input points.
|
||||
@ -347,7 +357,7 @@ class Transform3d:
|
||||
|
||||
return points_out
|
||||
|
||||
def transform_normals(self, normals):
|
||||
def transform_normals(self, normals) -> torch.Tensor:
|
||||
"""
|
||||
Use this transform to transform a set of normal vectors.
|
||||
|
||||
@ -379,19 +389,19 @@ class Transform3d:
|
||||
|
||||
return normals_out
|
||||
|
||||
def translate(self, *args, **kwargs):
|
||||
def translate(self, *args, **kwargs) -> "Transform3d":
|
||||
return self.compose(Translate(device=self.device, *args, **kwargs))
|
||||
|
||||
def scale(self, *args, **kwargs):
|
||||
def scale(self, *args, **kwargs) -> "Transform3d":
|
||||
return self.compose(Scale(device=self.device, *args, **kwargs))
|
||||
|
||||
def rotate(self, *args, **kwargs):
|
||||
def rotate(self, *args, **kwargs) -> "Transform3d":
|
||||
return self.compose(Rotate(device=self.device, *args, **kwargs))
|
||||
|
||||
def rotate_axis_angle(self, *args, **kwargs):
|
||||
def rotate_axis_angle(self, *args, **kwargs) -> "Transform3d":
|
||||
return self.compose(RotateAxisAngle(device=self.device, *args, **kwargs))
|
||||
|
||||
def clone(self):
|
||||
def clone(self) -> "Transform3d":
|
||||
"""
|
||||
Deep copy of Transforms object. All internal tensors are cloned
|
||||
individually.
|
||||
@ -411,7 +421,7 @@ class Transform3d:
|
||||
device: Device,
|
||||
copy: bool = False,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
) -> "Transform3d":
|
||||
"""
|
||||
Match functionality of torch.Tensor.to()
|
||||
If copy = True or the self Tensor is on a different device, the
|
||||
@ -448,10 +458,10 @@ class Transform3d:
|
||||
]
|
||||
return other
|
||||
|
||||
def cpu(self):
|
||||
def cpu(self) -> "Transform3d":
|
||||
return self.to("cpu")
|
||||
|
||||
def cuda(self):
|
||||
def cuda(self) -> "Transform3d":
|
||||
return self.to("cuda")
|
||||
|
||||
|
||||
@ -486,7 +496,7 @@ class Translate(Transform3d):
|
||||
mat[:, 3, :3] = xyz
|
||||
self._matrix = mat
|
||||
|
||||
def _get_matrix_inverse(self):
|
||||
def _get_matrix_inverse(self) -> torch.Tensor:
|
||||
"""
|
||||
Return the inverse of self._matrix.
|
||||
"""
|
||||
@ -533,7 +543,7 @@ class Scale(Transform3d):
|
||||
mat[:, 2, 2] = xyz[:, 2]
|
||||
self._matrix = mat
|
||||
|
||||
def _get_matrix_inverse(self):
|
||||
def _get_matrix_inverse(self) -> torch.Tensor:
|
||||
"""
|
||||
Return the inverse of self._matrix.
|
||||
"""
|
||||
@ -575,7 +585,7 @@ class Rotate(Transform3d):
|
||||
mat[:, :3, :3] = R
|
||||
self._matrix = mat
|
||||
|
||||
def _get_matrix_inverse(self):
|
||||
def _get_matrix_inverse(self) -> torch.Tensor:
|
||||
"""
|
||||
Return the inverse of self._matrix.
|
||||
"""
|
||||
@ -622,7 +632,7 @@ class RotateAxisAngle(Rotate):
|
||||
super().__init__(device=angle.device, R=R)
|
||||
|
||||
|
||||
def _handle_coord(c, dtype: torch.dtype, device: torch.device):
|
||||
def _handle_coord(c, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
|
||||
"""
|
||||
Helper function for _handle_input.
|
||||
|
||||
@ -649,7 +659,7 @@ def _handle_input(
|
||||
device: Optional[Device],
|
||||
name: str,
|
||||
allow_singleton: bool = False,
|
||||
):
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Helper function to handle parsing logic for building transforms. The output
|
||||
is always a tensor of shape (N, 3), but there are several types of allowed
|
||||
@ -707,7 +717,9 @@ def _handle_input(
|
||||
return xyz
|
||||
|
||||
|
||||
def _handle_angle_input(x, dtype: torch.dtype, device: Optional[Device], name: str):
|
||||
def _handle_angle_input(
|
||||
x, dtype: torch.dtype, device: Optional[Device], name: str
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Helper function for building a rotation function using angles.
|
||||
The output is always of shape (N,).
|
||||
@ -725,7 +737,7 @@ def _handle_angle_input(x, dtype: torch.dtype, device: Optional[Device], name: s
|
||||
return _handle_coord(x, dtype, device_)
|
||||
|
||||
|
||||
def _broadcast_bmm(a, b):
|
||||
def _broadcast_bmm(a, b) -> torch.Tensor:
|
||||
"""
|
||||
Batch multiply two matrices and broadcast if necessary.
|
||||
|
||||
|
@ -10,6 +10,7 @@ import unittest
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.transforms import random_rotations
|
||||
from pytorch3d.transforms.so3 import so3_exp_map
|
||||
from pytorch3d.transforms.transform3d import (
|
||||
Rotate,
|
||||
@ -21,6 +22,9 @@ from pytorch3d.transforms.transform3d import (
|
||||
|
||||
|
||||
class TestTransform(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
torch.manual_seed(42)
|
||||
|
||||
def test_to(self):
|
||||
tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]]))
|
||||
R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
|
||||
@ -406,6 +410,28 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
|
||||
with self.assertRaises(IndexError):
|
||||
t3d_selected = t3d[invalid_index]
|
||||
|
||||
def test_stack(self):
|
||||
rotations = random_rotations(3)
|
||||
transform3 = Transform3d().rotate(rotations).translate(torch.full((3, 3), 0.3))
|
||||
transform1 = Scale(37)
|
||||
transform4 = transform1.stack(transform3)
|
||||
self.assertEqual(len(transform1), 1)
|
||||
self.assertEqual(len(transform3), 3)
|
||||
self.assertEqual(len(transform4), 4)
|
||||
self.assertClose(
|
||||
transform4.get_matrix(),
|
||||
torch.cat([transform1.get_matrix(), transform3.get_matrix()]),
|
||||
)
|
||||
points = torch.rand(4, 5, 3)
|
||||
new_points_expect = torch.cat(
|
||||
[
|
||||
transform1.transform_points(points[:1]),
|
||||
transform3.transform_points(points[1:]),
|
||||
]
|
||||
)
|
||||
new_points = transform4.transform_points(points)
|
||||
self.assertClose(new_points, new_points_expect)
|
||||
|
||||
|
||||
class TestTranslate(unittest.TestCase):
|
||||
def test_python_scalar(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user