[pytorch3d[ padded to packed function in struct utils

Summary: Added a padded to packed utils function which takes either split sizes or a padding value to remove padded elements from a tensor.

Reviewed By: gkioxari

Differential Revision: D20454238

fbshipit-source-id: 180b807ff44c74c4ee9d5c1ac3b5c4a9b4be57c7
This commit is contained in:
Nikhila Ravi
2020-03-15 09:32:59 -07:00
committed by Facebook GitHub Bot
parent 4d3c886677
commit 20e457ca0e
2 changed files with 165 additions and 2 deletions

View File

@@ -97,6 +97,92 @@ class TestStructUtils(TestCaseMixin, unittest.TestCase):
split_size = torch.randint(1, K, size=(N,)).tolist()
struct_utils.padded_to_list(x, split_size)
def test_padded_to_packed(self):
device = torch.device("cuda:0")
N = 5
K = 20
ndim = 2
dims = [K] * ndim
x = torch.rand([N] + dims, device=device)
# Case 1: no split_size or pad_value provided
# Check output is just the flattened input.
x_packed = struct_utils.padded_to_packed(x)
self.assertTrue(x_packed.shape == (x.shape[0] * x.shape[1], x.shape[2]))
self.assertClose(x_packed, x.reshape(-1, K))
# Case 2: pad_value is provided.
# Check each section of the packed tensor matches the
# corresponding unpadded elements of the padded tensor.
# Check that only rows where all the values are padded
# are removed in the conversion to packed.
pad_value = -1
x_list = []
split_size = []
for _ in range(N):
dim = torch.randint(K, size=(1,)).item()
# Add some random values in the input which are the same as the pad_value.
# These should not be filtered out.
x_list.append(
torch.randint(
low=pad_value, high=10, size=(dim, K), device=device
)
)
split_size.append(dim)
x_padded = struct_utils.list_to_padded(x_list, pad_value=pad_value)
x_packed = struct_utils.padded_to_packed(x_padded, pad_value=pad_value)
curr = 0
for i in range(N):
self.assertClose(
x_packed[curr : curr + split_size[i], ...], x_list[i]
)
self.assertClose(torch.cat(x_list), x_packed)
curr += split_size[i]
# Case 3: split_size is provided.
# Check each section of the packed tensor matches the corresponding
# unpadded elements.
x_packed = struct_utils.padded_to_packed(
x_padded, split_size=split_size
)
curr = 0
for i in range(N):
self.assertClose(
x_packed[curr : curr + split_size[i], ...], x_list[i]
)
self.assertClose(torch.cat(x_list), x_packed)
curr += split_size[i]
# Case 4: split_size of the wrong shape is provided.
# Raise an error.
split_size = torch.randint(1, K, size=(2 * N,)).view(N, 2).unbind(0)
with self.assertRaisesRegex(ValueError, "1-dimensional"):
x_packed = struct_utils.padded_to_packed(
x_padded, split_size=split_size
)
split_size = torch.randint(1, K, size=(2 * N,)).view(N * 2).tolist()
with self.assertRaisesRegex(
ValueError, "same length as inputs first dimension"
):
x_packed = struct_utils.padded_to_packed(
x_padded, split_size=split_size
)
# Case 5: both pad_value and split_size are provided.
# Raise an error.
with self.assertRaisesRegex(ValueError, "Only one of"):
x_packed = struct_utils.padded_to_packed(
x_padded, split_size=split_size, pad_value=-1
)
# Case 6: Input has more than 3 dims.
# Raise an error.
with self.assertRaisesRegex(ValueError, "Supports only"):
x = torch.rand((N, K, K, K, K), device=device)
split_size = torch.randint(1, K, size=(N,)).tolist()
struct_utils.padded_to_list(x, split_size)
def test_list_to_packed(self):
device = torch.device("cuda:0")
N = 5