pytorch3d/tests/test_packed_to_padded.py
Georgia Gkioxari 60f3c4e7d2 cpp support for packed to padded
Summary:
Adds a CPU implementation of packed-to-padded and adds gradient support.
```
Benchmark                                     Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
PACKED_TO_PADDED_2_100_300_1_cpu                    138             221           3625
PACKED_TO_PADDED_2_100_300_1_cuda:0                 184             261           2716
PACKED_TO_PADDED_2_100_300_16_cpu                   555             726            901
PACKED_TO_PADDED_2_100_300_16_cuda:0                179             260           2794
PACKED_TO_PADDED_2_100_3000_1_cpu                   396             519           1262
PACKED_TO_PADDED_2_100_3000_1_cuda:0                181             274           2764
PACKED_TO_PADDED_2_100_3000_16_cpu                 4517            5003            111
PACKED_TO_PADDED_2_100_3000_16_cuda:0               224             397           2235
PACKED_TO_PADDED_2_1000_300_1_cpu                   138             212           3616
PACKED_TO_PADDED_2_1000_300_1_cuda:0                180             282           2775
PACKED_TO_PADDED_2_1000_300_16_cpu                  565             711            885
PACKED_TO_PADDED_2_1000_300_16_cuda:0               179             264           2797
PACKED_TO_PADDED_2_1000_3000_1_cpu                  389             494           1287
PACKED_TO_PADDED_2_1000_3000_1_cuda:0               180             271           2777
PACKED_TO_PADDED_2_1000_3000_16_cpu                4522            5170            111
PACKED_TO_PADDED_2_1000_3000_16_cuda:0              216             286           2313
PACKED_TO_PADDED_10_100_300_1_cpu                   251             345           1995
PACKED_TO_PADDED_10_100_300_1_cuda:0                178             262           2806
PACKED_TO_PADDED_10_100_300_16_cpu                 2354            2750            213
PACKED_TO_PADDED_10_100_300_16_cuda:0               178             291           2814
PACKED_TO_PADDED_10_100_3000_1_cpu                 1519            1786            330
PACKED_TO_PADDED_10_100_3000_1_cuda:0               179             237           2791
PACKED_TO_PADDED_10_100_3000_16_cpu               24705           25879             21
PACKED_TO_PADDED_10_100_3000_16_cuda:0              228             316           2191
PACKED_TO_PADDED_10_1000_300_1_cpu                  261             432           1919
PACKED_TO_PADDED_10_1000_300_1_cuda:0               181             261           2756
PACKED_TO_PADDED_10_1000_300_16_cpu                2349            2770            213
PACKED_TO_PADDED_10_1000_300_16_cuda:0              180             256           2782
PACKED_TO_PADDED_10_1000_3000_1_cpu                1613            1929            310
PACKED_TO_PADDED_10_1000_3000_1_cuda:0              183             253           2739
PACKED_TO_PADDED_10_1000_3000_16_cpu              22041           23653             23
PACKED_TO_PADDED_10_1000_3000_16_cuda:0             220             343           2270
PACKED_TO_PADDED_32_100_300_1_cpu                   555             750            901
PACKED_TO_PADDED_32_100_300_1_cuda:0                188             282           2661
PACKED_TO_PADDED_32_100_300_16_cpu                 7550            8131             67
PACKED_TO_PADDED_32_100_300_16_cuda:0               181             272           2770
PACKED_TO_PADDED_32_100_3000_1_cpu                 4574            6327            110
PACKED_TO_PADDED_32_100_3000_1_cuda:0               173             254           2884
PACKED_TO_PADDED_32_100_3000_16_cpu               70366           72563              8
PACKED_TO_PADDED_32_100_3000_16_cuda:0              349             654           1433
PACKED_TO_PADDED_32_1000_300_1_cpu                  612             728            818
PACKED_TO_PADDED_32_1000_300_1_cuda:0               189             295           2647
PACKED_TO_PADDED_32_1000_300_16_cpu                7699            8254             65
PACKED_TO_PADDED_32_1000_300_16_cuda:0              189             311           2646
PACKED_TO_PADDED_32_1000_3000_1_cpu                5105            5261             98
PACKED_TO_PADDED_32_1000_3000_1_cuda:0              191             260           2625
PACKED_TO_PADDED_32_1000_3000_16_cpu              87073           92708              6
PACKED_TO_PADDED_32_1000_3000_16_cuda:0             344             425           1455
--------------------------------------------------------------------------------

Benchmark                                           Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
PACKED_TO_PADDED_TORCH_2_100_300_1_cpu                    492             627           1016
PACKED_TO_PADDED_TORCH_2_100_300_1_cuda:0                 768             975            652
PACKED_TO_PADDED_TORCH_2_100_300_16_cpu                   659             804            760
PACKED_TO_PADDED_TORCH_2_100_300_16_cuda:0                781             918            641
PACKED_TO_PADDED_TORCH_2_100_3000_1_cpu                   624             734            802
PACKED_TO_PADDED_TORCH_2_100_3000_1_cuda:0                778             929            643
PACKED_TO_PADDED_TORCH_2_100_3000_16_cpu                 2609            2850            192
PACKED_TO_PADDED_TORCH_2_100_3000_16_cuda:0               758             901            660
PACKED_TO_PADDED_TORCH_2_1000_300_1_cpu                   467             612           1072
PACKED_TO_PADDED_TORCH_2_1000_300_1_cuda:0                772             905            648
PACKED_TO_PADDED_TORCH_2_1000_300_16_cpu                  689             839            726
PACKED_TO_PADDED_TORCH_2_1000_300_16_cuda:0               789            1143            635
PACKED_TO_PADDED_TORCH_2_1000_3000_1_cpu                  629             735            795
PACKED_TO_PADDED_TORCH_2_1000_3000_1_cuda:0               812             916            616
PACKED_TO_PADDED_TORCH_2_1000_3000_16_cpu                2716            3117            185
PACKED_TO_PADDED_TORCH_2_1000_3000_16_cuda:0              844            1288            593
PACKED_TO_PADDED_TORCH_10_100_300_1_cpu                  2387            2557            210
PACKED_TO_PADDED_TORCH_10_100_300_1_cuda:0               4112            4993            122
PACKED_TO_PADDED_TORCH_10_100_300_16_cpu                 3385            4254            148
PACKED_TO_PADDED_TORCH_10_100_300_16_cuda:0              3959            4902            127
PACKED_TO_PADDED_TORCH_10_100_3000_1_cpu                 2918            3105            172
PACKED_TO_PADDED_TORCH_10_100_3000_1_cuda:0              4054            4450            124
PACKED_TO_PADDED_TORCH_10_100_3000_16_cpu               12748           13623             40
PACKED_TO_PADDED_TORCH_10_100_3000_16_cuda:0             4023            4395            125
PACKED_TO_PADDED_TORCH_10_1000_300_1_cpu                 2258            2492            222
PACKED_TO_PADDED_TORCH_10_1000_300_1_cuda:0              3997            4312            126
PACKED_TO_PADDED_TORCH_10_1000_300_16_cpu                3404            3597            147
PACKED_TO_PADDED_TORCH_10_1000_300_16_cuda:0             3877            4227            129
PACKED_TO_PADDED_TORCH_10_1000_3000_1_cpu                2789            3054            180
PACKED_TO_PADDED_TORCH_10_1000_3000_1_cuda:0             3821            4402            131
PACKED_TO_PADDED_TORCH_10_1000_3000_16_cpu              11967           12963             42
PACKED_TO_PADDED_TORCH_10_1000_3000_16_cuda:0            3729            4290            135
PACKED_TO_PADDED_TORCH_32_100_300_1_cpu                  6933            8152             73
PACKED_TO_PADDED_TORCH_32_100_300_1_cuda:0              11856           12287             43
PACKED_TO_PADDED_TORCH_32_100_300_16_cpu                 9895           11205             51
PACKED_TO_PADDED_TORCH_32_100_300_16_cuda:0             12354           13596             41
PACKED_TO_PADDED_TORCH_32_100_3000_1_cpu                 9516           10128             53
PACKED_TO_PADDED_TORCH_32_100_3000_1_cuda:0             12917           13597             39
PACKED_TO_PADDED_TORCH_32_100_3000_16_cpu               41209           43783             13
PACKED_TO_PADDED_TORCH_32_100_3000_16_cuda:0            12210           13288             41
PACKED_TO_PADDED_TORCH_32_1000_300_1_cpu                 7179            7689             70
PACKED_TO_PADDED_TORCH_32_1000_300_1_cuda:0             11896           12381             43
PACKED_TO_PADDED_TORCH_32_1000_300_16_cpu               10127           15494             50
PACKED_TO_PADDED_TORCH_32_1000_300_16_cuda:0            12034           12817             42
PACKED_TO_PADDED_TORCH_32_1000_3000_1_cpu                8743           10251             58
PACKED_TO_PADDED_TORCH_32_1000_3000_1_cuda:0            12023           12908             42
PACKED_TO_PADDED_TORCH_32_1000_3000_16_cpu              39071           41777             13
PACKED_TO_PADDED_TORCH_32_1000_3000_16_cuda:0           11999           13690             42
--------------------------------------------------------------------------------
```

Reviewed By: bottler, nikhilaravi, jcjohnson

Differential Revision: D19870575

fbshipit-source-id: 23a2477b73373c411899633386c87ab034c3702a
2020-02-19 10:48:54 -08:00

297 lines
10 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
import torch
from pytorch3d.ops import packed_to_padded, padded_to_packed
from pytorch3d.structures.meshes import Meshes
from common_testing import TestCaseMixin
class TestPackedToPadded(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
    """Run the parent set-up, then fix the torch RNG seed to 1."""
    super().setUp()
    # Seeding here makes the random tensors drawn by each test reproducible.
    torch.manual_seed(1)
@staticmethod
def init_meshes(
    num_meshes: int = 10,
    num_verts: int = 1000,
    num_faces: int = 3000,
    device: str = "cpu",
):
    """
    Build a Meshes batch filled with random geometry.

    Args:
        num_meshes: number of meshes in the batch.
        num_verts: number of vertices per mesh; each vertex is a float32
            3-vector drawn with torch.rand.
        num_faces: number of faces per mesh; each face holds 3 int64 vertex
            indices drawn with torch.randint in [0, num_verts).
        device: device string on which the tensors are allocated.

    Returns:
        A Meshes object built from the per-mesh vertex and face lists.
    """
    target = torch.device(device)
    # Tuple elements are evaluated left to right, so for every mesh the
    # vertices are sampled before the faces.
    samples = [
        (
            torch.rand((num_verts, 3), dtype=torch.float32, device=target),
            torch.randint(
                num_verts, size=(num_faces, 3), dtype=torch.int64, device=target
            ),
        )
        for _ in range(num_meshes)
    ]
    verts_list = [verts for verts, _ in samples]
    faces_list = [faces for _, faces in samples]
    return Meshes(verts_list, faces_list)
@staticmethod
def packed_to_padded_python(inputs, first_idxs, max_size, device):
    """
    PyTorch implementation of packed_to_padded function.

    Args:
        inputs: packed tensor of shape (F,) or (F, D).
        first_idxs: 1D integer tensor; entry m is the index in `inputs` of
            the first element that belongs to mesh m. The last mesh runs to
            the end of `inputs`.
        max_size: size of the padded dimension of the output.
        device: device on which the zero-initialised output is allocated.

    Returns:
        Tensor of shape (num_meshes, max_size) for 1D inputs or
        (num_meshes, max_size, D) for 2D inputs. Row m holds the elements of
        mesh m followed by zeros.
    """
    num_meshes = first_idxs.size(0)
    if inputs.dim() == 2:
        padded_shape = (num_meshes, max_size, inputs.shape[1])
    else:
        padded_shape = (num_meshes, max_size)
    inputs_padded = torch.zeros(padded_shape, device=device)
    for m in range(num_meshes):
        start = int(first_idxs[m])
        if m == num_meshes - 1:
            end = inputs.shape[0]
        else:
            end = int(first_idxs[m + 1])
        # The destination slice must span the number of elements of this
        # mesh (end - start), not the packed end index: slicing with `:end`
        # only works when every mesh has exactly max_size elements.
        inputs_padded[m, : end - start] = inputs[start:end]
    return inputs_padded
@staticmethod
def padded_to_packed_python(inputs, first_idxs, num_inputs, device):
    """
    PyTorch implementation of padded_to_packed function.

    Args:
        inputs: padded tensor of shape (num_meshes, max_size) or
            (num_meshes, max_size, D).
        first_idxs: 1D integer tensor; entry m is the index in the packed
            output of the first element that belongs to mesh m. The last
            mesh runs to `num_inputs`.
        num_inputs: total number of elements in the packed output.
        device: device on which the zero-initialised output is allocated.

    Returns:
        Tensor of shape (num_inputs,) for 2D inputs or (num_inputs, D) for
        3D inputs, holding the valid (non-padding) elements of every mesh
        back to back.
    """
    num_meshes = inputs.size(0)
    if inputs.dim() == 3:
        packed_shape = (num_inputs, inputs.shape[2])
    else:
        packed_shape = (num_inputs,)
    inputs_packed = torch.zeros(packed_shape, device=device)
    for m in range(num_meshes):
        start = int(first_idxs[m])
        if m == num_meshes - 1:
            end = num_inputs
        else:
            end = int(first_idxs[m + 1])
        # Read only the (end - start) valid entries of row m; slicing the
        # padded row with `:end` picks up padding (or the whole row) as soon
        # as the meshes do not all have max_size elements.
        inputs_packed[start:end] = inputs[m, : end - start]
    return inputs_packed
def _test_packed_to_padded_helper(self, D, device):
    """
    Compare packed_to_padded with the pure PyTorch reference, both for the
    forward result and for the gradients.

    Args:
        D: trailing feature size of the packed values; 0 selects a flat
            1D input.
        device: device on which all tensors are created.
    """
    meshes = self.init_meshes(16, 100, 300, device=device)
    num_packed = meshes.faces_packed().shape[0]
    first_idxs = meshes.mesh_to_faces_packed_first_idx()
    max_faces = meshes.num_faces_per_mesh().max().item()

    # Empty for the flat case, (D,) otherwise; appended to every shape below.
    feature_dims = () if D == 0 else (D,)
    values = torch.rand(
        (num_packed,) + feature_dims, device=device, requires_grad=True
    )
    # Independent leaf with the same data for the reference implementation.
    values_torch = values.detach().clone().requires_grad_(True)

    values_padded = packed_to_padded(values, first_idxs, max_faces)
    values_padded_torch = TestPackedToPadded.packed_to_padded_python(
        values_torch, first_idxs, max_faces, device
    )
    # check forward
    self.assertClose(values_padded, values_padded_torch)

    # check backward
    grad_inputs = torch.rand(
        (len(meshes), max_faces) + feature_dims, device=device
    )
    values_padded.backward(grad_inputs)
    grad_outputs = values.grad
    values_padded_torch.backward(grad_inputs)
    grad_outputs_torch1 = values_torch.grad
    # Second reference: packing the upstream gradient directly.
    grad_outputs_torch2 = TestPackedToPadded.padded_to_packed_python(
        grad_inputs, first_idxs, values.size(0), device=device
    )
    for reference in (grad_outputs_torch1, grad_outputs_torch2):
        self.assertClose(grad_outputs, reference)
def test_packed_to_padded_flat_cpu(self):
    """packed_to_padded on the CPU with D = 0."""
    self._test_packed_to_padded_helper(D=0, device="cpu")
def test_packed_to_padded_D1_cpu(self):
    """packed_to_padded on the CPU with D = 1."""
    self._test_packed_to_padded_helper(D=1, device="cpu")
def test_packed_to_padded_D16_cpu(self):
    """packed_to_padded on the CPU with D = 16."""
    self._test_packed_to_padded_helper(D=16, device="cpu")
def test_packed_to_padded_flat_cuda(self):
    """packed_to_padded on cuda:0 with D = 0."""
    self._test_packed_to_padded_helper(D=0, device="cuda:0")
def test_packed_to_padded_D1_cuda(self):
    """packed_to_padded on cuda:0 with D = 1."""
    self._test_packed_to_padded_helper(D=1, device="cuda:0")
def test_packed_to_padded_D16_cuda(self):
    """packed_to_padded on cuda:0 with D = 16."""
    self._test_packed_to_padded_helper(D=16, device="cuda:0")
def _test_padded_to_packed_helper(self, D, device):
    """
    Compare padded_to_packed with the pure PyTorch reference, both for the
    forward result and for the gradients.

    Args:
        D: trailing feature size of the padded values; 0 selects a 2D
            (num_meshes, max_faces) input without a feature dimension.
        device: device on which all tensors are created.
    """
    meshes = self.init_meshes(16, 100, 300, device=device)
    first_idxs = meshes.mesh_to_faces_packed_first_idx()
    num_faces_per_mesh = meshes.num_faces_per_mesh()
    max_faces = num_faces_per_mesh.max().item()

    # Empty for the flat case, (D,) otherwise; appended to every shape below.
    feature_dims = () if D == 0 else (D,)
    values = torch.rand(
        (len(meshes), max_faces) + feature_dims, device=device
    )
    # Zero the padding region of every mesh before tracking gradients.
    for i, num in enumerate(num_faces_per_mesh):
        values[i, num:] = 0
    values.requires_grad = True
    # Independent leaf with the same data for the reference implementation.
    values_torch = values.detach().clone().requires_grad_(True)

    total_faces = num_faces_per_mesh.sum().item()
    values_packed = padded_to_packed(values, first_idxs, total_faces)
    values_packed_torch = TestPackedToPadded.padded_to_packed_python(
        values_torch, first_idxs, total_faces, device
    )
    # check forward
    self.assertClose(values_packed, values_packed_torch)

    # check backward
    grad_inputs = torch.rand((total_faces,) + feature_dims, device=device)
    values_packed.backward(grad_inputs)
    grad_outputs = values.grad
    values_packed_torch.backward(grad_inputs)
    grad_outputs_torch1 = values_torch.grad
    # Second reference: padding the upstream gradient directly.
    grad_outputs_torch2 = TestPackedToPadded.packed_to_padded_python(
        grad_inputs, first_idxs, values.size(1), device=device
    )
    for reference in (grad_outputs_torch1, grad_outputs_torch2):
        self.assertClose(grad_outputs, reference)
def test_padded_to_packed_flat_cpu(self):
    """padded_to_packed on the CPU with D = 0."""
    self._test_padded_to_packed_helper(D=0, device="cpu")
def test_padded_to_packed_D1_cpu(self):
    """padded_to_packed on the CPU with D = 1."""
    self._test_padded_to_packed_helper(D=1, device="cpu")
def test_padded_to_packed_D16_cpu(self):
    """padded_to_packed on the CPU with D = 16."""
    self._test_padded_to_packed_helper(D=16, device="cpu")
def test_padded_to_packed_flat_cuda(self):
    """padded_to_packed on cuda:0 with D = 0."""
    self._test_padded_to_packed_helper(D=0, device="cuda:0")
def test_padded_to_packed_D1_cuda(self):
    """padded_to_packed on cuda:0 with D = 1."""
    self._test_padded_to_packed_helper(D=1, device="cuda:0")
def test_padded_to_packed_D16_cuda(self):
    """padded_to_packed on cuda:0 with D = 16."""
    self._test_padded_to_packed_helper(D=16, device="cuda:0")
def test_invalid_inputs_shapes(self, device="cuda:0"):
    """
    Check that both ops raise ValueError, with the expected message, when
    given an input of unsupported rank.
    """
    # (op under test, shape of the random input, size argument, expected
    # error pattern)
    cases = (
        (packed_to_padded, (100, 50, 2), 100, "input can only be 2-dimensional."),
        (padded_to_packed, (100,), 20, "input can only be 3-dimensional."),
        (padded_to_packed, (100, 50, 2, 2), 20, "input can only be 3-dimensional."),
    )
    for op, shape, size, pattern in cases:
        with self.assertRaisesRegex(ValueError, pattern):
            values = torch.rand(shape, device=device)
            first_idxs = torch.tensor(
                [0, 80], dtype=torch.int64, device=device
            )
            op(values, first_idxs, size)
@staticmethod
def packed_to_padded_with_init(
    num_meshes: int,
    num_verts: int,
    num_faces: int,
    num_d: int,
    device: str = "cpu",
):
    """
    Prepare random inputs and return a closure that runs packed_to_padded
    on them once (set-up is kept outside the closure so that only the op
    itself is executed when the closure is called, e.g. for timing).

    Args:
        num_meshes: number of random meshes to create.
        num_verts: number of vertices per mesh.
        num_faces: number of faces per mesh.
        num_d: feature size of the packed values; 0 selects a flat 1D input.
        device: device string on which the meshes are created.

    Returns:
        A zero-argument callable that calls packed_to_padded on the prepared
        values and discards the result.
    """
    meshes = TestPackedToPadded.init_meshes(
        num_meshes, num_verts, num_faces, device
    )
    faces = meshes.faces_packed()
    mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx()
    max_faces = meshes.num_faces_per_mesh().max().item()
    if num_d == 0:
        values = torch.rand((faces.shape[0],), device=meshes.device)
    else:
        values = torch.rand((faces.shape[0], num_d), device=meshes.device)

    # Only synchronize when the data actually lives on a CUDA device:
    # torch.cuda.synchronize() fails on hosts/builds without CUDA, which
    # would make the default device="cpu" configuration unusable there.
    on_cuda = values.is_cuda
    if on_cuda:
        torch.cuda.synchronize()

    def out():
        packed_to_padded(values, mesh_to_faces_packed_first_idx, max_faces)
        if on_cuda:
            # Wait for the kernels so the call covers the whole op.
            torch.cuda.synchronize()

    return out
@staticmethod
def packed_to_padded_with_init_torch(
    num_meshes: int,
    num_verts: int,
    num_faces: int,
    num_d: int,
    device: str = "cpu",
):
    """
    Prepare random inputs and return a closure that runs the pure PyTorch
    reference implementation of packed_to_padded on them once (set-up is
    kept outside the closure so that only the op itself is executed when
    the closure is called, e.g. for timing).

    Args:
        num_meshes: number of random meshes to create.
        num_verts: number of vertices per mesh.
        num_faces: number of faces per mesh.
        num_d: feature size of the packed values; 0 selects a flat 1D input.
        device: device string on which the meshes and the padded output are
            created.

    Returns:
        A zero-argument callable that calls
        TestPackedToPadded.packed_to_padded_python on the prepared values
        and discards the result.
    """
    meshes = TestPackedToPadded.init_meshes(
        num_meshes, num_verts, num_faces, device
    )
    faces = meshes.faces_packed()
    mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx()
    max_faces = meshes.num_faces_per_mesh().max().item()
    if num_d == 0:
        values = torch.rand((faces.shape[0],), device=meshes.device)
    else:
        values = torch.rand((faces.shape[0], num_d), device=meshes.device)

    # Only synchronize when the data actually lives on a CUDA device:
    # torch.cuda.synchronize() fails on hosts/builds without CUDA, which
    # would make the default device="cpu" configuration unusable there.
    on_cuda = values.is_cuda
    if on_cuda:
        torch.cuda.synchronize()

    def out():
        TestPackedToPadded.packed_to_padded_python(
            values, mesh_to_faces_packed_first_idx, max_faces, device
        )
        if on_cuda:
            # Wait for the kernels so the call covers the whole op.
            torch.cuda.synchronize()

    return out