mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
cpu benchmarks for points to volumes
Summary: Add a CPU version to the benchmarks. ``` Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- ADD_POINTS_TO_VOLUMES_cpu_10_trilinear_[25, 25, 25]_1000 10100 46422 50 ADD_POINTS_TO_VOLUMES_cpu_10_trilinear_[25, 25, 25]_10000 28442 32100 18 ADD_POINTS_TO_VOLUMES_cpu_10_trilinear_[25, 25, 25]_100000 241127 254269 3 ADD_POINTS_TO_VOLUMES_cpu_10_trilinear_[101, 111, 121]_1000 54149 79480 10 ADD_POINTS_TO_VOLUMES_cpu_10_trilinear_[101, 111, 121]_10000 125459 212734 4 ADD_POINTS_TO_VOLUMES_cpu_10_trilinear_[101, 111, 121]_100000 512739 512739 1 ADD_POINTS_TO_VOLUMES_cpu_10_nearest_[25, 25, 25]_1000 2866 13365 175 ADD_POINTS_TO_VOLUMES_cpu_10_nearest_[25, 25, 25]_10000 7026 12604 72 ADD_POINTS_TO_VOLUMES_cpu_10_nearest_[25, 25, 25]_100000 48822 55607 11 ADD_POINTS_TO_VOLUMES_cpu_10_nearest_[101, 111, 121]_1000 38098 38576 14 ADD_POINTS_TO_VOLUMES_cpu_10_nearest_[101, 111, 121]_10000 48006 54120 11 ADD_POINTS_TO_VOLUMES_cpu_10_nearest_[101, 111, 121]_100000 131563 138536 4 ADD_POINTS_TO_VOLUMES_cpu_100_trilinear_[25, 25, 25]_1000 64615 91735 8 ADD_POINTS_TO_VOLUMES_cpu_100_trilinear_[25, 25, 25]_10000 228815 246095 3 ADD_POINTS_TO_VOLUMES_cpu_100_trilinear_[25, 25, 25]_100000 3086615 3086615 1 ADD_POINTS_TO_VOLUMES_cpu_100_trilinear_[101, 111, 121]_1000 464298 465292 2 ADD_POINTS_TO_VOLUMES_cpu_100_trilinear_[101, 111, 121]_10000 1053440 1053440 1 ADD_POINTS_TO_VOLUMES_cpu_100_trilinear_[101, 111, 121]_100000 6736236 6736236 1 ADD_POINTS_TO_VOLUMES_cpu_100_nearest_[25, 25, 25]_1000 11940 12440 42 ADD_POINTS_TO_VOLUMES_cpu_100_nearest_[25, 25, 25]_10000 56641 58051 9 ADD_POINTS_TO_VOLUMES_cpu_100_nearest_[25, 25, 25]_100000 711492 711492 1 ADD_POINTS_TO_VOLUMES_cpu_100_nearest_[101, 111, 121]_1000 326437 329846 2 ADD_POINTS_TO_VOLUMES_cpu_100_nearest_[101, 111, 121]_10000 418514 427911 2 ADD_POINTS_TO_VOLUMES_cpu_100_nearest_[101, 111, 121]_100000 1524285 1524285 1 ADD_POINTS_TO_VOLUMES_cuda:0_10_trilinear_[25, 25, 25]_1000 5949 13602 85 ADD_POINTS_TO_VOLUMES_cuda:0_10_trilinear_[25, 25, 25]_10000 5817 13001 86 ADD_POINTS_TO_VOLUMES_cuda:0_10_trilinear_[25, 25, 25]_100000 23833 25971 21 ADD_POINTS_TO_VOLUMES_cuda:0_10_trilinear_[101, 111, 121]_1000 9029 16178 56 ADD_POINTS_TO_VOLUMES_cuda:0_10_trilinear_[101, 111, 121]_10000 11595 18601 44 ADD_POINTS_TO_VOLUMES_cuda:0_10_trilinear_[101, 111, 121]_100000 46986 47344 11 ADD_POINTS_TO_VOLUMES_cuda:0_10_nearest_[25, 25, 25]_1000 2554 9747 196 ADD_POINTS_TO_VOLUMES_cuda:0_10_nearest_[25, 25, 25]_10000 2676 9537 187 ADD_POINTS_TO_VOLUMES_cuda:0_10_nearest_[25, 25, 25]_100000 6567 14179 77 ADD_POINTS_TO_VOLUMES_cuda:0_10_nearest_[101, 111, 121]_1000 5840 12811 86 ADD_POINTS_TO_VOLUMES_cuda:0_10_nearest_[101, 111, 121]_10000 6102 13128 82 ADD_POINTS_TO_VOLUMES_cuda:0_10_nearest_[101, 111, 121]_100000 11945 11995 42 ADD_POINTS_TO_VOLUMES_cuda:0_100_trilinear_[25, 25, 25]_1000 7642 13671 66 ADD_POINTS_TO_VOLUMES_cuda:0_100_trilinear_[25, 25, 25]_10000 25190 25260 20 ADD_POINTS_TO_VOLUMES_cuda:0_100_trilinear_[25, 25, 25]_100000 212018 212134 3 ADD_POINTS_TO_VOLUMES_cuda:0_100_trilinear_[101, 111, 121]_1000 40421 45692 13 ADD_POINTS_TO_VOLUMES_cuda:0_100_trilinear_[101, 111, 121]_10000 92078 92132 6 ADD_POINTS_TO_VOLUMES_cuda:0_100_trilinear_[101, 111, 121]_100000 457211 457229 2 ADD_POINTS_TO_VOLUMES_cuda:0_100_nearest_[25, 25, 25]_1000 3574 10377 140 ADD_POINTS_TO_VOLUMES_cuda:0_100_nearest_[25, 25, 25]_10000 7222 13023 70 ADD_POINTS_TO_VOLUMES_cuda:0_100_nearest_[25, 25, 25]_100000 48127 48165 11 ADD_POINTS_TO_VOLUMES_cuda:0_100_nearest_[101, 111, 121]_1000 34732 35295 15 ADD_POINTS_TO_VOLUMES_cuda:0_100_nearest_[101, 111, 121]_10000 43050 51064 12 ADD_POINTS_TO_VOLUMES_cuda:0_100_nearest_[101, 111, 121]_100000 106028 106058 5 -------------------------------------------------------------------------------- ``` Reviewed By: patricklabatut Differential Revision: D29522830 fbshipit-source-id: 1e857db03613b0c6afcb68a58cdd7ba032e1a874
This commit is contained in:
parent
5491b46511
commit
46cf1970ac
@ -12,6 +12,7 @@ from test_points_to_volumes import TestPointsToVolumes
|
||||
|
||||
def bm_points_to_volumes() -> None:
|
||||
case_grid = {
|
||||
"device": ["cpu", "cuda:0"],
|
||||
"batch_size": [10, 100],
|
||||
"interp_mode": ["trilinear", "nearest"],
|
||||
"volume_size": [[25, 25, 25], [101, 111, 121]],
|
||||
|
@ -26,16 +26,14 @@ if DEBUG:
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def init_cube_point_cloud(
|
||||
batch_size: int = 10, n_points: int = 100000, rotate_y: bool = True
|
||||
):
|
||||
def init_cube_point_cloud(batch_size: int, n_points: int, device: str, rotate_y: bool):
|
||||
"""
|
||||
Generate a random point cloud of `n_points` whose points
|
||||
are sampled from faces of a 3D cube.
|
||||
"""
|
||||
|
||||
# create the cube mesh batch_size times
|
||||
meshes = TestPointsToVolumes.init_cube_mesh(batch_size)
|
||||
meshes = TestPointsToVolumes.init_cube_mesh(batch_size=batch_size, device=device)
|
||||
|
||||
# generate point clouds by sampling points from the meshes
|
||||
pcl = sample_points_from_meshes(meshes, num_samples=n_points, return_normals=False)
|
||||
@ -66,7 +64,7 @@ def init_cube_point_cloud(
|
||||
|
||||
if rotate_y:
|
||||
# uniformly spaced rotations around y axis
|
||||
R = init_uniform_y_rotations(batch_size=batch_size)
|
||||
R = init_uniform_y_rotations(batch_size=batch_size, device=device)
|
||||
# rotate the point clouds around y axis
|
||||
pcl = torch.bmm(pcl - 0.5, R) + 0.5
|
||||
|
||||
@ -78,6 +76,7 @@ def init_volume_boundary_pointcloud(
|
||||
volume_size: Tuple[int, int, int],
|
||||
n_points: int,
|
||||
interp_mode: str,
|
||||
device: str,
|
||||
require_grad: bool = False,
|
||||
):
|
||||
"""
|
||||
@ -86,7 +85,9 @@ def init_volume_boundary_pointcloud(
|
||||
"""
|
||||
|
||||
# generate a 3D point cloud sampled from sides of a [0,1] cube
|
||||
xyz, rgb = init_cube_point_cloud(batch_size, n_points=n_points, rotate_y=True)
|
||||
xyz, rgb = init_cube_point_cloud(
|
||||
batch_size, n_points=n_points, device=device, rotate_y=True
|
||||
)
|
||||
|
||||
# make volume_size tensor
|
||||
volume_size_t = torch.tensor(volume_size, dtype=xyz.dtype, device=xyz.device)
|
||||
@ -128,12 +129,11 @@ def init_volume_boundary_pointcloud(
|
||||
return pointclouds, initial_volumes
|
||||
|
||||
|
||||
def init_uniform_y_rotations(batch_size: int = 10):
|
||||
def init_uniform_y_rotations(batch_size: int, device: torch.device):
|
||||
"""
|
||||
Generate a batch of `batch_size` 3x3 rotation matrices around y-axis
|
||||
whose angles are uniformly distributed between 0 and 2 pi.
|
||||
"""
|
||||
device = torch.device("cuda:0")
|
||||
axis = torch.tensor([0.0, 1.0, 0.0], device=device, dtype=torch.float32)
|
||||
angles = torch.linspace(0, 2.0 * np.pi, batch_size + 1, device=device)
|
||||
angles = angles[:batch_size]
|
||||
@ -153,6 +153,7 @@ class TestPointsToVolumes(TestCaseMixin, unittest.TestCase):
|
||||
volume_size: Tuple[int, int, int],
|
||||
n_points: int,
|
||||
interp_mode: str,
|
||||
device: str,
|
||||
):
|
||||
(pointclouds, initial_volumes) = init_volume_boundary_pointcloud(
|
||||
batch_size=batch_size,
|
||||
@ -160,10 +161,14 @@ class TestPointsToVolumes(TestCaseMixin, unittest.TestCase):
|
||||
n_points=n_points,
|
||||
interp_mode=interp_mode,
|
||||
require_grad=False,
|
||||
device=device,
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def _add_points_to_volumes():
|
||||
add_pointclouds_to_volumes(pointclouds, initial_volumes, mode=interp_mode)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return _add_points_to_volumes
|
||||
|
||||
@ -179,12 +184,12 @@ class TestPointsToVolumes(TestCaseMixin, unittest.TestCase):
|
||||
return arr3d
|
||||
|
||||
@staticmethod
|
||||
def init_cube_mesh(batch_size: int = 10):
|
||||
def init_cube_mesh(batch_size: int, device: str):
|
||||
"""
|
||||
Generate a batch of `batch_size` cube meshes.
|
||||
"""
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
device = torch.device(device)
|
||||
|
||||
verts, faces = [], []
|
||||
|
||||
@ -255,6 +260,7 @@ class TestPointsToVolumes(TestCaseMixin, unittest.TestCase):
|
||||
interp_mode=interp_mode,
|
||||
batch_size=batch_size,
|
||||
require_grad=True,
|
||||
device="cuda:0",
|
||||
)
|
||||
|
||||
volumes = add_pointclouds_to_volumes(
|
||||
|
@ -151,7 +151,7 @@ def init_cameras(
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
# trivial rotations
|
||||
R = init_uniform_y_rotations(batch_size).to(device)
|
||||
R = init_uniform_y_rotations(batch_size=batch_size, device=device)
|
||||
|
||||
# move camera 1.5 m away from the scene center
|
||||
T = torch.zeros((batch_size, 3), device=device)
|
||||
|
Loading…
x
Reference in New Issue
Block a user