mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
PLY TexturesVertex loading
Summary: Include TexturesVertex colors when loading and saving Meshes to PLY files. A couple of other improvements to the internals of ply_io, including using `None` instead of empty tensors for some missing data. Reviewed By: gkioxari Differential Revision: D27765260 fbshipit-source-id: b9857dc777c244b9d7d6643b608596d31435ecda
This commit is contained in:
parent
097b0ef2c6
commit
6c3fe952d1
@ -20,6 +20,7 @@ import numpy as np
|
||||
import torch
|
||||
from iopath.common.file_io import PathManager
|
||||
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
|
||||
from pytorch3d.renderer import TexturesVertex
|
||||
from pytorch3d.structures import Meshes, Pointclouds
|
||||
|
||||
from .pluggable_formats import (
|
||||
@ -66,7 +67,7 @@ class _PlyElementType:
|
||||
def __init__(self, name: str, count: int):
|
||||
self.name = name
|
||||
self.count = count
|
||||
self.properties = []
|
||||
self.properties: List[_Property] = []
|
||||
|
||||
def add_property(
|
||||
self, name: str, data_type: str, list_size_type: Optional[str] = None
|
||||
@ -142,7 +143,7 @@ class _PlyHeader:
|
||||
if f.readline() not in [b"ply\n", b"ply\r\n", "ply\n"]:
|
||||
raise ValueError("Invalid file header.")
|
||||
seen_format = False
|
||||
self.elements = []
|
||||
self.elements: List[_PlyElementType] = []
|
||||
self.obj_info = []
|
||||
while True:
|
||||
line = f.readline()
|
||||
@ -891,8 +892,8 @@ def _get_verts(
|
||||
|
||||
|
||||
def _load_ply(
|
||||
f, *, path_manager: PathManager, return_vertex_colors: bool = False
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
f, *, path_manager: PathManager
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Load the data from a .ply file.
|
||||
|
||||
@ -903,12 +904,11 @@ def _load_ply(
|
||||
ply format, then a text stream is not supported.
|
||||
It is easiest to use a binary stream in all cases.
|
||||
path_manager: PathManager for loading if f is a str.
|
||||
return_vertex_colors: whether to return vertex colors.
|
||||
|
||||
Returns:
|
||||
verts: FloatTensor of shape (V, 3).
|
||||
faces: None or LongTensor of vertex indices, shape (F, 3).
|
||||
vertex_colors: None or FloatTensor of shape (V, 3), only if requested
|
||||
vertex_colors: None or FloatTensor of shape (V, 3).
|
||||
"""
|
||||
header, elements = _load_ply_raw(f, path_manager=path_manager)
|
||||
|
||||
@ -950,16 +950,17 @@ def _load_ply(
|
||||
if faces is not None:
|
||||
_check_faces_indices(faces, max_index=verts.shape[0])
|
||||
|
||||
if return_vertex_colors:
|
||||
return verts, faces, vertex_colors
|
||||
return verts, faces, None
|
||||
|
||||
|
||||
def load_ply(
|
||||
f, *, path_manager: Optional[PathManager] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Load the data from a .ply file.
|
||||
Load the verts and faces from a .ply file.
|
||||
Note that the preferred way to load data from such a file
|
||||
is to use the IO.load_mesh and IO.load_pointcloud functions,
|
||||
which can read more of the data.
|
||||
|
||||
Example .ply file format:
|
||||
|
||||
@ -1016,8 +1017,8 @@ def _save_ply(
|
||||
*,
|
||||
verts: torch.Tensor,
|
||||
faces: Optional[torch.LongTensor],
|
||||
verts_normals: torch.Tensor,
|
||||
verts_colors: torch.Tensor,
|
||||
verts_normals: Optional[torch.Tensor],
|
||||
verts_colors: Optional[torch.Tensor],
|
||||
ascii: bool,
|
||||
decimal_places: Optional[int] = None,
|
||||
) -> None:
|
||||
@ -1029,16 +1030,16 @@ def _save_ply(
|
||||
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
|
||||
faces: LongTensor of shape (F, 3) giving faces.
|
||||
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
|
||||
verts_colors: FloatTensor of shape (V, 3) giving vertex colors.
|
||||
ascii: (bool) whether to use the ascii ply format.
|
||||
decimal_places: Number of decimal places for saving if ascii=True.
|
||||
"""
|
||||
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
|
||||
if faces is not None:
|
||||
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
|
||||
assert not len(verts_normals) or (
|
||||
assert faces is None or not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
|
||||
assert verts_normals is None or (
|
||||
verts_normals.dim() == 2 and verts_normals.size(1) == 3
|
||||
)
|
||||
assert not len(verts_colors) or (
|
||||
assert verts_colors is None or (
|
||||
verts_colors.dim() == 2 and verts_colors.size(1) == 3
|
||||
)
|
||||
|
||||
@ -1052,11 +1053,11 @@ def _save_ply(
|
||||
f.write(b"property float x\n")
|
||||
f.write(b"property float y\n")
|
||||
f.write(b"property float z\n")
|
||||
if verts_normals.numel() > 0:
|
||||
if verts_normals is not None:
|
||||
f.write(b"property float nx\n")
|
||||
f.write(b"property float ny\n")
|
||||
f.write(b"property float nz\n")
|
||||
if verts_colors.numel() > 0:
|
||||
if verts_colors is not None:
|
||||
f.write(b"property float red\n")
|
||||
f.write(b"property float green\n")
|
||||
f.write(b"property float blue\n")
|
||||
@ -1069,7 +1070,13 @@ def _save_ply(
|
||||
warnings.warn("Empty 'verts' provided")
|
||||
return
|
||||
|
||||
vert_data = torch.cat((verts, verts_normals, verts_colors), dim=1).detach().numpy()
|
||||
verts_tensors = [verts]
|
||||
if verts_normals is not None:
|
||||
verts_tensors.append(verts_normals)
|
||||
if verts_colors is not None:
|
||||
verts_tensors.append(verts_colors)
|
||||
|
||||
vert_data = torch.cat(verts_tensors, dim=1).detach().cpu().numpy()
|
||||
if ascii:
|
||||
if decimal_places is None:
|
||||
float_str = "%f"
|
||||
@ -1085,7 +1092,7 @@ def _save_ply(
|
||||
vert_data.tofile(f)
|
||||
|
||||
if faces is not None:
|
||||
faces_array = faces.detach().numpy()
|
||||
faces_array = faces.detach().cpu().numpy()
|
||||
|
||||
_check_faces_indices(faces, max_index=verts.shape[0])
|
||||
|
||||
@ -1125,12 +1132,6 @@ def save_ply(
|
||||
|
||||
"""
|
||||
|
||||
verts_normals = (
|
||||
torch.tensor([], dtype=torch.float32, device=verts.device)
|
||||
if verts_normals is None
|
||||
else verts_normals
|
||||
)
|
||||
|
||||
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
|
||||
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
|
||||
raise ValueError(message)
|
||||
@ -1143,16 +1144,18 @@ def save_ply(
|
||||
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
|
||||
raise ValueError(message)
|
||||
|
||||
if len(verts_normals) and not (
|
||||
if (
|
||||
verts_normals is not None
|
||||
and len(verts_normals)
|
||||
and not (
|
||||
verts_normals.dim() == 2
|
||||
and verts_normals.size(1) == 3
|
||||
and verts_normals.size(0) == verts.size(0)
|
||||
)
|
||||
):
|
||||
message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)."
|
||||
raise ValueError(message)
|
||||
|
||||
verts_colors = torch.FloatTensor([])
|
||||
|
||||
if path_manager is None:
|
||||
path_manager = PathManager()
|
||||
with _open_file(f, path_manager, "wb") as f:
|
||||
@ -1161,7 +1164,7 @@ def save_ply(
|
||||
verts=verts,
|
||||
faces=faces,
|
||||
verts_normals=verts_normals,
|
||||
verts_colors=verts_colors,
|
||||
verts_colors=None,
|
||||
ascii=ascii,
|
||||
decimal_places=decimal_places,
|
||||
)
|
||||
@ -1182,8 +1185,19 @@ class MeshPlyFormat(MeshFormatInterpreter):
|
||||
if not endswith(path, self.known_suffixes):
|
||||
return None
|
||||
|
||||
verts, faces = load_ply(f=path, path_manager=path_manager)
|
||||
mesh = Meshes(verts=[verts.to(device)], faces=[faces.to(device)])
|
||||
verts, faces, verts_colors = _load_ply(f=path, path_manager=path_manager)
|
||||
if faces is None:
|
||||
faces = torch.zeros(0, 3, dtype=torch.int64)
|
||||
|
||||
textures = None
|
||||
if include_textures and verts_colors is not None:
|
||||
textures = TexturesVertex([verts_colors.to(device)])
|
||||
|
||||
mesh = Meshes(
|
||||
verts=[verts.to(device)],
|
||||
faces=[faces.to(device)],
|
||||
textures=textures,
|
||||
)
|
||||
return mesh
|
||||
|
||||
def save(
|
||||
@ -1201,13 +1215,29 @@ class MeshPlyFormat(MeshFormatInterpreter):
|
||||
# TODO: normals are not saved. We only want to save them if they already exist.
|
||||
verts = data.verts_list()[0]
|
||||
faces = data.faces_list()[0]
|
||||
save_ply(
|
||||
f=path,
|
||||
|
||||
if isinstance(data.textures, TexturesVertex):
|
||||
mesh_verts_colors = data.textures.verts_features_list()[0]
|
||||
n_colors = mesh_verts_colors.shape[1]
|
||||
if n_colors == 3:
|
||||
verts_colors = mesh_verts_colors
|
||||
else:
|
||||
warnings.warn(
|
||||
f"Texture will not be saved as it has {n_colors} colors, not 3."
|
||||
)
|
||||
verts_colors = None
|
||||
else:
|
||||
verts_colors = None
|
||||
|
||||
with _open_file(path, path_manager, "wb") as f:
|
||||
_save_ply(
|
||||
f=f,
|
||||
verts=verts,
|
||||
faces=faces,
|
||||
verts_colors=verts_colors,
|
||||
verts_normals=None,
|
||||
ascii=binary is False,
|
||||
decimal_places=decimal_places,
|
||||
path_manager=path_manager,
|
||||
)
|
||||
return True
|
||||
|
||||
@ -1226,14 +1256,12 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter):
|
||||
if not endswith(path, self.known_suffixes):
|
||||
return None
|
||||
|
||||
verts, faces, features = _load_ply(
|
||||
f=path, path_manager=path_manager, return_vertex_colors=True
|
||||
)
|
||||
verts, faces, features = _load_ply(f=path, path_manager=path_manager)
|
||||
verts = verts.to(device)
|
||||
if features is None:
|
||||
pointcloud = Pointclouds(points=[verts])
|
||||
else:
|
||||
pointcloud = Pointclouds(points=[verts], features=[features.to(device)])
|
||||
if features is not None:
|
||||
features = [features.to(device)]
|
||||
|
||||
pointcloud = Pointclouds(points=[verts], features=features)
|
||||
return pointcloud
|
||||
|
||||
def save(
|
||||
@ -1249,13 +1277,14 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter):
|
||||
return False
|
||||
|
||||
points = data.points_list()[0]
|
||||
features = data.features_list()[0]
|
||||
features = data.features_packed()
|
||||
|
||||
with _open_file(path, path_manager, "wb") as f:
|
||||
_save_ply(
|
||||
f=f,
|
||||
verts=points,
|
||||
verts_colors=features,
|
||||
verts_normals=torch.FloatTensor([]),
|
||||
verts_normals=None,
|
||||
faces=None,
|
||||
ascii=binary is False,
|
||||
decimal_places=decimal_places,
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import itertools
|
||||
import struct
|
||||
import unittest
|
||||
from io import BytesIO, StringIO
|
||||
@ -12,7 +13,8 @@ from common_testing import TestCaseMixin
|
||||
from iopath.common.file_io import PathManager
|
||||
from pytorch3d.io import IO
|
||||
from pytorch3d.io.ply_io import load_ply, save_ply
|
||||
from pytorch3d.structures import Pointclouds
|
||||
from pytorch3d.renderer.mesh import TexturesVertex
|
||||
from pytorch3d.structures import Meshes, Pointclouds
|
||||
from pytorch3d.utils import torus
|
||||
|
||||
|
||||
@ -189,6 +191,57 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
||||
):
|
||||
io.load_mesh(f3.name)
|
||||
|
||||
def test_save_too_many_colors(self):
|
||||
verts = torch.tensor(
|
||||
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
|
||||
)
|
||||
faces = torch.tensor([[0, 1, 2], [0, 2, 3]])
|
||||
vert_colors = torch.rand((4, 7))
|
||||
texture_with_seven_colors = TexturesVertex(verts_features=[vert_colors])
|
||||
|
||||
mesh = Meshes(
|
||||
verts=[verts],
|
||||
faces=[faces],
|
||||
textures=texture_with_seven_colors,
|
||||
)
|
||||
|
||||
io = IO()
|
||||
msg = "Texture will not be saved as it has 7 colors, not 3."
|
||||
with NamedTemporaryFile(mode="w", suffix=".ply") as f:
|
||||
with self.assertWarnsRegex(UserWarning, msg):
|
||||
io.save_mesh(mesh.cuda(), f.name)
|
||||
|
||||
def test_save_load_meshes(self):
|
||||
verts = torch.tensor(
|
||||
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
|
||||
)
|
||||
faces = torch.tensor([[0, 1, 2], [0, 2, 3]])
|
||||
vert_colors = torch.rand_like(verts)
|
||||
texture = TexturesVertex(verts_features=[vert_colors])
|
||||
|
||||
for do_textures in itertools.product([True, False]):
|
||||
mesh = Meshes(
|
||||
verts=[verts],
|
||||
faces=[faces],
|
||||
textures=texture if do_textures else None,
|
||||
)
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
io = IO()
|
||||
with NamedTemporaryFile(mode="w", suffix=".ply") as f:
|
||||
io.save_mesh(mesh.cuda(), f.name)
|
||||
f.flush()
|
||||
mesh2 = io.load_mesh(f.name, device=device)
|
||||
self.assertEqual(mesh2.device, device)
|
||||
mesh2 = mesh2.cpu()
|
||||
self.assertClose(mesh2.verts_padded(), mesh.verts_padded())
|
||||
self.assertClose(mesh2.faces_padded(), mesh.faces_padded())
|
||||
if do_textures:
|
||||
self.assertIsInstance(mesh2.textures, TexturesVertex)
|
||||
self.assertClose(mesh2.textures.verts_features_list()[0], vert_colors)
|
||||
else:
|
||||
self.assertIsNone(mesh2.textures)
|
||||
|
||||
def test_save_ply_invalid_shapes(self):
|
||||
# Invalid vertices shape
|
||||
with self.assertRaises(ValueError) as error:
|
||||
|
Loading…
x
Reference in New Issue
Block a user