mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-07-31 10:52:50 +08:00
Support reading uv and uv map for ply format if texture_uv exists in ply file (#1100)
Summary: When the ply format looks as follows: ``` comment TextureFile ***.png element vertex 892 property double x property double y property double z property double nx property double ny property double nz property double texture_u property double texture_v ``` `MeshPlyFormat` class will read uv from the ply file and read the uv map as commented as TextureFile. Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1100 Reviewed By: MichaelRamamonjisoa Differential Revision: D50885176 Pulled By: bottler fbshipit-source-id: be75b1ec9a17a1ed87dbcf846a9072ea967aec37
This commit is contained in:
parent
f4f2209271
commit
55638f3bae
@ -10,6 +10,7 @@ This module implements utility functions for loading and saving
|
||||
meshes and point clouds as PLY files.
|
||||
"""
|
||||
import itertools
|
||||
import os
|
||||
import struct
|
||||
import sys
|
||||
import warnings
|
||||
@ -21,8 +22,14 @@ from typing import List, Optional, Tuple
|
||||
import numpy as np
|
||||
import torch
|
||||
from iopath.common.file_io import PathManager
|
||||
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file, PathOrStr
|
||||
from pytorch3d.renderer import TexturesVertex
|
||||
from pytorch3d.io.utils import (
|
||||
_check_faces_indices,
|
||||
_make_tensor,
|
||||
_open_file,
|
||||
_read_image,
|
||||
PathOrStr,
|
||||
)
|
||||
from pytorch3d.renderer import TexturesUV, TexturesVertex
|
||||
from pytorch3d.structures import Meshes, Pointclouds
|
||||
|
||||
from .pluggable_formats import (
|
||||
@ -804,6 +811,7 @@ class _VertsColumnIndices:
|
||||
color_idxs: Optional[List[int]]
|
||||
color_scale: float
|
||||
normal_idxs: Optional[List[int]]
|
||||
texture_uv_idxs: Optional[List[int]]
|
||||
|
||||
|
||||
def _get_verts_column_indices(
|
||||
@ -827,6 +835,8 @@ def _get_verts_column_indices(
|
||||
property uchar red
|
||||
property uchar green
|
||||
property uchar blue
|
||||
property double texture_u
|
||||
property double texture_v
|
||||
|
||||
then the return value will be ([0,1,2], [6,7,8], 1.0/255, [3,4,5])
|
||||
|
||||
@ -839,6 +849,7 @@ def _get_verts_column_indices(
|
||||
point_idxs: List[Optional[int]] = [None, None, None]
|
||||
color_idxs: List[Optional[int]] = [None, None, None]
|
||||
normal_idxs: List[Optional[int]] = [None, None, None]
|
||||
texture_uv_idxs: List[Optional[int]] = [None, None]
|
||||
for i, prop in enumerate(vertex_head.properties):
|
||||
if prop.list_size_type is not None:
|
||||
raise ValueError("Invalid vertices in file: did not expect list.")
|
||||
@ -851,6 +862,9 @@ def _get_verts_column_indices(
|
||||
for j, name in enumerate(["nx", "ny", "nz"]):
|
||||
if prop.name == name:
|
||||
normal_idxs[j] = i
|
||||
for j, name in enumerate(["texture_u", "texture_v"]):
|
||||
if prop.name == name:
|
||||
texture_uv_idxs[j] = i
|
||||
if None in point_idxs:
|
||||
raise ValueError("Invalid vertices in file.")
|
||||
color_scale = 1.0
|
||||
@ -864,6 +878,7 @@ def _get_verts_column_indices(
|
||||
color_idxs=None if None in color_idxs else color_idxs,
|
||||
color_scale=color_scale,
|
||||
normal_idxs=None if None in normal_idxs else normal_idxs,
|
||||
texture_uv_idxs=None if None in texture_uv_idxs else texture_uv_idxs,
|
||||
)
|
||||
|
||||
|
||||
@ -880,6 +895,7 @@ class _VertsData:
|
||||
verts: torch.Tensor
|
||||
verts_colors: Optional[torch.Tensor] = None
|
||||
verts_normals: Optional[torch.Tensor] = None
|
||||
verts_texture_uvs: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData:
|
||||
@ -922,6 +938,7 @@ def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData:
|
||||
|
||||
vertex_colors = None
|
||||
vertex_normals = None
|
||||
vertex_texture_uvs = None
|
||||
|
||||
if len(vertex) == 1:
|
||||
# This is the case where the whole vertex element has one type,
|
||||
@ -935,6 +952,10 @@ def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData:
|
||||
vertex_normals = torch.tensor(
|
||||
vertex[0][:, column_idxs.normal_idxs], dtype=torch.float32
|
||||
)
|
||||
if column_idxs.texture_uv_idxs is not None:
|
||||
vertex_texture_uvs = torch.tensor(
|
||||
vertex[0][:, column_idxs.texture_uv_idxs], dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
# The vertex element is heterogeneous. It was read as several arrays,
|
||||
# part by part, where a part is a set of properties with the same type.
|
||||
@ -973,11 +994,19 @@ def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData:
|
||||
for axis in range(3):
|
||||
partnum, col = prop_to_partnum_col[column_idxs.normal_idxs[axis]]
|
||||
vertex_normals.numpy()[:, axis] = vertex[partnum][:, col]
|
||||
|
||||
if column_idxs.texture_uv_idxs is not None:
|
||||
vertex_texture_uvs = torch.empty(
|
||||
size=(vertex_head.count, 2),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
for axis in range(2):
|
||||
partnum, col = prop_to_partnum_col[column_idxs.texture_uv_idxs[axis]]
|
||||
vertex_texture_uvs.numpy()[:, axis] = vertex[partnum][:, col]
|
||||
return _VertsData(
|
||||
verts=verts,
|
||||
verts_colors=vertex_colors,
|
||||
verts_normals=vertex_normals,
|
||||
verts_texture_uvs=vertex_texture_uvs,
|
||||
)
|
||||
|
||||
|
||||
@ -998,6 +1027,7 @@ class _PlyData:
|
||||
faces: Optional[torch.Tensor]
|
||||
verts_colors: Optional[torch.Tensor]
|
||||
verts_normals: Optional[torch.Tensor]
|
||||
verts_texture_uvs: Optional[torch.Tensor]
|
||||
|
||||
|
||||
def _load_ply(f, *, path_manager: PathManager) -> _PlyData:
|
||||
@ -1358,8 +1388,27 @@ class MeshPlyFormat(MeshFormatInterpreter):
|
||||
faces = torch.zeros(0, 3, dtype=torch.int64)
|
||||
|
||||
texture = None
|
||||
if include_textures and data.verts_colors is not None:
|
||||
texture = TexturesVertex([data.verts_colors.to(device)])
|
||||
if include_textures:
|
||||
if data.verts_colors is not None:
|
||||
texture = TexturesVertex([data.verts_colors.to(device)])
|
||||
elif data.verts_texture_uvs is not None:
|
||||
texture_file_path = None
|
||||
for comment in data.header.comments:
|
||||
if "TextureFile" in comment:
|
||||
given_texture_file = comment.split(" ")[-1]
|
||||
texture_file_path = os.path.join(
|
||||
os.path.dirname(str(path)), given_texture_file
|
||||
)
|
||||
if texture_file_path is not None:
|
||||
texture_map = _read_image(
|
||||
texture_file_path, path_manager, format="RGB"
|
||||
)
|
||||
texture_map = torch.tensor(texture_map, dtype=torch.float32) / 255.0
|
||||
texture = TexturesUV(
|
||||
[texture_map.to(device)],
|
||||
[faces.to(device)],
|
||||
[data.verts_texture_uvs.to(device)],
|
||||
)
|
||||
|
||||
verts_normals = None
|
||||
if data.verts_normals is not None:
|
||||
|
28
tests/data/uvs.ply
Normal file
28
tests/data/uvs.ply
Normal file
@ -0,0 +1,28 @@
|
||||
ply
|
||||
format ascii 1.0
|
||||
comment made by Greg Turk
|
||||
comment this file is a cube
|
||||
comment TextureFile test_nd_sphere.png
|
||||
element vertex 8
|
||||
property float x
|
||||
property float y
|
||||
property float z
|
||||
property float texture_u
|
||||
property float texture_v
|
||||
element face 6
|
||||
property list uchar int vertex_index
|
||||
end_header
|
||||
0 0 0 0 0
|
||||
0 0 1 0.2 0.3
|
||||
0 1 1 0.2 0.3
|
||||
0 1 0 0.2 0.3
|
||||
1 0 0 0.2 0.3
|
||||
1 0 1 0.2 0.3
|
||||
1 1 1 0.2 0.3
|
||||
1 1 0 0.4 0.5
|
||||
4 0 1 2 3
|
||||
4 7 6 5 4
|
||||
4 0 4 5 1
|
||||
4 1 5 6 2
|
||||
4 2 6 7 3
|
||||
4 3 7 4 0
|
@ -20,10 +20,11 @@ from pytorch3d.renderer.mesh import TexturesVertex
|
||||
from pytorch3d.structures import Meshes, Pointclouds
|
||||
from pytorch3d.utils import torus
|
||||
|
||||
from .common_testing import TestCaseMixin
|
||||
from .common_testing import get_tests_dir, TestCaseMixin
|
||||
|
||||
|
||||
global_path_manager = PathManager()
|
||||
DATA_DIR = get_tests_dir() / "data"
|
||||
|
||||
|
||||
def _load_ply_raw(stream):
|
||||
@ -778,6 +779,19 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
||||
data["minus_ones"], [-1, 255, -1, 65535, -1, 4294967295]
|
||||
)
|
||||
|
||||
def test_load_uvs(self):
|
||||
io = IO()
|
||||
mesh = io.load_mesh(DATA_DIR / "uvs.ply")
|
||||
self.assertEqual(mesh.textures.verts_uvs_padded().shape, (1, 8, 2))
|
||||
self.assertClose(
|
||||
mesh.textures.verts_uvs_padded()[0],
|
||||
torch.tensor([[0, 0]] + [[0.2, 0.3]] * 6 + [[0.4, 0.5]]),
|
||||
)
|
||||
self.assertEqual(
|
||||
mesh.textures.faces_uvs_padded().shape, mesh.faces_padded().shape
|
||||
)
|
||||
self.assertEqual(mesh.textures.maps_padded().shape, (1, 512, 512, 3))
|
||||
|
||||
def test_bad_ply_syntax(self):
|
||||
"""Some syntactically bad ply files."""
|
||||
lines = [
|
||||
|
Loading…
x
Reference in New Issue
Block a user