mirror of https://github.com/facebookresearch/pytorch3d.git
	packed_to_padded now accepts all sizes
Summary: We need to pack/unpack tensors in two places for mixed-frame raysampling (the metrics and the raysampler), but the tensors involved have more than two dimensions. Reshaping at the call sites and storing the original dimensions would only complicate that code with something packed_to_padded should support itself. A separate implicitron-only function was the other option, but it would be confusing to have two different padded_to_packed functions inside the pytorch3d codebase, one of which packs (b, max) and (b, max, f) and the other (b, max, ...).

Reviewed By: bottler

Differential Revision: D39729026

fbshipit-source-id: 2bdebf290dcc6c316b7fe1aeee49bbb5255e508c
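For context, a minimal sketch of what the change enables (shapes are made up; both functions live in the module this diff touches). Packed tensors with arbitrary trailing dimensions now round-trip through packed_to_padded and padded_to_packed:

import torch
from pytorch3d.ops.packed_to_padded import packed_to_padded, padded_to_packed

# Packed batch of two elements with 3 and 2 entries; each entry is a (4, 2) tensor.
# Before this change, only (F,) and (F, D) inputs were accepted.
inputs = torch.rand(5, 4, 2)                      # (F, ...) with F = 3 + 2
first_idxs = torch.tensor([0, 3])                 # packed start index of each element

padded = packed_to_padded(inputs, first_idxs, 3)  # -> (N, max_size, ...) = (2, 3, 4, 2)
packed = padded_to_packed(padded, first_idxs, 5)  # -> (5, 4, 2)
assert torch.equal(packed, inputs)                # zero padding is dropped on the way back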
parent c2d876c9e8
commit f34da3d3b6
@@ -65,7 +65,7 @@ def packed_to_padded(inputs, first_idxs, max_size):
     Torch wrapper that handles allowed input shapes. See description below.
 
     Args:
-        inputs: FloatTensor of shape (F,) or (F, D), representing the packed
+        inputs: FloatTensor of shape (F,) or (F, ...), representing the packed
             batch tensor, e.g. areas for faces in a batch of meshes.
         first_idxs: LongTensor of shape (N,) where N is the number of
             elements in the batch and `first_idxs[i] = f`
@@ -73,7 +73,7 @@ def packed_to_padded(inputs, first_idxs, max_size):
         max_size: Max length of an element in the batch.
 
     Returns:
-        inputs_padded: FloatTensor of shape (N, max_size) or (N, max_size, D)
+        inputs_padded: FloatTensor of shape (N, max_size) or (N, max_size, ...)
             where max_size is max of `sizes`. The values for batch element i
             which start at `inputs[first_idxs[i]]` will be copied to
             `inputs_padded[i, :]`, with zeros padding out the extra inputs.
@@ -83,15 +83,20 @@ def packed_to_padded(inputs, first_idxs, max_size):
     (N, max_size, 1).
     """
     # if inputs is of shape (F,), reshape into (F, 1)
-    flat = False
-    if inputs.dim() == 1:
-        flat = True
+    input_shape = inputs.shape
+    n_dims = inputs.dim()
+    if n_dims == 1:
         inputs = inputs.unsqueeze(1)
+    else:
+        inputs = inputs.reshape(input_shape[0], -1)
     inputs_padded = _PackedToPadded.apply(inputs, first_idxs, max_size)
     # if flat is True, reshape output to (N, max_size) from (N, max_size, 1)
-    if flat:
-        inputs_padded = inputs_padded.squeeze(2)
-    return inputs_padded
+    # else reshape output to (N, max_size, ...)
+    if n_dims == 1:
+        return inputs_padded.squeeze(2)
+    if n_dims == 2:
+        return inputs_padded
+    return inputs_padded.view(*inputs_padded.shape[:2], *input_shape[1:])
 
 
 class _PaddedToPacked(Function):
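The wrapper change above is a flatten-then-restore pattern around the fixed-rank autograd op: trailing dimensions are collapsed into one before _PackedToPadded.apply and expanded again afterwards. A standalone sketch of the same trick in plain PyTorch (helper names are hypothetical, not from the library):

import torch

def flatten_trailing(inputs: torch.Tensor) -> torch.Tensor:
    # (F,) -> (F, 1); (F, d1, d2, ...) -> (F, d1 * d2 * ...)
    if inputs.dim() == 1:
        return inputs.unsqueeze(1)
    return inputs.reshape(inputs.shape[0], -1)

def restore_trailing(padded: torch.Tensor, input_shape: torch.Size) -> torch.Tensor:
    # Undo flatten_trailing on the (N, max_size, prod) result.
    if len(input_shape) == 1:
        return padded.squeeze(2)                        # back to (N, max_size)
    if len(input_shape) == 2:
        return padded                                   # (N, max_size, D) unchanged
    return padded.view(*padded.shape[:2], *input_shape[1:])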
@@ -147,7 +152,7 @@ def padded_to_packed(inputs, first_idxs, num_inputs):
     Torch wrapper that handles allowed input shapes. See description below.
 
     Args:
-        inputs: FloatTensor of shape (N, max_size) or (N, max_size, D),
+        inputs: FloatTensor of shape (N, max_size) or (N, max_size, ...),
             representing the padded tensor, e.g. areas for faces in a batch of
             meshes.
         first_idxs: LongTensor of shape (N,) where N is the number of
@@ -156,20 +161,25 @@ def padded_to_packed(inputs, first_idxs, num_inputs):
         num_inputs: Number of packed entries (= F)
 
     Returns:
-        inputs_packed: FloatTensor of shape (F,) or (F, D) where
-            `inputs_packed[first_idx[i]:] = inputs[i, :]`.
+        inputs_packed: FloatTensor of shape (F,) or (F, ...) where
+            `inputs_packed[first_idx[i]:first_idx[i+1]] = inputs[i, :]`.
 
     To handle the allowed input shapes, we convert the inputs tensor of shape
     (N, max_size) to (N, max_size, 1). We reshape the output back to (F,) from
     (F, 1).
     """
     # if inputs is of shape (N, max_size), reshape into (N, max_size, 1))
-    flat = False
-    if inputs.dim() == 2:
-        flat = True
+    input_shape = inputs.shape
+    n_dims = inputs.dim()
+    if n_dims == 2:
         inputs = inputs.unsqueeze(2)
+    else:
+        inputs = inputs.reshape(*input_shape[:2], -1)
     inputs_packed = _PaddedToPacked.apply(inputs, first_idxs, num_inputs)
-    # if flat is True, reshape output to (F,) from (F, 1)
-    if flat:
-        inputs_packed = inputs_packed.squeeze(1)
-    return inputs_packed
+    # if input is flat, reshape output to (F,) from (F, 1)
+    # else reshape output to (F, ...)
+    if n_dims == 2:
+        return inputs_packed.squeeze(1)
+    if n_dims == 3:
+        return inputs_packed
+    return inputs_packed.view(-1, *input_shape[2:])
 
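The docstring fix above also tightens the contract: each padded row lands in a bounded slice of the packed output, `inputs_packed[first_idx[i]:first_idx[i+1]]`, so padding never leaks into the result. A small illustration with made-up values:

import torch
from pytorch3d.ops.packed_to_padded import padded_to_packed

padded = torch.tensor([[10.0, 11.0, 12.0],
                       [20.0, 21.0, 0.0]])  # element lengths 3 and 2; trailing 0.0 is padding
first_idxs = torch.tensor([0, 3])

packed = padded_to_packed(padded, first_idxs, 5)
# tensor([10., 11., 12., 20., 21.]) -- the padded zero is not copied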
@@ -45,18 +45,19 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase):
         PyTorch implementation of packed_to_padded function.
         """
         num_meshes = first_idxs.size(0)
-        D = inputs.shape[1] if inputs.dim() == 2 else 0
-        if D == 0:
+        if inputs.dim() == 1:
             inputs_padded = torch.zeros((num_meshes, max_size), device=device)
         else:
-            inputs_padded = torch.zeros((num_meshes, max_size, D), device=device)
+            inputs_padded = torch.zeros(
+                (num_meshes, max_size, *inputs.shape[1:]), device=device
+            )
         for m in range(num_meshes):
             s = first_idxs[m]
             if m == num_meshes - 1:
                 f = inputs.shape[0]
             else:
                 f = first_idxs[m + 1]
-            inputs_padded[m, :f] = inputs[s:f]
+            inputs_padded[m, : f - s] = inputs[s:f]
 
         return inputs_padded
 
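The last change in this hunk fixes the reference implementation's indexing: s and f are offsets into the packed tensor, so the padded row must be sliced by the element length f - s rather than by the absolute end index f. With made-up numbers:

import torch

inputs = torch.arange(5.0)          # packed values of two elements, lengths 3 and 2
first_idxs = torch.tensor([0, 3])
inputs_padded = torch.zeros(2, 3)   # (num_meshes, max_size)

m, s, f = 1, 3, 5                   # second element occupies packed slots 3:5
inputs_padded[m, : f - s] = inputs[s:f]   # correct: row slice of length f - s = 2
# inputs_padded[m, :f] would be a length-3 slice assigned a length-2 tensor -> shape error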
@@ -66,22 +67,21 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase):
         PyTorch implementation of padded_to_packed function.
         """
         num_meshes = inputs.size(0)
-        D = inputs.shape[2] if inputs.dim() == 3 else 0
-        if D == 0:
+        if inputs.dim() == 2:
             inputs_packed = torch.zeros((num_inputs,), device=device)
         else:
-            inputs_packed = torch.zeros((num_inputs, D), device=device)
+            inputs_packed = torch.zeros((num_inputs, *inputs.shape[2:]), device=device)
         for m in range(num_meshes):
             s = first_idxs[m]
             if m == num_meshes - 1:
                 f = num_inputs
             else:
                 f = first_idxs[m + 1]
-            inputs_packed[s:f] = inputs[m, :f]
+            inputs_packed[s:f] = inputs[m, : f - s]
 
         return inputs_packed
 
-    def _test_packed_to_padded_helper(self, D, device):
+    def _test_packed_to_padded_helper(self, dims, device):
         """
         Check the results from packed_to_padded and PyTorch implementations
         are the same.
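The allocation above generalizes via shape splatting: whatever trailing dimensions the padded input carries are reused for the packed output. A tiny sketch with made-up shapes:

import torch

inputs = torch.rand(2, 3, 16, 9)            # (N, max_size, ...) padded input
num_inputs = 5
out = torch.zeros((num_inputs, *inputs.shape[2:]))
print(out.shape)                            # torch.Size([5, 16, 9])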
@@ -91,10 +91,12 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase):
         mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx()
         max_faces = meshes.num_faces_per_mesh().max().item()
 
-        if D == 0:
+        if len(dims) == 0:
             values = torch.rand((faces.shape[0],), device=device, requires_grad=True)
         else:
-            values = torch.rand((faces.shape[0], D), device=device, requires_grad=True)
+            values = torch.rand(
+                (faces.shape[0], *dims), device=device, requires_grad=True
+            )
         values_torch = values.detach().clone()
         values_torch.requires_grad = True
         values_padded = packed_to_padded(
@@ -107,10 +109,10 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase):
         self.assertClose(values_padded, values_padded_torch)
 
         # check backward
-        if D == 0:
+        if len(dims) == 0:
             grad_inputs = torch.rand((len(meshes), max_faces), device=device)
         else:
-            grad_inputs = torch.rand((len(meshes), max_faces, D), device=device)
+            grad_inputs = torch.rand((len(meshes), max_faces, *dims), device=device)
         values_padded.backward(grad_inputs)
         grad_outputs = values.grad
         values_padded_torch.backward(grad_inputs)
@@ -122,27 +124,41 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase):
         self.assertClose(grad_outputs, grad_outputs_torch2)
 
     def test_packed_to_padded_flat_cpu(self):
-        self._test_packed_to_padded_helper(0, "cpu")
+        self._test_packed_to_padded_helper([], "cpu")
 
     def test_packed_to_padded_D1_cpu(self):
-        self._test_packed_to_padded_helper(1, "cpu")
+        self._test_packed_to_padded_helper([1], "cpu")
 
     def test_packed_to_padded_D16_cpu(self):
-        self._test_packed_to_padded_helper(16, "cpu")
+        self._test_packed_to_padded_helper([16], "cpu")
+
+    def test_packed_to_padded_D16_9_cpu(self):
+        self._test_packed_to_padded_helper([16, 9], "cpu")
+
+    def test_packed_to_padded_D16_3_2_cpu(self):
+        self._test_packed_to_padded_helper([16, 3, 2], "cpu")
 
     def test_packed_to_padded_flat_cuda(self):
         device = get_random_cuda_device()
-        self._test_packed_to_padded_helper(0, device)
+        self._test_packed_to_padded_helper([], device)
 
     def test_packed_to_padded_D1_cuda(self):
         device = get_random_cuda_device()
-        self._test_packed_to_padded_helper(1, device)
+        self._test_packed_to_padded_helper([1], device)
 
     def test_packed_to_padded_D16_cuda(self):
         device = get_random_cuda_device()
-        self._test_packed_to_padded_helper(16, device)
+        self._test_packed_to_padded_helper([16], device)
 
-    def _test_padded_to_packed_helper(self, D, device):
+    def test_packed_to_padded_D16_9_cuda(self):
+        device = get_random_cuda_device()
+        self._test_packed_to_padded_helper([16, 9], device)
+
+    def test_packed_to_padded_D16_3_2_cuda(self):
+        device = get_random_cuda_device()
+        self._test_packed_to_padded_helper([16, 3, 2], device)
+
+    def _test_padded_to_packed_helper(self, dims, device):
         """
         Check the results from packed_to_padded and PyTorch implementations
         are the same.
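With the helper now taking a list of trailing dimensions instead of a single D, one parameter covers every case. Roughly how the new test parameterization maps to input shapes (illustrative):

# dims == []          -> values of shape (F,)           (flat)
# dims == [16]        -> values of shape (F, 16)        (old D=16 case)
# dims == [16, 3, 2]  -> values of shape (F, 16, 3, 2)  (newly supported)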
@@ -151,10 +167,10 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase):
         mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx()
         num_faces_per_mesh = meshes.num_faces_per_mesh()
         max_faces = num_faces_per_mesh.max().item()
-        if D == 0:
+        if len(dims) == 0:
             values = torch.rand((len(meshes), max_faces), device=device)
         else:
-            values = torch.rand((len(meshes), max_faces, D), device=device)
+            values = torch.rand((len(meshes), max_faces, *dims), device=device)
         for i, num in enumerate(num_faces_per_mesh):
             values[i, num:] = 0
         values.requires_grad = True
@@ -173,11 +189,11 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase):
         self.assertClose(values_packed, values_packed_torch)
 
         # check backward
-        if D == 0:
+        if len(dims) == 0:
             grad_inputs = torch.rand((num_faces_per_mesh.sum().item()), device=device)
         else:
             grad_inputs = torch.rand(
-                (num_faces_per_mesh.sum().item(), D), device=device
+                (num_faces_per_mesh.sum().item(), *dims), device=device
             )
         values_packed.backward(grad_inputs)
         grad_outputs = values.grad
@@ -190,41 +206,39 @@ class TestPackedToPadded(TestCaseMixin, unittest.TestCase):
         self.assertClose(grad_outputs, grad_outputs_torch2)
 
     def test_padded_to_packed_flat_cpu(self):
-        self._test_padded_to_packed_helper(0, "cpu")
+        self._test_padded_to_packed_helper([], "cpu")
 
     def test_padded_to_packed_D1_cpu(self):
-        self._test_padded_to_packed_helper(1, "cpu")
+        self._test_padded_to_packed_helper([1], "cpu")
 
     def test_padded_to_packed_D16_cpu(self):
-        self._test_padded_to_packed_helper(16, "cpu")
+        self._test_padded_to_packed_helper([16], "cpu")
+
+    def test_padded_to_packed_D16_9_cpu(self):
+        self._test_padded_to_packed_helper([16, 9], "cpu")
+
+    def test_padded_to_packed_D16_3_2_cpu(self):
+        self._test_padded_to_packed_helper([16, 3, 2], "cpu")
 
     def test_padded_to_packed_flat_cuda(self):
         device = get_random_cuda_device()
-        self._test_padded_to_packed_helper(0, device)
+        self._test_padded_to_packed_helper([], device)
 
     def test_padded_to_packed_D1_cuda(self):
         device = get_random_cuda_device()
-        self._test_padded_to_packed_helper(1, device)
+        self._test_padded_to_packed_helper([1], device)
 
     def test_padded_to_packed_D16_cuda(self):
         device = get_random_cuda_device()
-        self._test_padded_to_packed_helper(16, device)
+        self._test_padded_to_packed_helper([16], device)
 
-    def test_invalid_inputs_shapes(self, device="cuda:0"):
-        with self.assertRaisesRegex(ValueError, "input can only be 2-dimensional."):
-            values = torch.rand((100, 50, 2), device=device)
-            first_idxs = torch.tensor([0, 80], dtype=torch.int64, device=device)
-            packed_to_padded(values, first_idxs, 100)
+    def test_padded_to_packed_D16_9_cuda(self):
+        device = get_random_cuda_device()
+        self._test_padded_to_packed_helper([16, 9], device)
 
-        with self.assertRaisesRegex(ValueError, "input can only be 3-dimensional."):
-            values = torch.rand((100,), device=device)
-            first_idxs = torch.tensor([0, 80], dtype=torch.int64, device=device)
-            padded_to_packed(values, first_idxs, 20)
-
-        with self.assertRaisesRegex(ValueError, "input can only be 3-dimensional."):
-            values = torch.rand((100, 50, 2, 2), device=device)
-            first_idxs = torch.tensor([0, 80], dtype=torch.int64, device=device)
-            padded_to_packed(values, first_idxs, 20)
+    def test_padded_to_packed_D16_3_2_cuda(self):
+        device = get_random_cuda_device()
+        self._test_padded_to_packed_helper([16, 3, 2], device)
 
     @staticmethod
     def packed_to_padded_with_init(