diff --git a/pytorch3d/structures/utils.py b/pytorch3d/structures/utils.py
index 5da55d70..bfe376e9 100644
--- a/pytorch3d/structures/utils.py
+++ b/pytorch3d/structures/utils.py
@@ -1,77 +1,97 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

-from typing import List, Union
+from typing import List, Sequence, Union

 import torch


 """
-Util functions containing representation transforms for points/verts/faces.
+Util functions for points/verts/faces/volumes.
 """


 def list_to_padded(
     x: List[torch.Tensor],
-    pad_size: Union[list, tuple, None] = None,
+    pad_size: Union[Sequence[int], None] = None,
     pad_value: float = 0.0,
     equisized: bool = False,
 ) -> torch.Tensor:
     r"""
-    Transforms a list of N tensors each of shape (Mi, Ki) into a single tensor
-    of shape (N, pad_size(0), pad_size(1)), or (N, max(Mi), max(Ki))
-    if pad_size is None.
+    Transforms a list of N tensors each of shape (Si_0, Si_1, ... Si_D)
+    into:
+    - a single tensor of shape (N, pad_size(0), pad_size(1), ..., pad_size(D))
+      if pad_size is provided
+    - or a tensor of shape (N, max(Si_0), max(Si_1), ..., max(Si_D)) if pad_size is None.

     Args:
       x: list of Tensors
-      pad_size: list(int) specifying the size of the padded tensor
+      pad_size: list(int) specifying the size of the padded tensor.
+        If `None` (default), the largest size of each dimension
+        is set as the `pad_size`.
       pad_value: float value to be used to fill the padded tensor
       equisized: bool indicating whether the items in x are of equal size
        (sometimes this is known and if provided saves computation)

     Returns:
-      x_padded: tensor consisting of padded input tensors
+      x_padded: tensor consisting of padded input tensors stored
+        over the newly allocated memory.
     """
     if equisized:
         return torch.stack(x, 0)

+    if not all(torch.is_tensor(y) for y in x):
+        raise ValueError("All items have to be instances of a torch.Tensor.")
+
+    # we set the common number of dimensions to the maximum
+    # of the dimensionalities of the tensors in the list
+    element_ndim = max(y.ndim for y in x)
+
+    # replace empty 1D tensors with empty tensors with a correct number of dimensions
+    x = [
+        (y.new_zeros([0] * element_ndim) if (y.ndim == 1 and y.nelement() == 0) else y)
+        for y in x
+    ]  # pyre-ignore
+
+    if any(y.ndim != x[0].ndim for y in x):
+        raise ValueError("All items have to have the same number of dimensions!")
+
     if pad_size is None:
-        pad_dim0 = max(y.shape[0] for y in x if len(y) > 0)
-        pad_dim1 = max(y.shape[1] for y in x if len(y) > 0)
+        pad_dims = [
+            max(y.shape[dim] for y in x if len(y) > 0) for dim in range(x[0].ndim)
+        ]
     else:
-        if len(pad_size) != 2:
-            raise ValueError("Pad size must contain target size for 1st and 2nd dim")
-        pad_dim0, pad_dim1 = pad_size
+        if any(len(pad_size) != y.ndim for y in x):
+            raise ValueError("Pad size must contain target size for all dimensions.")
+        pad_dims = pad_size

     N = len(x)
-    x_padded = torch.full(
-        (N, pad_dim0, pad_dim1), pad_value, dtype=x[0].dtype, device=x[0].device
-    )
+    x_padded = x[0].new_full((N, *pad_dims), pad_value)
     for i, y in enumerate(x):
         if len(y) > 0:
-            if y.ndim != 2:
-                raise ValueError("Supports only 2-dimensional tensor items")
-            x_padded[i, : y.shape[0], : y.shape[1]] = y
+            slices = (i, *(slice(0, y.shape[dim]) for dim in range(y.ndim)))
+            x_padded[slices] = y
     return x_padded


-def padded_to_list(x: torch.Tensor, split_size: Union[list, tuple, None] = None):
+def padded_to_list(
+    x: torch.Tensor,
+    split_size: Union[Sequence[int], Sequence[Sequence[int]], None] = None,
+):
     r"""
-    Transforms a padded tensor of shape (N, M, K) into a list of N tensors
-    of shape (Mi, Ki) where (Mi, Ki) is specified in split_size(i), or of shape
-    (M, K) if split_size is None.
-    Support only for 3-dimensional input tensor.
+    Transforms a padded tensor of shape (N, S_1, S_2, ..., S_D) into a list
+    of N tensors of shape:
+    - (Si_1, Si_2, ..., Si_D) where (Si_1, Si_2, ..., Si_D) is specified in split_size(i)
+    - or (S_1, S_2, ..., S_D) if split_size is None
+    - or (Si_1, S_2, ..., S_D) if split_size(i) is an integer.

     Args:
       x: tensor
-      split_size: list, tuple or int defining the number of items for each tensor
-        in the output list.
+      split_size: optional 1D or 2D list/tuple of ints defining the number of
+        items for each tensor.

     Returns:
-      x_list: a list of tensors
+      x_list: a list of tensors sharing the memory with the input.
     """
-    if x.ndim != 3:
-        raise ValueError("Supports only 3-dimensional input tensors")
-
     x_list = list(x.unbind(0))

     if split_size is None:
@@ -84,13 +104,9 @@ def padded_to_list(x: torch.Tensor, split_size: Union[list, tuple, None] = None)
     for i in range(N):
         if isinstance(split_size[i], int):
             x_list[i] = x_list[i][: split_size[i]]
-        elif len(split_size[i]) == 2:
-            x_list[i] = x_list[i][: split_size[i][0], : split_size[i][1]]
         else:
-            raise ValueError(
-                "Support only for 2-dimensional unbinded tensor. \
-                    Split size for more dimensions provided"
-            )
+            slices = tuple(slice(0, s) for s in split_size[i])  # pyre-ignore
+            x_list[i] = x_list[i][slices]

     return x_list

diff --git a/tests/test_struct_utils.py b/tests/test_struct_utils.py
index 4d555c70..dfb26740 100644
--- a/tests/test_struct_utils.py
+++ b/tests/test_struct_utils.py
@@ -9,42 +9,74 @@ from pytorch3d.structures import utils as struct_utils


 class TestStructUtils(TestCaseMixin, unittest.TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        torch.manual_seed(43)
+
+    def _check_list_to_padded_slices(self, x, x_padded, ndim):
+        N = len(x)
+        for i in range(N):
+            slices = [i]
+            for dim in range(ndim):
+                if x[i].nelement() == 0 and x[i].ndim == 1:
+                    slice_ = slice(0, 0, 1)
+                else:
+                    slice_ = slice(0, x[i].shape[dim], 1)
+                slices.append(slice_)
+            if x[i].nelement() == 0 and x[i].ndim == 1:
+                x_correct = x[i].new_zeros(*[[0] * ndim])
+            else:
+                x_correct = x[i]
+            self.assertClose(x_padded[slices], x_correct)
+
     def test_list_to_padded(self):
         device = torch.device("cuda:0")
         N = 5
         K = 20
-        ndim = 2
-        x = []
-        for _ in range(N):
-            dims = torch.randint(K, size=(ndim,)).tolist()
-            x.append(torch.rand(dims, device=device))
-        pad_size = [K] * ndim
-        x_padded = struct_utils.list_to_padded(
-            x, pad_size=pad_size, pad_value=0.0, equisized=False
-        )
+        for ndim in [1, 2, 3, 4]:
+            x = []
+            for _ in range(N):
+                dims = torch.randint(K, size=(ndim,)).tolist()
+                x.append(torch.rand(dims, device=device))

-        self.assertEqual(x_padded.shape[1], K)
-        self.assertEqual(x_padded.shape[2], K)
-        for i in range(N):
-            self.assertClose(x_padded[i, : x[i].shape[0], : x[i].shape[1]], x[i])
+            # set 0th element to an empty 1D tensor
+            x[0] = torch.tensor([], dtype=x[0].dtype, device=device)

-        # check for no pad size (defaults to max dimension)
-        x_padded = struct_utils.list_to_padded(x, pad_value=0.0, equisized=False)
-        max_size0 = max(y.shape[0] for y in x)
-        max_size1 = max(y.shape[1] for y in x)
-        self.assertEqual(x_padded.shape[1], max_size0)
-        self.assertEqual(x_padded.shape[2], max_size1)
-        for i in range(N):
-            self.assertClose(x_padded[i, : x[i].shape[0], : x[i].shape[1]], x[i])
+            # set 1st element to an empty tensor with correct number of dims
+            x[1] = x[1].new_zeros(*[[0] * ndim])

-        # check for equisized
-        x = [torch.rand((K, 10), device=device) for _ in range(N)]
-        x_padded = struct_utils.list_to_padded(x, equisized=True)
-        self.assertClose(x_padded, torch.stack(x, 0))
+            pad_size = [K] * ndim
+            x_padded = struct_utils.list_to_padded(
+                x, pad_size=pad_size, pad_value=0.0, equisized=False
+            )
+
+            for dim in range(ndim):
+                self.assertEqual(x_padded.shape[dim + 1], K)
+
+            self._check_list_to_padded_slices(x, x_padded, ndim)
+
+            # check for no pad size (defaults to max dimension)
+            x_padded = struct_utils.list_to_padded(x, pad_value=0.0, equisized=False)
+            max_sizes = (
+                max(
+                    (0 if (y.nelement() == 0 and y.ndim == 1) else y.shape[dim])
+                    for y in x
+                )
+                for dim in range(ndim)
+            )
+            for dim, max_size in enumerate(max_sizes):
+                self.assertEqual(x_padded.shape[dim + 1], max_size)
+
+            self._check_list_to_padded_slices(x, x_padded, ndim)
+
+            # check for equisized
+            x = [torch.rand((K, *([10] * (ndim - 1))), device=device) for _ in range(N)]
+            x_padded = struct_utils.list_to_padded(x, equisized=True)
+            self.assertClose(x_padded, torch.stack(x, 0))

         # catch ValueError for invalid dimensions
         with self.assertRaisesRegex(ValueError, "Pad size must"):
-            pad_size = [K] * 4
+            pad_size = [K] * (ndim + 1)
             struct_utils.list_to_padded(
                 x, pad_size=pad_size, pad_value=0.0, equisized=False
             )
@@ -56,7 +88,7 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase):
             dims = torch.randint(K, size=(ndim,)).tolist()
             x.append(torch.rand(dims, device=device))
         pad_size = [K] * 2
-        with self.assertRaisesRegex(ValueError, "Supports only"):
+        with self.assertRaisesRegex(ValueError, "Pad size must"):
             x_padded = struct_utils.list_to_padded(
                 x, pad_size=pad_size, pad_value=0.0, equisized=False
             )
@@ -66,27 +98,29 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase):
         N = 5
         K = 20
         ndim = 2
-        dims = [K] * ndim
-        x = torch.rand([N] + dims, device=device)
-        x_list = struct_utils.padded_to_list(x)
-        for i in range(N):
-            self.assertClose(x_list[i], x[i])
+        for ndim in (2, 3, 4):

-        split_size = torch.randint(1, K, size=(N,)).tolist()
-        x_list = struct_utils.padded_to_list(x, split_size)
-        for i in range(N):
-            self.assertClose(x_list[i], x[i, : split_size[i]])
+            dims = [K] * ndim
+            x = torch.rand([N] + dims, device=device)

-        split_size = torch.randint(1, K, size=(2 * N,)).view(N, 2).unbind(0)
-        x_list = struct_utils.padded_to_list(x, split_size)
-        for i in range(N):
-            self.assertClose(x_list[i], x[i, : split_size[i][0], : split_size[i][1]])
+            x_list = struct_utils.padded_to_list(x)
+            for i in range(N):
+                self.assertClose(x_list[i], x[i])

-        with self.assertRaisesRegex(ValueError, "Supports only"):
-            x = torch.rand((N, K, K, K, K), device=device)
-            split_size = torch.randint(1, K, size=(N,)).tolist()
-            struct_utils.padded_to_list(x, split_size)
+            split_size = torch.randint(1, K, size=(N, ndim)).unbind(0)
+            x_list = struct_utils.padded_to_list(x, split_size)
+            for i in range(N):
+                slices = [i]
+                for dim in range(ndim):
+                    slices.append(slice(0, split_size[i][dim], 1))
+                self.assertClose(x_list[i], x[slices])
+
+            # split size is a list of ints
+            split_size = [int(z) for z in torch.randint(1, K, size=(N,)).unbind(0)]
+            x_list = struct_utils.padded_to_list(x, split_size)
+            for i in range(N):
+                self.assertClose(x_list[i], x[i][: split_size[i]])

     def test_padded_to_packed(self):
         device = torch.device("cuda:0")
@@ -160,7 +194,7 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase):
         with self.assertRaisesRegex(ValueError, "Supports only"):
             x = torch.rand((N, K, K, K, K), device=device)
             split_size = torch.randint(1, K, size=(N,)).tolist()
-            struct_utils.padded_to_list(x, split_size)
+            struct_utils.padded_to_packed(x, split_size=split_size)

     def test_list_to_packed(self):
         device = torch.device("cuda:0")
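
Below is a minimal usage sketch (not part of the patch) of the generalized list_to_padded: with this change it accepts per-element tensors of any common dimensionality rather than only (Mi, Ki), and it promotes empty 1D tensors to the common rank before padding. The shapes and values are arbitrary examples chosen for illustration.

    import torch
    from pytorch3d.structures import utils as struct_utils

    x = [
        torch.rand(4, 2, 5),
        torch.rand(7, 3, 5),
        torch.tensor([]),  # empty 1D tensor; promoted internally to an empty 3D tensor
    ]

    # With pad_size=None, every dimension is padded to its per-dimension maximum.
    x_padded = struct_utils.list_to_padded(x, pad_value=0.0)
    print(x_padded.shape)  # torch.Size([3, 7, 3, 5])

    # An explicit pad_size must now give a target size for every dimension.
    x_padded = struct_utils.list_to_padded(x, pad_size=[8, 4, 6], pad_value=-1.0)
    print(x_padded.shape)  # torch.Size([3, 8, 4, 6])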
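
A second self-contained sketch (also illustrative only) for the generalized padded_to_list: split_size may now carry one entry per dimension for each element, or a plain int per element that trims only the first dimension.

    import torch
    from pytorch3d.structures import utils as struct_utils

    x_padded = torch.rand(3, 7, 3, 5)  # (N, S_1, S_2, S_3)

    # One tuple of target sizes per element: every dimension of item i is trimmed.
    x_list = struct_utils.padded_to_list(x_padded, [(4, 2, 5), (7, 3, 5), (1, 1, 1)])
    print([tuple(t.shape) for t in x_list])  # [(4, 2, 5), (7, 3, 5), (1, 1, 1)]

    # A plain int per element trims only the first dimension of that item.
    x_list = struct_utils.padded_to_list(x_padded, [4, 7, 1])
    print([tuple(t.shape) for t in x_list])  # [(4, 3, 5), (7, 3, 5), (1, 3, 5)]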