mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Fix Transform3d.stack of compositions
Summary: Add a test for Transform3d.stack, and make it work with composed transformations. Fixes https://github.com/facebookresearch/pytorch3d/issues/1072 . Reviewed By: patricklabatut Differential Revision: D34211920 fbshipit-source-id: bfbd0895494ca2ad3d08a61bc82ba23637e168cc
This commit is contained in:
parent
2a1de3b610
commit
c8f3d6bc0b
@ -1649,7 +1649,7 @@ def look_at_view_transform(
|
|||||||
elev=0.0,
|
elev=0.0,
|
||||||
azim=0.0,
|
azim=0.0,
|
||||||
degrees: bool = True,
|
degrees: bool = True,
|
||||||
eye: Optional[Sequence] = None,
|
eye: Optional[Union[Sequence, torch.Tensor]] = None,
|
||||||
at=((0, 0, 0),), # (1, 3)
|
at=((0, 0, 0),), # (1, 3)
|
||||||
up=((0, 1, 0),), # (1, 3)
|
up=((0, 1, 0),), # (1, 3)
|
||||||
device: Device = "cpu",
|
device: Device = "cpu",
|
||||||
|
@ -196,10 +196,10 @@ class Transform3d:
|
|||||||
index = [index]
|
index = [index]
|
||||||
return self.__class__(matrix=self.get_matrix()[index])
|
return self.__class__(matrix=self.get_matrix()[index])
|
||||||
|
|
||||||
def compose(self, *others):
|
def compose(self, *others: "Transform3d") -> "Transform3d":
|
||||||
"""
|
"""
|
||||||
Return a new Transform3d with the transforms to compose stored as
|
Return a new Transform3d representing the composition of self with the
|
||||||
an internal list.
|
given other transforms, which will be stored as an internal list.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
*others: Any number of Transform3d objects
|
*others: Any number of Transform3d objects
|
||||||
@ -216,7 +216,7 @@ class Transform3d:
|
|||||||
out._transforms = self._transforms + list(others)
|
out._transforms = self._transforms + list(others)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def get_matrix(self):
|
def get_matrix(self) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Return a matrix which is the result of composing this transform
|
Return a matrix which is the result of composing this transform
|
||||||
with others stored in self.transforms. Where necessary transforms
|
with others stored in self.transforms. Where necessary transforms
|
||||||
@ -240,13 +240,13 @@ class Transform3d:
|
|||||||
composed_matrix = _broadcast_bmm(composed_matrix, other_matrix)
|
composed_matrix = _broadcast_bmm(composed_matrix, other_matrix)
|
||||||
return composed_matrix
|
return composed_matrix
|
||||||
|
|
||||||
def _get_matrix_inverse(self):
|
def _get_matrix_inverse(self) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Return the inverse of self._matrix.
|
Return the inverse of self._matrix.
|
||||||
"""
|
"""
|
||||||
return torch.inverse(self._matrix)
|
return torch.inverse(self._matrix)
|
||||||
|
|
||||||
def inverse(self, invert_composed: bool = False):
|
def inverse(self, invert_composed: bool = False) -> "Transform3d":
|
||||||
"""
|
"""
|
||||||
Returns a new Transform3d object that represents an inverse of the
|
Returns a new Transform3d object that represents an inverse of the
|
||||||
current transformation.
|
current transformation.
|
||||||
@ -295,14 +295,24 @@ class Transform3d:
|
|||||||
|
|
||||||
return tinv
|
return tinv
|
||||||
|
|
||||||
def stack(self, *others):
|
def stack(self, *others: "Transform3d") -> "Transform3d":
|
||||||
|
"""
|
||||||
|
Return a new batched Transform3d representing the batch elements from
|
||||||
|
self and all the given other transforms all batched together.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*others: Any number of Transform3d objects
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new Transform3d.
|
||||||
|
"""
|
||||||
transforms = [self] + list(others)
|
transforms = [self] + list(others)
|
||||||
matrix = torch.cat([t._matrix for t in transforms], dim=0)
|
matrix = torch.cat([t.get_matrix() for t in transforms], dim=0)
|
||||||
out = Transform3d(dtype=self.dtype, device=self.device)
|
out = Transform3d(dtype=self.dtype, device=self.device)
|
||||||
out._matrix = matrix
|
out._matrix = matrix
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def transform_points(self, points, eps: Optional[float] = None):
|
def transform_points(self, points, eps: Optional[float] = None) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Use this transform to transform a set of 3D points. Assumes row major
|
Use this transform to transform a set of 3D points. Assumes row major
|
||||||
ordering of the input points.
|
ordering of the input points.
|
||||||
@ -347,7 +357,7 @@ class Transform3d:
|
|||||||
|
|
||||||
return points_out
|
return points_out
|
||||||
|
|
||||||
def transform_normals(self, normals):
|
def transform_normals(self, normals) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Use this transform to transform a set of normal vectors.
|
Use this transform to transform a set of normal vectors.
|
||||||
|
|
||||||
@ -379,19 +389,19 @@ class Transform3d:
|
|||||||
|
|
||||||
return normals_out
|
return normals_out
|
||||||
|
|
||||||
def translate(self, *args, **kwargs):
|
def translate(self, *args, **kwargs) -> "Transform3d":
|
||||||
return self.compose(Translate(device=self.device, *args, **kwargs))
|
return self.compose(Translate(device=self.device, *args, **kwargs))
|
||||||
|
|
||||||
def scale(self, *args, **kwargs):
|
def scale(self, *args, **kwargs) -> "Transform3d":
|
||||||
return self.compose(Scale(device=self.device, *args, **kwargs))
|
return self.compose(Scale(device=self.device, *args, **kwargs))
|
||||||
|
|
||||||
def rotate(self, *args, **kwargs):
|
def rotate(self, *args, **kwargs) -> "Transform3d":
|
||||||
return self.compose(Rotate(device=self.device, *args, **kwargs))
|
return self.compose(Rotate(device=self.device, *args, **kwargs))
|
||||||
|
|
||||||
def rotate_axis_angle(self, *args, **kwargs):
|
def rotate_axis_angle(self, *args, **kwargs) -> "Transform3d":
|
||||||
return self.compose(RotateAxisAngle(device=self.device, *args, **kwargs))
|
return self.compose(RotateAxisAngle(device=self.device, *args, **kwargs))
|
||||||
|
|
||||||
def clone(self):
|
def clone(self) -> "Transform3d":
|
||||||
"""
|
"""
|
||||||
Deep copy of Transforms object. All internal tensors are cloned
|
Deep copy of Transforms object. All internal tensors are cloned
|
||||||
individually.
|
individually.
|
||||||
@ -411,7 +421,7 @@ class Transform3d:
|
|||||||
device: Device,
|
device: Device,
|
||||||
copy: bool = False,
|
copy: bool = False,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
):
|
) -> "Transform3d":
|
||||||
"""
|
"""
|
||||||
Match functionality of torch.Tensor.to()
|
Match functionality of torch.Tensor.to()
|
||||||
If copy = True or the self Tensor is on a different device, the
|
If copy = True or the self Tensor is on a different device, the
|
||||||
@ -448,10 +458,10 @@ class Transform3d:
|
|||||||
]
|
]
|
||||||
return other
|
return other
|
||||||
|
|
||||||
def cpu(self):
|
def cpu(self) -> "Transform3d":
|
||||||
return self.to("cpu")
|
return self.to("cpu")
|
||||||
|
|
||||||
def cuda(self):
|
def cuda(self) -> "Transform3d":
|
||||||
return self.to("cuda")
|
return self.to("cuda")
|
||||||
|
|
||||||
|
|
||||||
@ -486,7 +496,7 @@ class Translate(Transform3d):
|
|||||||
mat[:, 3, :3] = xyz
|
mat[:, 3, :3] = xyz
|
||||||
self._matrix = mat
|
self._matrix = mat
|
||||||
|
|
||||||
def _get_matrix_inverse(self):
|
def _get_matrix_inverse(self) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Return the inverse of self._matrix.
|
Return the inverse of self._matrix.
|
||||||
"""
|
"""
|
||||||
@ -533,7 +543,7 @@ class Scale(Transform3d):
|
|||||||
mat[:, 2, 2] = xyz[:, 2]
|
mat[:, 2, 2] = xyz[:, 2]
|
||||||
self._matrix = mat
|
self._matrix = mat
|
||||||
|
|
||||||
def _get_matrix_inverse(self):
|
def _get_matrix_inverse(self) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Return the inverse of self._matrix.
|
Return the inverse of self._matrix.
|
||||||
"""
|
"""
|
||||||
@ -575,7 +585,7 @@ class Rotate(Transform3d):
|
|||||||
mat[:, :3, :3] = R
|
mat[:, :3, :3] = R
|
||||||
self._matrix = mat
|
self._matrix = mat
|
||||||
|
|
||||||
def _get_matrix_inverse(self):
|
def _get_matrix_inverse(self) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Return the inverse of self._matrix.
|
Return the inverse of self._matrix.
|
||||||
"""
|
"""
|
||||||
@ -622,7 +632,7 @@ class RotateAxisAngle(Rotate):
|
|||||||
super().__init__(device=angle.device, R=R)
|
super().__init__(device=angle.device, R=R)
|
||||||
|
|
||||||
|
|
||||||
def _handle_coord(c, dtype: torch.dtype, device: torch.device):
|
def _handle_coord(c, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Helper function for _handle_input.
|
Helper function for _handle_input.
|
||||||
|
|
||||||
@ -649,7 +659,7 @@ def _handle_input(
|
|||||||
device: Optional[Device],
|
device: Optional[Device],
|
||||||
name: str,
|
name: str,
|
||||||
allow_singleton: bool = False,
|
allow_singleton: bool = False,
|
||||||
):
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Helper function to handle parsing logic for building transforms. The output
|
Helper function to handle parsing logic for building transforms. The output
|
||||||
is always a tensor of shape (N, 3), but there are several types of allowed
|
is always a tensor of shape (N, 3), but there are several types of allowed
|
||||||
@ -707,7 +717,9 @@ def _handle_input(
|
|||||||
return xyz
|
return xyz
|
||||||
|
|
||||||
|
|
||||||
def _handle_angle_input(x, dtype: torch.dtype, device: Optional[Device], name: str):
|
def _handle_angle_input(
|
||||||
|
x, dtype: torch.dtype, device: Optional[Device], name: str
|
||||||
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Helper function for building a rotation function using angles.
|
Helper function for building a rotation function using angles.
|
||||||
The output is always of shape (N,).
|
The output is always of shape (N,).
|
||||||
@ -725,7 +737,7 @@ def _handle_angle_input(x, dtype: torch.dtype, device: Optional[Device], name: s
|
|||||||
return _handle_coord(x, dtype, device_)
|
return _handle_coord(x, dtype, device_)
|
||||||
|
|
||||||
|
|
||||||
def _broadcast_bmm(a, b):
|
def _broadcast_bmm(a, b) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Batch multiply two matrices and broadcast if necessary.
|
Batch multiply two matrices and broadcast if necessary.
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ import unittest
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from common_testing import TestCaseMixin
|
from common_testing import TestCaseMixin
|
||||||
|
from pytorch3d.transforms import random_rotations
|
||||||
from pytorch3d.transforms.so3 import so3_exp_map
|
from pytorch3d.transforms.so3 import so3_exp_map
|
||||||
from pytorch3d.transforms.transform3d import (
|
from pytorch3d.transforms.transform3d import (
|
||||||
Rotate,
|
Rotate,
|
||||||
@ -21,6 +22,9 @@ from pytorch3d.transforms.transform3d import (
|
|||||||
|
|
||||||
|
|
||||||
class TestTransform(TestCaseMixin, unittest.TestCase):
|
class TestTransform(TestCaseMixin, unittest.TestCase):
|
||||||
|
def setUp(self) -> None:
|
||||||
|
torch.manual_seed(42)
|
||||||
|
|
||||||
def test_to(self):
|
def test_to(self):
|
||||||
tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]]))
|
tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]]))
|
||||||
R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
|
R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
|
||||||
@ -406,6 +410,28 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
|
|||||||
with self.assertRaises(IndexError):
|
with self.assertRaises(IndexError):
|
||||||
t3d_selected = t3d[invalid_index]
|
t3d_selected = t3d[invalid_index]
|
||||||
|
|
||||||
|
def test_stack(self):
|
||||||
|
rotations = random_rotations(3)
|
||||||
|
transform3 = Transform3d().rotate(rotations).translate(torch.full((3, 3), 0.3))
|
||||||
|
transform1 = Scale(37)
|
||||||
|
transform4 = transform1.stack(transform3)
|
||||||
|
self.assertEqual(len(transform1), 1)
|
||||||
|
self.assertEqual(len(transform3), 3)
|
||||||
|
self.assertEqual(len(transform4), 4)
|
||||||
|
self.assertClose(
|
||||||
|
transform4.get_matrix(),
|
||||||
|
torch.cat([transform1.get_matrix(), transform3.get_matrix()]),
|
||||||
|
)
|
||||||
|
points = torch.rand(4, 5, 3)
|
||||||
|
new_points_expect = torch.cat(
|
||||||
|
[
|
||||||
|
transform1.transform_points(points[:1]),
|
||||||
|
transform3.transform_points(points[1:]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
new_points = transform4.transform_points(points)
|
||||||
|
self.assertClose(new_points, new_points_expect)
|
||||||
|
|
||||||
|
|
||||||
class TestTranslate(unittest.TestCase):
|
class TestTranslate(unittest.TestCase):
|
||||||
def test_python_scalar(self):
|
def test_python_scalar(self):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user