save colors as uint8 in PLY

Summary: Allow saving colors as 8bit when writing .ply files.

Reviewed By: patricklabatut, nikitos9000

Differential Revision: D30905312

fbshipit-source-id: 44500982c9ed6d6ee901e04f9623e22792a0e7f7
This commit is contained in:
Jeremy Reizenstein
2021-09-30 00:47:51 -07:00
committed by Facebook GitHub Bot
parent 1b1ba5612f
commit dd76b41014
2 changed files with 131 additions and 20 deletions

View File

@@ -1067,7 +1067,7 @@ def load_ply(
return verts, faces
def _save_ply(
def _write_ply_header(
f,
*,
verts: torch.Tensor,
@@ -1075,10 +1075,10 @@ def _save_ply(
verts_normals: Optional[torch.Tensor],
verts_colors: Optional[torch.Tensor],
ascii: bool,
decimal_places: Optional[int] = None,
colors_as_uint8: bool,
) -> None:
"""
Internal implementation for saving 3D data to a .ply file.
Internal implementation for writing header when saving to a .ply file.
Args:
f: File object to which the 3D data should be written.
@@ -1087,7 +1087,8 @@ def _save_ply(
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
verts_colors: FloatTensor of shape (V, 3) giving vertex colors.
ascii: (bool) whether to use the ascii ply format.
decimal_places: Number of decimal places for saving if ascii=True.
colors_as_uint8: Whether to save colors as numbers in the range
[0, 255] instead of float32.
"""
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
assert faces is None or not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
@@ -1113,33 +1114,88 @@ def _save_ply(
f.write(b"property float ny\n")
f.write(b"property float nz\n")
if verts_colors is not None:
f.write(b"property float red\n")
f.write(b"property float green\n")
f.write(b"property float blue\n")
color_ply_type = b"uchar" if colors_as_uint8 else b"float"
for color in (b"red", b"green", b"blue"):
f.write(b"property " + color_ply_type + b" " + color + b"\n")
if len(verts) and faces is not None:
f.write(f"element face {faces.shape[0]}\n".encode("ascii"))
f.write(b"property list uchar int vertex_index\n")
f.write(b"end_header\n")
def _save_ply(
f,
*,
verts: torch.Tensor,
faces: Optional[torch.LongTensor],
verts_normals: Optional[torch.Tensor],
verts_colors: Optional[torch.Tensor],
ascii: bool,
decimal_places: Optional[int] = None,
colors_as_uint8: bool,
) -> None:
"""
Internal implementation for saving 3D data to a .ply file.
Args:
f: File object to which the 3D data should be written.
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
faces: LongTensor of shape (F, 3) giving faces.
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
verts_colors: FloatTensor of shape (V, 3) giving vertex colors.
ascii: (bool) whether to use the ascii ply format.
decimal_places: Number of decimal places for saving if ascii=True.
colors_as_uint8: Whether to save colors as numbers in the range
[0, 255] instead of float32.
"""
_write_ply_header(
f,
verts=verts,
faces=faces,
verts_normals=verts_normals,
verts_colors=verts_colors,
ascii=ascii,
colors_as_uint8=colors_as_uint8,
)
if not (len(verts)):
warnings.warn("Empty 'verts' provided")
return
verts_tensors = [verts]
color_np_type = np.ubyte if colors_as_uint8 else np.float32
verts_dtype = [("verts", np.float32, 3)]
if verts_normals is not None:
verts_tensors.append(verts_normals)
verts_dtype.append(("normals", np.float32, 3))
if verts_colors is not None:
verts_tensors.append(verts_colors)
verts_dtype.append(("colors", color_np_type, 3))
vert_data = np.zeros(verts.shape[0], dtype=verts_dtype)
vert_data["verts"] = verts.detach().cpu().numpy()
if verts_normals is not None:
vert_data["normals"] = verts_normals.detach().cpu().numpy()
if verts_colors is not None:
color_data = verts_colors.detach().cpu().numpy()
if colors_as_uint8:
vert_data["colors"] = np.rint(color_data * 255)
else:
vert_data["colors"] = color_data
vert_data = torch.cat(verts_tensors, dim=1).detach().cpu().numpy()
if ascii:
if decimal_places is None:
float_str = "%f"
float_str = b"%f"
else:
float_str = "%" + ".%df" % decimal_places
np.savetxt(f, vert_data, float_str)
float_str = b"%" + b".%df" % decimal_places
float_group_str = (float_str + b" ") * 3
formats = [float_group_str]
if verts_normals is not None:
formats.append(float_group_str)
if verts_colors is not None:
formats.append(b"%d %d %d " if colors_as_uint8 else float_group_str)
formats[-1] = formats[-1][:-1] + b"\n"
for line_data in vert_data:
for data, format in zip(line_data, formats):
f.write(format % tuple(data))
else:
assert vert_data.dtype == np.float32
if isinstance(f, BytesIO):
# tofile only works with real files, but is faster than this.
f.write(vert_data.tobytes())
@@ -1189,7 +1245,6 @@ def save_ply(
ascii: (bool) whether to use the ascii ply format.
decimal_places: Number of decimal places for saving if ascii=True.
path_manager: PathManager for interpreting f if it is a str.
"""
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
@@ -1227,6 +1282,7 @@ def save_ply(
verts_colors=None,
ascii=ascii,
decimal_places=decimal_places,
colors_as_uint8=False,
)
@@ -1272,8 +1328,14 @@ class MeshPlyFormat(MeshFormatInterpreter):
path_manager: PathManager,
binary: Optional[bool],
decimal_places: Optional[int] = None,
colors_as_uint8: bool = False,
**kwargs,
) -> bool:
"""
Extra optional args:
colors_as_uint8: (bool) Whether to save colors as numbers in the
range [0, 255] instead of float32.
"""
if not endswith(path, self.known_suffixes):
return False
@@ -1307,6 +1369,7 @@ class MeshPlyFormat(MeshFormatInterpreter):
verts_normals=verts_normals,
ascii=binary is False,
decimal_places=decimal_places,
colors_as_uint8=colors_as_uint8,
)
return True
@@ -1342,8 +1405,14 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter):
path_manager: PathManager,
binary: Optional[bool],
decimal_places: Optional[int] = None,
colors_as_uint8: bool = False,
**kwargs,
) -> bool:
"""
Extra optional args:
colors_as_uint8: (bool) Whether to save colors as numbers in the
range [0, 255] instead of float32.
"""
if not endswith(path, self.known_suffixes):
return False
@@ -1360,5 +1429,6 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter):
faces=None,
ascii=binary is False,
decimal_places=decimal_places,
colors_as_uint8=colors_as_uint8,
)
return True