MC rasterize supports heterogeneous bundle; refactoring of bundle-to-padded

Summary:
Rasterize MC was not adapted to heterogeneous bundles; this diff adds support for them.

There are some caveats, though:
1) on CO3D, we get up to 18 points per image, which is too few for a reasonable visualisation (see below);
2) rasterising for a batch of 100 is slow.

I also moved the unpacking code close to the bundle so that it can be reused.
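
For reference, the padded representation produced by the unpacking code can be sketched with the packed_to_padded op touched below. This is a minimal illustrative sketch, not code from this diff; the tensor names, counts and sizes are invented, and it assumes the op exported from pytorch3d.ops:

import torch
from pytorch3d.ops import packed_to_padded

# Say 3 images contributed 2, 3 and 1 MC samples; per-sample features are
# packed into a single (F, C) = (6, 4) tensor.
counts = torch.tensor([2, 3, 1])
features_packed = torch.rand(int(counts.sum()), 4)

# first_idxs[i] is the packed index where image i starts (exclusive cumsum).
first_idxs = torch.cat([counts.new_zeros(1), counts.cumsum(0)[:-1]])
max_size = int(counts.max())

# Result is (N, max_size, C) = (3, 3, 4); images with fewer samples are zero-padded.
features_padded = packed_to_padded(features_packed, first_idxs, max_size)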

{F789678778}

Reviewed By: bottler, davnov134

Differential Revision: D41008600

fbshipit-source-id: 9f10f1f9f9a174cf8c534b9b9859587d69832b71
Author: Roman Shapovalov, 2022-11-07 13:43:31 -08:00 (committed by Facebook GitHub Bot)
Commit: f3c1e0837c, parent 7be49bf46f
10 changed files with 210 additions and 111 deletions


@@ -60,7 +60,9 @@ class _PackedToPadded(Function):
         return grad_input, None, None
 
 
-def packed_to_padded(inputs, first_idxs, max_size):
+def packed_to_padded(
+    inputs: torch.Tensor, first_idxs: torch.LongTensor, max_size: int
+) -> torch.Tensor:
     """
     Torch wrapper that handles allowed input shapes. See description below.
 
@@ -74,7 +76,7 @@ def packed_to_padded(inputs, first_idxs, max_size):
 
     Returns:
         inputs_padded: FloatTensor of shape (N, max_size) or (N, max_size, ...)
-            where max_size is max of `sizes`. The values for batch element i
+            where max_size is max of `sizes`. The values for batch element i
             which start at `inputs[first_idxs[i]]` will be copied to
             `inputs_padded[i, :]`, with zeros padding out the extra inputs.
 
@@ -89,6 +91,7 @@ def packed_to_padded(inputs, first_idxs, max_size):
         inputs = inputs.unsqueeze(1)
     else:
         inputs = inputs.reshape(input_shape[0], -1)
+    # pyre-ignore [16]
     inputs_padded = _PackedToPadded.apply(inputs, first_idxs, max_size)
     # if flat is True, reshape output to (N, max_size) from (N, max_size, 1)
     # else reshape output to (N, max_size, ...)
@@ -147,39 +150,49 @@ class _PaddedToPacked(Function):
         return grad_input, None, None
 
 
-def padded_to_packed(inputs, first_idxs, num_inputs):
+def padded_to_packed(
+    inputs: torch.Tensor,
+    first_idxs: torch.LongTensor,
+    num_inputs: int,
+    max_size_dim: int = 1,
+) -> torch.Tensor:
     """
     Torch wrapper that handles allowed input shapes. See description below.
 
     Args:
-        inputs: FloatTensor of shape (N, max_size) or (N, max_size, ...),
+        inputs: FloatTensor of shape (N, ..., max_size) or (N, ..., max_size, ...),
             representing the padded tensor, e.g. areas for faces in a batch of
-            meshes.
+            meshes, where max_size occurs on max_size_dim-th position.
         first_idxs: LongTensor of shape (N,) where N is the number of
             elements in the batch and `first_idxs[i] = f`
             means that the inputs for batch element i begin at `inputs_packed[f]`.
         num_inputs: Number of packed entries (= F)
+        max_size_dim: the dimension to be packed
 
     Returns:
         inputs_packed: FloatTensor of shape (F,) or (F, ...) where
-            `inputs_packed[first_idx[i]:first_idx[i+1]] = inputs[i, :]`.
+            `inputs_packed[first_idx[i]:first_idx[i+1]] = inputs[i, ..., :delta[i]]`,
+            where `delta[i] = first_idx[i+1] - first_idx[i]`.
 
     To handle the allowed input shapes, we convert the inputs tensor of shape
-    (N, max_size) to (N, max_size, 1). We reshape the output back to (F,) from
+    (N, max_size) to (N, max_size, 1). We reshape the output back to (F,) from
     (F, 1).
     """
+    n_dims = inputs.dim()
+    # move the variable dim to position 1
+    inputs = inputs.movedim(max_size_dim, 1)
     # if inputs is of shape (N, max_size), reshape into (N, max_size, 1))
     input_shape = inputs.shape
-    n_dims = inputs.dim()
     if n_dims == 2:
         inputs = inputs.unsqueeze(2)
     else:
         inputs = inputs.reshape(*input_shape[:2], -1)
+    # pyre-ignore [16]
     inputs_packed = _PaddedToPacked.apply(inputs, first_idxs, num_inputs)
     # if input is flat, reshape output to (F,) from (F, 1)
     # else reshape output to (F, ...)
     if n_dims == 2:
         return inputs_packed.squeeze(1)
     if n_dims == 3:
         return inputs_packed
     return inputs_packed.view(-1, *input_shape[2:])
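
As a usage note on the new max_size_dim argument: it lets the variable-length axis of the padded tensor sit somewhere other than dim 1. Below is a hedged sketch of the expected call; the shapes, counts and variable names are invented for illustration, assuming padded_to_packed as exported from pytorch3d.ops:

import torch
from pytorch3d.ops import padded_to_packed

N, C, max_size = 3, 4, 5
counts = torch.tensor([5, 2, 3])      # valid entries per batch element
first_idxs = torch.tensor([0, 5, 7])  # exclusive cumulative sum of counts
num_inputs = int(counts.sum())        # F = 10

# Padded tensor whose variable-length axis is dim 2 rather than the default dim 1.
padded = torch.rand(N, C, max_size)

# max_size_dim=2 moves that axis next to the batch dim before packing,
# so the packed result has shape (F, C) = (10, 4).
packed = padded_to_packed(padded, first_idxs, num_inputs, max_size_dim=2)
assert packed.shape == (num_inputs, C)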