PLY TexturesVertex loading

Summary:
Include TexturesVertex colors when loading and saving Meshes to PLY files.

A couple of other improvements to the internals of ply_io, including using `None` instead of empty tensors for some missing data.

Reviewed By: gkioxari

Differential Revision: D27765260

fbshipit-source-id: b9857dc777c244b9d7d6643b608596d31435ecda
This commit is contained in:
Jeremy Reizenstein 2021-05-04 05:35:24 -07:00 committed by Facebook GitHub Bot
parent 097b0ef2c6
commit 6c3fe952d1
2 changed files with 135 additions and 53 deletions

View File

@ -20,6 +20,7 @@ import numpy as np
import torch import torch
from iopath.common.file_io import PathManager from iopath.common.file_io import PathManager
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
from pytorch3d.renderer import TexturesVertex
from pytorch3d.structures import Meshes, Pointclouds from pytorch3d.structures import Meshes, Pointclouds
from .pluggable_formats import ( from .pluggable_formats import (
@ -66,7 +67,7 @@ class _PlyElementType:
def __init__(self, name: str, count: int): def __init__(self, name: str, count: int):
self.name = name self.name = name
self.count = count self.count = count
self.properties = [] self.properties: List[_Property] = []
def add_property( def add_property(
self, name: str, data_type: str, list_size_type: Optional[str] = None self, name: str, data_type: str, list_size_type: Optional[str] = None
@ -142,7 +143,7 @@ class _PlyHeader:
if f.readline() not in [b"ply\n", b"ply\r\n", "ply\n"]: if f.readline() not in [b"ply\n", b"ply\r\n", "ply\n"]:
raise ValueError("Invalid file header.") raise ValueError("Invalid file header.")
seen_format = False seen_format = False
self.elements = [] self.elements: List[_PlyElementType] = []
self.obj_info = [] self.obj_info = []
while True: while True:
line = f.readline() line = f.readline()
@ -891,8 +892,8 @@ def _get_verts(
def _load_ply( def _load_ply(
f, *, path_manager: PathManager, return_vertex_colors: bool = False f, *, path_manager: PathManager
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
""" """
Load the data from a .ply file. Load the data from a .ply file.
@ -903,12 +904,11 @@ def _load_ply(
ply format, then a text stream is not supported. ply format, then a text stream is not supported.
It is easiest to use a binary stream in all cases. It is easiest to use a binary stream in all cases.
path_manager: PathManager for loading if f is a str. path_manager: PathManager for loading if f is a str.
return_vertex_colors: whether to return vertex colors.
Returns: Returns:
verts: FloatTensor of shape (V, 3). verts: FloatTensor of shape (V, 3).
faces: None or LongTensor of vertex indices, shape (F, 3). faces: None or LongTensor of vertex indices, shape (F, 3).
vertex_colors: None or FloatTensor of shape (V, 3), only if requested vertex_colors: None or FloatTensor of shape (V, 3).
""" """
header, elements = _load_ply_raw(f, path_manager=path_manager) header, elements = _load_ply_raw(f, path_manager=path_manager)
@ -950,16 +950,17 @@ def _load_ply(
if faces is not None: if faces is not None:
_check_faces_indices(faces, max_index=verts.shape[0]) _check_faces_indices(faces, max_index=verts.shape[0])
if return_vertex_colors: return verts, faces, vertex_colors
return verts, faces, vertex_colors
return verts, faces, None
def load_ply( def load_ply(
f, *, path_manager: Optional[PathManager] = None f, *, path_manager: Optional[PathManager] = None
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Load the data from a .ply file. Load the verts and faces from a .ply file.
Note that the preferred way to load data from such a file
is to use the IO.load_mesh and IO.load_pointcloud functions,
which can read more of the data.
Example .ply file format: Example .ply file format:
@ -1016,8 +1017,8 @@ def _save_ply(
*, *,
verts: torch.Tensor, verts: torch.Tensor,
faces: Optional[torch.LongTensor], faces: Optional[torch.LongTensor],
verts_normals: torch.Tensor, verts_normals: Optional[torch.Tensor],
verts_colors: torch.Tensor, verts_colors: Optional[torch.Tensor],
ascii: bool, ascii: bool,
decimal_places: Optional[int] = None, decimal_places: Optional[int] = None,
) -> None: ) -> None:
@ -1029,16 +1030,16 @@ def _save_ply(
verts: FloatTensor of shape (V, 3) giving vertex coordinates. verts: FloatTensor of shape (V, 3) giving vertex coordinates.
faces: LongTensor of shape (F, 3) giving faces. faces: LongTensor of shape (F, 3) giving faces.
verts_normals: FloatTensor of shape (V, 3) giving vertex normals. verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
verts_colors: FloatTensor of shape (V, 3) giving vertex colors.
ascii: (bool) whether to use the ascii ply format. ascii: (bool) whether to use the ascii ply format.
decimal_places: Number of decimal places for saving if ascii=True. decimal_places: Number of decimal places for saving if ascii=True.
""" """
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3) assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
if faces is not None: assert faces is None or not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3) assert verts_normals is None or (
assert not len(verts_normals) or (
verts_normals.dim() == 2 and verts_normals.size(1) == 3 verts_normals.dim() == 2 and verts_normals.size(1) == 3
) )
assert not len(verts_colors) or ( assert verts_colors is None or (
verts_colors.dim() == 2 and verts_colors.size(1) == 3 verts_colors.dim() == 2 and verts_colors.size(1) == 3
) )
@ -1052,11 +1053,11 @@ def _save_ply(
f.write(b"property float x\n") f.write(b"property float x\n")
f.write(b"property float y\n") f.write(b"property float y\n")
f.write(b"property float z\n") f.write(b"property float z\n")
if verts_normals.numel() > 0: if verts_normals is not None:
f.write(b"property float nx\n") f.write(b"property float nx\n")
f.write(b"property float ny\n") f.write(b"property float ny\n")
f.write(b"property float nz\n") f.write(b"property float nz\n")
if verts_colors.numel() > 0: if verts_colors is not None:
f.write(b"property float red\n") f.write(b"property float red\n")
f.write(b"property float green\n") f.write(b"property float green\n")
f.write(b"property float blue\n") f.write(b"property float blue\n")
@ -1069,7 +1070,13 @@ def _save_ply(
warnings.warn("Empty 'verts' provided") warnings.warn("Empty 'verts' provided")
return return
vert_data = torch.cat((verts, verts_normals, verts_colors), dim=1).detach().numpy() verts_tensors = [verts]
if verts_normals is not None:
verts_tensors.append(verts_normals)
if verts_colors is not None:
verts_tensors.append(verts_colors)
vert_data = torch.cat(verts_tensors, dim=1).detach().cpu().numpy()
if ascii: if ascii:
if decimal_places is None: if decimal_places is None:
float_str = "%f" float_str = "%f"
@ -1085,7 +1092,7 @@ def _save_ply(
vert_data.tofile(f) vert_data.tofile(f)
if faces is not None: if faces is not None:
faces_array = faces.detach().numpy() faces_array = faces.detach().cpu().numpy()
_check_faces_indices(faces, max_index=verts.shape[0]) _check_faces_indices(faces, max_index=verts.shape[0])
@ -1125,12 +1132,6 @@ def save_ply(
""" """
verts_normals = (
torch.tensor([], dtype=torch.float32, device=verts.device)
if verts_normals is None
else verts_normals
)
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3): if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)." message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
raise ValueError(message) raise ValueError(message)
@ -1143,16 +1144,18 @@ def save_ply(
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)." message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
raise ValueError(message) raise ValueError(message)
if len(verts_normals) and not ( if (
verts_normals.dim() == 2 verts_normals is not None
and verts_normals.size(1) == 3 and len(verts_normals)
and verts_normals.size(0) == verts.size(0) and not (
verts_normals.dim() == 2
and verts_normals.size(1) == 3
and verts_normals.size(0) == verts.size(0)
)
): ):
message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)." message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)."
raise ValueError(message) raise ValueError(message)
verts_colors = torch.FloatTensor([])
if path_manager is None: if path_manager is None:
path_manager = PathManager() path_manager = PathManager()
with _open_file(f, path_manager, "wb") as f: with _open_file(f, path_manager, "wb") as f:
@ -1161,7 +1164,7 @@ def save_ply(
verts=verts, verts=verts,
faces=faces, faces=faces,
verts_normals=verts_normals, verts_normals=verts_normals,
verts_colors=verts_colors, verts_colors=None,
ascii=ascii, ascii=ascii,
decimal_places=decimal_places, decimal_places=decimal_places,
) )
@ -1182,8 +1185,19 @@ class MeshPlyFormat(MeshFormatInterpreter):
if not endswith(path, self.known_suffixes): if not endswith(path, self.known_suffixes):
return None return None
verts, faces = load_ply(f=path, path_manager=path_manager) verts, faces, verts_colors = _load_ply(f=path, path_manager=path_manager)
mesh = Meshes(verts=[verts.to(device)], faces=[faces.to(device)]) if faces is None:
faces = torch.zeros(0, 3, dtype=torch.int64)
textures = None
if include_textures and verts_colors is not None:
textures = TexturesVertex([verts_colors.to(device)])
mesh = Meshes(
verts=[verts.to(device)],
faces=[faces.to(device)],
textures=textures,
)
return mesh return mesh
def save( def save(
@ -1201,14 +1215,30 @@ class MeshPlyFormat(MeshFormatInterpreter):
# TODO: normals are not saved. We only want to save them if they already exist. # TODO: normals are not saved. We only want to save them if they already exist.
verts = data.verts_list()[0] verts = data.verts_list()[0]
faces = data.faces_list()[0] faces = data.faces_list()[0]
save_ply(
f=path, if isinstance(data.textures, TexturesVertex):
verts=verts, mesh_verts_colors = data.textures.verts_features_list()[0]
faces=faces, n_colors = mesh_verts_colors.shape[1]
ascii=binary is False, if n_colors == 3:
decimal_places=decimal_places, verts_colors = mesh_verts_colors
path_manager=path_manager, else:
) warnings.warn(
f"Texture will not be saved as it has {n_colors} colors, not 3."
)
verts_colors = None
else:
verts_colors = None
with _open_file(path, path_manager, "wb") as f:
_save_ply(
f=f,
verts=verts,
faces=faces,
verts_colors=verts_colors,
verts_normals=None,
ascii=binary is False,
decimal_places=decimal_places,
)
return True return True
@ -1226,14 +1256,12 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter):
if not endswith(path, self.known_suffixes): if not endswith(path, self.known_suffixes):
return None return None
verts, faces, features = _load_ply( verts, faces, features = _load_ply(f=path, path_manager=path_manager)
f=path, path_manager=path_manager, return_vertex_colors=True
)
verts = verts.to(device) verts = verts.to(device)
if features is None: if features is not None:
pointcloud = Pointclouds(points=[verts]) features = [features.to(device)]
else:
pointcloud = Pointclouds(points=[verts], features=[features.to(device)]) pointcloud = Pointclouds(points=[verts], features=features)
return pointcloud return pointcloud
def save( def save(
@ -1249,13 +1277,14 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter):
return False return False
points = data.points_list()[0] points = data.points_list()[0]
features = data.features_list()[0] features = data.features_packed()
with _open_file(path, path_manager, "wb") as f: with _open_file(path, path_manager, "wb") as f:
_save_ply( _save_ply(
f=f, f=f,
verts=points, verts=points,
verts_colors=features, verts_colors=features,
verts_normals=torch.FloatTensor([]), verts_normals=None,
faces=None, faces=None,
ascii=binary is False, ascii=binary is False,
decimal_places=decimal_places, decimal_places=decimal_places,

View File

@ -1,5 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import itertools
import struct import struct
import unittest import unittest
from io import BytesIO, StringIO from io import BytesIO, StringIO
@ -12,7 +13,8 @@ from common_testing import TestCaseMixin
from iopath.common.file_io import PathManager from iopath.common.file_io import PathManager
from pytorch3d.io import IO from pytorch3d.io import IO
from pytorch3d.io.ply_io import load_ply, save_ply from pytorch3d.io.ply_io import load_ply, save_ply
from pytorch3d.structures import Pointclouds from pytorch3d.renderer.mesh import TexturesVertex
from pytorch3d.structures import Meshes, Pointclouds
from pytorch3d.utils import torus from pytorch3d.utils import torus
@ -189,6 +191,57 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
): ):
io.load_mesh(f3.name) io.load_mesh(f3.name)
def test_save_too_many_colors(self):
verts = torch.tensor(
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
)
faces = torch.tensor([[0, 1, 2], [0, 2, 3]])
vert_colors = torch.rand((4, 7))
texture_with_seven_colors = TexturesVertex(verts_features=[vert_colors])
mesh = Meshes(
verts=[verts],
faces=[faces],
textures=texture_with_seven_colors,
)
io = IO()
msg = "Texture will not be saved as it has 7 colors, not 3."
with NamedTemporaryFile(mode="w", suffix=".ply") as f:
with self.assertWarnsRegex(UserWarning, msg):
io.save_mesh(mesh.cuda(), f.name)
def test_save_load_meshes(self):
verts = torch.tensor(
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
)
faces = torch.tensor([[0, 1, 2], [0, 2, 3]])
vert_colors = torch.rand_like(verts)
texture = TexturesVertex(verts_features=[vert_colors])
for do_textures in itertools.product([True, False]):
mesh = Meshes(
verts=[verts],
faces=[faces],
textures=texture if do_textures else None,
)
device = torch.device("cuda:0")
io = IO()
with NamedTemporaryFile(mode="w", suffix=".ply") as f:
io.save_mesh(mesh.cuda(), f.name)
f.flush()
mesh2 = io.load_mesh(f.name, device=device)
self.assertEqual(mesh2.device, device)
mesh2 = mesh2.cpu()
self.assertClose(mesh2.verts_padded(), mesh.verts_padded())
self.assertClose(mesh2.faces_padded(), mesh.faces_padded())
if do_textures:
self.assertIsInstance(mesh2.textures, TexturesVertex)
self.assertClose(mesh2.textures.verts_features_list()[0], vert_colors)
else:
self.assertIsNone(mesh2.textures)
def test_save_ply_invalid_shapes(self): def test_save_ply_invalid_shapes(self):
# Invalid vertices shape # Invalid vertices shape
with self.assertRaises(ValueError) as error: with self.assertRaises(ValueError) as error: