mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 09:52:11 +08:00 
			
		
		
		
	Support reading uv and uv map for ply format if texture_uv exists in ply file (#1100)
Summary: When the ply format looks as follows: ``` comment TextureFile ***.png element vertex 892 property double x property double y property double z property double nx property double ny property double nz property double texture_u property double texture_v ``` `MeshPlyFormat` class will read uv from the ply file and read the uv map as commented as TextureFile. Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1100 Reviewed By: MichaelRamamonjisoa Differential Revision: D50885176 Pulled By: bottler fbshipit-source-id: be75b1ec9a17a1ed87dbcf846a9072ea967aec37
This commit is contained in:
		
							parent
							
								
									f4f2209271
								
							
						
					
					
						commit
						55638f3bae
					
				@ -10,6 +10,7 @@ This module implements utility functions for loading and saving
 | 
			
		||||
meshes and point clouds as PLY files.
 | 
			
		||||
"""
 | 
			
		||||
import itertools
 | 
			
		||||
import os
 | 
			
		||||
import struct
 | 
			
		||||
import sys
 | 
			
		||||
import warnings
 | 
			
		||||
@ -21,8 +22,14 @@ from typing import List, Optional, Tuple
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
from iopath.common.file_io import PathManager
 | 
			
		||||
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file, PathOrStr
 | 
			
		||||
from pytorch3d.renderer import TexturesVertex
 | 
			
		||||
from pytorch3d.io.utils import (
 | 
			
		||||
    _check_faces_indices,
 | 
			
		||||
    _make_tensor,
 | 
			
		||||
    _open_file,
 | 
			
		||||
    _read_image,
 | 
			
		||||
    PathOrStr,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.renderer import TexturesUV, TexturesVertex
 | 
			
		||||
from pytorch3d.structures import Meshes, Pointclouds
 | 
			
		||||
 | 
			
		||||
from .pluggable_formats import (
 | 
			
		||||
@ -804,6 +811,7 @@ class _VertsColumnIndices:
 | 
			
		||||
    color_idxs: Optional[List[int]]
 | 
			
		||||
    color_scale: float
 | 
			
		||||
    normal_idxs: Optional[List[int]]
 | 
			
		||||
    texture_uv_idxs: Optional[List[int]]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_verts_column_indices(
 | 
			
		||||
@ -827,6 +835,8 @@ def _get_verts_column_indices(
 | 
			
		||||
        property uchar red
 | 
			
		||||
        property uchar green
 | 
			
		||||
        property uchar blue
 | 
			
		||||
        property double texture_u
 | 
			
		||||
        property double texture_v
 | 
			
		||||
 | 
			
		||||
    then the return value will be ([0,1,2], [6,7,8], 1.0/255, [3,4,5])
 | 
			
		||||
 | 
			
		||||
@ -839,6 +849,7 @@ def _get_verts_column_indices(
 | 
			
		||||
    point_idxs: List[Optional[int]] = [None, None, None]
 | 
			
		||||
    color_idxs: List[Optional[int]] = [None, None, None]
 | 
			
		||||
    normal_idxs: List[Optional[int]] = [None, None, None]
 | 
			
		||||
    texture_uv_idxs: List[Optional[int]] = [None, None]
 | 
			
		||||
    for i, prop in enumerate(vertex_head.properties):
 | 
			
		||||
        if prop.list_size_type is not None:
 | 
			
		||||
            raise ValueError("Invalid vertices in file: did not expect list.")
 | 
			
		||||
@ -851,6 +862,9 @@ def _get_verts_column_indices(
 | 
			
		||||
        for j, name in enumerate(["nx", "ny", "nz"]):
 | 
			
		||||
            if prop.name == name:
 | 
			
		||||
                normal_idxs[j] = i
 | 
			
		||||
        for j, name in enumerate(["texture_u", "texture_v"]):
 | 
			
		||||
            if prop.name == name:
 | 
			
		||||
                texture_uv_idxs[j] = i
 | 
			
		||||
    if None in point_idxs:
 | 
			
		||||
        raise ValueError("Invalid vertices in file.")
 | 
			
		||||
    color_scale = 1.0
 | 
			
		||||
@ -864,6 +878,7 @@ def _get_verts_column_indices(
 | 
			
		||||
        color_idxs=None if None in color_idxs else color_idxs,
 | 
			
		||||
        color_scale=color_scale,
 | 
			
		||||
        normal_idxs=None if None in normal_idxs else normal_idxs,
 | 
			
		||||
        texture_uv_idxs=None if None in texture_uv_idxs else texture_uv_idxs,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -880,6 +895,7 @@ class _VertsData:
 | 
			
		||||
    verts: torch.Tensor
 | 
			
		||||
    verts_colors: Optional[torch.Tensor] = None
 | 
			
		||||
    verts_normals: Optional[torch.Tensor] = None
 | 
			
		||||
    verts_texture_uvs: Optional[torch.Tensor] = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData:
 | 
			
		||||
@ -922,6 +938,7 @@ def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData:
 | 
			
		||||
 | 
			
		||||
    vertex_colors = None
 | 
			
		||||
    vertex_normals = None
 | 
			
		||||
    vertex_texture_uvs = None
 | 
			
		||||
 | 
			
		||||
    if len(vertex) == 1:
 | 
			
		||||
        # This is the case where the whole vertex element has one type,
 | 
			
		||||
@ -935,6 +952,10 @@ def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData:
 | 
			
		||||
            vertex_normals = torch.tensor(
 | 
			
		||||
                vertex[0][:, column_idxs.normal_idxs], dtype=torch.float32
 | 
			
		||||
            )
 | 
			
		||||
        if column_idxs.texture_uv_idxs is not None:
 | 
			
		||||
            vertex_texture_uvs = torch.tensor(
 | 
			
		||||
                vertex[0][:, column_idxs.texture_uv_idxs], dtype=torch.float32
 | 
			
		||||
            )
 | 
			
		||||
    else:
 | 
			
		||||
        # The vertex element is heterogeneous. It was read as several arrays,
 | 
			
		||||
        # part by part, where a part is a set of properties with the same type.
 | 
			
		||||
@ -973,11 +994,19 @@ def _get_verts(header: _PlyHeader, elements: dict) -> _VertsData:
 | 
			
		||||
            for axis in range(3):
 | 
			
		||||
                partnum, col = prop_to_partnum_col[column_idxs.normal_idxs[axis]]
 | 
			
		||||
                vertex_normals.numpy()[:, axis] = vertex[partnum][:, col]
 | 
			
		||||
 | 
			
		||||
        if column_idxs.texture_uv_idxs is not None:
 | 
			
		||||
            vertex_texture_uvs = torch.empty(
 | 
			
		||||
                size=(vertex_head.count, 2),
 | 
			
		||||
                dtype=torch.float32,
 | 
			
		||||
            )
 | 
			
		||||
            for axis in range(2):
 | 
			
		||||
                partnum, col = prop_to_partnum_col[column_idxs.texture_uv_idxs[axis]]
 | 
			
		||||
                vertex_texture_uvs.numpy()[:, axis] = vertex[partnum][:, col]
 | 
			
		||||
    return _VertsData(
 | 
			
		||||
        verts=verts,
 | 
			
		||||
        verts_colors=vertex_colors,
 | 
			
		||||
        verts_normals=vertex_normals,
 | 
			
		||||
        verts_texture_uvs=vertex_texture_uvs,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -998,6 +1027,7 @@ class _PlyData:
 | 
			
		||||
    faces: Optional[torch.Tensor]
 | 
			
		||||
    verts_colors: Optional[torch.Tensor]
 | 
			
		||||
    verts_normals: Optional[torch.Tensor]
 | 
			
		||||
    verts_texture_uvs: Optional[torch.Tensor]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _load_ply(f, *, path_manager: PathManager) -> _PlyData:
 | 
			
		||||
@ -1358,8 +1388,27 @@ class MeshPlyFormat(MeshFormatInterpreter):
 | 
			
		||||
            faces = torch.zeros(0, 3, dtype=torch.int64)
 | 
			
		||||
 | 
			
		||||
        texture = None
 | 
			
		||||
        if include_textures and data.verts_colors is not None:
 | 
			
		||||
            texture = TexturesVertex([data.verts_colors.to(device)])
 | 
			
		||||
        if include_textures:
 | 
			
		||||
            if data.verts_colors is not None:
 | 
			
		||||
                texture = TexturesVertex([data.verts_colors.to(device)])
 | 
			
		||||
            elif data.verts_texture_uvs is not None:
 | 
			
		||||
                texture_file_path = None
 | 
			
		||||
                for comment in data.header.comments:
 | 
			
		||||
                    if "TextureFile" in comment:
 | 
			
		||||
                        given_texture_file = comment.split(" ")[-1]
 | 
			
		||||
                        texture_file_path = os.path.join(
 | 
			
		||||
                            os.path.dirname(str(path)), given_texture_file
 | 
			
		||||
                        )
 | 
			
		||||
                if texture_file_path is not None:
 | 
			
		||||
                    texture_map = _read_image(
 | 
			
		||||
                        texture_file_path, path_manager, format="RGB"
 | 
			
		||||
                    )
 | 
			
		||||
                    texture_map = torch.tensor(texture_map, dtype=torch.float32) / 255.0
 | 
			
		||||
                    texture = TexturesUV(
 | 
			
		||||
                        [texture_map.to(device)],
 | 
			
		||||
                        [faces.to(device)],
 | 
			
		||||
                        [data.verts_texture_uvs.to(device)],
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
        verts_normals = None
 | 
			
		||||
        if data.verts_normals is not None:
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										28
									
								
								tests/data/uvs.ply
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								tests/data/uvs.ply
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,28 @@
 | 
			
		||||
ply
 | 
			
		||||
format ascii 1.0
 | 
			
		||||
comment made by Greg Turk
 | 
			
		||||
comment this file is a cube
 | 
			
		||||
comment TextureFile test_nd_sphere.png
 | 
			
		||||
element vertex 8
 | 
			
		||||
property float x
 | 
			
		||||
property float y
 | 
			
		||||
property float z
 | 
			
		||||
property float texture_u
 | 
			
		||||
property float texture_v
 | 
			
		||||
element face 6
 | 
			
		||||
property list uchar int vertex_index
 | 
			
		||||
end_header
 | 
			
		||||
0 0 0 0 0
 | 
			
		||||
0 0 1 0.2 0.3
 | 
			
		||||
0 1 1 0.2 0.3
 | 
			
		||||
0 1 0 0.2 0.3
 | 
			
		||||
1 0 0 0.2 0.3
 | 
			
		||||
1 0 1 0.2 0.3
 | 
			
		||||
1 1 1 0.2 0.3
 | 
			
		||||
1 1 0 0.4 0.5
 | 
			
		||||
4 0 1 2 3
 | 
			
		||||
4 7 6 5 4
 | 
			
		||||
4 0 4 5 1
 | 
			
		||||
4 1 5 6 2
 | 
			
		||||
4 2 6 7 3
 | 
			
		||||
4 3 7 4 0
 | 
			
		||||
@ -20,10 +20,11 @@ from pytorch3d.renderer.mesh import TexturesVertex
 | 
			
		||||
from pytorch3d.structures import Meshes, Pointclouds
 | 
			
		||||
from pytorch3d.utils import torus
 | 
			
		||||
 | 
			
		||||
from .common_testing import TestCaseMixin
 | 
			
		||||
from .common_testing import get_tests_dir, TestCaseMixin
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
global_path_manager = PathManager()
 | 
			
		||||
DATA_DIR = get_tests_dir() / "data"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _load_ply_raw(stream):
 | 
			
		||||
@ -778,6 +779,19 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
                data["minus_ones"], [-1, 255, -1, 65535, -1, 4294967295]
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def test_load_uvs(self):
 | 
			
		||||
        io = IO()
 | 
			
		||||
        mesh = io.load_mesh(DATA_DIR / "uvs.ply")
 | 
			
		||||
        self.assertEqual(mesh.textures.verts_uvs_padded().shape, (1, 8, 2))
 | 
			
		||||
        self.assertClose(
 | 
			
		||||
            mesh.textures.verts_uvs_padded()[0],
 | 
			
		||||
            torch.tensor([[0, 0]] + [[0.2, 0.3]] * 6 + [[0.4, 0.5]]),
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            mesh.textures.faces_uvs_padded().shape, mesh.faces_padded().shape
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(mesh.textures.maps_padded().shape, (1, 512, 512, 3))
 | 
			
		||||
 | 
			
		||||
    def test_bad_ply_syntax(self):
 | 
			
		||||
        """Some syntactically bad ply files."""
 | 
			
		||||
        lines = [
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user