mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 23:00:34 +08:00
Support for multi-dimensional list_to_padded/padded_to_list.
Summary: Extends `list_to_padded`/`padded_to_list` to work for tensors with an arbitrary number of input dimensions. Reviewed By: nikhilaravi, gkioxari Differential Revision: D23813969 fbshipit-source-id: 52c212a2ecdb3c4dfb6ac47217715e07998f37f1
This commit is contained in:
committed by
Facebook GitHub Bot
parent
0ba55a83ad
commit
b4dea43963
@@ -1,77 +1,97 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from typing import List, Union
|
||||
from typing import List, Sequence, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
"""
|
||||
Util functions containing representation transforms for points/verts/faces.
|
||||
Util functions for points/verts/faces/volumes.
|
||||
"""
|
||||
|
||||
|
||||
def list_to_padded(
|
||||
x: List[torch.Tensor],
|
||||
pad_size: Union[list, tuple, None] = None,
|
||||
pad_size: Union[Sequence[int], None] = None,
|
||||
pad_value: float = 0.0,
|
||||
equisized: bool = False,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Transforms a list of N tensors each of shape (Mi, Ki) into a single tensor
|
||||
of shape (N, pad_size(0), pad_size(1)), or (N, max(Mi), max(Ki))
|
||||
if pad_size is None.
|
||||
Transforms a list of N tensors each of shape (Si_0, Si_1, ... Si_D)
|
||||
into:
|
||||
- a single tensor of shape (N, pad_size(0), pad_size(1), ..., pad_size(D))
|
||||
if pad_size is provided
|
||||
- or a tensor of shape (N, max(Si_0), max(Si_1), ..., max(Si_D)) if pad_size is None.
|
||||
|
||||
Args:
|
||||
x: list of Tensors
|
||||
pad_size: list(int) specifying the size of the padded tensor
|
||||
pad_size: list(int) specifying the size of the padded tensor.
|
||||
If `None` (default), the largest size of each dimension
|
||||
is set as the `pad_size`.
|
||||
pad_value: float value to be used to fill the padded tensor
|
||||
equisized: bool indicating whether the items in x are of equal size
|
||||
(sometimes this is known and if provided saves computation)
|
||||
|
||||
Returns:
|
||||
x_padded: tensor consisting of padded input tensors
|
||||
x_padded: tensor consisting of padded input tensors stored
|
||||
over the newly allocated memory.
|
||||
"""
|
||||
if equisized:
|
||||
return torch.stack(x, 0)
|
||||
|
||||
if not all(torch.is_tensor(y) for y in x):
|
||||
raise ValueError("All items have to be instances of a torch.Tensor.")
|
||||
|
||||
# we set the common number of dimensions to the maximum
|
||||
# of the dimensionalities of the tensors in the list
|
||||
element_ndim = max(y.ndim for y in x)
|
||||
|
||||
# replace empty 1D tensors with empty tensors with a correct number of dimensions
|
||||
x = [
|
||||
(y.new_zeros([0] * element_ndim) if (y.ndim == 1 and y.nelement() == 0) else y)
|
||||
for y in x
|
||||
] # pyre-ignore
|
||||
|
||||
if any(y.ndim != x[0].ndim for y in x):
|
||||
raise ValueError("All items have to have the same number of dimensions!")
|
||||
|
||||
if pad_size is None:
|
||||
pad_dim0 = max(y.shape[0] for y in x if len(y) > 0)
|
||||
pad_dim1 = max(y.shape[1] for y in x if len(y) > 0)
|
||||
pad_dims = [
|
||||
max(y.shape[dim] for y in x if len(y) > 0) for dim in range(x[0].ndim)
|
||||
]
|
||||
else:
|
||||
if len(pad_size) != 2:
|
||||
raise ValueError("Pad size must contain target size for 1st and 2nd dim")
|
||||
pad_dim0, pad_dim1 = pad_size
|
||||
if any(len(pad_size) != y.ndim for y in x):
|
||||
raise ValueError("Pad size must contain target size for all dimensions.")
|
||||
pad_dims = pad_size
|
||||
|
||||
N = len(x)
|
||||
x_padded = torch.full(
|
||||
(N, pad_dim0, pad_dim1), pad_value, dtype=x[0].dtype, device=x[0].device
|
||||
)
|
||||
x_padded = x[0].new_full((N, *pad_dims), pad_value)
|
||||
for i, y in enumerate(x):
|
||||
if len(y) > 0:
|
||||
if y.ndim != 2:
|
||||
raise ValueError("Supports only 2-dimensional tensor items")
|
||||
x_padded[i, : y.shape[0], : y.shape[1]] = y
|
||||
slices = (i, *(slice(0, y.shape[dim]) for dim in range(y.ndim)))
|
||||
x_padded[slices] = y
|
||||
return x_padded
|
||||
|
||||
|
||||
def padded_to_list(x: torch.Tensor, split_size: Union[list, tuple, None] = None):
|
||||
def padded_to_list(
|
||||
x: torch.Tensor,
|
||||
split_size: Union[Sequence[int], Sequence[Sequence[int]], None] = None,
|
||||
):
|
||||
r"""
|
||||
Transforms a padded tensor of shape (N, M, K) into a list of N tensors
|
||||
of shape (Mi, Ki) where (Mi, Ki) is specified in split_size(i), or of shape
|
||||
(M, K) if split_size is None.
|
||||
Support only for 3-dimensional input tensor.
|
||||
Transforms a padded tensor of shape (N, S_1, S_2, ..., S_D) into a list
|
||||
of N tensors of shape:
|
||||
- (Si_1, Si_2, ..., Si_D) where (Si_1, Si_2, ..., Si_D) is specified in split_size(i)
|
||||
- or (S_1, S_2, ..., S_D) if split_size is None
|
||||
- or (Si_1, S_2, ..., S_D) if split_size(i) is an integer.
|
||||
|
||||
Args:
|
||||
x: tensor
|
||||
split_size: list, tuple or int defining the number of items for each tensor
|
||||
in the output list.
|
||||
split_size: optional 1D or 2D list/tuple of ints defining the number of
|
||||
items for each tensor.
|
||||
|
||||
Returns:
|
||||
x_list: a list of tensors
|
||||
x_list: a list of tensors sharing the memory with the input.
|
||||
"""
|
||||
if x.ndim != 3:
|
||||
raise ValueError("Supports only 3-dimensional input tensors")
|
||||
|
||||
x_list = list(x.unbind(0))
|
||||
|
||||
if split_size is None:
|
||||
@@ -84,13 +104,9 @@ def padded_to_list(x: torch.Tensor, split_size: Union[list, tuple, None] = None)
|
||||
for i in range(N):
|
||||
if isinstance(split_size[i], int):
|
||||
x_list[i] = x_list[i][: split_size[i]]
|
||||
elif len(split_size[i]) == 2:
|
||||
x_list[i] = x_list[i][: split_size[i][0], : split_size[i][1]]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Support only for 2-dimensional unbinded tensor. \
|
||||
Split size for more dimensions provided"
|
||||
)
|
||||
slices = tuple(slice(0, s) for s in split_size[i]) # pyre-ignore
|
||||
x_list[i] = x_list[i][slices]
|
||||
return x_list
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user