diff --git a/tests/bm_mesh_io.py b/tests/bm_mesh_io.py index fabdf0e5..33b0861d 100644 --- a/tests/bm_mesh_io.py +++ b/tests/bm_mesh_io.py @@ -36,3 +36,29 @@ def bm_save_load() -> None: simple_kwargs_list, warmup_iters=1, ) + + complex_kwargs_list = [{"N": 8}, {"N": 32}, {"N": 128}] + benchmark( + TestMeshObjIO.bm_load_complex_obj, + "LOAD_COMPLEX_OBJ", + complex_kwargs_list, + warmup_iters=1, + ) + benchmark( + TestMeshObjIO.bm_save_complex_obj, + "SAVE_COMPLEX_OBJ", + complex_kwargs_list, + warmup_iters=1, + ) + benchmark( + TestMeshPlyIO.bm_load_complex_ply, + "LOAD_COMPLEX_PLY", + complex_kwargs_list, + warmup_iters=1, + ) + benchmark( + TestMeshPlyIO.bm_save_complex_ply, + "SAVE_COMPLEX_PLY", + complex_kwargs_list, + warmup_iters=1, + ) diff --git a/tests/test_obj_io.py b/tests/test_obj_io.py index 5f7fb4e9..bcfaf90b 100644 --- a/tests/test_obj_io.py +++ b/tests/test_obj_io.py @@ -8,6 +8,7 @@ import torch from pytorch3d.io import load_obj, load_objs_as_meshes, save_obj from pytorch3d.structures import Meshes, Textures, join_meshes +from pytorch3d.utils import torus from common_testing import TestCaseMixin @@ -601,15 +602,42 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): self.assertClose(cow3_tea.verts_list()[3], mesh_teapot.verts_list()[0]) self.assertClose(cow3_tea.faces_list()[3], mesh_teapot.faces_list()[0]) + @staticmethod + def _bm_save_obj( + verts: torch.Tensor, faces: torch.Tensor, decimal_places: int + ): + return lambda: save_obj(StringIO(), verts, faces, decimal_places) + + @staticmethod + def _bm_load_obj( + verts: torch.Tensor, faces: torch.Tensor, decimal_places: int + ): + f = StringIO() + save_obj(f, verts, faces, decimal_places) + s = f.getvalue() + # Recreate stream so it's unaffected by how it was created. + return lambda: load_obj(StringIO(s)) + @staticmethod def bm_save_simple_obj_with_init(V: int, F: int): - verts_list = torch.tensor(V * [[0.11, 0.22, 0.33]]).view(-1, 3) - faces_list = torch.tensor(F * [[1, 2, 3]]).view(-1, 3) - return lambda: save_obj( - StringIO(), verts_list, faces_list, decimal_places=2 - ) + verts = torch.tensor(V * [[0.11, 0.22, 0.33]]).view(-1, 3) + faces = torch.tensor(F * [[1, 2, 3]]).view(-1, 3) + return TestMeshObjIO._bm_save_obj(verts, faces, decimal_places=2) @staticmethod def bm_load_simple_obj_with_init(V: int, F: int): - obj = "\n".join(["v 0.1 0.2 0.3"] * V + ["f 1 2 3"] * F) - return lambda: load_obj(StringIO(obj)) + verts = torch.tensor(V * [[0.1, 0.2, 0.3]]).view(-1, 3) + faces = torch.tensor(F * [[1, 2, 3]]).view(-1, 3) + return TestMeshObjIO._bm_load_obj(verts, faces, decimal_places=2) + + @staticmethod + def bm_save_complex_obj(N: int): + meshes = torus(r=0.25, R=1.0, sides=N, rings=2 * N) + [verts], [faces] = meshes.verts_list(), meshes.faces_list() + return TestMeshObjIO._bm_save_obj(verts, faces, decimal_places=5) + + @staticmethod + def bm_load_complex_obj(N: int): + meshes = torus(r=0.25, R=1.0, sides=N, rings=2 * N) + [verts], [faces] = meshes.verts_list(), meshes.faces_list() + return TestMeshObjIO._bm_load_obj(verts, faces, decimal_places=5) diff --git a/tests/test_ply_io.py b/tests/test_ply_io.py index d35b239f..3568ccca 100644 --- a/tests/test_ply_io.py +++ b/tests/test_ply_io.py @@ -6,6 +6,7 @@ from io import BytesIO, StringIO import torch from pytorch3d.io.ply_io import _load_ply_raw, load_ply, save_ply +from pytorch3d.utils import torus from common_testing import TestCaseMixin @@ -407,19 +408,43 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): load_ply(StringIO("\n".join(lines2))) @staticmethod - def bm_save_simple_ply_with_init(V: int, F: int): - verts_list = torch.tensor(V * [[0.11, 0.22, 0.33]]).view(-1, 3) - faces_list = torch.tensor(F * [[0, 1, 2]]).view(-1, 3) + def _bm_save_ply( + verts: torch.Tensor, faces: torch.Tensor, decimal_places: int + ): return lambda: save_ply( - StringIO(), verts_list, faces_list, decimal_places=2 + StringIO(), verts, faces, decimal_places=decimal_places ) + @staticmethod + def _bm_load_ply( + verts: torch.Tensor, faces: torch.Tensor, decimal_places: int + ): + f = StringIO() + save_ply(f, verts, faces, decimal_places) + s = f.getvalue() + # Recreate stream so it's unaffected by how it was created. + return lambda: load_ply(StringIO(s)) + + @staticmethod + def bm_save_simple_ply_with_init(V: int, F: int): + verts = torch.tensor(V * [[0.11, 0.22, 0.33]]).view(-1, 3) + faces = torch.tensor(F * [[0, 1, 2]]).view(-1, 3) + return TestMeshPlyIO._bm_save_ply(verts, faces, decimal_places=2) + @staticmethod def bm_load_simple_ply_with_init(V: int, F: int): verts = torch.tensor([[0.1, 0.2, 0.3]]).expand(V, 3) faces = torch.tensor([[0, 1, 2]], dtype=torch.int64).expand(F, 3) - ply_file = StringIO() - save_ply(ply_file, verts=verts, faces=faces) - ply = ply_file.getvalue() - # Recreate stream so it's unaffected by how it was created. - return lambda: load_ply(StringIO(ply)) + return TestMeshPlyIO._bm_load_ply(verts, faces, decimal_places=2) + + @staticmethod + def bm_save_complex_ply(N: int): + meshes = torus(r=0.25, R=1.0, sides=N, rings=2 * N) + [verts], [faces] = meshes.verts_list(), meshes.faces_list() + return TestMeshPlyIO._bm_save_ply(verts, faces, decimal_places=5) + + @staticmethod + def bm_load_complex_ply(N: int): + meshes = torus(r=0.25, R=1.0, sides=N, rings=2 * N) + [verts], [faces] = meshes.verts_list(), meshes.faces_list() + return TestMeshPlyIO._bm_load_ply(verts, faces, decimal_places=5)