mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
[pytorch3d[ padded to packed function in struct utils
Summary: Added a padded to packed utils function which takes either split sizes or a padding value to remove padded elements from a tensor. Reviewed By: gkioxari Differential Revision: D20454238 fbshipit-source-id: 180b807ff44c74c4ee9d5c1ac3b5c4a9b4be57c7
This commit is contained in:
parent
4d3c886677
commit
20e457ca0e
@ -66,11 +66,16 @@ def padded_to_list(
|
||||
|
||||
Args:
|
||||
x: tensor
|
||||
split_size: the shape of the final tensor to be returned (of length N).
|
||||
split_size: list, tuple or int defining the number of items for each tensor
|
||||
in the output list.
|
||||
|
||||
Returns:
|
||||
x_list: a list of tensors
|
||||
"""
|
||||
if x.ndim != 3:
|
||||
raise ValueError("Supports only 3-dimensional input tensors")
|
||||
x_list = list(x.unbind(0))
|
||||
|
||||
if split_size is None:
|
||||
return x_list
|
||||
|
||||
@ -141,9 +146,81 @@ def packed_to_list(x: torch.Tensor, split_size: Union[list, int]):
|
||||
|
||||
Args:
|
||||
x: tensor
|
||||
split_size: list or int defining the number of items for each split
|
||||
split_size: list, tuple or int defining the number of items for each tensor
|
||||
in the output list.
|
||||
|
||||
Returns:
|
||||
x_list: A list of Tensors
|
||||
"""
|
||||
return x.split(split_size, dim=0)
|
||||
|
||||
|
||||
def padded_to_packed(
|
||||
x: torch.Tensor,
|
||||
split_size: Union[list, tuple, None] = None,
|
||||
pad_value: Union[float, int, None] = None,
|
||||
):
|
||||
r"""
|
||||
Transforms a padded tensor of shape (N, M, K) into a packed tensor
|
||||
of shape:
|
||||
- (sum(Mi), K) where (Mi, K) are the dimensions of
|
||||
each of the tensors in the batch and Mi is specified by split_size(i)
|
||||
- (N*M, K) if split_size is None
|
||||
|
||||
Support only for 3-dimensional input tensor and 1-dimensional split size.
|
||||
|
||||
Args:
|
||||
x: tensor
|
||||
split_size: list, tuple or int defining the number of items for each tensor
|
||||
in the output list.
|
||||
pad_value: optional value to use to filter the padded values in the input
|
||||
tensor.
|
||||
|
||||
Only one of split_size or pad_value should be provided, or both can be None.
|
||||
|
||||
Returns:
|
||||
x_packed: a packed tensor.
|
||||
"""
|
||||
if x.ndim != 3:
|
||||
raise ValueError("Supports only 3-dimensional input tensors")
|
||||
|
||||
N, M, D = x.shape
|
||||
|
||||
if split_size is not None and pad_value is not None:
|
||||
raise ValueError(
|
||||
"Only one of split_size or pad_value should be provided."
|
||||
)
|
||||
|
||||
x_packed = x.view(-1, D) # flatten padded
|
||||
|
||||
if pad_value is None and split_size is None:
|
||||
return x_packed
|
||||
|
||||
# Convert to packed using pad value
|
||||
if pad_value is not None:
|
||||
mask = x_packed.ne(pad_value).any(-1)
|
||||
x_packed = x_packed[mask]
|
||||
return x_packed
|
||||
|
||||
# Convert to packed using split sizes
|
||||
N = len(split_size)
|
||||
if x.shape[0] != N:
|
||||
raise ValueError(
|
||||
"Split size must be of same length as inputs first dimension"
|
||||
)
|
||||
|
||||
if not all(isinstance(i, int) for i in split_size):
|
||||
raise ValueError(
|
||||
"Support only 1-dimensional unbinded tensor. \
|
||||
Split size for more dimensions provided"
|
||||
)
|
||||
|
||||
padded_to_packed_idx = torch.cat(
|
||||
[
|
||||
torch.arange(v, dtype=torch.int64, device=x.device) + i * M
|
||||
for (i, v) in enumerate(split_size)
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
return x_packed[padded_to_packed_idx]
|
||||
|
@ -97,6 +97,92 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase):
|
||||
split_size = torch.randint(1, K, size=(N,)).tolist()
|
||||
struct_utils.padded_to_list(x, split_size)
|
||||
|
||||
def test_padded_to_packed(self):
|
||||
device = torch.device("cuda:0")
|
||||
N = 5
|
||||
K = 20
|
||||
ndim = 2
|
||||
dims = [K] * ndim
|
||||
x = torch.rand([N] + dims, device=device)
|
||||
|
||||
# Case 1: no split_size or pad_value provided
|
||||
# Check output is just the flattened input.
|
||||
x_packed = struct_utils.padded_to_packed(x)
|
||||
self.assertTrue(x_packed.shape == (x.shape[0] * x.shape[1], x.shape[2]))
|
||||
self.assertClose(x_packed, x.reshape(-1, K))
|
||||
|
||||
# Case 2: pad_value is provided.
|
||||
# Check each section of the packed tensor matches the
|
||||
# corresponding unpadded elements of the padded tensor.
|
||||
# Check that only rows where all the values are padded
|
||||
# are removed in the conversion to packed.
|
||||
pad_value = -1
|
||||
x_list = []
|
||||
split_size = []
|
||||
for _ in range(N):
|
||||
dim = torch.randint(K, size=(1,)).item()
|
||||
# Add some random values in the input which are the same as the pad_value.
|
||||
# These should not be filtered out.
|
||||
x_list.append(
|
||||
torch.randint(
|
||||
low=pad_value, high=10, size=(dim, K), device=device
|
||||
)
|
||||
)
|
||||
split_size.append(dim)
|
||||
x_padded = struct_utils.list_to_padded(x_list, pad_value=pad_value)
|
||||
x_packed = struct_utils.padded_to_packed(x_padded, pad_value=pad_value)
|
||||
curr = 0
|
||||
for i in range(N):
|
||||
self.assertClose(
|
||||
x_packed[curr : curr + split_size[i], ...], x_list[i]
|
||||
)
|
||||
self.assertClose(torch.cat(x_list), x_packed)
|
||||
curr += split_size[i]
|
||||
|
||||
# Case 3: split_size is provided.
|
||||
# Check each section of the packed tensor matches the corresponding
|
||||
# unpadded elements.
|
||||
x_packed = struct_utils.padded_to_packed(
|
||||
x_padded, split_size=split_size
|
||||
)
|
||||
curr = 0
|
||||
for i in range(N):
|
||||
self.assertClose(
|
||||
x_packed[curr : curr + split_size[i], ...], x_list[i]
|
||||
)
|
||||
self.assertClose(torch.cat(x_list), x_packed)
|
||||
curr += split_size[i]
|
||||
|
||||
# Case 4: split_size of the wrong shape is provided.
|
||||
# Raise an error.
|
||||
split_size = torch.randint(1, K, size=(2 * N,)).view(N, 2).unbind(0)
|
||||
with self.assertRaisesRegex(ValueError, "1-dimensional"):
|
||||
x_packed = struct_utils.padded_to_packed(
|
||||
x_padded, split_size=split_size
|
||||
)
|
||||
|
||||
split_size = torch.randint(1, K, size=(2 * N,)).view(N * 2).tolist()
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "same length as inputs first dimension"
|
||||
):
|
||||
x_packed = struct_utils.padded_to_packed(
|
||||
x_padded, split_size=split_size
|
||||
)
|
||||
|
||||
# Case 5: both pad_value and split_size are provided.
|
||||
# Raise an error.
|
||||
with self.assertRaisesRegex(ValueError, "Only one of"):
|
||||
x_packed = struct_utils.padded_to_packed(
|
||||
x_padded, split_size=split_size, pad_value=-1
|
||||
)
|
||||
|
||||
# Case 6: Input has more than 3 dims.
|
||||
# Raise an error.
|
||||
with self.assertRaisesRegex(ValueError, "Supports only"):
|
||||
x = torch.rand((N, K, K, K, K), device=device)
|
||||
split_size = torch.randint(1, K, size=(N,)).tolist()
|
||||
struct_utils.padded_to_list(x, split_size)
|
||||
|
||||
def test_list_to_packed(self):
|
||||
device = torch.device("cuda:0")
|
||||
N = 5
|
||||
|
Loading…
x
Reference in New Issue
Block a user