mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
PLY TexturesVertex loading
Summary: Include TexturesVertex colors when loading and saving Meshes to PLY files. A couple of other improvements to the internals of ply_io, including using `None` instead of empty tensors for some missing data. Reviewed By: gkioxari Differential Revision: D27765260 fbshipit-source-id: b9857dc777c244b9d7d6643b608596d31435ecda
This commit is contained in:
parent
097b0ef2c6
commit
6c3fe952d1
@ -20,6 +20,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from iopath.common.file_io import PathManager
|
from iopath.common.file_io import PathManager
|
||||||
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
|
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
|
||||||
|
from pytorch3d.renderer import TexturesVertex
|
||||||
from pytorch3d.structures import Meshes, Pointclouds
|
from pytorch3d.structures import Meshes, Pointclouds
|
||||||
|
|
||||||
from .pluggable_formats import (
|
from .pluggable_formats import (
|
||||||
@ -66,7 +67,7 @@ class _PlyElementType:
|
|||||||
def __init__(self, name: str, count: int):
|
def __init__(self, name: str, count: int):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.count = count
|
self.count = count
|
||||||
self.properties = []
|
self.properties: List[_Property] = []
|
||||||
|
|
||||||
def add_property(
|
def add_property(
|
||||||
self, name: str, data_type: str, list_size_type: Optional[str] = None
|
self, name: str, data_type: str, list_size_type: Optional[str] = None
|
||||||
@ -142,7 +143,7 @@ class _PlyHeader:
|
|||||||
if f.readline() not in [b"ply\n", b"ply\r\n", "ply\n"]:
|
if f.readline() not in [b"ply\n", b"ply\r\n", "ply\n"]:
|
||||||
raise ValueError("Invalid file header.")
|
raise ValueError("Invalid file header.")
|
||||||
seen_format = False
|
seen_format = False
|
||||||
self.elements = []
|
self.elements: List[_PlyElementType] = []
|
||||||
self.obj_info = []
|
self.obj_info = []
|
||||||
while True:
|
while True:
|
||||||
line = f.readline()
|
line = f.readline()
|
||||||
@ -891,8 +892,8 @@ def _get_verts(
|
|||||||
|
|
||||||
|
|
||||||
def _load_ply(
|
def _load_ply(
|
||||||
f, *, path_manager: PathManager, return_vertex_colors: bool = False
|
f, *, path_manager: PathManager
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
"""
|
"""
|
||||||
Load the data from a .ply file.
|
Load the data from a .ply file.
|
||||||
|
|
||||||
@ -903,12 +904,11 @@ def _load_ply(
|
|||||||
ply format, then a text stream is not supported.
|
ply format, then a text stream is not supported.
|
||||||
It is easiest to use a binary stream in all cases.
|
It is easiest to use a binary stream in all cases.
|
||||||
path_manager: PathManager for loading if f is a str.
|
path_manager: PathManager for loading if f is a str.
|
||||||
return_vertex_colors: whether to return vertex colors.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
verts: FloatTensor of shape (V, 3).
|
verts: FloatTensor of shape (V, 3).
|
||||||
faces: None or LongTensor of vertex indices, shape (F, 3).
|
faces: None or LongTensor of vertex indices, shape (F, 3).
|
||||||
vertex_colors: None or FloatTensor of shape (V, 3), only if requested
|
vertex_colors: None or FloatTensor of shape (V, 3).
|
||||||
"""
|
"""
|
||||||
header, elements = _load_ply_raw(f, path_manager=path_manager)
|
header, elements = _load_ply_raw(f, path_manager=path_manager)
|
||||||
|
|
||||||
@ -950,16 +950,17 @@ def _load_ply(
|
|||||||
if faces is not None:
|
if faces is not None:
|
||||||
_check_faces_indices(faces, max_index=verts.shape[0])
|
_check_faces_indices(faces, max_index=verts.shape[0])
|
||||||
|
|
||||||
if return_vertex_colors:
|
return verts, faces, vertex_colors
|
||||||
return verts, faces, vertex_colors
|
|
||||||
return verts, faces, None
|
|
||||||
|
|
||||||
|
|
||||||
def load_ply(
|
def load_ply(
|
||||||
f, *, path_manager: Optional[PathManager] = None
|
f, *, path_manager: Optional[PathManager] = None
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Load the data from a .ply file.
|
Load the verts and faces from a .ply file.
|
||||||
|
Note that the preferred way to load data from such a file
|
||||||
|
is to use the IO.load_mesh and IO.load_pointcloud functions,
|
||||||
|
which can read more of the data.
|
||||||
|
|
||||||
Example .ply file format:
|
Example .ply file format:
|
||||||
|
|
||||||
@ -1016,8 +1017,8 @@ def _save_ply(
|
|||||||
*,
|
*,
|
||||||
verts: torch.Tensor,
|
verts: torch.Tensor,
|
||||||
faces: Optional[torch.LongTensor],
|
faces: Optional[torch.LongTensor],
|
||||||
verts_normals: torch.Tensor,
|
verts_normals: Optional[torch.Tensor],
|
||||||
verts_colors: torch.Tensor,
|
verts_colors: Optional[torch.Tensor],
|
||||||
ascii: bool,
|
ascii: bool,
|
||||||
decimal_places: Optional[int] = None,
|
decimal_places: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -1029,16 +1030,16 @@ def _save_ply(
|
|||||||
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
|
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
|
||||||
faces: LongTensor of shape (F, 3) giving faces.
|
faces: LongTensor of shape (F, 3) giving faces.
|
||||||
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
|
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
|
||||||
|
verts_colors: FloatTensor of shape (V, 3) giving vertex colors.
|
||||||
ascii: (bool) whether to use the ascii ply format.
|
ascii: (bool) whether to use the ascii ply format.
|
||||||
decimal_places: Number of decimal places for saving if ascii=True.
|
decimal_places: Number of decimal places for saving if ascii=True.
|
||||||
"""
|
"""
|
||||||
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
|
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
|
||||||
if faces is not None:
|
assert faces is None or not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
|
||||||
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
|
assert verts_normals is None or (
|
||||||
assert not len(verts_normals) or (
|
|
||||||
verts_normals.dim() == 2 and verts_normals.size(1) == 3
|
verts_normals.dim() == 2 and verts_normals.size(1) == 3
|
||||||
)
|
)
|
||||||
assert not len(verts_colors) or (
|
assert verts_colors is None or (
|
||||||
verts_colors.dim() == 2 and verts_colors.size(1) == 3
|
verts_colors.dim() == 2 and verts_colors.size(1) == 3
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1052,11 +1053,11 @@ def _save_ply(
|
|||||||
f.write(b"property float x\n")
|
f.write(b"property float x\n")
|
||||||
f.write(b"property float y\n")
|
f.write(b"property float y\n")
|
||||||
f.write(b"property float z\n")
|
f.write(b"property float z\n")
|
||||||
if verts_normals.numel() > 0:
|
if verts_normals is not None:
|
||||||
f.write(b"property float nx\n")
|
f.write(b"property float nx\n")
|
||||||
f.write(b"property float ny\n")
|
f.write(b"property float ny\n")
|
||||||
f.write(b"property float nz\n")
|
f.write(b"property float nz\n")
|
||||||
if verts_colors.numel() > 0:
|
if verts_colors is not None:
|
||||||
f.write(b"property float red\n")
|
f.write(b"property float red\n")
|
||||||
f.write(b"property float green\n")
|
f.write(b"property float green\n")
|
||||||
f.write(b"property float blue\n")
|
f.write(b"property float blue\n")
|
||||||
@ -1069,7 +1070,13 @@ def _save_ply(
|
|||||||
warnings.warn("Empty 'verts' provided")
|
warnings.warn("Empty 'verts' provided")
|
||||||
return
|
return
|
||||||
|
|
||||||
vert_data = torch.cat((verts, verts_normals, verts_colors), dim=1).detach().numpy()
|
verts_tensors = [verts]
|
||||||
|
if verts_normals is not None:
|
||||||
|
verts_tensors.append(verts_normals)
|
||||||
|
if verts_colors is not None:
|
||||||
|
verts_tensors.append(verts_colors)
|
||||||
|
|
||||||
|
vert_data = torch.cat(verts_tensors, dim=1).detach().cpu().numpy()
|
||||||
if ascii:
|
if ascii:
|
||||||
if decimal_places is None:
|
if decimal_places is None:
|
||||||
float_str = "%f"
|
float_str = "%f"
|
||||||
@ -1085,7 +1092,7 @@ def _save_ply(
|
|||||||
vert_data.tofile(f)
|
vert_data.tofile(f)
|
||||||
|
|
||||||
if faces is not None:
|
if faces is not None:
|
||||||
faces_array = faces.detach().numpy()
|
faces_array = faces.detach().cpu().numpy()
|
||||||
|
|
||||||
_check_faces_indices(faces, max_index=verts.shape[0])
|
_check_faces_indices(faces, max_index=verts.shape[0])
|
||||||
|
|
||||||
@ -1125,12 +1132,6 @@ def save_ply(
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
verts_normals = (
|
|
||||||
torch.tensor([], dtype=torch.float32, device=verts.device)
|
|
||||||
if verts_normals is None
|
|
||||||
else verts_normals
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
|
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
|
||||||
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
|
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
|
||||||
raise ValueError(message)
|
raise ValueError(message)
|
||||||
@ -1143,16 +1144,18 @@ def save_ply(
|
|||||||
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
|
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
|
||||||
raise ValueError(message)
|
raise ValueError(message)
|
||||||
|
|
||||||
if len(verts_normals) and not (
|
if (
|
||||||
verts_normals.dim() == 2
|
verts_normals is not None
|
||||||
and verts_normals.size(1) == 3
|
and len(verts_normals)
|
||||||
and verts_normals.size(0) == verts.size(0)
|
and not (
|
||||||
|
verts_normals.dim() == 2
|
||||||
|
and verts_normals.size(1) == 3
|
||||||
|
and verts_normals.size(0) == verts.size(0)
|
||||||
|
)
|
||||||
):
|
):
|
||||||
message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)."
|
message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)."
|
||||||
raise ValueError(message)
|
raise ValueError(message)
|
||||||
|
|
||||||
verts_colors = torch.FloatTensor([])
|
|
||||||
|
|
||||||
if path_manager is None:
|
if path_manager is None:
|
||||||
path_manager = PathManager()
|
path_manager = PathManager()
|
||||||
with _open_file(f, path_manager, "wb") as f:
|
with _open_file(f, path_manager, "wb") as f:
|
||||||
@ -1161,7 +1164,7 @@ def save_ply(
|
|||||||
verts=verts,
|
verts=verts,
|
||||||
faces=faces,
|
faces=faces,
|
||||||
verts_normals=verts_normals,
|
verts_normals=verts_normals,
|
||||||
verts_colors=verts_colors,
|
verts_colors=None,
|
||||||
ascii=ascii,
|
ascii=ascii,
|
||||||
decimal_places=decimal_places,
|
decimal_places=decimal_places,
|
||||||
)
|
)
|
||||||
@ -1182,8 +1185,19 @@ class MeshPlyFormat(MeshFormatInterpreter):
|
|||||||
if not endswith(path, self.known_suffixes):
|
if not endswith(path, self.known_suffixes):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
verts, faces = load_ply(f=path, path_manager=path_manager)
|
verts, faces, verts_colors = _load_ply(f=path, path_manager=path_manager)
|
||||||
mesh = Meshes(verts=[verts.to(device)], faces=[faces.to(device)])
|
if faces is None:
|
||||||
|
faces = torch.zeros(0, 3, dtype=torch.int64)
|
||||||
|
|
||||||
|
textures = None
|
||||||
|
if include_textures and verts_colors is not None:
|
||||||
|
textures = TexturesVertex([verts_colors.to(device)])
|
||||||
|
|
||||||
|
mesh = Meshes(
|
||||||
|
verts=[verts.to(device)],
|
||||||
|
faces=[faces.to(device)],
|
||||||
|
textures=textures,
|
||||||
|
)
|
||||||
return mesh
|
return mesh
|
||||||
|
|
||||||
def save(
|
def save(
|
||||||
@ -1201,14 +1215,30 @@ class MeshPlyFormat(MeshFormatInterpreter):
|
|||||||
# TODO: normals are not saved. We only want to save them if they already exist.
|
# TODO: normals are not saved. We only want to save them if they already exist.
|
||||||
verts = data.verts_list()[0]
|
verts = data.verts_list()[0]
|
||||||
faces = data.faces_list()[0]
|
faces = data.faces_list()[0]
|
||||||
save_ply(
|
|
||||||
f=path,
|
if isinstance(data.textures, TexturesVertex):
|
||||||
verts=verts,
|
mesh_verts_colors = data.textures.verts_features_list()[0]
|
||||||
faces=faces,
|
n_colors = mesh_verts_colors.shape[1]
|
||||||
ascii=binary is False,
|
if n_colors == 3:
|
||||||
decimal_places=decimal_places,
|
verts_colors = mesh_verts_colors
|
||||||
path_manager=path_manager,
|
else:
|
||||||
)
|
warnings.warn(
|
||||||
|
f"Texture will not be saved as it has {n_colors} colors, not 3."
|
||||||
|
)
|
||||||
|
verts_colors = None
|
||||||
|
else:
|
||||||
|
verts_colors = None
|
||||||
|
|
||||||
|
with _open_file(path, path_manager, "wb") as f:
|
||||||
|
_save_ply(
|
||||||
|
f=f,
|
||||||
|
verts=verts,
|
||||||
|
faces=faces,
|
||||||
|
verts_colors=verts_colors,
|
||||||
|
verts_normals=None,
|
||||||
|
ascii=binary is False,
|
||||||
|
decimal_places=decimal_places,
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@ -1226,14 +1256,12 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter):
|
|||||||
if not endswith(path, self.known_suffixes):
|
if not endswith(path, self.known_suffixes):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
verts, faces, features = _load_ply(
|
verts, faces, features = _load_ply(f=path, path_manager=path_manager)
|
||||||
f=path, path_manager=path_manager, return_vertex_colors=True
|
|
||||||
)
|
|
||||||
verts = verts.to(device)
|
verts = verts.to(device)
|
||||||
if features is None:
|
if features is not None:
|
||||||
pointcloud = Pointclouds(points=[verts])
|
features = [features.to(device)]
|
||||||
else:
|
|
||||||
pointcloud = Pointclouds(points=[verts], features=[features.to(device)])
|
pointcloud = Pointclouds(points=[verts], features=features)
|
||||||
return pointcloud
|
return pointcloud
|
||||||
|
|
||||||
def save(
|
def save(
|
||||||
@ -1249,13 +1277,14 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
points = data.points_list()[0]
|
points = data.points_list()[0]
|
||||||
features = data.features_list()[0]
|
features = data.features_packed()
|
||||||
|
|
||||||
with _open_file(path, path_manager, "wb") as f:
|
with _open_file(path, path_manager, "wb") as f:
|
||||||
_save_ply(
|
_save_ply(
|
||||||
f=f,
|
f=f,
|
||||||
verts=points,
|
verts=points,
|
||||||
verts_colors=features,
|
verts_colors=features,
|
||||||
verts_normals=torch.FloatTensor([]),
|
verts_normals=None,
|
||||||
faces=None,
|
faces=None,
|
||||||
ascii=binary is False,
|
ascii=binary is False,
|
||||||
decimal_places=decimal_places,
|
decimal_places=decimal_places,
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
import itertools
|
||||||
import struct
|
import struct
|
||||||
import unittest
|
import unittest
|
||||||
from io import BytesIO, StringIO
|
from io import BytesIO, StringIO
|
||||||
@ -12,7 +13,8 @@ from common_testing import TestCaseMixin
|
|||||||
from iopath.common.file_io import PathManager
|
from iopath.common.file_io import PathManager
|
||||||
from pytorch3d.io import IO
|
from pytorch3d.io import IO
|
||||||
from pytorch3d.io.ply_io import load_ply, save_ply
|
from pytorch3d.io.ply_io import load_ply, save_ply
|
||||||
from pytorch3d.structures import Pointclouds
|
from pytorch3d.renderer.mesh import TexturesVertex
|
||||||
|
from pytorch3d.structures import Meshes, Pointclouds
|
||||||
from pytorch3d.utils import torus
|
from pytorch3d.utils import torus
|
||||||
|
|
||||||
|
|
||||||
@ -189,6 +191,57 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
|||||||
):
|
):
|
||||||
io.load_mesh(f3.name)
|
io.load_mesh(f3.name)
|
||||||
|
|
||||||
|
def test_save_too_many_colors(self):
|
||||||
|
verts = torch.tensor(
|
||||||
|
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
|
||||||
|
)
|
||||||
|
faces = torch.tensor([[0, 1, 2], [0, 2, 3]])
|
||||||
|
vert_colors = torch.rand((4, 7))
|
||||||
|
texture_with_seven_colors = TexturesVertex(verts_features=[vert_colors])
|
||||||
|
|
||||||
|
mesh = Meshes(
|
||||||
|
verts=[verts],
|
||||||
|
faces=[faces],
|
||||||
|
textures=texture_with_seven_colors,
|
||||||
|
)
|
||||||
|
|
||||||
|
io = IO()
|
||||||
|
msg = "Texture will not be saved as it has 7 colors, not 3."
|
||||||
|
with NamedTemporaryFile(mode="w", suffix=".ply") as f:
|
||||||
|
with self.assertWarnsRegex(UserWarning, msg):
|
||||||
|
io.save_mesh(mesh.cuda(), f.name)
|
||||||
|
|
||||||
|
def test_save_load_meshes(self):
|
||||||
|
verts = torch.tensor(
|
||||||
|
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
|
||||||
|
)
|
||||||
|
faces = torch.tensor([[0, 1, 2], [0, 2, 3]])
|
||||||
|
vert_colors = torch.rand_like(verts)
|
||||||
|
texture = TexturesVertex(verts_features=[vert_colors])
|
||||||
|
|
||||||
|
for do_textures in itertools.product([True, False]):
|
||||||
|
mesh = Meshes(
|
||||||
|
verts=[verts],
|
||||||
|
faces=[faces],
|
||||||
|
textures=texture if do_textures else None,
|
||||||
|
)
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
|
||||||
|
io = IO()
|
||||||
|
with NamedTemporaryFile(mode="w", suffix=".ply") as f:
|
||||||
|
io.save_mesh(mesh.cuda(), f.name)
|
||||||
|
f.flush()
|
||||||
|
mesh2 = io.load_mesh(f.name, device=device)
|
||||||
|
self.assertEqual(mesh2.device, device)
|
||||||
|
mesh2 = mesh2.cpu()
|
||||||
|
self.assertClose(mesh2.verts_padded(), mesh.verts_padded())
|
||||||
|
self.assertClose(mesh2.faces_padded(), mesh.faces_padded())
|
||||||
|
if do_textures:
|
||||||
|
self.assertIsInstance(mesh2.textures, TexturesVertex)
|
||||||
|
self.assertClose(mesh2.textures.verts_features_list()[0], vert_colors)
|
||||||
|
else:
|
||||||
|
self.assertIsNone(mesh2.textures)
|
||||||
|
|
||||||
def test_save_ply_invalid_shapes(self):
|
def test_save_ply_invalid_shapes(self):
|
||||||
# Invalid vertices shape
|
# Invalid vertices shape
|
||||||
with self.assertRaises(ValueError) as error:
|
with self.assertRaises(ValueError) as error:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user