From 20e457ca0e651c56bfe24354295591816da461aa Mon Sep 17 00:00:00 2001 From: Nikhila Ravi Date: Sun, 15 Mar 2020 09:32:59 -0700 Subject: [PATCH] [pytorch3d[ padded to packed function in struct utils Summary: Added a padded to packed utils function which takes either split sizes or a padding value to remove padded elements from a tensor. Reviewed By: gkioxari Differential Revision: D20454238 fbshipit-source-id: 180b807ff44c74c4ee9d5c1ac3b5c4a9b4be57c7 --- pytorch3d/structures/utils.py | 81 ++++++++++++++++++++++++++++++++- tests/test_struct_utils.py | 86 +++++++++++++++++++++++++++++++++++ 2 files changed, 165 insertions(+), 2 deletions(-) diff --git a/pytorch3d/structures/utils.py b/pytorch3d/structures/utils.py index 726bb076..9fc1c81d 100644 --- a/pytorch3d/structures/utils.py +++ b/pytorch3d/structures/utils.py @@ -66,11 +66,16 @@ def padded_to_list( Args: x: tensor - split_size: the shape of the final tensor to be returned (of length N). + split_size: list, tuple or int defining the number of items for each tensor + in the output list. + + Returns: + x_list: a list of tensors """ if x.ndim != 3: raise ValueError("Supports only 3-dimensional input tensors") x_list = list(x.unbind(0)) + if split_size is None: return x_list @@ -141,9 +146,81 @@ def packed_to_list(x: torch.Tensor, split_size: Union[list, int]): Args: x: tensor - split_size: list or int defining the number of items for each split + split_size: list, tuple or int defining the number of items for each tensor + in the output list. Returns: x_list: A list of Tensors """ return x.split(split_size, dim=0) + + +def padded_to_packed( + x: torch.Tensor, + split_size: Union[list, tuple, None] = None, + pad_value: Union[float, int, None] = None, +): + r""" + Transforms a padded tensor of shape (N, M, K) into a packed tensor + of shape: + - (sum(Mi), K) where (Mi, K) are the dimensions of + each of the tensors in the batch and Mi is specified by split_size(i) + - (N*M, K) if split_size is None + + Support only for 3-dimensional input tensor and 1-dimensional split size. + + Args: + x: tensor + split_size: list, tuple or int defining the number of items for each tensor + in the output list. + pad_value: optional value to use to filter the padded values in the input + tensor. + + Only one of split_size or pad_value should be provided, or both can be None. + + Returns: + x_packed: a packed tensor. + """ + if x.ndim != 3: + raise ValueError("Supports only 3-dimensional input tensors") + + N, M, D = x.shape + + if split_size is not None and pad_value is not None: + raise ValueError( + "Only one of split_size or pad_value should be provided." + ) + + x_packed = x.view(-1, D) # flatten padded + + if pad_value is None and split_size is None: + return x_packed + + # Convert to packed using pad value + if pad_value is not None: + mask = x_packed.ne(pad_value).any(-1) + x_packed = x_packed[mask] + return x_packed + + # Convert to packed using split sizes + N = len(split_size) + if x.shape[0] != N: + raise ValueError( + "Split size must be of same length as inputs first dimension" + ) + + if not all(isinstance(i, int) for i in split_size): + raise ValueError( + "Support only 1-dimensional unbinded tensor. \ + Split size for more dimensions provided" + ) + + padded_to_packed_idx = torch.cat( + [ + torch.arange(v, dtype=torch.int64, device=x.device) + i * M + for (i, v) in enumerate(split_size) + ], + dim=0, + ) + + return x_packed[padded_to_packed_idx] diff --git a/tests/test_struct_utils.py b/tests/test_struct_utils.py index 5b63a8b9..6f2b5f7f 100644 --- a/tests/test_struct_utils.py +++ b/tests/test_struct_utils.py @@ -97,6 +97,92 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase): split_size = torch.randint(1, K, size=(N,)).tolist() struct_utils.padded_to_list(x, split_size) + def test_padded_to_packed(self): + device = torch.device("cuda:0") + N = 5 + K = 20 + ndim = 2 + dims = [K] * ndim + x = torch.rand([N] + dims, device=device) + + # Case 1: no split_size or pad_value provided + # Check output is just the flattened input. + x_packed = struct_utils.padded_to_packed(x) + self.assertTrue(x_packed.shape == (x.shape[0] * x.shape[1], x.shape[2])) + self.assertClose(x_packed, x.reshape(-1, K)) + + # Case 2: pad_value is provided. + # Check each section of the packed tensor matches the + # corresponding unpadded elements of the padded tensor. + # Check that only rows where all the values are padded + # are removed in the conversion to packed. + pad_value = -1 + x_list = [] + split_size = [] + for _ in range(N): + dim = torch.randint(K, size=(1,)).item() + # Add some random values in the input which are the same as the pad_value. + # These should not be filtered out. + x_list.append( + torch.randint( + low=pad_value, high=10, size=(dim, K), device=device + ) + ) + split_size.append(dim) + x_padded = struct_utils.list_to_padded(x_list, pad_value=pad_value) + x_packed = struct_utils.padded_to_packed(x_padded, pad_value=pad_value) + curr = 0 + for i in range(N): + self.assertClose( + x_packed[curr : curr + split_size[i], ...], x_list[i] + ) + self.assertClose(torch.cat(x_list), x_packed) + curr += split_size[i] + + # Case 3: split_size is provided. + # Check each section of the packed tensor matches the corresponding + # unpadded elements. + x_packed = struct_utils.padded_to_packed( + x_padded, split_size=split_size + ) + curr = 0 + for i in range(N): + self.assertClose( + x_packed[curr : curr + split_size[i], ...], x_list[i] + ) + self.assertClose(torch.cat(x_list), x_packed) + curr += split_size[i] + + # Case 4: split_size of the wrong shape is provided. + # Raise an error. + split_size = torch.randint(1, K, size=(2 * N,)).view(N, 2).unbind(0) + with self.assertRaisesRegex(ValueError, "1-dimensional"): + x_packed = struct_utils.padded_to_packed( + x_padded, split_size=split_size + ) + + split_size = torch.randint(1, K, size=(2 * N,)).view(N * 2).tolist() + with self.assertRaisesRegex( + ValueError, "same length as inputs first dimension" + ): + x_packed = struct_utils.padded_to_packed( + x_padded, split_size=split_size + ) + + # Case 5: both pad_value and split_size are provided. + # Raise an error. + with self.assertRaisesRegex(ValueError, "Only one of"): + x_packed = struct_utils.padded_to_packed( + x_padded, split_size=split_size, pad_value=-1 + ) + + # Case 6: Input has more than 3 dims. + # Raise an error. + with self.assertRaisesRegex(ValueError, "Supports only"): + x = torch.rand((N, K, K, K, K), device=device) + split_size = torch.randint(1, K, size=(N,)).tolist() + struct_utils.padded_to_list(x, split_size) + def test_list_to_packed(self): device = torch.device("cuda:0") N = 5