mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 06:40:35 +08:00
Initialization of Transform3D with a custom matrix.
Summary: Allows to initialize a Transform3D object with a batch of user-defined transformation matrices: ``` t = Transform3D(matrix=torch.randn(2, 4, 4)) ``` Reviewed By: nikhilaravi Differential Revision: D20693475 fbshipit-source-id: dccc49b2ca4c19a034844c63463953ba8f52c1bc
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e37085d999
commit
90dc7a0856
@@ -134,8 +134,37 @@ class Transform3d:
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, dtype=torch.float32, device="cpu"):
|
||||
self._matrix = torch.eye(4, dtype=dtype, device=device).view(1, 4, 4)
|
||||
def __init__(
|
||||
self,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device="cpu",
|
||||
matrix: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dtype: The data type of the transformation matrix.
|
||||
to be used if `matrix = None`.
|
||||
device: The device for storing the implemented transformation.
|
||||
If `matrix != None`, uses the device of input `matrix`.
|
||||
matrix: A tensor of shape (4, 4) or of shape (minibatch, 4, 4)
|
||||
representing the 4x4 3D transformation matrix.
|
||||
If `None`, initializes with identity using
|
||||
the specified `device` and `dtype`.
|
||||
"""
|
||||
|
||||
if matrix is None:
|
||||
self._matrix = torch.eye(4, dtype=dtype, device=device).view(1, 4, 4)
|
||||
else:
|
||||
if matrix.ndim not in (2, 3):
|
||||
raise ValueError('"matrix" has to be a 2- or a 3-dimensional tensor.')
|
||||
if matrix.shape[-2] != 4 or matrix.shape[-1] != 4:
|
||||
raise ValueError(
|
||||
'"matrix" has to be a tensor of shape (minibatch, 4, 4)'
|
||||
)
|
||||
# set the device from matrix
|
||||
device = matrix.device
|
||||
self._matrix = matrix.view(-1, 4, 4)
|
||||
|
||||
self._transforms = [] # store transforms to compose
|
||||
self._lu = None
|
||||
self.device = device
|
||||
|
||||
Reference in New Issue
Block a user