diff --git a/pytorch3d/io/ply_io.py b/pytorch3d/io/ply_io.py index bce00bec..1d59b193 100644 --- a/pytorch3d/io/ply_io.py +++ b/pytorch3d/io/ply_io.py @@ -10,6 +10,7 @@ This module implements utility functions for loading and saving meshes and point clouds as PLY files. """ import itertools +import os import struct import sys import warnings @@ -21,8 +22,14 @@ from typing import List, Optional, Tuple import numpy as np import torch from iopath.common.file_io import PathManager -from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file, PathOrStr -from pytorch3d.renderer import TexturesVertex +from pytorch3d.io.utils import ( + _check_faces_indices, + _make_tensor, + _open_file, + _read_image, + PathOrStr, +) +from pytorch3d.renderer import TexturesUV, TexturesVertex from pytorch3d.structures import Meshes, Pointclouds from .pluggable_formats import ( @@ -804,6 +811,7 @@ class _VertsColumnIndices: color_idxs: Optional[List[int]] color_scale: float normal_idxs: Optional[List[int]] + texture_uv_idxs: Optional[List[int]] def _get_verts_column_indices( @@ -827,6 +835,8 @@ def _get_verts_column_indices( property uchar red property uchar green property uchar blue + property double texture_u + property double texture_v then the return value will be ([0,1,2], [6,7,8], 1.0/255, [3,4,5]) @@ -839,6 +849,7 @@ def _get_verts_column_indices( point_idxs: List[Optional[int]] = [None, None, None] color_idxs: List[Optional[int]] = [None, None, None] normal_idxs: List[Optional[int]] = [None, None, None] + texture_uv_idxs: List[Optional[int]] = [None, None] for i, prop in enumerate(vertex_head.properties): if prop.list_size_type is not None: raise ValueError("Invalid vertices in file: did not expect list.") @@ -851,6 +862,9 @@ def _get_verts_column_indices( for j, name in enumerate(["nx", "ny", "nz"]): if prop.name == name: normal_idxs[j] = i + for j, name in enumerate(["texture_u", "texture_v"]): + if prop.name == name: + texture_uv_idxs[j] = i if None in point_idxs: raise ValueError("Invalid vertices in file.") color_scale = 1.0 @@ -864,6 +878,7 @@ def _get_verts_column_indices( color_idxs=None if None in color_idxs else color_idxs, color_scale=color_scale, normal_idxs=None if None in normal_idxs else normal_idxs, + texture_uv_idxs=None if None in texture_uv_idxs else texture_uv_idxs, ) @@ -880,6 +895,7 @@ class _VertsData: verts: torch.Tensor verts_colors: Optional[torch.Tensor] = None verts_normals: Optional[torch.Tensor] = None + verts_texture_uvs: Optional[torch.Tensor] = None def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData: @@ -922,6 +938,7 @@ def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData: vertex_colors = None vertex_normals = None + vertex_texture_uvs = None if len(vertex) == 1: # This is the case where the whole vertex element has one type, @@ -935,6 +952,10 @@ def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData: vertex_normals = torch.tensor( vertex[0][:, column_idxs.normal_idxs], dtype=torch.float32 ) + if column_idxs.texture_uv_idxs is not None: + vertex_texture_uvs = torch.tensor( + vertex[0][:, column_idxs.texture_uv_idxs], dtype=torch.float32 + ) else: # The vertex element is heterogeneous. It was read as several arrays, # part by part, where a part is a set of properties with the same type. @@ -973,11 +994,19 @@ def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData: for axis in range(3): partnum, col = prop_to_partnum_col[column_idxs.normal_idxs[axis]] vertex_normals.numpy()[:, axis] = vertex[partnum][:, col] - + if column_idxs.texture_uv_idxs is not None: + vertex_texture_uvs = torch.empty( + size=(vertex_head.count, 2), + dtype=torch.float32, + ) + for axis in range(2): + partnum, col = prop_to_partnum_col[column_idxs.texture_uv_idxs[axis]] + vertex_texture_uvs.numpy()[:, axis] = vertex[partnum][:, col] return _VertsData( verts=verts, verts_colors=vertex_colors, verts_normals=vertex_normals, + verts_texture_uvs=vertex_texture_uvs, ) @@ -998,6 +1027,7 @@ class _PlyData: faces: Optional[torch.Tensor] verts_colors: Optional[torch.Tensor] verts_normals: Optional[torch.Tensor] + verts_texture_uvs: Optional[torch.Tensor] def _load_ply(f, *, path_manager: PathManager) -> _PlyData: @@ -1358,8 +1388,27 @@ class MeshPlyFormat(MeshFormatInterpreter): faces = torch.zeros(0, 3, dtype=torch.int64) texture = None - if include_textures and data.verts_colors is not None: - texture = TexturesVertex([data.verts_colors.to(device)]) + if include_textures: + if data.verts_colors is not None: + texture = TexturesVertex([data.verts_colors.to(device)]) + elif data.verts_texture_uvs is not None: + texture_file_path = None + for comment in data.header.comments: + if "TextureFile" in comment: + given_texture_file = comment.split(" ")[-1] + texture_file_path = os.path.join( + os.path.dirname(str(path)), given_texture_file + ) + if texture_file_path is not None: + texture_map = _read_image( + texture_file_path, path_manager, format="RGB" + ) + texture_map = torch.tensor(texture_map, dtype=torch.float32) / 255.0 + texture = TexturesUV( + [texture_map.to(device)], + [faces.to(device)], + [data.verts_texture_uvs.to(device)], + ) verts_normals = None if data.verts_normals is not None: diff --git a/tests/data/uvs.ply b/tests/data/uvs.ply new file mode 100644 index 00000000..2e8532e1 --- /dev/null +++ b/tests/data/uvs.ply @@ -0,0 +1,28 @@ +ply +format ascii 1.0 +comment made by Greg Turk +comment this file is a cube +comment TextureFile test_nd_sphere.png +element vertex 8 +property float x +property float y +property float z +property float texture_u +property float texture_v +element face 6 +property list uchar int vertex_index +end_header +0 0 0 0 0 +0 0 1 0.2 0.3 +0 1 1 0.2 0.3 +0 1 0 0.2 0.3 +1 0 0 0.2 0.3 +1 0 1 0.2 0.3 +1 1 1 0.2 0.3 +1 1 0 0.4 0.5 +4 0 1 2 3 +4 7 6 5 4 +4 0 4 5 1 +4 1 5 6 2 +4 2 6 7 3 +4 3 7 4 0 diff --git a/tests/test_io_ply.py b/tests/test_io_ply.py index 4f32fea6..2b560cad 100644 --- a/tests/test_io_ply.py +++ b/tests/test_io_ply.py @@ -20,10 +20,11 @@ from pytorch3d.renderer.mesh import TexturesVertex from pytorch3d.structures import Meshes, Pointclouds from pytorch3d.utils import torus -from .common_testing import TestCaseMixin +from .common_testing import get_tests_dir, TestCaseMixin global_path_manager = PathManager() +DATA_DIR = get_tests_dir() / "data" def _load_ply_raw(stream): @@ -778,6 +779,19 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): data["minus_ones"], [-1, 255, -1, 65535, -1, 4294967295] ) + def test_load_uvs(self): + io = IO() + mesh = io.load_mesh(DATA_DIR / "uvs.ply") + self.assertEqual(mesh.textures.verts_uvs_padded().shape, (1, 8, 2)) + self.assertClose( + mesh.textures.verts_uvs_padded()[0], + torch.tensor([[0, 0]] + [[0.2, 0.3]] * 6 + [[0.4, 0.5]]), + ) + self.assertEqual( + mesh.textures.faces_uvs_padded().shape, mesh.faces_padded().shape + ) + self.assertEqual(mesh.textures.maps_padded().shape, (1, 512, 512, 3)) + def test_bad_ply_syntax(self): """Some syntactically bad ply files.""" lines = [