Support reading uv and uv map for ply format if texture_uv exists in ply file (#1100)

Summary:
When the ply format looks as follows:
  ```
comment TextureFile ***.png
element vertex 892
property double x
property double y
property double z
property double nx
property double ny
property double nz
property double texture_u
property double texture_v
```
`MeshPlyFormat` class will read uv from the ply file and read the uv map as commented as TextureFile.

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1100

Reviewed By: MichaelRamamonjisoa

Differential Revision: D50885176

Pulled By: bottler

fbshipit-source-id: be75b1ec9a17a1ed87dbcf846a9072ea967aec37
This commit is contained in:
YangHai 2023-11-14 07:44:14 -08:00 committed by Facebook GitHub Bot
parent f4f2209271
commit 55638f3bae
3 changed files with 97 additions and 6 deletions

View File

@ -10,6 +10,7 @@ This module implements utility functions for loading and saving
meshes and point clouds as PLY files. meshes and point clouds as PLY files.
""" """
import itertools import itertools
import os
import struct import struct
import sys import sys
import warnings import warnings
@ -21,8 +22,14 @@ from typing import List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
from iopath.common.file_io import PathManager from iopath.common.file_io import PathManager
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file, PathOrStr from pytorch3d.io.utils import (
from pytorch3d.renderer import TexturesVertex _check_faces_indices,
_make_tensor,
_open_file,
_read_image,
PathOrStr,
)
from pytorch3d.renderer import TexturesUV, TexturesVertex
from pytorch3d.structures import Meshes, Pointclouds from pytorch3d.structures import Meshes, Pointclouds
from .pluggable_formats import ( from .pluggable_formats import (
@ -804,6 +811,7 @@ class _VertsColumnIndices:
color_idxs: Optional[List[int]] color_idxs: Optional[List[int]]
color_scale: float color_scale: float
normal_idxs: Optional[List[int]] normal_idxs: Optional[List[int]]
texture_uv_idxs: Optional[List[int]]
def _get_verts_column_indices( def _get_verts_column_indices(
@ -827,6 +835,8 @@ def _get_verts_column_indices(
property uchar red property uchar red
property uchar green property uchar green
property uchar blue property uchar blue
property double texture_u
property double texture_v
then the return value will be ([0,1,2], [6,7,8], 1.0/255, [3,4,5]) then the return value will be ([0,1,2], [6,7,8], 1.0/255, [3,4,5])
@ -839,6 +849,7 @@ def _get_verts_column_indices(
point_idxs: List[Optional[int]] = [None, None, None] point_idxs: List[Optional[int]] = [None, None, None]
color_idxs: List[Optional[int]] = [None, None, None] color_idxs: List[Optional[int]] = [None, None, None]
normal_idxs: List[Optional[int]] = [None, None, None] normal_idxs: List[Optional[int]] = [None, None, None]
texture_uv_idxs: List[Optional[int]] = [None, None]
for i, prop in enumerate(vertex_head.properties): for i, prop in enumerate(vertex_head.properties):
if prop.list_size_type is not None: if prop.list_size_type is not None:
raise ValueError("Invalid vertices in file: did not expect list.") raise ValueError("Invalid vertices in file: did not expect list.")
@ -851,6 +862,9 @@ def _get_verts_column_indices(
for j, name in enumerate(["nx", "ny", "nz"]): for j, name in enumerate(["nx", "ny", "nz"]):
if prop.name == name: if prop.name == name:
normal_idxs[j] = i normal_idxs[j] = i
for j, name in enumerate(["texture_u", "texture_v"]):
if prop.name == name:
texture_uv_idxs[j] = i
if None in point_idxs: if None in point_idxs:
raise ValueError("Invalid vertices in file.") raise ValueError("Invalid vertices in file.")
color_scale = 1.0 color_scale = 1.0
@ -864,6 +878,7 @@ def _get_verts_column_indices(
color_idxs=None if None in color_idxs else color_idxs, color_idxs=None if None in color_idxs else color_idxs,
color_scale=color_scale, color_scale=color_scale,
normal_idxs=None if None in normal_idxs else normal_idxs, normal_idxs=None if None in normal_idxs else normal_idxs,
texture_uv_idxs=None if None in texture_uv_idxs else texture_uv_idxs,
) )
@ -880,6 +895,7 @@ class _VertsData:
verts: torch.Tensor verts: torch.Tensor
verts_colors: Optional[torch.Tensor] = None verts_colors: Optional[torch.Tensor] = None
verts_normals: Optional[torch.Tensor] = None verts_normals: Optional[torch.Tensor] = None
verts_texture_uvs: Optional[torch.Tensor] = None
def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData: def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData:
@ -922,6 +938,7 @@ def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData:
vertex_colors = None vertex_colors = None
vertex_normals = None vertex_normals = None
vertex_texture_uvs = None
if len(vertex) == 1: if len(vertex) == 1:
# This is the case where the whole vertex element has one type, # This is the case where the whole vertex element has one type,
@ -935,6 +952,10 @@ def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData:
vertex_normals = torch.tensor( vertex_normals = torch.tensor(
vertex[0][:, column_idxs.normal_idxs], dtype=torch.float32 vertex[0][:, column_idxs.normal_idxs], dtype=torch.float32
) )
if column_idxs.texture_uv_idxs is not None:
vertex_texture_uvs = torch.tensor(
vertex[0][:, column_idxs.texture_uv_idxs], dtype=torch.float32
)
else: else:
# The vertex element is heterogeneous. It was read as several arrays, # The vertex element is heterogeneous. It was read as several arrays,
# part by part, where a part is a set of properties with the same type. # part by part, where a part is a set of properties with the same type.
@ -973,11 +994,19 @@ def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData:
for axis in range(3): for axis in range(3):
partnum, col = prop_to_partnum_col[column_idxs.normal_idxs[axis]] partnum, col = prop_to_partnum_col[column_idxs.normal_idxs[axis]]
vertex_normals.numpy()[:, axis] = vertex[partnum][:, col] vertex_normals.numpy()[:, axis] = vertex[partnum][:, col]
if column_idxs.texture_uv_idxs is not None:
vertex_texture_uvs = torch.empty(
size=(vertex_head.count, 2),
dtype=torch.float32,
)
for axis in range(2):
partnum, col = prop_to_partnum_col[column_idxs.texture_uv_idxs[axis]]
vertex_texture_uvs.numpy()[:, axis] = vertex[partnum][:, col]
return _VertsData( return _VertsData(
verts=verts, verts=verts,
verts_colors=vertex_colors, verts_colors=vertex_colors,
verts_normals=vertex_normals, verts_normals=vertex_normals,
verts_texture_uvs=vertex_texture_uvs,
) )
@ -998,6 +1027,7 @@ class _PlyData:
faces: Optional[torch.Tensor] faces: Optional[torch.Tensor]
verts_colors: Optional[torch.Tensor] verts_colors: Optional[torch.Tensor]
verts_normals: Optional[torch.Tensor] verts_normals: Optional[torch.Tensor]
verts_texture_uvs: Optional[torch.Tensor]
def _load_ply(f, *, path_manager: PathManager) -> _PlyData: def _load_ply(f, *, path_manager: PathManager) -> _PlyData:
@ -1358,8 +1388,27 @@ class MeshPlyFormat(MeshFormatInterpreter):
faces = torch.zeros(0, 3, dtype=torch.int64) faces = torch.zeros(0, 3, dtype=torch.int64)
texture = None texture = None
if include_textures and data.verts_colors is not None: if include_textures:
if data.verts_colors is not None:
texture = TexturesVertex([data.verts_colors.to(device)]) texture = TexturesVertex([data.verts_colors.to(device)])
elif data.verts_texture_uvs is not None:
texture_file_path = None
for comment in data.header.comments:
if "TextureFile" in comment:
given_texture_file = comment.split(" ")[-1]
texture_file_path = os.path.join(
os.path.dirname(str(path)), given_texture_file
)
if texture_file_path is not None:
texture_map = _read_image(
texture_file_path, path_manager, format="RGB"
)
texture_map = torch.tensor(texture_map, dtype=torch.float32) / 255.0
texture = TexturesUV(
[texture_map.to(device)],
[faces.to(device)],
[data.verts_texture_uvs.to(device)],
)
verts_normals = None verts_normals = None
if data.verts_normals is not None: if data.verts_normals is not None:

28
tests/data/uvs.ply Normal file
View File

@ -0,0 +1,28 @@
ply
format ascii 1.0
comment made by Greg Turk
comment this file is a cube
comment TextureFile test_nd_sphere.png
element vertex 8
property float x
property float y
property float z
property float texture_u
property float texture_v
element face 6
property list uchar int vertex_index
end_header
0 0 0 0 0
0 0 1 0.2 0.3
0 1 1 0.2 0.3
0 1 0 0.2 0.3
1 0 0 0.2 0.3
1 0 1 0.2 0.3
1 1 1 0.2 0.3
1 1 0 0.4 0.5
4 0 1 2 3
4 7 6 5 4
4 0 4 5 1
4 1 5 6 2
4 2 6 7 3
4 3 7 4 0

View File

@ -20,10 +20,11 @@ from pytorch3d.renderer.mesh import TexturesVertex
from pytorch3d.structures import Meshes, Pointclouds from pytorch3d.structures import Meshes, Pointclouds
from pytorch3d.utils import torus from pytorch3d.utils import torus
from .common_testing import TestCaseMixin from .common_testing import get_tests_dir, TestCaseMixin
global_path_manager = PathManager() global_path_manager = PathManager()
DATA_DIR = get_tests_dir() / "data"
def _load_ply_raw(stream): def _load_ply_raw(stream):
@ -778,6 +779,19 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
data["minus_ones"], [-1, 255, -1, 65535, -1, 4294967295] data["minus_ones"], [-1, 255, -1, 65535, -1, 4294967295]
) )
def test_load_uvs(self):
io = IO()
mesh = io.load_mesh(DATA_DIR / "uvs.ply")
self.assertEqual(mesh.textures.verts_uvs_padded().shape, (1, 8, 2))
self.assertClose(
mesh.textures.verts_uvs_padded()[0],
torch.tensor([[0, 0]] + [[0.2, 0.3]] * 6 + [[0.4, 0.5]]),
)
self.assertEqual(
mesh.textures.faces_uvs_padded().shape, mesh.faces_padded().shape
)
self.assertEqual(mesh.textures.maps_padded().shape, (1, 512, 512, 3))
def test_bad_ply_syntax(self): def test_bad_ply_syntax(self):
"""Some syntactically bad ply files.""" """Some syntactically bad ply files."""
lines = [ lines = [