Support for multi-dimensional list_to_padded/padded_to_list.

Summary: Extends `list_to_padded`/`padded_to_list` to work for tensors with an arbitrary number of input dimensions.

Reviewed By: nikhilaravi, gkioxari

Differential Revision: D23813969

fbshipit-source-id: 52c212a2ecdb3c4dfb6ac47217715e07998f37f1
Author: David Novotny 2021-01-04 09:41:28 -08:00, committed by Facebook GitHub Bot
parent 0ba55a83ad
commit b4dea43963
2 changed files with 131 additions and 81 deletions
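
In practice, the extension lets the two functions round-trip lists of tensors of any rank. A minimal sketch of the new behavior (the shapes and the 3-element list below are illustrative, not taken from this diff):

import torch
from pytorch3d.structures import utils as struct_utils

# Three 3-dimensional tensors whose sizes differ in every dimension.
x = [torch.rand(2, 5, 3), torch.rand(4, 2, 3), torch.rand(3, 3, 3)]

# With pad_size=None, each dimension is padded to its per-dimension maximum,
# yielding a tensor of shape (N, max_0, max_1, max_2) = (3, 4, 5, 3).
x_padded = struct_utils.list_to_padded(x, pad_value=0.0)
assert x_padded.shape == (3, 4, 5, 3)

# Round trip: passing the per-element shapes as split_size recovers the inputs.
x_list = struct_utils.padded_to_list(x_padded, [tuple(y.shape) for y in x])
assert all(torch.equal(a, b) for a, b in zip(x, x_list))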

pytorch3d/structures/utils.py

@@ -1,77 +1,97 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
-from typing import List, Union
+from typing import List, Sequence, Union
 
 import torch
 
 """
-Util functions containing representation transforms for points/verts/faces.
+Util functions for points/verts/faces/volumes.
 """
 
 
 def list_to_padded(
     x: List[torch.Tensor],
-    pad_size: Union[list, tuple, None] = None,
+    pad_size: Union[Sequence[int], None] = None,
     pad_value: float = 0.0,
     equisized: bool = False,
 ) -> torch.Tensor:
     r"""
-    Transforms a list of N tensors each of shape (Mi, Ki) into a single tensor
-    of shape (N, pad_size(0), pad_size(1)), or (N, max(Mi), max(Ki))
-    if pad_size is None.
+    Transforms a list of N tensors each of shape (Si_0, Si_1, ... Si_D)
+    into:
+    - a single tensor of shape (N, pad_size(0), pad_size(1), ..., pad_size(D))
+      if pad_size is provided
+    - or a tensor of shape (N, max(Si_0), max(Si_1), ..., max(Si_D)) if pad_size is None.
 
     Args:
       x: list of Tensors
-      pad_size: list(int) specifying the size of the padded tensor
+      pad_size: list(int) specifying the size of the padded tensor.
+        If `None` (default), the largest size of each dimension
+        is set as the `pad_size`.
       pad_value: float value to be used to fill the padded tensor
       equisized: bool indicating whether the items in x are of equal size
        (sometimes this is known and if provided saves computation)
 
     Returns:
-      x_padded: tensor consisting of padded input tensors
+      x_padded: tensor consisting of padded input tensors stored
+        over the newly allocated memory.
     """
     if equisized:
         return torch.stack(x, 0)
 
+    if not all(torch.is_tensor(y) for y in x):
+        raise ValueError("All items have to be instances of a torch.Tensor.")
+
+    # we set the common number of dimensions to the maximum
+    # of the dimensionalities of the tensors in the list
+    element_ndim = max(y.ndim for y in x)
+
+    # replace empty 1D tensors with empty tensors with a correct number of dimensions
+    x = [
+        (y.new_zeros([0] * element_ndim) if (y.ndim == 1 and y.nelement() == 0) else y)
+        for y in x
+    ]  # pyre-ignore
+
+    if any(y.ndim != x[0].ndim for y in x):
+        raise ValueError("All items have to have the same number of dimensions!")
+
     if pad_size is None:
-        pad_dim0 = max(y.shape[0] for y in x if len(y) > 0)
-        pad_dim1 = max(y.shape[1] for y in x if len(y) > 0)
+        pad_dims = [
+            max(y.shape[dim] for y in x if len(y) > 0) for dim in range(x[0].ndim)
+        ]
     else:
-        if len(pad_size) != 2:
-            raise ValueError("Pad size must contain target size for 1st and 2nd dim")
-        pad_dim0, pad_dim1 = pad_size
+        if any(len(pad_size) != y.ndim for y in x):
+            raise ValueError("Pad size must contain target size for all dimensions.")
+        pad_dims = pad_size
 
     N = len(x)
-    x_padded = torch.full(
-        (N, pad_dim0, pad_dim1), pad_value, dtype=x[0].dtype, device=x[0].device
-    )
+    x_padded = x[0].new_full((N, *pad_dims), pad_value)
     for i, y in enumerate(x):
         if len(y) > 0:
-            if y.ndim != 2:
-                raise ValueError("Supports only 2-dimensional tensor items")
-            x_padded[i, : y.shape[0], : y.shape[1]] = y
+            slices = (i, *(slice(0, y.shape[dim]) for dim in range(y.ndim)))
+            x_padded[slices] = y
     return x_padded
 
 
-def padded_to_list(x: torch.Tensor, split_size: Union[list, tuple, None] = None):
+def padded_to_list(
+    x: torch.Tensor,
+    split_size: Union[Sequence[int], Sequence[Sequence[int]], None] = None,
+):
     r"""
-    Transforms a padded tensor of shape (N, M, K) into a list of N tensors
-    of shape (Mi, Ki) where (Mi, Ki) is specified in split_size(i), or of shape
-    (M, K) if split_size is None.
-    Support only for 3-dimensional input tensor.
+    Transforms a padded tensor of shape (N, S_1, S_2, ..., S_D) into a list
+    of N tensors of shape:
+    - (Si_1, Si_2, ..., Si_D) where (Si_1, Si_2, ..., Si_D) is specified in split_size(i)
+    - or (S_1, S_2, ..., S_D) if split_size is None
+    - or (Si_1, S_2, ..., S_D) if split_size(i) is an integer.
 
     Args:
       x: tensor
-      split_size: list, tuple or int defining the number of items for each tensor
-        in the output list.
+      split_size: optional 1D or 2D list/tuple of ints defining the number of
+        items for each tensor.
 
     Returns:
-      x_list: a list of tensors
+      x_list: a list of tensors sharing the memory with the input.
     """
-    if x.ndim != 3:
-        raise ValueError("Supports only 3-dimensional input tensors")
     x_list = list(x.unbind(0))
 
     if split_size is None:
@@ -84,13 +104,9 @@ def padded_to_list(x: torch.Tensor, split_size: Union[list, tuple, None] = None)
     for i in range(N):
         if isinstance(split_size[i], int):
             x_list[i] = x_list[i][: split_size[i]]
-        elif len(split_size[i]) == 2:
-            x_list[i] = x_list[i][: split_size[i][0], : split_size[i][1]]
         else:
-            raise ValueError(
-                "Support only for 2-dimensional unbinded tensor. \
-                Split size for more dimensions provided"
-            )
+            slices = tuple(slice(0, s) for s in split_size[i])  # pyre-ignore
+            x_list[i] = x_list[i][slices]
     return x_list
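
Both rewritten loops lean on the same dimension-agnostic indexing idiom: a tuple of Python slice objects indexes one tensor dimension each, so a single line replaces the hard-coded 2D/3D slicing. A standalone illustration of the idiom (names and shapes here are illustrative):

import torch

y = torch.arange(24.0).reshape(2, 3, 4)
target = torch.zeros(4, 5, 6)

# Equivalent to target[:2, :3, :4] = y, but works for any y.ndim.
slices = tuple(slice(0, s) for s in y.shape)
target[slices] = y
assert torch.equal(target[slices], y)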

tests/test_struct_utils.py

@@ -9,42 +9,74 @@ from pytorch3d.structures import utils as struct_utils
 class TestStructUtils(TestCaseMixin, unittest.TestCase):
     def setUp(self) -> None:
         super().setUp()
         torch.manual_seed(43)
 
+    def _check_list_to_padded_slices(self, x, x_padded, ndim):
+        N = len(x)
+        for i in range(N):
+            slices = [i]
+            for dim in range(ndim):
+                if x[i].nelement() == 0 and x[i].ndim == 1:
+                    slice_ = slice(0, 0, 1)
+                else:
+                    slice_ = slice(0, x[i].shape[dim], 1)
+                slices.append(slice_)
+            if x[i].nelement() == 0 and x[i].ndim == 1:
+                x_correct = x[i].new_zeros(*[[0] * ndim])
+            else:
+                x_correct = x[i]
+            self.assertClose(x_padded[slices], x_correct)
+
     def test_list_to_padded(self):
         device = torch.device("cuda:0")
         N = 5
         K = 20
-        ndim = 2
-        x = []
-        for _ in range(N):
-            dims = torch.randint(K, size=(ndim,)).tolist()
-            x.append(torch.rand(dims, device=device))
-        pad_size = [K] * ndim
-        x_padded = struct_utils.list_to_padded(
-            x, pad_size=pad_size, pad_value=0.0, equisized=False
-        )
+        for ndim in [1, 2, 3, 4]:
+            x = []
+            for _ in range(N):
+                dims = torch.randint(K, size=(ndim,)).tolist()
+                x.append(torch.rand(dims, device=device))
 
-        self.assertEqual(x_padded.shape[1], K)
-        self.assertEqual(x_padded.shape[2], K)
-        for i in range(N):
-            self.assertClose(x_padded[i, : x[i].shape[0], : x[i].shape[1]], x[i])
+            # set 0th element to an empty 1D tensor
+            x[0] = torch.tensor([], dtype=x[0].dtype, device=device)
 
-        # check for no pad size (defaults to max dimension)
-        x_padded = struct_utils.list_to_padded(x, pad_value=0.0, equisized=False)
-        max_size0 = max(y.shape[0] for y in x)
-        max_size1 = max(y.shape[1] for y in x)
-        self.assertEqual(x_padded.shape[1], max_size0)
-        self.assertEqual(x_padded.shape[2], max_size1)
-        for i in range(N):
-            self.assertClose(x_padded[i, : x[i].shape[0], : x[i].shape[1]], x[i])
+            # set 1st element to an empty tensor with correct number of dims
+            x[1] = x[1].new_zeros(*[[0] * ndim])
 
-        # check for equisized
-        x = [torch.rand((K, 10), device=device) for _ in range(N)]
-        x_padded = struct_utils.list_to_padded(x, equisized=True)
-        self.assertClose(x_padded, torch.stack(x, 0))
+            pad_size = [K] * ndim
+            x_padded = struct_utils.list_to_padded(
+                x, pad_size=pad_size, pad_value=0.0, equisized=False
+            )
+            for dim in range(ndim):
+                self.assertEqual(x_padded.shape[dim + 1], K)
+            self._check_list_to_padded_slices(x, x_padded, ndim)
+
+            # check for no pad size (defaults to max dimension)
+            x_padded = struct_utils.list_to_padded(x, pad_value=0.0, equisized=False)
+            max_sizes = (
+                max(
+                    (0 if (y.nelement() == 0 and y.ndim == 1) else y.shape[dim])
+                    for y in x
+                )
+                for dim in range(ndim)
+            )
+            for dim, max_size in enumerate(max_sizes):
+                self.assertEqual(x_padded.shape[dim + 1], max_size)
+            self._check_list_to_padded_slices(x, x_padded, ndim)
+
+            # check for equisized
+            x = [torch.rand((K, *([10] * (ndim - 1))), device=device) for _ in range(N)]
+            x_padded = struct_utils.list_to_padded(x, equisized=True)
+            self.assertClose(x_padded, torch.stack(x, 0))
 
         # catch ValueError for invalid dimensions
         with self.assertRaisesRegex(ValueError, "Pad size must"):
-            pad_size = [K] * 4
+            pad_size = [K] * (ndim + 1)
             struct_utils.list_to_padded(
                 x, pad_size=pad_size, pad_value=0.0, equisized=False
             )
@@ -56,7 +88,7 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase):
             dims = torch.randint(K, size=(ndim,)).tolist()
             x.append(torch.rand(dims, device=device))
         pad_size = [K] * 2
-        with self.assertRaisesRegex(ValueError, "Supports only"):
+        with self.assertRaisesRegex(ValueError, "Pad size must"):
             x_padded = struct_utils.list_to_padded(
                 x, pad_size=pad_size, pad_value=0.0, equisized=False
             )
@@ -66,27 +98,29 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase):
         N = 5
         K = 20
-        ndim = 2
-        dims = [K] * ndim
-        x = torch.rand([N] + dims, device=device)
-
-        x_list = struct_utils.padded_to_list(x)
-        for i in range(N):
-            self.assertClose(x_list[i], x[i])
+        for ndim in (2, 3, 4):
-
-        split_size = torch.randint(1, K, size=(N,)).tolist()
-        x_list = struct_utils.padded_to_list(x, split_size)
-        for i in range(N):
-            self.assertClose(x_list[i], x[i, : split_size[i]])
+            dims = [K] * ndim
+            x = torch.rand([N] + dims, device=device)
-
-        split_size = torch.randint(1, K, size=(2 * N,)).view(N, 2).unbind(0)
-        x_list = struct_utils.padded_to_list(x, split_size)
-        for i in range(N):
-            self.assertClose(x_list[i], x[i, : split_size[i][0], : split_size[i][1]])
+            x_list = struct_utils.padded_to_list(x)
+            for i in range(N):
+                self.assertClose(x_list[i], x[i])
-
-        with self.assertRaisesRegex(ValueError, "Supports only"):
-            x = torch.rand((N, K, K, K, K), device=device)
-            split_size = torch.randint(1, K, size=(N,)).tolist()
-            struct_utils.padded_to_list(x, split_size)
+            split_size = torch.randint(1, K, size=(N, ndim)).unbind(0)
+            x_list = struct_utils.padded_to_list(x, split_size)
+            for i in range(N):
+                slices = [i]
+                for dim in range(ndim):
+                    slices.append(slice(0, split_size[i][dim], 1))
+                self.assertClose(x_list[i], x[slices])
+
+            # split size is a list of ints
+            split_size = [int(z) for z in torch.randint(1, K, size=(N,)).unbind(0)]
+            x_list = struct_utils.padded_to_list(x, split_size)
+            for i in range(N):
+                self.assertClose(x_list[i], x[i][: split_size[i]])
 
     def test_padded_to_packed(self):
         device = torch.device("cuda:0")
@@ -160,7 +194,7 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase):
         with self.assertRaisesRegex(ValueError, "Supports only"):
             x = torch.rand((N, K, K, K, K), device=device)
             split_size = torch.randint(1, K, size=(N,)).tolist()
-            struct_utils.padded_to_list(x, split_size)
+            struct_utils.padded_to_packed(x, split_size=split_size)
 
     def test_list_to_packed(self):
         device = torch.device("cuda:0")
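
A detail exercised by the tests above: the new list_to_padded first promotes an empty 1D tensor to an empty tensor with the common number of dimensions, so placeholder elements no longer need to match the rank of their neighbors. A small sketch of that behavior (the inputs are illustrative):

import torch
from pytorch3d.structures import utils as struct_utils

# The empty 1D tensor is promoted to shape (0, 0) to match its 2D neighbor,
# then skipped when computing per-dimension maxima, giving shape (2, 3, 2).
x = [torch.tensor([]), torch.rand(3, 2)]
x_padded = struct_utils.list_to_padded(x)
assert x_padded.shape == (2, 3, 2)
assert torch.all(x_padded[0] == 0.0)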