From dd76b4101468d61233eff7f240870ab13a8b8662 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Thu, 30 Sep 2021 00:47:51 -0700 Subject: [PATCH] save colors as uint8 in PLY Summary: Allow saving colors as 8bit when writing .ply files. Reviewed By: patricklabatut, nikitos9000 Differential Revision: D30905312 fbshipit-source-id: 44500982c9ed6d6ee901e04f9623e22792a0e7f7 --- pytorch3d/io/ply_io.py | 102 ++++++++++++++++++++++++++++++++++------- tests/test_io_ply.py | 49 ++++++++++++++++++-- 2 files changed, 131 insertions(+), 20 deletions(-) diff --git a/pytorch3d/io/ply_io.py b/pytorch3d/io/ply_io.py index fc8386f8..d7070ab8 100644 --- a/pytorch3d/io/ply_io.py +++ b/pytorch3d/io/ply_io.py @@ -1067,7 +1067,7 @@ def load_ply( return verts, faces -def _save_ply( +def _write_ply_header( f, *, verts: torch.Tensor, @@ -1075,10 +1075,10 @@ def _save_ply( verts_normals: Optional[torch.Tensor], verts_colors: Optional[torch.Tensor], ascii: bool, - decimal_places: Optional[int] = None, + colors_as_uint8: bool, ) -> None: """ - Internal implementation for saving 3D data to a .ply file. + Internal implementation for writing header when saving to a .ply file. Args: f: File object to which the 3D data should be written. @@ -1087,7 +1087,8 @@ def _save_ply( verts_normals: FloatTensor of shape (V, 3) giving vertex normals. verts_colors: FloatTensor of shape (V, 3) giving vertex colors. ascii: (bool) whether to use the ascii ply format. - decimal_places: Number of decimal places for saving if ascii=True. + colors_as_uint8: Whether to save colors as numbers in the range + [0, 255] instead of float32. """ assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3) assert faces is None or not len(faces) or (faces.dim() == 2 and faces.size(1) == 3) @@ -1113,33 +1114,88 @@ def _save_ply( f.write(b"property float ny\n") f.write(b"property float nz\n") if verts_colors is not None: - f.write(b"property float red\n") - f.write(b"property float green\n") - f.write(b"property float blue\n") + color_ply_type = b"uchar" if colors_as_uint8 else b"float" + for color in (b"red", b"green", b"blue"): + f.write(b"property " + color_ply_type + b" " + color + b"\n") if len(verts) and faces is not None: f.write(f"element face {faces.shape[0]}\n".encode("ascii")) f.write(b"property list uchar int vertex_index\n") f.write(b"end_header\n") + +def _save_ply( + f, + *, + verts: torch.Tensor, + faces: Optional[torch.LongTensor], + verts_normals: Optional[torch.Tensor], + verts_colors: Optional[torch.Tensor], + ascii: bool, + decimal_places: Optional[int] = None, + colors_as_uint8: bool, +) -> None: + """ + Internal implementation for saving 3D data to a .ply file. + + Args: + f: File object to which the 3D data should be written. + verts: FloatTensor of shape (V, 3) giving vertex coordinates. + faces: LongTensor of shape (F, 3) giving faces. + verts_normals: FloatTensor of shape (V, 3) giving vertex normals. + verts_colors: FloatTensor of shape (V, 3) giving vertex colors. + ascii: (bool) whether to use the ascii ply format. + decimal_places: Number of decimal places for saving if ascii=True. + colors_as_uint8: Whether to save colors as numbers in the range + [0, 255] instead of float32. + """ + _write_ply_header( + f, + verts=verts, + faces=faces, + verts_normals=verts_normals, + verts_colors=verts_colors, + ascii=ascii, + colors_as_uint8=colors_as_uint8, + ) + if not (len(verts)): warnings.warn("Empty 'verts' provided") return - verts_tensors = [verts] + color_np_type = np.ubyte if colors_as_uint8 else np.float32 + verts_dtype = [("verts", np.float32, 3)] if verts_normals is not None: - verts_tensors.append(verts_normals) + verts_dtype.append(("normals", np.float32, 3)) if verts_colors is not None: - verts_tensors.append(verts_colors) + verts_dtype.append(("colors", color_np_type, 3)) + + vert_data = np.zeros(verts.shape[0], dtype=verts_dtype) + vert_data["verts"] = verts.detach().cpu().numpy() + if verts_normals is not None: + vert_data["normals"] = verts_normals.detach().cpu().numpy() + if verts_colors is not None: + color_data = verts_colors.detach().cpu().numpy() + if colors_as_uint8: + vert_data["colors"] = np.rint(color_data * 255) + else: + vert_data["colors"] = color_data - vert_data = torch.cat(verts_tensors, dim=1).detach().cpu().numpy() if ascii: if decimal_places is None: - float_str = "%f" + float_str = b"%f" else: - float_str = "%" + ".%df" % decimal_places - np.savetxt(f, vert_data, float_str) + float_str = b"%" + b".%df" % decimal_places + float_group_str = (float_str + b" ") * 3 + formats = [float_group_str] + if verts_normals is not None: + formats.append(float_group_str) + if verts_colors is not None: + formats.append(b"%d %d %d " if colors_as_uint8 else float_group_str) + formats[-1] = formats[-1][:-1] + b"\n" + for line_data in vert_data: + for data, format in zip(line_data, formats): + f.write(format % tuple(data)) else: - assert vert_data.dtype == np.float32 if isinstance(f, BytesIO): # tofile only works with real files, but is faster than this. f.write(vert_data.tobytes()) @@ -1189,7 +1245,6 @@ def save_ply( ascii: (bool) whether to use the ascii ply format. decimal_places: Number of decimal places for saving if ascii=True. path_manager: PathManager for interpreting f if it is a str. - """ if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3): @@ -1227,6 +1282,7 @@ def save_ply( verts_colors=None, ascii=ascii, decimal_places=decimal_places, + colors_as_uint8=False, ) @@ -1272,8 +1328,14 @@ class MeshPlyFormat(MeshFormatInterpreter): path_manager: PathManager, binary: Optional[bool], decimal_places: Optional[int] = None, + colors_as_uint8: bool = False, **kwargs, ) -> bool: + """ + Extra optional args: + colors_as_uint8: (bool) Whether to save colors as numbers in the + range [0, 255] instead of float32. + """ if not endswith(path, self.known_suffixes): return False @@ -1307,6 +1369,7 @@ class MeshPlyFormat(MeshFormatInterpreter): verts_normals=verts_normals, ascii=binary is False, decimal_places=decimal_places, + colors_as_uint8=colors_as_uint8, ) return True @@ -1342,8 +1405,14 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter): path_manager: PathManager, binary: Optional[bool], decimal_places: Optional[int] = None, + colors_as_uint8: bool = False, **kwargs, ) -> bool: + """ + Extra optional args: + colors_as_uint8: (bool) Whether to save colors as numbers in the + range [0, 255] instead of float32. + """ if not endswith(path, self.known_suffixes): return False @@ -1360,5 +1429,6 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter): faces=None, ascii=binary is False, decimal_places=decimal_places, + colors_as_uint8=colors_as_uint8, ) return True diff --git a/tests/test_io_ply.py b/tests/test_io_ply.py index 95c77590..f9f8d30d 100644 --- a/tests/test_io_ply.py +++ b/tests/test_io_ply.py @@ -528,12 +528,16 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): ).encode("ascii") data = struct.pack("<" + "f" * 48, *range(48)) points = torch.FloatTensor([0, 1, 2]) + 6 * torch.arange(8)[:, None] - features = torch.FloatTensor([3, 4, 5]) + 6 * torch.arange(8)[:, None] + features_large = torch.FloatTensor([3, 4, 5]) + 6 * torch.arange(8)[:, None] + features = features_large / 255.0 + pointcloud_largefeatures = Pointclouds( + points=[points], features=[features_large] + ) pointcloud = Pointclouds(points=[points], features=[features]) io = IO() with NamedTemporaryFile(mode="rb", suffix=".ply") as f: - io.save_pointcloud(data=pointcloud, path=f.name) + io.save_pointcloud(data=pointcloud_largefeatures, path=f.name) f.flush() f.seek(0) actual_data = f.read() @@ -541,16 +545,53 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): self.assertEqual(header + data, actual_data) self.assertClose(reloaded_pointcloud.points_list()[0], points) - self.assertClose(reloaded_pointcloud.features_list()[0], features) + self.assertClose(reloaded_pointcloud.features_list()[0], features_large) + # Test the load-save cycle leaves file completely unchanged + with NamedTemporaryFile(mode="rb", suffix=".ply") as f: + io.save_pointcloud( + data=reloaded_pointcloud, + path=f.name, + ) + f.flush() + f.seek(0) + data2 = f.read() + self.assertEqual(data2, actual_data) with NamedTemporaryFile(mode="r", suffix=".ply") as f: - io.save_pointcloud(data=pointcloud, path=f.name, binary=False) + io.save_pointcloud( + data=pointcloud, path=f.name, binary=False, decimal_places=9 + ) reloaded_pointcloud2 = io.load_pointcloud(f.name) self.assertEqual(f.readline(), "ply\n") self.assertEqual(f.readline(), "format ascii 1.0\n") self.assertClose(reloaded_pointcloud2.points_list()[0], points) self.assertClose(reloaded_pointcloud2.features_list()[0], features) + for binary in [True, False]: + with NamedTemporaryFile(mode="rb", suffix=".ply") as f: + io.save_pointcloud( + data=pointcloud, path=f.name, colors_as_uint8=True, binary=binary + ) + f.flush() + f.seek(0) + actual_data = f.read() + reloaded_pointcloud3 = io.load_pointcloud(f.name) + self.assertClose(reloaded_pointcloud3.features_list()[0], features) + self.assertIn(b"property uchar green", actual_data) + + # Test the load-save cycle leaves file completely unchanged + with NamedTemporaryFile(mode="rb", suffix=".ply") as f: + io.save_pointcloud( + data=reloaded_pointcloud3, + path=f.name, + binary=binary, + colors_as_uint8=True, + ) + f.flush() + f.seek(0) + data2 = f.read() + self.assertEqual(data2, actual_data) + def test_load_pointcloud_bad_order(self): """ Ply file with a strange property order