mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
save colors as uint8 in PLY
Summary: Allow saving colors as 8bit when writing .ply files. Reviewed By: patricklabatut, nikitos9000 Differential Revision: D30905312 fbshipit-source-id: 44500982c9ed6d6ee901e04f9623e22792a0e7f7
This commit is contained in:
parent
1b1ba5612f
commit
dd76b41014
@ -1067,7 +1067,7 @@ def load_ply(
|
|||||||
return verts, faces
|
return verts, faces
|
||||||
|
|
||||||
|
|
||||||
def _save_ply(
|
def _write_ply_header(
|
||||||
f,
|
f,
|
||||||
*,
|
*,
|
||||||
verts: torch.Tensor,
|
verts: torch.Tensor,
|
||||||
@ -1075,10 +1075,10 @@ def _save_ply(
|
|||||||
verts_normals: Optional[torch.Tensor],
|
verts_normals: Optional[torch.Tensor],
|
||||||
verts_colors: Optional[torch.Tensor],
|
verts_colors: Optional[torch.Tensor],
|
||||||
ascii: bool,
|
ascii: bool,
|
||||||
decimal_places: Optional[int] = None,
|
colors_as_uint8: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Internal implementation for saving 3D data to a .ply file.
|
Internal implementation for writing header when saving to a .ply file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
f: File object to which the 3D data should be written.
|
f: File object to which the 3D data should be written.
|
||||||
@ -1087,7 +1087,8 @@ def _save_ply(
|
|||||||
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
|
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
|
||||||
verts_colors: FloatTensor of shape (V, 3) giving vertex colors.
|
verts_colors: FloatTensor of shape (V, 3) giving vertex colors.
|
||||||
ascii: (bool) whether to use the ascii ply format.
|
ascii: (bool) whether to use the ascii ply format.
|
||||||
decimal_places: Number of decimal places for saving if ascii=True.
|
colors_as_uint8: Whether to save colors as numbers in the range
|
||||||
|
[0, 255] instead of float32.
|
||||||
"""
|
"""
|
||||||
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
|
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
|
||||||
assert faces is None or not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
|
assert faces is None or not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
|
||||||
@ -1113,33 +1114,88 @@ def _save_ply(
|
|||||||
f.write(b"property float ny\n")
|
f.write(b"property float ny\n")
|
||||||
f.write(b"property float nz\n")
|
f.write(b"property float nz\n")
|
||||||
if verts_colors is not None:
|
if verts_colors is not None:
|
||||||
f.write(b"property float red\n")
|
color_ply_type = b"uchar" if colors_as_uint8 else b"float"
|
||||||
f.write(b"property float green\n")
|
for color in (b"red", b"green", b"blue"):
|
||||||
f.write(b"property float blue\n")
|
f.write(b"property " + color_ply_type + b" " + color + b"\n")
|
||||||
if len(verts) and faces is not None:
|
if len(verts) and faces is not None:
|
||||||
f.write(f"element face {faces.shape[0]}\n".encode("ascii"))
|
f.write(f"element face {faces.shape[0]}\n".encode("ascii"))
|
||||||
f.write(b"property list uchar int vertex_index\n")
|
f.write(b"property list uchar int vertex_index\n")
|
||||||
f.write(b"end_header\n")
|
f.write(b"end_header\n")
|
||||||
|
|
||||||
|
|
||||||
|
def _save_ply(
|
||||||
|
f,
|
||||||
|
*,
|
||||||
|
verts: torch.Tensor,
|
||||||
|
faces: Optional[torch.LongTensor],
|
||||||
|
verts_normals: Optional[torch.Tensor],
|
||||||
|
verts_colors: Optional[torch.Tensor],
|
||||||
|
ascii: bool,
|
||||||
|
decimal_places: Optional[int] = None,
|
||||||
|
colors_as_uint8: bool,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Internal implementation for saving 3D data to a .ply file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
f: File object to which the 3D data should be written.
|
||||||
|
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
|
||||||
|
faces: LongTensor of shape (F, 3) giving faces.
|
||||||
|
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
|
||||||
|
verts_colors: FloatTensor of shape (V, 3) giving vertex colors.
|
||||||
|
ascii: (bool) whether to use the ascii ply format.
|
||||||
|
decimal_places: Number of decimal places for saving if ascii=True.
|
||||||
|
colors_as_uint8: Whether to save colors as numbers in the range
|
||||||
|
[0, 255] instead of float32.
|
||||||
|
"""
|
||||||
|
_write_ply_header(
|
||||||
|
f,
|
||||||
|
verts=verts,
|
||||||
|
faces=faces,
|
||||||
|
verts_normals=verts_normals,
|
||||||
|
verts_colors=verts_colors,
|
||||||
|
ascii=ascii,
|
||||||
|
colors_as_uint8=colors_as_uint8,
|
||||||
|
)
|
||||||
|
|
||||||
if not (len(verts)):
|
if not (len(verts)):
|
||||||
warnings.warn("Empty 'verts' provided")
|
warnings.warn("Empty 'verts' provided")
|
||||||
return
|
return
|
||||||
|
|
||||||
verts_tensors = [verts]
|
color_np_type = np.ubyte if colors_as_uint8 else np.float32
|
||||||
|
verts_dtype = [("verts", np.float32, 3)]
|
||||||
if verts_normals is not None:
|
if verts_normals is not None:
|
||||||
verts_tensors.append(verts_normals)
|
verts_dtype.append(("normals", np.float32, 3))
|
||||||
if verts_colors is not None:
|
if verts_colors is not None:
|
||||||
verts_tensors.append(verts_colors)
|
verts_dtype.append(("colors", color_np_type, 3))
|
||||||
|
|
||||||
|
vert_data = np.zeros(verts.shape[0], dtype=verts_dtype)
|
||||||
|
vert_data["verts"] = verts.detach().cpu().numpy()
|
||||||
|
if verts_normals is not None:
|
||||||
|
vert_data["normals"] = verts_normals.detach().cpu().numpy()
|
||||||
|
if verts_colors is not None:
|
||||||
|
color_data = verts_colors.detach().cpu().numpy()
|
||||||
|
if colors_as_uint8:
|
||||||
|
vert_data["colors"] = np.rint(color_data * 255)
|
||||||
|
else:
|
||||||
|
vert_data["colors"] = color_data
|
||||||
|
|
||||||
vert_data = torch.cat(verts_tensors, dim=1).detach().cpu().numpy()
|
|
||||||
if ascii:
|
if ascii:
|
||||||
if decimal_places is None:
|
if decimal_places is None:
|
||||||
float_str = "%f"
|
float_str = b"%f"
|
||||||
else:
|
else:
|
||||||
float_str = "%" + ".%df" % decimal_places
|
float_str = b"%" + b".%df" % decimal_places
|
||||||
np.savetxt(f, vert_data, float_str)
|
float_group_str = (float_str + b" ") * 3
|
||||||
|
formats = [float_group_str]
|
||||||
|
if verts_normals is not None:
|
||||||
|
formats.append(float_group_str)
|
||||||
|
if verts_colors is not None:
|
||||||
|
formats.append(b"%d %d %d " if colors_as_uint8 else float_group_str)
|
||||||
|
formats[-1] = formats[-1][:-1] + b"\n"
|
||||||
|
for line_data in vert_data:
|
||||||
|
for data, format in zip(line_data, formats):
|
||||||
|
f.write(format % tuple(data))
|
||||||
else:
|
else:
|
||||||
assert vert_data.dtype == np.float32
|
|
||||||
if isinstance(f, BytesIO):
|
if isinstance(f, BytesIO):
|
||||||
# tofile only works with real files, but is faster than this.
|
# tofile only works with real files, but is faster than this.
|
||||||
f.write(vert_data.tobytes())
|
f.write(vert_data.tobytes())
|
||||||
@ -1189,7 +1245,6 @@ def save_ply(
|
|||||||
ascii: (bool) whether to use the ascii ply format.
|
ascii: (bool) whether to use the ascii ply format.
|
||||||
decimal_places: Number of decimal places for saving if ascii=True.
|
decimal_places: Number of decimal places for saving if ascii=True.
|
||||||
path_manager: PathManager for interpreting f if it is a str.
|
path_manager: PathManager for interpreting f if it is a str.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
|
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
|
||||||
@ -1227,6 +1282,7 @@ def save_ply(
|
|||||||
verts_colors=None,
|
verts_colors=None,
|
||||||
ascii=ascii,
|
ascii=ascii,
|
||||||
decimal_places=decimal_places,
|
decimal_places=decimal_places,
|
||||||
|
colors_as_uint8=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -1272,8 +1328,14 @@ class MeshPlyFormat(MeshFormatInterpreter):
|
|||||||
path_manager: PathManager,
|
path_manager: PathManager,
|
||||||
binary: Optional[bool],
|
binary: Optional[bool],
|
||||||
decimal_places: Optional[int] = None,
|
decimal_places: Optional[int] = None,
|
||||||
|
colors_as_uint8: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Extra optional args:
|
||||||
|
colors_as_uint8: (bool) Whether to save colors as numbers in the
|
||||||
|
range [0, 255] instead of float32.
|
||||||
|
"""
|
||||||
if not endswith(path, self.known_suffixes):
|
if not endswith(path, self.known_suffixes):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -1307,6 +1369,7 @@ class MeshPlyFormat(MeshFormatInterpreter):
|
|||||||
verts_normals=verts_normals,
|
verts_normals=verts_normals,
|
||||||
ascii=binary is False,
|
ascii=binary is False,
|
||||||
decimal_places=decimal_places,
|
decimal_places=decimal_places,
|
||||||
|
colors_as_uint8=colors_as_uint8,
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -1342,8 +1405,14 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter):
|
|||||||
path_manager: PathManager,
|
path_manager: PathManager,
|
||||||
binary: Optional[bool],
|
binary: Optional[bool],
|
||||||
decimal_places: Optional[int] = None,
|
decimal_places: Optional[int] = None,
|
||||||
|
colors_as_uint8: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Extra optional args:
|
||||||
|
colors_as_uint8: (bool) Whether to save colors as numbers in the
|
||||||
|
range [0, 255] instead of float32.
|
||||||
|
"""
|
||||||
if not endswith(path, self.known_suffixes):
|
if not endswith(path, self.known_suffixes):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -1360,5 +1429,6 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter):
|
|||||||
faces=None,
|
faces=None,
|
||||||
ascii=binary is False,
|
ascii=binary is False,
|
||||||
decimal_places=decimal_places,
|
decimal_places=decimal_places,
|
||||||
|
colors_as_uint8=colors_as_uint8,
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
@ -528,12 +528,16 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
|||||||
).encode("ascii")
|
).encode("ascii")
|
||||||
data = struct.pack("<" + "f" * 48, *range(48))
|
data = struct.pack("<" + "f" * 48, *range(48))
|
||||||
points = torch.FloatTensor([0, 1, 2]) + 6 * torch.arange(8)[:, None]
|
points = torch.FloatTensor([0, 1, 2]) + 6 * torch.arange(8)[:, None]
|
||||||
features = torch.FloatTensor([3, 4, 5]) + 6 * torch.arange(8)[:, None]
|
features_large = torch.FloatTensor([3, 4, 5]) + 6 * torch.arange(8)[:, None]
|
||||||
|
features = features_large / 255.0
|
||||||
|
pointcloud_largefeatures = Pointclouds(
|
||||||
|
points=[points], features=[features_large]
|
||||||
|
)
|
||||||
pointcloud = Pointclouds(points=[points], features=[features])
|
pointcloud = Pointclouds(points=[points], features=[features])
|
||||||
|
|
||||||
io = IO()
|
io = IO()
|
||||||
with NamedTemporaryFile(mode="rb", suffix=".ply") as f:
|
with NamedTemporaryFile(mode="rb", suffix=".ply") as f:
|
||||||
io.save_pointcloud(data=pointcloud, path=f.name)
|
io.save_pointcloud(data=pointcloud_largefeatures, path=f.name)
|
||||||
f.flush()
|
f.flush()
|
||||||
f.seek(0)
|
f.seek(0)
|
||||||
actual_data = f.read()
|
actual_data = f.read()
|
||||||
@ -541,16 +545,53 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(header + data, actual_data)
|
self.assertEqual(header + data, actual_data)
|
||||||
self.assertClose(reloaded_pointcloud.points_list()[0], points)
|
self.assertClose(reloaded_pointcloud.points_list()[0], points)
|
||||||
self.assertClose(reloaded_pointcloud.features_list()[0], features)
|
self.assertClose(reloaded_pointcloud.features_list()[0], features_large)
|
||||||
|
# Test the load-save cycle leaves file completely unchanged
|
||||||
|
with NamedTemporaryFile(mode="rb", suffix=".ply") as f:
|
||||||
|
io.save_pointcloud(
|
||||||
|
data=reloaded_pointcloud,
|
||||||
|
path=f.name,
|
||||||
|
)
|
||||||
|
f.flush()
|
||||||
|
f.seek(0)
|
||||||
|
data2 = f.read()
|
||||||
|
self.assertEqual(data2, actual_data)
|
||||||
|
|
||||||
with NamedTemporaryFile(mode="r", suffix=".ply") as f:
|
with NamedTemporaryFile(mode="r", suffix=".ply") as f:
|
||||||
io.save_pointcloud(data=pointcloud, path=f.name, binary=False)
|
io.save_pointcloud(
|
||||||
|
data=pointcloud, path=f.name, binary=False, decimal_places=9
|
||||||
|
)
|
||||||
reloaded_pointcloud2 = io.load_pointcloud(f.name)
|
reloaded_pointcloud2 = io.load_pointcloud(f.name)
|
||||||
self.assertEqual(f.readline(), "ply\n")
|
self.assertEqual(f.readline(), "ply\n")
|
||||||
self.assertEqual(f.readline(), "format ascii 1.0\n")
|
self.assertEqual(f.readline(), "format ascii 1.0\n")
|
||||||
self.assertClose(reloaded_pointcloud2.points_list()[0], points)
|
self.assertClose(reloaded_pointcloud2.points_list()[0], points)
|
||||||
self.assertClose(reloaded_pointcloud2.features_list()[0], features)
|
self.assertClose(reloaded_pointcloud2.features_list()[0], features)
|
||||||
|
|
||||||
|
for binary in [True, False]:
|
||||||
|
with NamedTemporaryFile(mode="rb", suffix=".ply") as f:
|
||||||
|
io.save_pointcloud(
|
||||||
|
data=pointcloud, path=f.name, colors_as_uint8=True, binary=binary
|
||||||
|
)
|
||||||
|
f.flush()
|
||||||
|
f.seek(0)
|
||||||
|
actual_data = f.read()
|
||||||
|
reloaded_pointcloud3 = io.load_pointcloud(f.name)
|
||||||
|
self.assertClose(reloaded_pointcloud3.features_list()[0], features)
|
||||||
|
self.assertIn(b"property uchar green", actual_data)
|
||||||
|
|
||||||
|
# Test the load-save cycle leaves file completely unchanged
|
||||||
|
with NamedTemporaryFile(mode="rb", suffix=".ply") as f:
|
||||||
|
io.save_pointcloud(
|
||||||
|
data=reloaded_pointcloud3,
|
||||||
|
path=f.name,
|
||||||
|
binary=binary,
|
||||||
|
colors_as_uint8=True,
|
||||||
|
)
|
||||||
|
f.flush()
|
||||||
|
f.seek(0)
|
||||||
|
data2 = f.read()
|
||||||
|
self.assertEqual(data2, actual_data)
|
||||||
|
|
||||||
def test_load_pointcloud_bad_order(self):
|
def test_load_pointcloud_bad_order(self):
|
||||||
"""
|
"""
|
||||||
Ply file with a strange property order
|
Ply file with a strange property order
|
||||||
|
Loading…
x
Reference in New Issue
Block a user