mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-22 07:10:34 +08:00
__getitem__ for Transform3D
Summary: Implements the `__getitem__` method for `Transform3D` Reviewed By: nikhilaravi Differential Revision: D23813975 fbshipit-source-id: 5da752ed8ea029ad0af58bb7a7856f0995519b7a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
ac3f8dc833
commit
1e4a2e8624
@@ -2,7 +2,7 @@
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -172,6 +172,22 @@ class Transform3d:
|
||||
def __len__(self):
|
||||
return self.get_matrix().shape[0]
|
||||
|
||||
def __getitem__(
|
||||
self, index: Union[int, List[int], slice, torch.Tensor]
|
||||
) -> "Transform3d":
|
||||
"""
|
||||
Args:
|
||||
index: Specifying the index of the transform to retrieve.
|
||||
Can be an int, slice, list of ints, boolean, long tensor.
|
||||
Supports negative indices.
|
||||
|
||||
Returns:
|
||||
Transform3d object with selected transforms. The tensors are not cloned.
|
||||
"""
|
||||
if isinstance(index, int):
|
||||
index = [index]
|
||||
return self.__class__(matrix=self.get_matrix()[index])
|
||||
|
||||
def compose(self, *others):
|
||||
"""
|
||||
Return a new Transform3d with the tranforms to compose stored as
|
||||
@@ -361,6 +377,9 @@ class Transform3d:
|
||||
def scale(self, *args, **kwargs):
|
||||
return self.compose(Scale(device=self.device, *args, **kwargs))
|
||||
|
||||
def rotate(self, *args, **kwargs):
|
||||
return self.compose(Rotate(device=self.device, *args, **kwargs))
|
||||
|
||||
def rotate_axis_angle(self, *args, **kwargs):
|
||||
return self.compose(RotateAxisAngle(device=self.device, *args, **kwargs))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user