mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
__getitem__ for Transform3D
Summary: Implements the `__getitem__` method for `Transform3D` Reviewed By: nikhilaravi Differential Revision: D23813975 fbshipit-source-id: 5da752ed8ea029ad0af58bb7a7856f0995519b7a
This commit is contained in:
parent
ac3f8dc833
commit
1e4a2e8624
@ -2,7 +2,7 @@
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -172,6 +172,22 @@ class Transform3d:
|
||||
def __len__(self):
|
||||
return self.get_matrix().shape[0]
|
||||
|
||||
def __getitem__(
|
||||
self, index: Union[int, List[int], slice, torch.Tensor]
|
||||
) -> "Transform3d":
|
||||
"""
|
||||
Args:
|
||||
index: Specifying the index of the transform to retrieve.
|
||||
Can be an int, slice, list of ints, boolean, long tensor.
|
||||
Supports negative indices.
|
||||
|
||||
Returns:
|
||||
Transform3d object with selected transforms. The tensors are not cloned.
|
||||
"""
|
||||
if isinstance(index, int):
|
||||
index = [index]
|
||||
return self.__class__(matrix=self.get_matrix()[index])
|
||||
|
||||
def compose(self, *others):
|
||||
"""
|
||||
Return a new Transform3d with the tranforms to compose stored as
|
||||
@ -361,6 +377,9 @@ class Transform3d:
|
||||
def scale(self, *args, **kwargs):
|
||||
return self.compose(Scale(device=self.device, *args, **kwargs))
|
||||
|
||||
def rotate(self, *args, **kwargs):
|
||||
return self.compose(Rotate(device=self.device, *args, **kwargs))
|
||||
|
||||
def rotate_axis_angle(self, *args, **kwargs):
|
||||
return self.compose(RotateAxisAngle(device=self.device, *args, **kwargs))
|
||||
|
||||
|
@ -5,6 +5,7 @@ import math
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.transforms.so3 import so3_exponential_map
|
||||
from pytorch3d.transforms.transform3d import (
|
||||
Rotate,
|
||||
@ -15,7 +16,7 @@ from pytorch3d.transforms.transform3d import (
|
||||
)
|
||||
|
||||
|
||||
class TestTransform(unittest.TestCase):
|
||||
class TestTransform(TestCaseMixin, unittest.TestCase):
|
||||
def test_to(self):
|
||||
tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]]))
|
||||
R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
|
||||
@ -89,6 +90,22 @@ class TestTransform(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(points_out, points_out_expected))
|
||||
self.assertTrue(torch.allclose(normals_out, normals_out_expected))
|
||||
|
||||
def test_rotate(self):
|
||||
R = so3_exponential_map(torch.randn((1, 3)))
|
||||
t = Transform3d().rotate(R)
|
||||
points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
|
||||
1, 3, 3
|
||||
)
|
||||
normals = torch.tensor(
|
||||
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
|
||||
).view(1, 3, 3)
|
||||
points_out = t.transform_points(points)
|
||||
normals_out = t.transform_normals(normals)
|
||||
points_out_expected = torch.bmm(points, R)
|
||||
normals_out_expected = torch.bmm(normals, R)
|
||||
self.assertTrue(torch.allclose(points_out, points_out_expected))
|
||||
self.assertTrue(torch.allclose(normals_out, normals_out_expected))
|
||||
|
||||
def test_scale(self):
|
||||
t = Transform3d().scale(2.0).scale(0.5, 0.25, 1.0)
|
||||
points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
|
||||
@ -237,6 +254,103 @@ class TestTransform(unittest.TestCase):
|
||||
for m in (m1, m2, m3, m4):
|
||||
self.assertTrue(torch.allclose(m, m5, atol=1e-3))
|
||||
|
||||
def _check_indexed_transforms(self, t3d, t3d_selected, indices):
|
||||
t3d_matrix = t3d.get_matrix()
|
||||
t3d_selected_matrix = t3d_selected.get_matrix()
|
||||
for order_index, selected_index in indices:
|
||||
self.assertClose(
|
||||
t3d_matrix[selected_index], t3d_selected_matrix[order_index]
|
||||
)
|
||||
|
||||
def test_get_item(self, batch_size=5):
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
matrices = torch.randn(
|
||||
size=[batch_size, 4, 4], device=device, dtype=torch.float32
|
||||
)
|
||||
|
||||
# init the Transforms3D class
|
||||
t3d = Transform3d(matrix=matrices)
|
||||
|
||||
# int index
|
||||
index = 1
|
||||
t3d_selected = t3d[index]
|
||||
self.assertEqual(len(t3d_selected), 1)
|
||||
self._check_indexed_transforms(t3d, t3d_selected, [(0, 1)])
|
||||
|
||||
# negative int index
|
||||
index = -1
|
||||
t3d_selected = t3d[index]
|
||||
self.assertEqual(len(t3d_selected), 1)
|
||||
self._check_indexed_transforms(t3d, t3d_selected, [(0, -1)])
|
||||
|
||||
# list index
|
||||
index = [1, 2]
|
||||
t3d_selected = t3d[index]
|
||||
self.assertEqual(len(t3d_selected), len(index))
|
||||
self._check_indexed_transforms(t3d, t3d_selected, enumerate(index))
|
||||
|
||||
# empty list index
|
||||
index = []
|
||||
t3d_selected = t3d[index]
|
||||
self.assertEqual(len(t3d_selected), 0)
|
||||
self.assertEqual(t3d_selected.get_matrix().nelement(), 0)
|
||||
|
||||
# slice index
|
||||
index = slice(0, 2, 1)
|
||||
t3d_selected = t3d[index]
|
||||
self.assertEqual(len(t3d_selected), 2)
|
||||
self._check_indexed_transforms(t3d, t3d_selected, [(0, 0), (1, 1)])
|
||||
|
||||
# empty slice index
|
||||
index = slice(0, 0, 1)
|
||||
t3d_selected = t3d[index]
|
||||
self.assertEqual(len(t3d_selected), 0)
|
||||
self.assertEqual(t3d_selected.get_matrix().nelement(), 0)
|
||||
|
||||
# bool tensor
|
||||
index = (torch.rand(batch_size) > 0.5).to(device)
|
||||
index[:2] = True # make sure smth is selected
|
||||
t3d_selected = t3d[index]
|
||||
self.assertEqual(len(t3d_selected), index.sum())
|
||||
self._check_indexed_transforms(
|
||||
t3d,
|
||||
t3d_selected,
|
||||
zip(
|
||||
torch.arange(index.sum()),
|
||||
torch.nonzero(index, as_tuple=False).squeeze(),
|
||||
),
|
||||
)
|
||||
|
||||
# all false bool tensor
|
||||
index = torch.zeros(batch_size).bool()
|
||||
t3d_selected = t3d[index]
|
||||
self.assertEqual(len(t3d_selected), 0)
|
||||
self.assertEqual(t3d_selected.get_matrix().nelement(), 0)
|
||||
|
||||
# int tensor
|
||||
index = torch.tensor([1, 2], dtype=torch.int64, device=device)
|
||||
t3d_selected = t3d[index]
|
||||
self.assertEqual(len(t3d_selected), index.numel())
|
||||
self._check_indexed_transforms(t3d, t3d_selected, enumerate(index.tolist()))
|
||||
|
||||
# negative int tensor
|
||||
index = -(torch.tensor([1, 2], dtype=torch.int64, device=device))
|
||||
t3d_selected = t3d[index]
|
||||
self.assertEqual(len(t3d_selected), index.numel())
|
||||
self._check_indexed_transforms(t3d, t3d_selected, enumerate(index.tolist()))
|
||||
|
||||
# invalid index
|
||||
for invalid_index in (
|
||||
torch.tensor([1, 0, 1], dtype=torch.float32, device=device), # float tensor
|
||||
1.2, # float index
|
||||
torch.tensor(
|
||||
[[1, 0, 1], [1, 0, 1]], dtype=torch.int32, device=device
|
||||
), # multidimensional tensor
|
||||
):
|
||||
with self.assertRaises(IndexError):
|
||||
t3d_selected = t3d[invalid_index]
|
||||
|
||||
|
||||
class TestTranslate(unittest.TestCase):
|
||||
def test_python_scalar(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user