__getitem__ for Transform3D

Summary: Implements the `__getitem__` method for `Transform3D`

Reviewed By: nikhilaravi

Differential Revision: D23813975

fbshipit-source-id: 5da752ed8ea029ad0af58bb7a7856f0995519b7a
This commit is contained in:
David Novotny 2021-01-05 03:37:38 -08:00 committed by Facebook GitHub Bot
parent ac3f8dc833
commit 1e4a2e8624
2 changed files with 135 additions and 2 deletions

View File

@ -2,7 +2,7 @@
import math
import warnings
from typing import Optional
from typing import List, Optional, Union
import torch
@ -172,6 +172,22 @@ class Transform3d:
def __len__(self):
return self.get_matrix().shape[0]
def __getitem__(
self, index: Union[int, List[int], slice, torch.Tensor]
) -> "Transform3d":
"""
Args:
index: Specifying the index of the transform to retrieve.
Can be an int, slice, list of ints, boolean, long tensor.
Supports negative indices.
Returns:
Transform3d object with selected transforms. The tensors are not cloned.
"""
if isinstance(index, int):
index = [index]
return self.__class__(matrix=self.get_matrix()[index])
def compose(self, *others):
"""
Return a new Transform3d with the tranforms to compose stored as
@ -361,6 +377,9 @@ class Transform3d:
def scale(self, *args, **kwargs):
return self.compose(Scale(device=self.device, *args, **kwargs))
def rotate(self, *args, **kwargs):
return self.compose(Rotate(device=self.device, *args, **kwargs))
def rotate_axis_angle(self, *args, **kwargs):
return self.compose(RotateAxisAngle(device=self.device, *args, **kwargs))

View File

@ -5,6 +5,7 @@ import math
import unittest
import torch
from common_testing import TestCaseMixin
from pytorch3d.transforms.so3 import so3_exponential_map
from pytorch3d.transforms.transform3d import (
Rotate,
@ -15,7 +16,7 @@ from pytorch3d.transforms.transform3d import (
)
class TestTransform(unittest.TestCase):
class TestTransform(TestCaseMixin, unittest.TestCase):
def test_to(self):
tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]]))
R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
@ -89,6 +90,22 @@ class TestTransform(unittest.TestCase):
self.assertTrue(torch.allclose(points_out, points_out_expected))
self.assertTrue(torch.allclose(normals_out, normals_out_expected))
def test_rotate(self):
R = so3_exponential_map(torch.randn((1, 3)))
t = Transform3d().rotate(R)
points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
1, 3, 3
)
normals = torch.tensor(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
).view(1, 3, 3)
points_out = t.transform_points(points)
normals_out = t.transform_normals(normals)
points_out_expected = torch.bmm(points, R)
normals_out_expected = torch.bmm(normals, R)
self.assertTrue(torch.allclose(points_out, points_out_expected))
self.assertTrue(torch.allclose(normals_out, normals_out_expected))
def test_scale(self):
t = Transform3d().scale(2.0).scale(0.5, 0.25, 1.0)
points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
@ -237,6 +254,103 @@ class TestTransform(unittest.TestCase):
for m in (m1, m2, m3, m4):
self.assertTrue(torch.allclose(m, m5, atol=1e-3))
def _check_indexed_transforms(self, t3d, t3d_selected, indices):
t3d_matrix = t3d.get_matrix()
t3d_selected_matrix = t3d_selected.get_matrix()
for order_index, selected_index in indices:
self.assertClose(
t3d_matrix[selected_index], t3d_selected_matrix[order_index]
)
def test_get_item(self, batch_size=5):
device = torch.device("cuda:0")
matrices = torch.randn(
size=[batch_size, 4, 4], device=device, dtype=torch.float32
)
# init the Transforms3D class
t3d = Transform3d(matrix=matrices)
# int index
index = 1
t3d_selected = t3d[index]
self.assertEqual(len(t3d_selected), 1)
self._check_indexed_transforms(t3d, t3d_selected, [(0, 1)])
# negative int index
index = -1
t3d_selected = t3d[index]
self.assertEqual(len(t3d_selected), 1)
self._check_indexed_transforms(t3d, t3d_selected, [(0, -1)])
# list index
index = [1, 2]
t3d_selected = t3d[index]
self.assertEqual(len(t3d_selected), len(index))
self._check_indexed_transforms(t3d, t3d_selected, enumerate(index))
# empty list index
index = []
t3d_selected = t3d[index]
self.assertEqual(len(t3d_selected), 0)
self.assertEqual(t3d_selected.get_matrix().nelement(), 0)
# slice index
index = slice(0, 2, 1)
t3d_selected = t3d[index]
self.assertEqual(len(t3d_selected), 2)
self._check_indexed_transforms(t3d, t3d_selected, [(0, 0), (1, 1)])
# empty slice index
index = slice(0, 0, 1)
t3d_selected = t3d[index]
self.assertEqual(len(t3d_selected), 0)
self.assertEqual(t3d_selected.get_matrix().nelement(), 0)
# bool tensor
index = (torch.rand(batch_size) > 0.5).to(device)
index[:2] = True # make sure smth is selected
t3d_selected = t3d[index]
self.assertEqual(len(t3d_selected), index.sum())
self._check_indexed_transforms(
t3d,
t3d_selected,
zip(
torch.arange(index.sum()),
torch.nonzero(index, as_tuple=False).squeeze(),
),
)
# all false bool tensor
index = torch.zeros(batch_size).bool()
t3d_selected = t3d[index]
self.assertEqual(len(t3d_selected), 0)
self.assertEqual(t3d_selected.get_matrix().nelement(), 0)
# int tensor
index = torch.tensor([1, 2], dtype=torch.int64, device=device)
t3d_selected = t3d[index]
self.assertEqual(len(t3d_selected), index.numel())
self._check_indexed_transforms(t3d, t3d_selected, enumerate(index.tolist()))
# negative int tensor
index = -(torch.tensor([1, 2], dtype=torch.int64, device=device))
t3d_selected = t3d[index]
self.assertEqual(len(t3d_selected), index.numel())
self._check_indexed_transforms(t3d, t3d_selected, enumerate(index.tolist()))
# invalid index
for invalid_index in (
torch.tensor([1, 0, 1], dtype=torch.float32, device=device), # float tensor
1.2, # float index
torch.tensor(
[[1, 0, 1], [1, 0, 1]], dtype=torch.int32, device=device
), # multidimensional tensor
):
with self.assertRaises(IndexError):
t3d_selected = t3d[invalid_index]
class TestTranslate(unittest.TestCase):
def test_python_scalar(self):