Support for multi-dimensional list_to_padded/padded_to_list.

Summary: Extends `list_to_padded`/`padded_to_list` to work for tensors with an arbitrary number of input dimensions.

Reviewed By: nikhilaravi, gkioxari

Differential Revision: D23813969

fbshipit-source-id: 52c212a2ecdb3c4dfb6ac47217715e07998f37f1
This commit is contained in:
David Novotny 2021-01-04 09:41:28 -08:00 committed by Facebook GitHub Bot
parent 0ba55a83ad
commit b4dea43963
2 changed files with 131 additions and 81 deletions

View File

@ -1,77 +1,97 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import List, Union from typing import List, Sequence, Union
import torch import torch
""" """
Util functions containing representation transforms for points/verts/faces. Util functions for points/verts/faces/volumes.
""" """
def list_to_padded( def list_to_padded(
x: List[torch.Tensor], x: List[torch.Tensor],
pad_size: Union[list, tuple, None] = None, pad_size: Union[Sequence[int], None] = None,
pad_value: float = 0.0, pad_value: float = 0.0,
equisized: bool = False, equisized: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
r""" r"""
Transforms a list of N tensors each of shape (Mi, Ki) into a single tensor Transforms a list of N tensors each of shape (Si_0, Si_1, ... Si_D)
of shape (N, pad_size(0), pad_size(1)), or (N, max(Mi), max(Ki)) into:
if pad_size is None. - a single tensor of shape (N, pad_size(0), pad_size(1), ..., pad_size(D))
if pad_size is provided
- or a tensor of shape (N, max(Si_0), max(Si_1), ..., max(Si_D)) if pad_size is None.
Args: Args:
x: list of Tensors x: list of Tensors
pad_size: list(int) specifying the size of the padded tensor pad_size: list(int) specifying the size of the padded tensor.
If `None` (default), the largest size of each dimension
is set as the `pad_size`.
pad_value: float value to be used to fill the padded tensor pad_value: float value to be used to fill the padded tensor
equisized: bool indicating whether the items in x are of equal size equisized: bool indicating whether the items in x are of equal size
(sometimes this is known and if provided saves computation) (sometimes this is known and if provided saves computation)
Returns: Returns:
x_padded: tensor consisting of padded input tensors x_padded: tensor consisting of padded input tensors stored
over the newly allocated memory.
""" """
if equisized: if equisized:
return torch.stack(x, 0) return torch.stack(x, 0)
if not all(torch.is_tensor(y) for y in x):
raise ValueError("All items have to be instances of a torch.Tensor.")
# we set the common number of dimensions to the maximum
# of the dimensionalities of the tensors in the list
element_ndim = max(y.ndim for y in x)
# replace empty 1D tensors with empty tensors with a correct number of dimensions
x = [
(y.new_zeros([0] * element_ndim) if (y.ndim == 1 and y.nelement() == 0) else y)
for y in x
] # pyre-ignore
if any(y.ndim != x[0].ndim for y in x):
raise ValueError("All items have to have the same number of dimensions!")
if pad_size is None: if pad_size is None:
pad_dim0 = max(y.shape[0] for y in x if len(y) > 0) pad_dims = [
pad_dim1 = max(y.shape[1] for y in x if len(y) > 0) max(y.shape[dim] for y in x if len(y) > 0) for dim in range(x[0].ndim)
]
else: else:
if len(pad_size) != 2: if any(len(pad_size) != y.ndim for y in x):
raise ValueError("Pad size must contain target size for 1st and 2nd dim") raise ValueError("Pad size must contain target size for all dimensions.")
pad_dim0, pad_dim1 = pad_size pad_dims = pad_size
N = len(x) N = len(x)
x_padded = torch.full( x_padded = x[0].new_full((N, *pad_dims), pad_value)
(N, pad_dim0, pad_dim1), pad_value, dtype=x[0].dtype, device=x[0].device
)
for i, y in enumerate(x): for i, y in enumerate(x):
if len(y) > 0: if len(y) > 0:
if y.ndim != 2: slices = (i, *(slice(0, y.shape[dim]) for dim in range(y.ndim)))
raise ValueError("Supports only 2-dimensional tensor items") x_padded[slices] = y
x_padded[i, : y.shape[0], : y.shape[1]] = y
return x_padded return x_padded
def padded_to_list(x: torch.Tensor, split_size: Union[list, tuple, None] = None): def padded_to_list(
x: torch.Tensor,
split_size: Union[Sequence[int], Sequence[Sequence[int]], None] = None,
):
r""" r"""
Transforms a padded tensor of shape (N, M, K) into a list of N tensors Transforms a padded tensor of shape (N, S_1, S_2, ..., S_D) into a list
of shape (Mi, Ki) where (Mi, Ki) is specified in split_size(i), or of shape of N tensors of shape:
(M, K) if split_size is None. - (Si_1, Si_2, ..., Si_D) where (Si_1, Si_2, ..., Si_D) is specified in split_size(i)
Support only for 3-dimensional input tensor. - or (S_1, S_2, ..., S_D) if split_size is None
- or (Si_1, S_2, ..., S_D) if split_size(i) is an integer.
Args: Args:
x: tensor x: tensor
split_size: list, tuple or int defining the number of items for each tensor split_size: optional 1D or 2D list/tuple of ints defining the number of
in the output list. items for each tensor.
Returns: Returns:
x_list: a list of tensors x_list: a list of tensors sharing the memory with the input.
""" """
if x.ndim != 3:
raise ValueError("Supports only 3-dimensional input tensors")
x_list = list(x.unbind(0)) x_list = list(x.unbind(0))
if split_size is None: if split_size is None:
@ -84,13 +104,9 @@ def padded_to_list(x: torch.Tensor, split_size: Union[list, tuple, None] = None)
for i in range(N): for i in range(N):
if isinstance(split_size[i], int): if isinstance(split_size[i], int):
x_list[i] = x_list[i][: split_size[i]] x_list[i] = x_list[i][: split_size[i]]
elif len(split_size[i]) == 2:
x_list[i] = x_list[i][: split_size[i][0], : split_size[i][1]]
else: else:
raise ValueError( slices = tuple(slice(0, s) for s in split_size[i]) # pyre-ignore
"Support only for 2-dimensional unbinded tensor. \ x_list[i] = x_list[i][slices]
Split size for more dimensions provided"
)
return x_list return x_list

View File

@ -9,42 +9,74 @@ from pytorch3d.structures import utils as struct_utils
class TestStructUtils(TestCaseMixin, unittest.TestCase): class TestStructUtils(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(43)
def _check_list_to_padded_slices(self, x, x_padded, ndim):
N = len(x)
for i in range(N):
slices = [i]
for dim in range(ndim):
if x[i].nelement() == 0 and x[i].ndim == 1:
slice_ = slice(0, 0, 1)
else:
slice_ = slice(0, x[i].shape[dim], 1)
slices.append(slice_)
if x[i].nelement() == 0 and x[i].ndim == 1:
x_correct = x[i].new_zeros(*[[0] * ndim])
else:
x_correct = x[i]
self.assertClose(x_padded[slices], x_correct)
def test_list_to_padded(self): def test_list_to_padded(self):
device = torch.device("cuda:0") device = torch.device("cuda:0")
N = 5 N = 5
K = 20 K = 20
ndim = 2 for ndim in [1, 2, 3, 4]:
x = [] x = []
for _ in range(N): for _ in range(N):
dims = torch.randint(K, size=(ndim,)).tolist() dims = torch.randint(K, size=(ndim,)).tolist()
x.append(torch.rand(dims, device=device)) x.append(torch.rand(dims, device=device))
# set 0th element to an empty 1D tensor
x[0] = torch.tensor([], dtype=x[0].dtype, device=device)
# set 1st element to an empty tensor with correct number of dims
x[1] = x[1].new_zeros(*[[0] * ndim])
pad_size = [K] * ndim pad_size = [K] * ndim
x_padded = struct_utils.list_to_padded( x_padded = struct_utils.list_to_padded(
x, pad_size=pad_size, pad_value=0.0, equisized=False x, pad_size=pad_size, pad_value=0.0, equisized=False
) )
self.assertEqual(x_padded.shape[1], K) for dim in range(ndim):
self.assertEqual(x_padded.shape[2], K) self.assertEqual(x_padded.shape[dim + 1], K)
for i in range(N):
self.assertClose(x_padded[i, : x[i].shape[0], : x[i].shape[1]], x[i]) self._check_list_to_padded_slices(x, x_padded, ndim)
# check for no pad size (defaults to max dimension) # check for no pad size (defaults to max dimension)
x_padded = struct_utils.list_to_padded(x, pad_value=0.0, equisized=False) x_padded = struct_utils.list_to_padded(x, pad_value=0.0, equisized=False)
max_size0 = max(y.shape[0] for y in x) max_sizes = (
max_size1 = max(y.shape[1] for y in x) max(
self.assertEqual(x_padded.shape[1], max_size0) (0 if (y.nelement() == 0 and y.ndim == 1) else y.shape[dim])
self.assertEqual(x_padded.shape[2], max_size1) for y in x
for i in range(N): )
self.assertClose(x_padded[i, : x[i].shape[0], : x[i].shape[1]], x[i]) for dim in range(ndim)
)
for dim, max_size in enumerate(max_sizes):
self.assertEqual(x_padded.shape[dim + 1], max_size)
self._check_list_to_padded_slices(x, x_padded, ndim)
# check for equisized # check for equisized
x = [torch.rand((K, 10), device=device) for _ in range(N)] x = [torch.rand((K, *([10] * (ndim - 1))), device=device) for _ in range(N)]
x_padded = struct_utils.list_to_padded(x, equisized=True) x_padded = struct_utils.list_to_padded(x, equisized=True)
self.assertClose(x_padded, torch.stack(x, 0)) self.assertClose(x_padded, torch.stack(x, 0))
# catch ValueError for invalid dimensions # catch ValueError for invalid dimensions
with self.assertRaisesRegex(ValueError, "Pad size must"): with self.assertRaisesRegex(ValueError, "Pad size must"):
pad_size = [K] * 4 pad_size = [K] * (ndim + 1)
struct_utils.list_to_padded( struct_utils.list_to_padded(
x, pad_size=pad_size, pad_value=0.0, equisized=False x, pad_size=pad_size, pad_value=0.0, equisized=False
) )
@ -56,7 +88,7 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase):
dims = torch.randint(K, size=(ndim,)).tolist() dims = torch.randint(K, size=(ndim,)).tolist()
x.append(torch.rand(dims, device=device)) x.append(torch.rand(dims, device=device))
pad_size = [K] * 2 pad_size = [K] * 2
with self.assertRaisesRegex(ValueError, "Supports only"): with self.assertRaisesRegex(ValueError, "Pad size must"):
x_padded = struct_utils.list_to_padded( x_padded = struct_utils.list_to_padded(
x, pad_size=pad_size, pad_value=0.0, equisized=False x, pad_size=pad_size, pad_value=0.0, equisized=False
) )
@ -66,6 +98,9 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase):
N = 5 N = 5
K = 20 K = 20
ndim = 2 ndim = 2
for ndim in (2, 3, 4):
dims = [K] * ndim dims = [K] * ndim
x = torch.rand([N] + dims, device=device) x = torch.rand([N] + dims, device=device)
@ -73,20 +108,19 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase):
for i in range(N): for i in range(N):
self.assertClose(x_list[i], x[i]) self.assertClose(x_list[i], x[i])
split_size = torch.randint(1, K, size=(N,)).tolist() split_size = torch.randint(1, K, size=(N, ndim)).unbind(0)
x_list = struct_utils.padded_to_list(x, split_size) x_list = struct_utils.padded_to_list(x, split_size)
for i in range(N): for i in range(N):
self.assertClose(x_list[i], x[i, : split_size[i]]) slices = [i]
for dim in range(ndim):
slices.append(slice(0, split_size[i][dim], 1))
self.assertClose(x_list[i], x[slices])
split_size = torch.randint(1, K, size=(2 * N,)).view(N, 2).unbind(0) # split size is a list of ints
split_size = [int(z) for z in torch.randint(1, K, size=(N,)).unbind(0)]
x_list = struct_utils.padded_to_list(x, split_size) x_list = struct_utils.padded_to_list(x, split_size)
for i in range(N): for i in range(N):
self.assertClose(x_list[i], x[i, : split_size[i][0], : split_size[i][1]]) self.assertClose(x_list[i], x[i][: split_size[i]])
with self.assertRaisesRegex(ValueError, "Supports only"):
x = torch.rand((N, K, K, K, K), device=device)
split_size = torch.randint(1, K, size=(N,)).tolist()
struct_utils.padded_to_list(x, split_size)
def test_padded_to_packed(self): def test_padded_to_packed(self):
device = torch.device("cuda:0") device = torch.device("cuda:0")
@ -160,7 +194,7 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase):
with self.assertRaisesRegex(ValueError, "Supports only"): with self.assertRaisesRegex(ValueError, "Supports only"):
x = torch.rand((N, K, K, K, K), device=device) x = torch.rand((N, K, K, K, K), device=device)
split_size = torch.randint(1, K, size=(N,)).tolist() split_size = torch.randint(1, K, size=(N,)).tolist()
struct_utils.padded_to_list(x, split_size) struct_utils.padded_to_packed(x, split_size=split_size)
def test_list_to_packed(self): def test_list_to_packed(self):
device = torch.device("cuda:0") device = torch.device("cuda:0")