mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 05:40:34 +08:00
cpp support for packed to padded
Summary: Cpu implementation for packed to padded and added gradients ``` Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- PACKED_TO_PADDED_2_100_300_1_cpu 138 221 3625 PACKED_TO_PADDED_2_100_300_1_cuda:0 184 261 2716 PACKED_TO_PADDED_2_100_300_16_cpu 555 726 901 PACKED_TO_PADDED_2_100_300_16_cuda:0 179 260 2794 PACKED_TO_PADDED_2_100_3000_1_cpu 396 519 1262 PACKED_TO_PADDED_2_100_3000_1_cuda:0 181 274 2764 PACKED_TO_PADDED_2_100_3000_16_cpu 4517 5003 111 PACKED_TO_PADDED_2_100_3000_16_cuda:0 224 397 2235 PACKED_TO_PADDED_2_1000_300_1_cpu 138 212 3616 PACKED_TO_PADDED_2_1000_300_1_cuda:0 180 282 2775 PACKED_TO_PADDED_2_1000_300_16_cpu 565 711 885 PACKED_TO_PADDED_2_1000_300_16_cuda:0 179 264 2797 PACKED_TO_PADDED_2_1000_3000_1_cpu 389 494 1287 PACKED_TO_PADDED_2_1000_3000_1_cuda:0 180 271 2777 PACKED_TO_PADDED_2_1000_3000_16_cpu 4522 5170 111 PACKED_TO_PADDED_2_1000_3000_16_cuda:0 216 286 2313 PACKED_TO_PADDED_10_100_300_1_cpu 251 345 1995 PACKED_TO_PADDED_10_100_300_1_cuda:0 178 262 2806 PACKED_TO_PADDED_10_100_300_16_cpu 2354 2750 213 PACKED_TO_PADDED_10_100_300_16_cuda:0 178 291 2814 PACKED_TO_PADDED_10_100_3000_1_cpu 1519 1786 330 PACKED_TO_PADDED_10_100_3000_1_cuda:0 179 237 2791 PACKED_TO_PADDED_10_100_3000_16_cpu 24705 25879 21 PACKED_TO_PADDED_10_100_3000_16_cuda:0 228 316 2191 PACKED_TO_PADDED_10_1000_300_1_cpu 261 432 1919 PACKED_TO_PADDED_10_1000_300_1_cuda:0 181 261 2756 PACKED_TO_PADDED_10_1000_300_16_cpu 2349 2770 213 PACKED_TO_PADDED_10_1000_300_16_cuda:0 180 256 2782 PACKED_TO_PADDED_10_1000_3000_1_cpu 1613 1929 310 PACKED_TO_PADDED_10_1000_3000_1_cuda:0 183 253 2739 PACKED_TO_PADDED_10_1000_3000_16_cpu 22041 23653 23 PACKED_TO_PADDED_10_1000_3000_16_cuda:0 220 343 2270 PACKED_TO_PADDED_32_100_300_1_cpu 555 750 901 PACKED_TO_PADDED_32_100_300_1_cuda:0 188 282 2661 PACKED_TO_PADDED_32_100_300_16_cpu 7550 8131 67 PACKED_TO_PADDED_32_100_300_16_cuda:0 181 272 2770 
PACKED_TO_PADDED_32_100_3000_1_cpu 4574 6327 110 PACKED_TO_PADDED_32_100_3000_1_cuda:0 173 254 2884 PACKED_TO_PADDED_32_100_3000_16_cpu 70366 72563 8 PACKED_TO_PADDED_32_100_3000_16_cuda:0 349 654 1433 PACKED_TO_PADDED_32_1000_300_1_cpu 612 728 818 PACKED_TO_PADDED_32_1000_300_1_cuda:0 189 295 2647 PACKED_TO_PADDED_32_1000_300_16_cpu 7699 8254 65 PACKED_TO_PADDED_32_1000_300_16_cuda:0 189 311 2646 PACKED_TO_PADDED_32_1000_3000_1_cpu 5105 5261 98 PACKED_TO_PADDED_32_1000_3000_1_cuda:0 191 260 2625 PACKED_TO_PADDED_32_1000_3000_16_cpu 87073 92708 6 PACKED_TO_PADDED_32_1000_3000_16_cuda:0 344 425 1455 -------------------------------------------------------------------------------- Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- PACKED_TO_PADDED_TORCH_2_100_300_1_cpu 492 627 1016 PACKED_TO_PADDED_TORCH_2_100_300_1_cuda:0 768 975 652 PACKED_TO_PADDED_TORCH_2_100_300_16_cpu 659 804 760 PACKED_TO_PADDED_TORCH_2_100_300_16_cuda:0 781 918 641 PACKED_TO_PADDED_TORCH_2_100_3000_1_cpu 624 734 802 PACKED_TO_PADDED_TORCH_2_100_3000_1_cuda:0 778 929 643 PACKED_TO_PADDED_TORCH_2_100_3000_16_cpu 2609 2850 192 PACKED_TO_PADDED_TORCH_2_100_3000_16_cuda:0 758 901 660 PACKED_TO_PADDED_TORCH_2_1000_300_1_cpu 467 612 1072 PACKED_TO_PADDED_TORCH_2_1000_300_1_cuda:0 772 905 648 PACKED_TO_PADDED_TORCH_2_1000_300_16_cpu 689 839 726 PACKED_TO_PADDED_TORCH_2_1000_300_16_cuda:0 789 1143 635 PACKED_TO_PADDED_TORCH_2_1000_3000_1_cpu 629 735 795 PACKED_TO_PADDED_TORCH_2_1000_3000_1_cuda:0 812 916 616 PACKED_TO_PADDED_TORCH_2_1000_3000_16_cpu 2716 3117 185 PACKED_TO_PADDED_TORCH_2_1000_3000_16_cuda:0 844 1288 593 PACKED_TO_PADDED_TORCH_10_100_300_1_cpu 2387 2557 210 PACKED_TO_PADDED_TORCH_10_100_300_1_cuda:0 4112 4993 122 PACKED_TO_PADDED_TORCH_10_100_300_16_cpu 3385 4254 148 PACKED_TO_PADDED_TORCH_10_100_300_16_cuda:0 3959 4902 127 PACKED_TO_PADDED_TORCH_10_100_3000_1_cpu 2918 3105 172 
PACKED_TO_PADDED_TORCH_10_100_3000_1_cuda:0 4054 4450 124 PACKED_TO_PADDED_TORCH_10_100_3000_16_cpu 12748 13623 40 PACKED_TO_PADDED_TORCH_10_100_3000_16_cuda:0 4023 4395 125 PACKED_TO_PADDED_TORCH_10_1000_300_1_cpu 2258 2492 222 PACKED_TO_PADDED_TORCH_10_1000_300_1_cuda:0 3997 4312 126 PACKED_TO_PADDED_TORCH_10_1000_300_16_cpu 3404 3597 147 PACKED_TO_PADDED_TORCH_10_1000_300_16_cuda:0 3877 4227 129 PACKED_TO_PADDED_TORCH_10_1000_3000_1_cpu 2789 3054 180 PACKED_TO_PADDED_TORCH_10_1000_3000_1_cuda:0 3821 4402 131 PACKED_TO_PADDED_TORCH_10_1000_3000_16_cpu 11967 12963 42 PACKED_TO_PADDED_TORCH_10_1000_3000_16_cuda:0 3729 4290 135 PACKED_TO_PADDED_TORCH_32_100_300_1_cpu 6933 8152 73 PACKED_TO_PADDED_TORCH_32_100_300_1_cuda:0 11856 12287 43 PACKED_TO_PADDED_TORCH_32_100_300_16_cpu 9895 11205 51 PACKED_TO_PADDED_TORCH_32_100_300_16_cuda:0 12354 13596 41 PACKED_TO_PADDED_TORCH_32_100_3000_1_cpu 9516 10128 53 PACKED_TO_PADDED_TORCH_32_100_3000_1_cuda:0 12917 13597 39 PACKED_TO_PADDED_TORCH_32_100_3000_16_cpu 41209 43783 13 PACKED_TO_PADDED_TORCH_32_100_3000_16_cuda:0 12210 13288 41 PACKED_TO_PADDED_TORCH_32_1000_300_1_cpu 7179 7689 70 PACKED_TO_PADDED_TORCH_32_1000_300_1_cuda:0 11896 12381 43 PACKED_TO_PADDED_TORCH_32_1000_300_16_cpu 10127 15494 50 PACKED_TO_PADDED_TORCH_32_1000_300_16_cuda:0 12034 12817 42 PACKED_TO_PADDED_TORCH_32_1000_3000_1_cpu 8743 10251 58 PACKED_TO_PADDED_TORCH_32_1000_3000_1_cuda:0 12023 12908 42 PACKED_TO_PADDED_TORCH_32_1000_3000_16_cpu 39071 41777 13 PACKED_TO_PADDED_TORCH_32_1000_3000_16_cuda:0 11999 13690 42 -------------------------------------------------------------------------------- ``` Reviewed By: bottler, nikhilaravi, jcjohnson Differential Revision: D19870575 fbshipit-source-id: 23a2477b73373c411899633386c87ab034c3702a
This commit is contained in:
committed by
Facebook Github Bot
parent
8301163d24
commit
60f3c4e7d2
47
tests/bm_packed_to_padded.py
Normal file
47
tests/bm_packed_to_padded.py
Normal file
@@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
from itertools import product
|
||||
import torch
|
||||
from fvcore.common.benchmark import benchmark
|
||||
|
||||
from test_packed_to_padded import TestPackedToPadded
|
||||
|
||||
|
||||
def bm_packed_to_padded() -> None:
    """
    Benchmark packed_to_padded over a grid of batch sizes, mesh sizes,
    feature widths and devices.

    Two benchmarks are run on the same grid: "PACKED_TO_PADDED" uses
    TestPackedToPadded.packed_to_padded_with_init and
    "PACKED_TO_PADDED_TORCH" uses
    TestPackedToPadded.packed_to_padded_with_init_torch.
    """
    # The CPU is always measured; the GPU is added only when one is present.
    backend = ["cpu"]
    if torch.cuda.is_available():
        backend.append("cuda:0")

    num_meshes = [2, 10, 32]
    num_verts = [100, 1000]
    num_faces = [300, 3000]
    num_ds = [0, 1, 16]

    # One kwargs dict per point of the cartesian product, keyed by the
    # parameter names of the *_with_init factories.
    keys = ("num_meshes", "num_verts", "num_faces", "num_d", "device")
    kwargs_list = [
        dict(zip(keys, case))
        for case in product(num_meshes, num_verts, num_faces, num_ds, backend)
    ]

    benchmark(
        TestPackedToPadded.packed_to_padded_with_init,
        "PACKED_TO_PADDED",
        kwargs_list,
        warmup_iters=1,
    )

    benchmark(
        TestPackedToPadded.packed_to_padded_with_init_torch,
        "PACKED_TO_PADDED_TORCH",
        kwargs_list,
        warmup_iters=1,
    )
|
||||
@@ -10,56 +10,30 @@ from test_sample_points_from_meshes import TestSamplePoints
|
||||
|
||||
|
||||
def bm_sample_points() -> None:
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda:0"
|
||||
kwargs_list = []
|
||||
num_meshes = [2, 10, 32]
|
||||
num_verts = [100, 1000]
|
||||
num_faces = [300, 3000]
|
||||
num_samples = [5000, 10000]
|
||||
test_cases = product(num_meshes, num_verts, num_faces, num_samples)
|
||||
for case in test_cases:
|
||||
n, v, f, s = case
|
||||
kwargs_list.append(
|
||||
{
|
||||
"num_meshes": n,
|
||||
"num_verts": v,
|
||||
"num_faces": f,
|
||||
"num_samples": s,
|
||||
"device": device,
|
||||
}
|
||||
)
|
||||
benchmark(
|
||||
TestSamplePoints.sample_points_with_init,
|
||||
"SAMPLE_MESH",
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
||||
backend = ["cpu"]
|
||||
if torch.cuda.is_available():
|
||||
backend.append("cuda:0")
|
||||
kwargs_list = []
|
||||
backend_cuda = ["False"]
|
||||
if torch.cuda.is_available():
|
||||
backend_cuda.append("True")
|
||||
|
||||
num_meshes = [2, 10, 32]
|
||||
num_verts = [100, 1000]
|
||||
num_faces = [300, 3000]
|
||||
|
||||
test_cases = product(num_meshes, num_verts, num_faces, backend_cuda)
|
||||
num_samples = [5000, 10000]
|
||||
test_cases = product(num_meshes, num_verts, num_faces, num_samples, backend)
|
||||
for case in test_cases:
|
||||
n, v, f, c = case
|
||||
n, v, f, s, b = case
|
||||
kwargs_list.append(
|
||||
{"num_meshes": n, "num_verts": v, "num_faces": f, "cuda": c}
|
||||
{
|
||||
"num_meshes": n,
|
||||
"num_verts": v,
|
||||
"num_faces": f,
|
||||
"num_samples": s,
|
||||
"device": b,
|
||||
}
|
||||
)
|
||||
benchmark(
|
||||
TestSamplePoints.face_areas_with_init,
|
||||
"FACE_AREAS",
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
benchmark(
|
||||
TestSamplePoints.packed_to_padded_with_init,
|
||||
"PACKED_TO_PADDED",
|
||||
TestSamplePoints.sample_points_with_init,
|
||||
"SAMPLE_MESH",
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
||||
296
tests/test_packed_to_padded.py
Normal file
296
tests/test_packed_to_padded.py
Normal file
@@ -0,0 +1,296 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import unittest
|
||||
import torch
|
||||
|
||||
from pytorch3d.ops import packed_to_padded, padded_to_packed
|
||||
from pytorch3d.structures.meshes import Meshes
|
||||
|
||||
from common_testing import TestCaseMixin
|
||||
|
||||
|
||||
class TestPackedToPadded(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
    """Run the parent setUp, then seed PyTorch's global RNG with 1."""
    super().setUp()
    # Fixed seed so the random tensors drawn by each test are reproducible.
    torch.manual_seed(1)
|
||||
|
||||
@staticmethod
|
||||
def init_meshes(
    num_meshes: int = 10,
    num_verts: int = 1000,
    num_faces: int = 3000,
    device: str = "cpu",
):
    """
    Build a Meshes batch of random geometry.

    Args:
        num_meshes: number of meshes in the batch.
        num_verts: vertices per mesh; each mesh gets a (num_verts, 3)
            float32 tensor drawn with torch.rand.
        num_faces: faces per mesh; each mesh gets a (num_faces, 3) int64
            tensor of vertex indices drawn from [0, num_verts).
        device: device on which the tensors are created.

    Returns:
        Meshes constructed from the per-mesh vertex and face lists.
    """
    target = torch.device(device)
    verts_list, faces_list = [], []
    for _ in range(num_meshes):
        # Vertices are drawn before faces for every mesh, which keeps the
        # sequence of RNG draws fixed for a given seed.
        verts_list.append(
            torch.rand((num_verts, 3), dtype=torch.float32, device=target)
        )
        faces_list.append(
            torch.randint(
                num_verts, size=(num_faces, 3), dtype=torch.int64, device=target
            )
        )
    return Meshes(verts_list, faces_list)
|
||||
|
||||
@staticmethod
|
||||
def packed_to_padded_python(inputs, first_idxs, max_size, device):
    """
    PyTorch implementation of packed_to_padded function.

    Args:
        inputs: packed tensor of shape (F,) or (F, D).
        first_idxs: 1-D integer tensor with one entry per batch element;
            first_idxs[m] is the row of `inputs` where element m starts.
            Element m ends where element m + 1 starts (the last element
            ends at the end of `inputs`).
        max_size: length of the padded dimension of the output.
        device: device on which the output tensor is allocated.

    Returns:
        Tensor of shape (N, max_size) for 1-D inputs or (N, max_size, D)
        for 2-D inputs, with N = first_idxs.size(0). Rows past the end of
        each element are zero.
    """
    num_meshes = first_idxs.size(0)
    if inputs.dim() == 2:
        padded_shape = (num_meshes, max_size, inputs.shape[1])
    else:
        padded_shape = (num_meshes, max_size)
    inputs_padded = torch.zeros(padded_shape, device=device)

    for m in range(num_meshes):
        s = int(first_idxs[m])
        if m == num_meshes - 1:
            f = inputs.shape[0]
        else:
            f = int(first_idxs[m + 1])
        # The destination is indexed per element, so its extent is the
        # element length (f - s), not the packed end offset f. Using f only
        # lines up when every element has exactly max_size rows.
        inputs_padded[m, : f - s] = inputs[s:f]

    return inputs_padded
|
||||
|
||||
@staticmethod
|
||||
def padded_to_packed_python(inputs, first_idxs, num_inputs, device):
    """
    PyTorch implementation of padded_to_packed function.

    Args:
        inputs: padded tensor of shape (N, max_size) or (N, max_size, D).
        first_idxs: 1-D integer tensor; first_idxs[m] is the row of the
            packed output where element m starts. Element m ends where
            element m + 1 starts (the last element ends at num_inputs).
        num_inputs: total number of rows in the packed output.
        device: device on which the output tensor is allocated.

    Returns:
        Tensor of shape (num_inputs,) for 2-D inputs or (num_inputs, D)
        for 3-D inputs, holding the leading rows of each padded element
        laid out back to back.
    """
    num_meshes = inputs.size(0)
    if inputs.dim() == 3:
        packed_shape = (num_inputs, inputs.shape[2])
    else:
        packed_shape = (num_inputs,)
    inputs_packed = torch.zeros(packed_shape, device=device)

    for m in range(num_meshes):
        s = int(first_idxs[m])
        if m == num_meshes - 1:
            f = num_inputs
        else:
            f = int(first_idxs[m + 1])
        # Read exactly the element length (f - s) from the padded row; f is
        # an offset into the packed tensor, not into the padded one.
        inputs_packed[s:f] = inputs[m, : f - s]

    return inputs_packed
|
||||
|
||||
def _test_packed_to_padded_helper(self, D, device):
    """
    Compare packed_to_padded with the PyTorch reference implementation,
    both for the forward output and for the gradient w.r.t. the packed
    input.

    Args:
        D: feature width of the packed values; 0 selects a flat (F,)
            input, any other value an (F, D) input.
        device: device on which the meshes and values are created.
    """
    meshes = self.init_meshes(16, 100, 300, device=device)
    num_packed = meshes.faces_packed().shape[0]
    first_idxs = meshes.mesh_to_faces_packed_first_idx()
    max_faces = meshes.num_faces_per_mesh().max().item()

    # Trailing feature dimension, absent for the flat case.
    trailing = () if D == 0 else (D,)

    values = torch.rand(
        (num_packed, *trailing), device=device, requires_grad=True
    )
    # Independent leaf with the same data, so each implementation
    # accumulates its own gradient.
    values_torch = values.detach().clone()
    values_torch.requires_grad = True

    values_padded = packed_to_padded(values, first_idxs, max_faces)
    values_padded_torch = TestPackedToPadded.packed_to_padded_python(
        values_torch, first_idxs, max_faces, device
    )

    # check forward
    self.assertClose(values_padded, values_padded_torch)

    # check backward
    grad_inputs = torch.rand(
        (len(meshes), max_faces, *trailing), device=device
    )
    values_padded.backward(grad_inputs)
    grad_outputs = values.grad
    values_padded_torch.backward(grad_inputs)
    grad_outputs_torch1 = values_torch.grad
    # Second reference: packing the upstream gradient directly.
    grad_outputs_torch2 = TestPackedToPadded.padded_to_packed_python(
        grad_inputs, first_idxs, values.size(0), device=device
    )
    self.assertClose(grad_outputs, grad_outputs_torch1)
    self.assertClose(grad_outputs, grad_outputs_torch2)
|
||||
|
||||
def test_packed_to_padded_flat_cpu(self):
    """packed_to_padded on a flat (D == 0) input, on the CPU."""
    self._test_packed_to_padded_helper(D=0, device="cpu")
|
||||
|
||||
def test_packed_to_padded_D1_cpu(self):
    """packed_to_padded with feature width D == 1, on the CPU."""
    self._test_packed_to_padded_helper(D=1, device="cpu")
|
||||
|
||||
def test_packed_to_padded_D16_cpu(self):
    """packed_to_padded with feature width D == 16, on the CPU."""
    self._test_packed_to_padded_helper(D=16, device="cpu")
|
||||
|
||||
def test_packed_to_padded_flat_cuda(self):
    """packed_to_padded on a flat (D == 0) input, on cuda:0."""
    # Report a skip rather than an error on hosts without a GPU.
    if not torch.cuda.is_available():
        self.skipTest("CUDA is not available")
    self._test_packed_to_padded_helper(0, "cuda:0")
|
||||
|
||||
def test_packed_to_padded_D1_cuda(self):
    """packed_to_padded with feature width D == 1, on cuda:0."""
    # Report a skip rather than an error on hosts without a GPU.
    if not torch.cuda.is_available():
        self.skipTest("CUDA is not available")
    self._test_packed_to_padded_helper(1, "cuda:0")
|
||||
|
||||
def test_packed_to_padded_D16_cuda(self):
    """packed_to_padded with feature width D == 16, on cuda:0."""
    # Report a skip rather than an error on hosts without a GPU.
    if not torch.cuda.is_available():
        self.skipTest("CUDA is not available")
    self._test_packed_to_padded_helper(16, "cuda:0")
|
||||
|
||||
def _test_padded_to_packed_helper(self, D, device):
    """
    Compare padded_to_packed with the PyTorch reference implementation,
    both for the forward output and for the gradient w.r.t. the padded
    input.

    Args:
        D: feature width of the padded values; 0 selects an
            (N, max_faces) input, any other value (N, max_faces, D).
        device: device on which the meshes and values are created.
    """
    meshes = self.init_meshes(16, 100, 300, device=device)
    first_idxs = meshes.mesh_to_faces_packed_first_idx()
    num_faces_per_mesh = meshes.num_faces_per_mesh()
    max_faces = num_faces_per_mesh.max().item()

    # Trailing feature dimension, absent for the flat case.
    trailing = () if D == 0 else (D,)

    values = torch.rand((len(meshes), max_faces, *trailing), device=device)
    # Zero the rows past each mesh's face count before enabling autograd.
    for i, num in enumerate(num_faces_per_mesh):
        values[i, num:] = 0
    values.requires_grad = True
    # Independent leaf with the same data for the reference path.
    values_torch = values.detach().clone()
    values_torch.requires_grad = True

    values_packed = padded_to_packed(
        values, first_idxs, num_faces_per_mesh.sum().item()
    )
    values_packed_torch = TestPackedToPadded.padded_to_packed_python(
        values_torch, first_idxs, num_faces_per_mesh.sum().item(), device
    )

    # check forward
    self.assertClose(values_packed, values_packed_torch)

    # check backward
    grad_inputs = torch.rand(
        (num_faces_per_mesh.sum().item(), *trailing), device=device
    )
    values_packed.backward(grad_inputs)
    grad_outputs = values.grad
    values_packed_torch.backward(grad_inputs)
    grad_outputs_torch1 = values_torch.grad
    # Second reference: padding the upstream gradient directly.
    grad_outputs_torch2 = TestPackedToPadded.packed_to_padded_python(
        grad_inputs, first_idxs, values.size(1), device=device
    )
    self.assertClose(grad_outputs, grad_outputs_torch1)
    self.assertClose(grad_outputs, grad_outputs_torch2)
|
||||
|
||||
def test_padded_to_packed_flat_cpu(self):
    """padded_to_packed on a flat (D == 0) input, on the CPU."""
    self._test_padded_to_packed_helper(D=0, device="cpu")
|
||||
|
||||
def test_padded_to_packed_D1_cpu(self):
    """padded_to_packed with feature width D == 1, on the CPU."""
    self._test_padded_to_packed_helper(D=1, device="cpu")
|
||||
|
||||
def test_padded_to_packed_D16_cpu(self):
    """padded_to_packed with feature width D == 16, on the CPU."""
    self._test_padded_to_packed_helper(D=16, device="cpu")
|
||||
|
||||
def test_padded_to_packed_flat_cuda(self):
    """padded_to_packed on a flat (D == 0) input, on cuda:0."""
    # Report a skip rather than an error on hosts without a GPU.
    if not torch.cuda.is_available():
        self.skipTest("CUDA is not available")
    self._test_padded_to_packed_helper(0, "cuda:0")
|
||||
|
||||
def test_padded_to_packed_D1_cuda(self):
    """padded_to_packed with feature width D == 1, on cuda:0."""
    # Report a skip rather than an error on hosts without a GPU.
    if not torch.cuda.is_available():
        self.skipTest("CUDA is not available")
    self._test_padded_to_packed_helper(1, "cuda:0")
|
||||
|
||||
def test_padded_to_packed_D16_cuda(self):
    """padded_to_packed with feature width D == 16, on cuda:0."""
    # Report a skip rather than an error on hosts without a GPU.
    if not torch.cuda.is_available():
        self.skipTest("CUDA is not available")
    self._test_padded_to_packed_helper(16, "cuda:0")
|
||||
|
||||
def test_invalid_inputs_shapes(self, device="cuda:0"):
    """
    packed_to_padded and padded_to_packed must raise ValueError for
    inputs with an unsupported number of dimensions.

    Args:
        device: device on which the test tensors are created.
    """
    # Report a skip rather than an error when the requested GPU is absent.
    if torch.device(device).type == "cuda" and not torch.cuda.is_available():
        self.skipTest("CUDA is not available")

    # Tensors are built outside the assertRaisesRegex bodies so that only
    # the call under test can satisfy the expected exception.
    values = torch.rand((100, 50, 2), device=device)
    first_idxs = torch.tensor([0, 80], dtype=torch.int64, device=device)
    with self.assertRaisesRegex(
        ValueError, "input can only be 2-dimensional."
    ):
        packed_to_padded(values, first_idxs, 100)

    values = torch.rand((100,), device=device)
    first_idxs = torch.tensor([0, 80], dtype=torch.int64, device=device)
    with self.assertRaisesRegex(
        ValueError, "input can only be 3-dimensional."
    ):
        padded_to_packed(values, first_idxs, 20)

    values = torch.rand((100, 50, 2, 2), device=device)
    first_idxs = torch.tensor([0, 80], dtype=torch.int64, device=device)
    with self.assertRaisesRegex(
        ValueError, "input can only be 3-dimensional."
    ):
        padded_to_packed(values, first_idxs, 20)
|
||||
|
||||
@staticmethod
|
||||
def packed_to_padded_with_init(
    num_meshes: int,
    num_verts: int,
    num_faces: int,
    num_d: int,
    device: str = "cpu",
):
    """
    Prepare inputs for a packed_to_padded benchmark and return the
    callable to be timed.

    Args:
        num_meshes: number of meshes in the batch.
        num_verts: vertices per mesh.
        num_faces: faces per mesh.
        num_d: feature width of the packed values; 0 selects a flat (F,)
            tensor, any other value an (F, num_d) tensor.
        device: device on which the meshes and values are created.

    Returns:
        A zero-argument function that runs packed_to_padded once on the
        prepared inputs.
    """
    meshes = TestPackedToPadded.init_meshes(
        num_meshes, num_verts, num_faces, device
    )
    faces = meshes.faces_packed()
    mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx()
    max_faces = meshes.num_faces_per_mesh().max().item()
    if num_d == 0:
        values = torch.rand((faces.shape[0],), device=meshes.device)
    else:
        values = torch.rand((faces.shape[0], num_d), device=meshes.device)

    # Only synchronize when the work runs on a GPU: the benchmark always
    # includes a CPU backend, and torch.cuda.synchronize() raises on hosts
    # without CUDA.
    on_cuda = torch.device(device).type == "cuda"
    if on_cuda:
        torch.cuda.synchronize()

    def out():
        packed_to_padded(values, mesh_to_faces_packed_first_idx, max_faces)
        if on_cuda:
            # Wait for the kernel so the timing covers the whole op.
            torch.cuda.synchronize()

    return out
|
||||
|
||||
@staticmethod
|
||||
def packed_to_padded_with_init_torch(
    num_meshes: int,
    num_verts: int,
    num_faces: int,
    num_d: int,
    device: str = "cpu",
):
    """
    Prepare inputs for a benchmark of the PyTorch reference
    packed_to_padded_python and return the callable to be timed.

    Args:
        num_meshes: number of meshes in the batch.
        num_verts: vertices per mesh.
        num_faces: faces per mesh.
        num_d: feature width of the packed values; 0 selects a flat (F,)
            tensor, any other value an (F, num_d) tensor.
        device: device on which the meshes, values and output are created.

    Returns:
        A zero-argument function that runs packed_to_padded_python once on
        the prepared inputs.
    """
    meshes = TestPackedToPadded.init_meshes(
        num_meshes, num_verts, num_faces, device
    )
    faces = meshes.faces_packed()
    mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx()
    max_faces = meshes.num_faces_per_mesh().max().item()
    if num_d == 0:
        values = torch.rand((faces.shape[0],), device=meshes.device)
    else:
        values = torch.rand((faces.shape[0], num_d), device=meshes.device)

    # Only synchronize when the work runs on a GPU: the benchmark always
    # includes a CPU backend, and torch.cuda.synchronize() raises on hosts
    # without CUDA.
    on_cuda = torch.device(device).type == "cuda"
    if on_cuda:
        torch.cuda.synchronize()

    def out():
        TestPackedToPadded.packed_to_padded_python(
            values, mesh_to_faces_packed_first_idx, max_faces, device
        )
        if on_cuda:
            # Wait for the kernels so the timing covers the whole op.
            torch.cuda.synchronize()

    return out
|
||||
@@ -294,48 +294,6 @@ class TestSamplePoints(unittest.TestCase):
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def packed_to_padded_tensor(inputs, first_idxs, max_size):
|
||||
"""
|
||||
PyTorch implementation of cuda packed_to_padded_tensor function.
|
||||
"""
|
||||
num_meshes = first_idxs.size(0)
|
||||
inputs_padded = torch.zeros((num_meshes, max_size))
|
||||
for m in range(num_meshes):
|
||||
s = first_idxs[m]
|
||||
if m == num_meshes - 1:
|
||||
f = inputs.size(0)
|
||||
else:
|
||||
f = first_idxs[m + 1]
|
||||
inputs_padded[m, :f] = inputs[s:f]
|
||||
|
||||
return inputs_padded
|
||||
|
||||
def test_packed_to_padded_tensor(self):
|
||||
"""
|
||||
Check the results from packed_to_padded cuda and PyTorch implementions
|
||||
are the same.
|
||||
"""
|
||||
meshes = self.init_meshes(1, 3, 5, device="cuda:0")
|
||||
verts = meshes.verts_packed()
|
||||
faces = meshes.faces_packed()
|
||||
mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx()
|
||||
max_faces = meshes.num_faces_per_mesh().max().item()
|
||||
|
||||
areas, _ = _C.face_areas_normals(verts, faces)
|
||||
areas_padded = _C.packed_to_padded_tensor(
|
||||
areas, mesh_to_faces_packed_first_idx, max_faces
|
||||
).cpu()
|
||||
areas_padded_cpu = TestSamplePoints.packed_to_padded_tensor(
|
||||
areas, mesh_to_faces_packed_first_idx, max_faces
|
||||
)
|
||||
self.assertTrue(torch.allclose(areas_padded, areas_padded_cpu))
|
||||
with self.assertRaises(Exception) as err:
|
||||
_C.packed_to_padded_tensor(
|
||||
areas.cpu(), mesh_to_faces_packed_first_idx, max_faces
|
||||
)
|
||||
self.assertTrue("Not implemented on the CPU" in str(err.exception))
|
||||
|
||||
@staticmethod
|
||||
def sample_points_with_init(
|
||||
num_meshes: int,
|
||||
@@ -344,7 +302,6 @@ class TestSamplePoints(unittest.TestCase):
|
||||
num_samples: int,
|
||||
device: str = "cpu",
|
||||
):
|
||||
device = torch.device(device)
|
||||
verts_list = []
|
||||
faces_list = []
|
||||
for _ in range(num_meshes):
|
||||
@@ -366,32 +323,3 @@ class TestSamplePoints(unittest.TestCase):
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return sample_points
|
||||
|
||||
@staticmethod
|
||||
def packed_to_padded_with_init(
|
||||
num_meshes: int, num_verts: int, num_faces: int, cuda: str = True
|
||||
):
|
||||
device = "cuda" if cuda else "cpu"
|
||||
meshes = TestSamplePoints.init_meshes(
|
||||
num_meshes, num_verts, num_faces, device
|
||||
)
|
||||
verts = meshes.verts_packed()
|
||||
faces = meshes.faces_packed()
|
||||
mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx()
|
||||
max_faces = meshes.num_faces_per_mesh().max().item()
|
||||
|
||||
areas, _ = _C.face_areas_normals(verts, faces)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def packed_to_padded():
|
||||
if cuda:
|
||||
_C.packed_to_padded_tensor(
|
||||
areas, mesh_to_faces_packed_first_idx, max_faces
|
||||
)
|
||||
else:
|
||||
TestSamplePoints.packed_to_padded_tensor(
|
||||
areas, mesh_to_faces_packed_first_idx, max_faces
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return packed_to_padded
|
||||
|
||||
Reference in New Issue
Block a user