diff --git a/pytorch3d/ops/packed_to_padded.py b/pytorch3d/ops/packed_to_padded.py index dee1363c..8fd2b718 100644 --- a/pytorch3d/ops/packed_to_padded.py +++ b/pytorch3d/ops/packed_to_padded.py @@ -65,7 +65,7 @@ def packed_to_padded(inputs, first_idxs, max_size): Torch wrapper that handles allowed input shapes. See description below. Args: - inputs: FloatTensor of shape (F,) or (F, D), representing the packed + inputs: FloatTensor of shape (F,) or (F, ...), representing the packed batch tensor, e.g. areas for faces in a batch of meshes. first_idxs: LongTensor of shape (N,) where N is the number of elements in the batch and `first_idxs[i] = f` @@ -73,7 +73,7 @@ def packed_to_padded(inputs, first_idxs, max_size): max_size: Max length of an element in the batch. Returns: - inputs_padded: FloatTensor of shape (N, max_size) or (N, max_size, D) + inputs_padded: FloatTensor of shape (N, max_size) or (N, max_size, ...) where max_size is max of `sizes`. The values for batch element i which start at `inputs[first_idxs[i]]` will be copied to `inputs_padded[i, :]`, with zeros padding out the extra inputs. @@ -83,15 +83,20 @@ def packed_to_padded(inputs, first_idxs, max_size): (N, max_size, 1). """ # if inputs is of shape (F,), reshape into (F, 1) - flat = False - if inputs.dim() == 1: - flat = True + input_shape = inputs.shape + n_dims = inputs.dim() + if n_dims == 1: inputs = inputs.unsqueeze(1) + else: + inputs = inputs.reshape(input_shape[0], -1) inputs_padded = _PackedToPadded.apply(inputs, first_idxs, max_size) # if flat is True, reshape output to (N, max_size) from (N, max_size, 1) - if flat: - inputs_padded = inputs_padded.squeeze(2) - return inputs_padded + # else reshape output to (N, max_size, ...) + if n_dims == 1: + return inputs_padded.squeeze(2) + if n_dims == 2: + return inputs_padded + return inputs_padded.view(*inputs_padded.shape[:2], *input_shape[1:]) class _PaddedToPacked(Function): @@ -147,7 +152,7 @@ def padded_to_packed(inputs, first_idxs, num_inputs): Torch wrapper that handles allowed input shapes. See description below. Args: - inputs: FloatTensor of shape (N, max_size) or (N, max_size, D), + inputs: FloatTensor of shape (N, max_size) or (N, max_size, ...), representing the padded tensor, e.g. areas for faces in a batch of meshes. first_idxs: LongTensor of shape (N,) where N is the number of @@ -156,20 +161,25 @@ def padded_to_packed(inputs, first_idxs, num_inputs): num_inputs: Number of packed entries (= F) Returns: - inputs_packed: FloatTensor of shape (F,) or (F, D) where - `inputs_packed[first_idx[i]:] = inputs[i, :]`. + inputs_packed: FloatTensor of shape (F,) or (F, ...) where + `inputs_packed[first_idx[i]:first_idx[i+1]] = inputs[i, :]`. To handle the allowed input shapes, we convert the inputs tensor of shape (N, max_size) to (N, max_size, 1). We reshape the output back to (F,) from (F, 1). """ # if inputs is of shape (N, max_size), reshape into (N, max_size, 1)) - flat = False - if inputs.dim() == 2: - flat = True + input_shape = inputs.shape + n_dims = inputs.dim() + if n_dims == 2: inputs = inputs.unsqueeze(2) + else: + inputs = inputs.reshape(*input_shape[:2], -1) inputs_packed = _PaddedToPacked.apply(inputs, first_idxs, num_inputs) - # if flat is True, reshape output to (F,) from (F, 1) - if flat: - inputs_packed = inputs_packed.squeeze(1) - return inputs_packed + # if input is flat, reshape output to (F,) from (F, 1) + # else reshape output to (F, ...) + if n_dims == 2: + return inputs_packed.squeeze(1) + if n_dims == 3: + return inputs_packed + return inputs_packed.view(-1, *input_shape[2:]) diff --git a/tests/test_packed_to_padded.py b/tests/test_packed_to_padded.py index 31104efd..79dc35c0 100644 --- a/tests/test_packed_to_padded.py +++ b/tests/test_packed_to_padded.py @@ -45,18 +45,19 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase): PyTorch implementation of packed_to_padded function. """ num_meshes = first_idxs.size(0) - D = inputs.shape[1] if inputs.dim() == 2 else 0 - if D == 0: + if inputs.dim() == 1: inputs_padded = torch.zeros((num_meshes, max_size), device=device) else: - inputs_padded = torch.zeros((num_meshes, max_size, D), device=device) + inputs_padded = torch.zeros( + (num_meshes, max_size, *inputs.shape[1:]), device=device + ) for m in range(num_meshes): s = first_idxs[m] if m == num_meshes - 1: f = inputs.shape[0] else: f = first_idxs[m + 1] - inputs_padded[m, :f] = inputs[s:f] + inputs_padded[m, : f - s] = inputs[s:f] return inputs_padded @@ -66,22 +67,21 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase): PyTorch implementation of padded_to_packed function. """ num_meshes = inputs.size(0) - D = inputs.shape[2] if inputs.dim() == 3 else 0 - if D == 0: + if inputs.dim() == 2: inputs_packed = torch.zeros((num_inputs,), device=device) else: - inputs_packed = torch.zeros((num_inputs, D), device=device) + inputs_packed = torch.zeros((num_inputs, *inputs.shape[2:]), device=device) for m in range(num_meshes): s = first_idxs[m] if m == num_meshes - 1: f = num_inputs else: f = first_idxs[m + 1] - inputs_packed[s:f] = inputs[m, :f] + inputs_packed[s:f] = inputs[m, : f - s] return inputs_packed - def _test_packed_to_padded_helper(self, D, device): + def _test_packed_to_padded_helper(self, dims, device): """ Check the results from packed_to_padded and PyTorch implementations are the same. @@ -91,10 +91,12 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase): mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx() max_faces = meshes.num_faces_per_mesh().max().item() - if D == 0: + if len(dims) == 0: values = torch.rand((faces.shape[0],), device=device, requires_grad=True) else: - values = torch.rand((faces.shape[0], D), device=device, requires_grad=True) + values = torch.rand( + (faces.shape[0], *dims), device=device, requires_grad=True + ) values_torch = values.detach().clone() values_torch.requires_grad = True values_padded = packed_to_padded( @@ -107,10 +109,10 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase): self.assertClose(values_padded, values_padded_torch) # check backward - if D == 0: + if len(dims) == 0: grad_inputs = torch.rand((len(meshes), max_faces), device=device) else: - grad_inputs = torch.rand((len(meshes), max_faces, D), device=device) + grad_inputs = torch.rand((len(meshes), max_faces, *dims), device=device) values_padded.backward(grad_inputs) grad_outputs = values.grad values_padded_torch.backward(grad_inputs) @@ -122,27 +124,41 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase): self.assertClose(grad_outputs, grad_outputs_torch2) def test_packed_to_padded_flat_cpu(self): - self._test_packed_to_padded_helper(0, "cpu") + self._test_packed_to_padded_helper([], "cpu") def test_packed_to_padded_D1_cpu(self): - self._test_packed_to_padded_helper(1, "cpu") + self._test_packed_to_padded_helper([1], "cpu") def test_packed_to_padded_D16_cpu(self): - self._test_packed_to_padded_helper(16, "cpu") + self._test_packed_to_padded_helper([16], "cpu") + + def test_packed_to_padded_D16_9_cpu(self): + self._test_packed_to_padded_helper([16, 9], "cpu") + + def test_packed_to_padded_D16_3_2_cpu(self): + self._test_packed_to_padded_helper([16, 3, 2], "cpu") def test_packed_to_padded_flat_cuda(self): device = get_random_cuda_device() - self._test_packed_to_padded_helper(0, device) + self._test_packed_to_padded_helper([], device) def test_packed_to_padded_D1_cuda(self): device = get_random_cuda_device() - self._test_packed_to_padded_helper(1, device) + self._test_packed_to_padded_helper([1], device) def test_packed_to_padded_D16_cuda(self): device = get_random_cuda_device() - self._test_packed_to_padded_helper(16, device) + self._test_packed_to_padded_helper([16], device) - def _test_padded_to_packed_helper(self, D, device): + def test_packed_to_padded_D16_9_cuda(self): + device = get_random_cuda_device() + self._test_packed_to_padded_helper([16, 9], device) + + def test_packed_to_padded_D16_3_2_cuda(self): + device = get_random_cuda_device() + self._test_packed_to_padded_helper([16, 3, 2], device) + + def _test_padded_to_packed_helper(self, dims, device): """ Check the results from packed_to_padded and PyTorch implementations are the same. @@ -151,10 +167,10 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase): mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx() num_faces_per_mesh = meshes.num_faces_per_mesh() max_faces = num_faces_per_mesh.max().item() - if D == 0: + if len(dims) == 0: values = torch.rand((len(meshes), max_faces), device=device) else: - values = torch.rand((len(meshes), max_faces, D), device=device) + values = torch.rand((len(meshes), max_faces, *dims), device=device) for i, num in enumerate(num_faces_per_mesh): values[i, num:] = 0 values.requires_grad = True @@ -173,11 +189,11 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase): self.assertClose(values_packed, values_packed_torch) # check backward - if D == 0: + if len(dims) == 0: grad_inputs = torch.rand((num_faces_per_mesh.sum().item()), device=device) else: grad_inputs = torch.rand( - (num_faces_per_mesh.sum().item(), D), device=device + (num_faces_per_mesh.sum().item(), *dims), device=device ) values_packed.backward(grad_inputs) grad_outputs = values.grad @@ -190,41 +206,39 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase): self.assertClose(grad_outputs, grad_outputs_torch2) def test_padded_to_packed_flat_cpu(self): - self._test_padded_to_packed_helper(0, "cpu") + self._test_padded_to_packed_helper([], "cpu") def test_padded_to_packed_D1_cpu(self): - self._test_padded_to_packed_helper(1, "cpu") + self._test_padded_to_packed_helper([1], "cpu") def test_padded_to_packed_D16_cpu(self): - self._test_padded_to_packed_helper(16, "cpu") + self._test_padded_to_packed_helper([16], "cpu") + + def test_padded_to_packed_D16_9_cpu(self): + self._test_padded_to_packed_helper([16, 9], "cpu") + + def test_padded_to_packed_D16_3_2_cpu(self): + self._test_padded_to_packed_helper([16, 3, 2], "cpu") def test_padded_to_packed_flat_cuda(self): device = get_random_cuda_device() - self._test_padded_to_packed_helper(0, device) + self._test_padded_to_packed_helper([], device) def test_padded_to_packed_D1_cuda(self): device = get_random_cuda_device() - self._test_padded_to_packed_helper(1, device) + self._test_padded_to_packed_helper([1], device) def test_padded_to_packed_D16_cuda(self): device = get_random_cuda_device() - self._test_padded_to_packed_helper(16, device) + self._test_padded_to_packed_helper([16], device) - def test_invalid_inputs_shapes(self, device="cuda:0"): - with self.assertRaisesRegex(ValueError, "input can only be 2-dimensional."): - values = torch.rand((100, 50, 2), device=device) - first_idxs = torch.tensor([0, 80], dtype=torch.int64, device=device) - packed_to_padded(values, first_idxs, 100) + def test_padded_to_packed_D16_9_cuda(self): + device = get_random_cuda_device() + self._test_padded_to_packed_helper([16, 9], device) - with self.assertRaisesRegex(ValueError, "input can only be 3-dimensional."): - values = torch.rand((100,), device=device) - first_idxs = torch.tensor([0, 80], dtype=torch.int64, device=device) - padded_to_packed(values, first_idxs, 20) - - with self.assertRaisesRegex(ValueError, "input can only be 3-dimensional."): - values = torch.rand((100, 50, 2, 2), device=device) - first_idxs = torch.tensor([0, 80], dtype=torch.int64, device=device) - padded_to_packed(values, first_idxs, 20) + def test_padded_to_packed_D16_3_2_cuda(self): + device = get_random_cuda_device() + self._test_padded_to_packed_helper([16, 3, 2], device) @staticmethod def packed_to_padded_with_init(