mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Experimental glTF reading
Summary: Experimental data loader for taking the default scene from a GLB file and converting it to a single mesh in PyTorch3D. Reviewed By: nikhilaravi Differential Revision: D25900167 fbshipit-source-id: bff22ac00298b83a0bd071ae5c8923561e1d81d7
This commit is contained in:
parent
0e85652f07
commit
ed6983ea84
@ -22,3 +22,13 @@ and to save a pointcloud you might do
|
|||||||
pcl = Pointclouds(...)
|
pcl = Pointclouds(...)
|
||||||
IO().save_point_cloud(pcl, "output_pointcloud.obj")
|
IO().save_point_cloud(pcl, "output_pointcloud.obj")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
For meshes, this supports OBJ, PLY and OFF files.
|
||||||
|
|
||||||
|
For pointclouds, this supports PLY files.
|
||||||
|
|
||||||
|
In addition, there is experimental support for loading meshes from
|
||||||
|
[glTF 2 assets](https://github.com/KhronosGroup/glTF/tree/master/specification/2.0)
|
||||||
|
stored either in a GLB container file or a glTF JSON file with embedded binary data.
|
||||||
|
This must be enabled explicitly, as described in
|
||||||
|
`pytorch3d/io/experimental_gltf_io.ply`.
|
||||||
|
572
pytorch3d/io/experimental_gltf_io.py
Normal file
572
pytorch3d/io/experimental_gltf_io.py
Normal file
@ -0,0 +1,572 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
This module implements loading meshes from glTF 2 assets stored in a
|
||||||
|
GLB container file or a glTF JSON file with embedded binary data.
|
||||||
|
It is experimental.
|
||||||
|
|
||||||
|
The module provides a MeshFormatInterpreter called
|
||||||
|
MeshGlbFormat which must be used explicitly.
|
||||||
|
e.g.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from pytorch3d.io import IO
|
||||||
|
from pytorch3d.io.experimental_gltf_io import MeshGlbFormat
|
||||||
|
|
||||||
|
io = IO()
|
||||||
|
io.register_meshes_format(MeshGlbFormat())
|
||||||
|
io.load_mesh(...)
|
||||||
|
|
||||||
|
This implementation is quite restricted in what it supports.
|
||||||
|
|
||||||
|
- It does not try to validate the input against the standard.
|
||||||
|
- It loads the default scene only.
|
||||||
|
- Only triangulated geometry is supported.
|
||||||
|
- The geometry of all meshes of the entire scene is aggregated into a single mesh.
|
||||||
|
Use `load_meshes()` instead to get un-aggregated (but transformed) ones.
|
||||||
|
- All material properties are ignored except for either vertex color, baseColorTexture
|
||||||
|
or baseColorFactor. If available, one of these (in this order) is exclusively
|
||||||
|
used which does not match the semantics of the standard.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import struct
|
||||||
|
import warnings
|
||||||
|
from base64 import b64decode
|
||||||
|
from collections import deque
|
||||||
|
from enum import IntEnum
|
||||||
|
from io import BytesIO
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union, cast
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from iopath.common.file_io import PathManager
|
||||||
|
from PIL import Image
|
||||||
|
from pytorch3d.io.utils import _open_file
|
||||||
|
from pytorch3d.renderer.mesh import TexturesBase, TexturesUV, TexturesVertex
|
||||||
|
from pytorch3d.structures import Meshes, join_meshes_as_scene
|
||||||
|
from pytorch3d.transforms import Transform3d, quaternion_to_matrix
|
||||||
|
|
||||||
|
from .pluggable_formats import MeshFormatInterpreter, endswith
|
||||||
|
|
||||||
|
|
||||||
|
_GLTF_MAGIC = 0x46546C67
|
||||||
|
_JSON_CHUNK_TYPE = 0x4E4F534A
|
||||||
|
_BINARY_CHUNK_TYPE = 0x004E4942
|
||||||
|
_DATA_URI_PREFIX = "data:application/octet-stream;base64,"
|
||||||
|
|
||||||
|
|
||||||
|
class _PrimitiveMode(IntEnum):
|
||||||
|
POINTS = 0
|
||||||
|
LINES = 1
|
||||||
|
LINE_LOOP = 2
|
||||||
|
LINE_STRIP = 3
|
||||||
|
TRIANGLES = 4
|
||||||
|
TRIANGLE_STRIP = 5
|
||||||
|
TRIANGLE_FAN = 6
|
||||||
|
|
||||||
|
|
||||||
|
class _ComponentType(IntEnum):
|
||||||
|
BYTE = 5120
|
||||||
|
UNSIGNED_BYTE = 5121
|
||||||
|
SHORT = 5122
|
||||||
|
UNSIGNED_SHORT = 5123
|
||||||
|
UNSIGNED_INT = 5125
|
||||||
|
FLOAT = 5126
|
||||||
|
|
||||||
|
|
||||||
|
_ITEM_TYPES: Dict[int, Any] = {
|
||||||
|
5120: np.int8,
|
||||||
|
5121: np.uint8,
|
||||||
|
5122: np.int16,
|
||||||
|
5123: np.uint16,
|
||||||
|
5125: np.uint32,
|
||||||
|
5126: np.float32,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
_ElementShape = Union[Tuple[int], Tuple[int, int]]
|
||||||
|
_ELEMENT_SHAPES: Dict[str, _ElementShape] = {
|
||||||
|
"SCALAR": (1,),
|
||||||
|
"VEC2": (2,),
|
||||||
|
"VEC3": (3,),
|
||||||
|
"VEC4": (4,),
|
||||||
|
"MAT2": (2, 2),
|
||||||
|
"MAT3": (3, 3),
|
||||||
|
"MAT4": (4, 4),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _read_header(stream: BinaryIO) -> Optional[Tuple[int, int]]:
|
||||||
|
header = stream.read(12)
|
||||||
|
magic, version, length = struct.unpack("<III", header)
|
||||||
|
|
||||||
|
if magic != _GLTF_MAGIC:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return version, length
|
||||||
|
|
||||||
|
|
||||||
|
def _read_chunks(
|
||||||
|
stream: BinaryIO, length: int
|
||||||
|
) -> Optional[Tuple[Dict[str, Any], np.ndarray]]:
|
||||||
|
"""
|
||||||
|
Get the json header and the binary data from a
|
||||||
|
GLB file.
|
||||||
|
"""
|
||||||
|
json_data = None
|
||||||
|
binary_data = None
|
||||||
|
|
||||||
|
while stream.tell() < length:
|
||||||
|
chunk_header = stream.read(8)
|
||||||
|
chunk_length, chunk_type = struct.unpack("<II", chunk_header)
|
||||||
|
chunk_data = stream.read(chunk_length)
|
||||||
|
if chunk_type == _JSON_CHUNK_TYPE:
|
||||||
|
json_data = json.loads(chunk_data)
|
||||||
|
elif chunk_type == _BINARY_CHUNK_TYPE:
|
||||||
|
binary_data = chunk_data
|
||||||
|
else:
|
||||||
|
warnings.warn("Unsupported chunk type")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if json_data is None:
|
||||||
|
raise ValueError("Missing json header")
|
||||||
|
|
||||||
|
if binary_data is not None:
|
||||||
|
binary_data = np.frombuffer(binary_data, dtype=np.uint8)
|
||||||
|
|
||||||
|
return json_data, binary_data
|
||||||
|
|
||||||
|
|
||||||
|
def _make_node_transform(node: Dict[str, Any]) -> Transform3d:
|
||||||
|
"""
|
||||||
|
Convert a transform from the json data in to a PyTorch3D
|
||||||
|
Transform3d format.
|
||||||
|
"""
|
||||||
|
array = node.get("matrix")
|
||||||
|
if array is not None: # Stored in column-major order
|
||||||
|
M = np.array(array, dtype=np.float32).reshape(4, 4, order="F")
|
||||||
|
return Transform3d(matrix=torch.from_numpy(M))
|
||||||
|
|
||||||
|
out = Transform3d()
|
||||||
|
|
||||||
|
# Given some of (scale/rotation/translation), we do them in that order to
|
||||||
|
# get points in to the world space.
|
||||||
|
# See https://github.com/KhronosGroup/glTF/issues/743 .
|
||||||
|
|
||||||
|
array = node.get("scale", None)
|
||||||
|
if array is not None:
|
||||||
|
scale_vector = torch.FloatTensor(array)
|
||||||
|
out = out.scale(scale_vector[None])
|
||||||
|
|
||||||
|
# Rotation quaternion (x, y, z, w) where w is the scalar
|
||||||
|
array = node.get("rotation", None)
|
||||||
|
if array is not None:
|
||||||
|
x, y, z, w = array
|
||||||
|
# We negate w. This is equivalent to inverting the rotation.
|
||||||
|
# This is needed as quaternion_to_matrix makes a matrix which
|
||||||
|
# operates on column vectors, whereas Transform3d wants a
|
||||||
|
# matrix which operates on row vectors.
|
||||||
|
rotation_quaternion = torch.FloatTensor([-w, x, y, z])
|
||||||
|
rotation_matrix = quaternion_to_matrix(rotation_quaternion)
|
||||||
|
out = out.rotate(R=rotation_matrix)
|
||||||
|
|
||||||
|
array = node.get("translation", None)
|
||||||
|
if array is not None:
|
||||||
|
translation_vector = torch.FloatTensor(array)
|
||||||
|
out = out.translate(x=translation_vector[None])
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class _GLTFLoader:
|
||||||
|
def __init__(self, stream: BinaryIO):
|
||||||
|
self._json_data = None
|
||||||
|
# Map from buffer index to (decoded) binary data
|
||||||
|
self._binary_data = {}
|
||||||
|
|
||||||
|
version_and_length = _read_header(stream)
|
||||||
|
if version_and_length is None: # GLTF
|
||||||
|
stream.seek(0)
|
||||||
|
json_data = json.load(stream)
|
||||||
|
else: # GLB
|
||||||
|
version, length = version_and_length
|
||||||
|
if version != 2:
|
||||||
|
warnings.warn("Unsupported version")
|
||||||
|
return
|
||||||
|
json_and_binary_data = _read_chunks(stream, length)
|
||||||
|
if json_and_binary_data is None:
|
||||||
|
raise ValueError("Data not found")
|
||||||
|
json_data, binary_data = json_and_binary_data
|
||||||
|
self._binary_data[0] = binary_data
|
||||||
|
|
||||||
|
self._json_data = json_data
|
||||||
|
self._accessors = json_data.get("accessors", [])
|
||||||
|
self._buffer_views = json_data.get("bufferViews", [])
|
||||||
|
self._buffers = json_data.get("buffers", [])
|
||||||
|
self._texture_map_images = {}
|
||||||
|
|
||||||
|
def _access_image(self, image_index: int) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Get the data for an image from the file. This is only called
|
||||||
|
by _get_texture_map_image which caches it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
image_json = self._json_data["images"][image_index]
|
||||||
|
buffer_view = self._buffer_views[image_json["bufferView"]]
|
||||||
|
if "byteStride" in buffer_view:
|
||||||
|
raise NotImplementedError("strided buffer views")
|
||||||
|
|
||||||
|
length = buffer_view["byteLength"]
|
||||||
|
offset = buffer_view.get("byteOffset", 0)
|
||||||
|
|
||||||
|
binary_data = self.get_binary_data(buffer_view["buffer"])
|
||||||
|
|
||||||
|
bytesio = BytesIO(binary_data[offset : offset + length].tobytes())
|
||||||
|
# pyre-fixme[16]: `Image.Image` has no attribute `__enter__`.
|
||||||
|
with Image.open(bytesio) as f:
|
||||||
|
array = np.array(f)
|
||||||
|
if array.dtype == np.uint8:
|
||||||
|
return array.astype(np.float32) / 255.0
|
||||||
|
else:
|
||||||
|
return array
|
||||||
|
|
||||||
|
def _get_texture_map_image(self, image_index: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Return a texture map image as a torch tensor.
|
||||||
|
Calling this function repeatedly with the same arguments returns
|
||||||
|
the very same tensor, this allows a memory optimization to happen
|
||||||
|
later in TexturesUV.join_scene.
|
||||||
|
Any alpha channel is ignored.
|
||||||
|
"""
|
||||||
|
im = self._texture_map_images.get(image_index)
|
||||||
|
if im is not None:
|
||||||
|
return im
|
||||||
|
|
||||||
|
im = torch.from_numpy(self._access_image(image_index))[:, :, :3]
|
||||||
|
self._texture_map_images[image_index] = im
|
||||||
|
return im
|
||||||
|
|
||||||
|
def _access_data(self, accessor_index: int) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Get the raw data from an accessor as a numpy array.
|
||||||
|
"""
|
||||||
|
accessor = self._accessors[accessor_index]
|
||||||
|
|
||||||
|
buffer_view_index = accessor.get("bufferView")
|
||||||
|
# Undefined buffer view (all zeros) are not (yet) supported
|
||||||
|
if buffer_view_index is None:
|
||||||
|
raise NotImplementedError("Undefined buffer view")
|
||||||
|
|
||||||
|
accessor_byte_offset = accessor.get("byteOffset", 0)
|
||||||
|
component_type = accessor["componentType"]
|
||||||
|
element_count = accessor["count"]
|
||||||
|
element_type = accessor["type"]
|
||||||
|
|
||||||
|
# Sparse accessors are not (yet) supported
|
||||||
|
if accessor.get("sparse") is not None:
|
||||||
|
raise NotImplementedError("Sparse Accessors")
|
||||||
|
|
||||||
|
buffer_view = self._buffer_views[buffer_view_index]
|
||||||
|
buffer_index = buffer_view["buffer"]
|
||||||
|
buffer_byte_length = buffer_view["byteLength"]
|
||||||
|
element_byte_offset = buffer_view.get("byteOffset", 0)
|
||||||
|
element_byte_stride = buffer_view.get("byteStride", 0)
|
||||||
|
if element_byte_stride != 0 and element_byte_stride < 4:
|
||||||
|
raise ValueError("Stride is too small.")
|
||||||
|
if element_byte_stride > 252:
|
||||||
|
raise ValueError("Stride is too big.")
|
||||||
|
|
||||||
|
element_shape = _ELEMENT_SHAPES[element_type]
|
||||||
|
item_type = _ITEM_TYPES[component_type]
|
||||||
|
item_dtype = np.dtype(item_type)
|
||||||
|
item_count = np.prod(element_shape)
|
||||||
|
item_size = item_dtype.itemsize
|
||||||
|
size = element_count * item_count * item_size
|
||||||
|
if size > buffer_byte_length:
|
||||||
|
raise ValueError("Buffer did not have enough data for the accessor")
|
||||||
|
|
||||||
|
buffer_ = self._buffers[buffer_index]
|
||||||
|
binary_data = self.get_binary_data(buffer_index)
|
||||||
|
if len(binary_data) < buffer_["byteLength"]:
|
||||||
|
raise ValueError("Not enough binary data for the buffer")
|
||||||
|
|
||||||
|
if element_byte_stride == 0:
|
||||||
|
element_byte_stride = item_size * item_count
|
||||||
|
# The same buffer can store interleaved elements
|
||||||
|
if element_byte_stride < item_size * item_count:
|
||||||
|
raise ValueError("Items should not overlap")
|
||||||
|
|
||||||
|
dtype = np.dtype(
|
||||||
|
{
|
||||||
|
"names": ["element"],
|
||||||
|
"formats": [str(element_shape) + item_dtype.str],
|
||||||
|
"offsets": [0],
|
||||||
|
"itemsize": element_byte_stride,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
byte_offset = accessor_byte_offset + element_byte_offset
|
||||||
|
if byte_offset % item_size != 0:
|
||||||
|
raise ValueError("Misaligned data")
|
||||||
|
byte_length = element_count * element_byte_stride
|
||||||
|
buffer_view = binary_data[byte_offset : byte_offset + byte_length].view(dtype)[
|
||||||
|
"element"
|
||||||
|
]
|
||||||
|
|
||||||
|
# Convert matrix data from column-major (OpenGL) to row-major order
|
||||||
|
if element_type in ("MAT2", "MAT3", "MAT4"):
|
||||||
|
buffer_view = np.transpose(buffer_view, (0, 2, 1))
|
||||||
|
|
||||||
|
return buffer_view
|
||||||
|
|
||||||
|
def _get_primitive_attribute(
|
||||||
|
self, primitive_attributes: Dict[str, Any], key: str, dtype
|
||||||
|
) -> Optional[np.ndarray]:
|
||||||
|
accessor_index = primitive_attributes.get(key)
|
||||||
|
if accessor_index is None:
|
||||||
|
return None
|
||||||
|
primitive_attribute = self._access_data(accessor_index)
|
||||||
|
if key == "JOINTS_0":
|
||||||
|
pass
|
||||||
|
elif dtype == np.uint8:
|
||||||
|
primitive_attribute /= 255.0
|
||||||
|
elif dtype == np.uint16:
|
||||||
|
primitive_attribute /= 65535.0
|
||||||
|
else:
|
||||||
|
if dtype != np.float32:
|
||||||
|
raise ValueError("Unexpected data type")
|
||||||
|
primitive_attribute = primitive_attribute.astype(dtype)
|
||||||
|
return primitive_attribute
|
||||||
|
|
||||||
|
def get_binary_data(self, buffer_index: int):
|
||||||
|
"""
|
||||||
|
Get the binary data from a buffer as a 1D numpy array of bytes.
|
||||||
|
This is implemented for explicit uri data buffers or the main GLB data
|
||||||
|
segment.
|
||||||
|
"""
|
||||||
|
buffer_ = self._buffers[buffer_index]
|
||||||
|
binary_data = self._binary_data.get(buffer_index)
|
||||||
|
if binary_data is None: # Lazily decode binary data
|
||||||
|
uri = buffer_.get("uri")
|
||||||
|
if not uri.startswith(_DATA_URI_PREFIX):
|
||||||
|
raise NotImplementedError("Unexpected URI type")
|
||||||
|
binary_data = b64decode(uri[len(_DATA_URI_PREFIX) :])
|
||||||
|
binary_data = np.frombuffer(binary_data, dtype=np.uint8)
|
||||||
|
self._binary_data[buffer_index] = binary_data
|
||||||
|
return binary_data
|
||||||
|
|
||||||
|
def get_texture_for_mesh(
|
||||||
|
self, primitive: Dict[str, Any], indices: torch.Tensor
|
||||||
|
) -> Optional[TexturesBase]:
|
||||||
|
"""
|
||||||
|
Get the texture object representing the given mesh primitive.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
primitive: the mesh primitive being loaded.
|
||||||
|
indices: the face indices of the mesh
|
||||||
|
"""
|
||||||
|
attributes = primitive["attributes"]
|
||||||
|
vertex_colors = self._get_primitive_attribute(attributes, "COLOR_0", np.float32)
|
||||||
|
if vertex_colors is not None:
|
||||||
|
return TexturesVertex(torch.from_numpy(vertex_colors))
|
||||||
|
|
||||||
|
vertex_texcoords_0 = self._get_primitive_attribute(
|
||||||
|
attributes, "TEXCOORD_0", np.float32
|
||||||
|
)
|
||||||
|
if vertex_texcoords_0 is not None:
|
||||||
|
verts_uvs = torch.from_numpy(vertex_texcoords_0)
|
||||||
|
verts_uvs[:, 1] = 1 - verts_uvs[:, -1]
|
||||||
|
faces_uvs = indices
|
||||||
|
material_index = primitive.get("material", 0)
|
||||||
|
material = self._json_data["materials"][material_index]
|
||||||
|
material_roughness = material["pbrMetallicRoughness"]
|
||||||
|
if "baseColorTexture" in material_roughness:
|
||||||
|
texture_index = material_roughness["baseColorTexture"]["index"]
|
||||||
|
texture_json = self._json_data["textures"][texture_index]
|
||||||
|
# Todo - include baseColorFactor when also given
|
||||||
|
# Todo - look at the sampler
|
||||||
|
image_index = texture_json["source"]
|
||||||
|
map = self._get_texture_map_image(image_index)
|
||||||
|
elif "baseColorFactor" in material_roughness:
|
||||||
|
# Constant color?
|
||||||
|
map = torch.FloatTensor(material_roughness["baseColorFactor"])[
|
||||||
|
None, None, :3
|
||||||
|
]
|
||||||
|
texture = TexturesUV(
|
||||||
|
maps=[map], # alpha channel ignored
|
||||||
|
faces_uvs=[faces_uvs],
|
||||||
|
verts_uvs=[verts_uvs],
|
||||||
|
)
|
||||||
|
return texture
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def load(self, include_textures: bool) -> List[Tuple[Optional[str], Meshes]]:
|
||||||
|
"""
|
||||||
|
Attempt to load all the meshes making up the default scene from
|
||||||
|
the file as a list of possibly-named Meshes objects.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
include_textures: Whether to try loading textures.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Meshes object containing one mesh.
|
||||||
|
"""
|
||||||
|
if self._json_data is None:
|
||||||
|
raise ValueError("Initialization problem")
|
||||||
|
|
||||||
|
# This loads the default scene from the file.
|
||||||
|
# This is usually the only one.
|
||||||
|
# It is possible to have multiple scenes, in which case
|
||||||
|
# you could choose another here instead of taking the default.
|
||||||
|
scene_index = self._json_data.get("scene")
|
||||||
|
|
||||||
|
if scene_index is None:
|
||||||
|
raise ValueError("Default scene is not specified.")
|
||||||
|
|
||||||
|
scene = self._json_data["scenes"][scene_index]
|
||||||
|
nodes = self._json_data.get("nodes", [])
|
||||||
|
meshes = self._json_data.get("meshes", [])
|
||||||
|
root_node_indices = scene["nodes"]
|
||||||
|
|
||||||
|
mesh_transform = Transform3d()
|
||||||
|
names_meshes_list: List[Tuple[Optional[str], Meshes]] = []
|
||||||
|
|
||||||
|
# Keep track and apply the transform of the scene node to mesh vertices
|
||||||
|
Q = deque([(Transform3d(), node_index) for node_index in root_node_indices])
|
||||||
|
|
||||||
|
while Q:
|
||||||
|
parent_transform, current_node_index = Q.popleft()
|
||||||
|
|
||||||
|
current_node = nodes[current_node_index]
|
||||||
|
|
||||||
|
transform = _make_node_transform(current_node)
|
||||||
|
current_transform = transform.compose(parent_transform)
|
||||||
|
|
||||||
|
if "mesh" in current_node:
|
||||||
|
mesh_index = current_node["mesh"]
|
||||||
|
mesh = meshes[mesh_index]
|
||||||
|
mesh_name = mesh.get("name", None)
|
||||||
|
mesh_transform = current_transform
|
||||||
|
|
||||||
|
for primitive in mesh["primitives"]:
|
||||||
|
attributes = primitive["attributes"]
|
||||||
|
accessor_index = attributes["POSITION"]
|
||||||
|
positions = torch.from_numpy(
|
||||||
|
self._access_data(accessor_index).copy()
|
||||||
|
)
|
||||||
|
positions = mesh_transform.transform_points(positions)
|
||||||
|
|
||||||
|
mode = primitive.get("mode", _PrimitiveMode.TRIANGLES)
|
||||||
|
if mode != _PrimitiveMode.TRIANGLES:
|
||||||
|
raise NotImplementedError("Non triangular meshes")
|
||||||
|
|
||||||
|
if "indices" in primitive:
|
||||||
|
accessor_index = primitive["indices"]
|
||||||
|
indices = self._access_data(accessor_index).astype(np.int64)
|
||||||
|
else:
|
||||||
|
indices = np.arange(0, len(positions), dtype=np.int64)
|
||||||
|
indices = torch.from_numpy(indices.reshape(-1, 3))
|
||||||
|
|
||||||
|
texture = None
|
||||||
|
if include_textures:
|
||||||
|
texture = self.get_texture_for_mesh(primitive, indices)
|
||||||
|
|
||||||
|
mesh_obj = Meshes(
|
||||||
|
verts=[positions], faces=[indices], textures=texture
|
||||||
|
)
|
||||||
|
names_meshes_list.append((mesh_name, mesh_obj))
|
||||||
|
|
||||||
|
if "children" in current_node:
|
||||||
|
children_node_indices = current_node["children"]
|
||||||
|
Q.extend(
|
||||||
|
[
|
||||||
|
(current_transform, node_index)
|
||||||
|
for node_index in children_node_indices
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return names_meshes_list
|
||||||
|
|
||||||
|
|
||||||
|
def load_meshes(
|
||||||
|
path: Union[str, Path],
|
||||||
|
path_manager: PathManager,
|
||||||
|
include_textures: bool = True,
|
||||||
|
) -> List[Tuple[Optional[str], Meshes]]:
|
||||||
|
"""
|
||||||
|
Loads all the meshes from the default scene in the given GLB file.
|
||||||
|
and returns them separately.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: path to read from
|
||||||
|
path_manager: PathManager object for interpreting the path
|
||||||
|
include_textures: whether to load textures
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (name, mesh) pairs, where the name is the optional name property
|
||||||
|
from the GLB file, or None if it is absent, and the mesh is a Meshes
|
||||||
|
object containing one mesh.
|
||||||
|
"""
|
||||||
|
with _open_file(path, path_manager, "rb") as f:
|
||||||
|
loader = _GLTFLoader(cast(BinaryIO, f))
|
||||||
|
names_meshes_list = loader.load(include_textures=include_textures)
|
||||||
|
return names_meshes_list
|
||||||
|
|
||||||
|
|
||||||
|
class MeshGlbFormat(MeshFormatInterpreter):
|
||||||
|
"""
|
||||||
|
Implements loading meshes from glTF 2 assets stored in a
|
||||||
|
GLB container file or a glTF JSON file with embedded binary data.
|
||||||
|
|
||||||
|
This implementation is quite restricted in what it supports.
|
||||||
|
|
||||||
|
- It does not try to validate the input against the standard.
|
||||||
|
- It loads the default scene only.
|
||||||
|
- Only triangulated geometry is supported.
|
||||||
|
- The geometry of all meshes of the entire scene is aggregated into a single mesh.
|
||||||
|
Use `load_meshes()` instead to get un-aggregated (but transformed) ones.
|
||||||
|
- All material properties are ignored except for either vertex color, baseColorTexture
|
||||||
|
or baseColorFactor. If available, one of these (in this order) is exclusively
|
||||||
|
used which does not match the semantics of the standard.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.known_suffixes = (".glb",)
|
||||||
|
|
||||||
|
def read(
|
||||||
|
self,
|
||||||
|
path: Union[str, Path],
|
||||||
|
include_textures: bool,
|
||||||
|
device,
|
||||||
|
path_manager: PathManager,
|
||||||
|
**kwargs,
|
||||||
|
) -> Optional[Meshes]:
|
||||||
|
if not endswith(path, self.known_suffixes):
|
||||||
|
return None
|
||||||
|
|
||||||
|
names_meshes_list = load_meshes(
|
||||||
|
path=path,
|
||||||
|
path_manager=path_manager,
|
||||||
|
include_textures=include_textures,
|
||||||
|
)
|
||||||
|
|
||||||
|
meshes_list = [mesh for name, mesh in names_meshes_list]
|
||||||
|
mesh = join_meshes_as_scene(meshes_list)
|
||||||
|
return mesh.to(device)
|
||||||
|
|
||||||
|
def save(
|
||||||
|
self,
|
||||||
|
data: Meshes,
|
||||||
|
path: Union[str, Path],
|
||||||
|
path_manager: PathManager,
|
||||||
|
binary: Optional[bool],
|
||||||
|
**kwargs,
|
||||||
|
) -> bool:
|
||||||
|
return False
|
@ -86,6 +86,8 @@ class IO:
|
|||||||
interpreter: the new interpreter to use, which must be an instance
|
interpreter: the new interpreter to use, which must be an instance
|
||||||
of a class which inherits MeshFormatInterpreter.
|
of a class which inherits MeshFormatInterpreter.
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(interpreter, MeshFormatInterpreter):
|
||||||
|
raise ValueError("Invalid interpreter")
|
||||||
self.mesh_interpreters.appendleft(interpreter)
|
self.mesh_interpreters.appendleft(interpreter)
|
||||||
|
|
||||||
def register_pointcloud_format(
|
def register_pointcloud_format(
|
||||||
@ -98,6 +100,8 @@ class IO:
|
|||||||
interpreter: the new interpreter to use, which must be an instance
|
interpreter: the new interpreter to use, which must be an instance
|
||||||
of a class which inherits PointcloudFormatInterpreter.
|
of a class which inherits PointcloudFormatInterpreter.
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(interpreter, PointcloudFormatInterpreter):
|
||||||
|
raise ValueError("Invalid interpreter")
|
||||||
self.pointcloud_interpreters.appendleft(interpreter)
|
self.pointcloud_interpreters.appendleft(interpreter)
|
||||||
|
|
||||||
def load_mesh(
|
def load_mesh(
|
||||||
|
@ -22,7 +22,7 @@ from .shader import (
|
|||||||
)
|
)
|
||||||
from .shading import gouraud_shading, phong_shading
|
from .shading import gouraud_shading, phong_shading
|
||||||
from .textures import Textures # DEPRECATED
|
from .textures import Textures # DEPRECATED
|
||||||
from .textures import TexturesAtlas, TexturesUV, TexturesVertex
|
from .textures import TexturesAtlas, TexturesBase, TexturesUV, TexturesVertex
|
||||||
|
|
||||||
|
|
||||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||||
|
@ -155,7 +155,7 @@ class TestCaseMixin(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not close and msg is None:
|
if not close and msg is None:
|
||||||
diff = backend.abs(input - other) + 0.0
|
diff = backend.abs(input + 0.0 - other)
|
||||||
ratio = diff / backend.abs(other)
|
ratio = diff / backend.abs(other)
|
||||||
try_relative = (diff <= atol) | (backend.isfinite(ratio) & (ratio > 0))
|
try_relative = (diff <= atol) | (backend.isfinite(ratio) & (ratio > 0))
|
||||||
if try_relative.all():
|
if try_relative.all():
|
||||||
@ -167,7 +167,9 @@ class TestCaseMixin(unittest.TestCase):
|
|||||||
else:
|
else:
|
||||||
extra = ""
|
extra = ""
|
||||||
shape = tuple(input.shape)
|
shape = tuple(input.shape)
|
||||||
|
loc = np.unravel_index(diff.argmax(), shape)
|
||||||
max_diff = diff.max()
|
max_diff = diff.max()
|
||||||
self.fail(f"Not close. Max diff {max_diff}.{extra} Shape {shape}.")
|
msg = f"Not close. Max diff {max_diff}.{extra} Shape {shape}. At {loc}."
|
||||||
|
self.fail(msg)
|
||||||
|
|
||||||
self.assertTrue(close, msg)
|
self.assertTrue(close, msg)
|
||||||
|
BIN
tests/data/cow.glb
Normal file
BIN
tests/data/cow.glb
Normal file
Binary file not shown.
BIN
tests/data/glb_cow.png
Normal file
BIN
tests/data/glb_cow.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 10 KiB |
BIN
tests/data/glb_cow_gray.png
Normal file
BIN
tests/data/glb_cow_gray.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 5.3 KiB |
180
tests/test_io_gltf.py
Normal file
180
tests/test_io_gltf.py
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
from math import radians
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from common_testing import TestCaseMixin, get_pytorch3d_dir, get_tests_dir
|
||||||
|
from PIL import Image
|
||||||
|
from pytorch3d.io import IO
|
||||||
|
from pytorch3d.io.experimental_gltf_io import MeshGlbFormat
|
||||||
|
from pytorch3d.renderer import (
|
||||||
|
AmbientLights,
|
||||||
|
BlendParams,
|
||||||
|
FoVPerspectiveCameras,
|
||||||
|
PointLights,
|
||||||
|
RasterizationSettings,
|
||||||
|
look_at_view_transform,
|
||||||
|
rotate_on_spot,
|
||||||
|
)
|
||||||
|
from pytorch3d.renderer.mesh import (
|
||||||
|
HardPhongShader,
|
||||||
|
MeshRasterizer,
|
||||||
|
MeshRenderer,
|
||||||
|
TexturesVertex,
|
||||||
|
)
|
||||||
|
from pytorch3d.structures import Meshes
|
||||||
|
from pytorch3d.transforms import axis_angle_to_matrix
|
||||||
|
from pytorch3d.vis.texture_vis import texturesuv_image_PIL
|
||||||
|
|
||||||
|
|
||||||
|
DATA_DIR = get_tests_dir() / "data"
|
||||||
|
TUTORIAL_DATA_DIR = get_pytorch3d_dir() / "docs/tutorials/data"
|
||||||
|
DEBUG = False
|
||||||
|
|
||||||
|
|
||||||
|
def _load(path, **kwargs) -> Meshes:
|
||||||
|
io = IO()
|
||||||
|
io.register_meshes_format(MeshGlbFormat())
|
||||||
|
return io.load_mesh(path, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _render(
|
||||||
|
mesh: Meshes,
|
||||||
|
name: str,
|
||||||
|
dist: float = 3.0,
|
||||||
|
elev: float = 10.0,
|
||||||
|
azim: float = 0,
|
||||||
|
image_size: int = 256,
|
||||||
|
pan=None,
|
||||||
|
RT=None,
|
||||||
|
use_ambient=False,
|
||||||
|
):
|
||||||
|
device = mesh.device
|
||||||
|
if RT is not None:
|
||||||
|
R, T = RT
|
||||||
|
else:
|
||||||
|
R, T = look_at_view_transform(dist, elev, azim)
|
||||||
|
if pan is not None:
|
||||||
|
R, T = rotate_on_spot(R, T, pan)
|
||||||
|
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
|
||||||
|
|
||||||
|
raster_settings = RasterizationSettings(
|
||||||
|
image_size=image_size, blur_radius=0.0, faces_per_pixel=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Init shader settings
|
||||||
|
if use_ambient:
|
||||||
|
lights = AmbientLights(device=device)
|
||||||
|
else:
|
||||||
|
lights = PointLights(device=device)
|
||||||
|
lights.location = torch.tensor([0.0, 0.0, 2.0], device=device)[None]
|
||||||
|
|
||||||
|
blend_params = BlendParams(
|
||||||
|
sigma=1e-1,
|
||||||
|
gamma=1e-4,
|
||||||
|
background_color=torch.tensor([1.0, 1.0, 1.0], device=device),
|
||||||
|
)
|
||||||
|
# Init renderer
|
||||||
|
renderer = MeshRenderer(
|
||||||
|
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
|
||||||
|
shader=HardPhongShader(
|
||||||
|
device=device, lights=lights, cameras=cameras, blend_params=blend_params
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
output = renderer(mesh)
|
||||||
|
|
||||||
|
image = (output[0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
|
||||||
|
|
||||||
|
if DEBUG:
|
||||||
|
Image.fromarray(image).save(DATA_DIR / f"glb_{name}_.png")
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
class TestMeshGltfIO(TestCaseMixin, unittest.TestCase):
|
||||||
|
def test_load_apartment(self):
|
||||||
|
"""
|
||||||
|
This is the example habitat example scene from inside
|
||||||
|
http://dl.fbaipublicfiles.com/habitat/habitat-test-scenes.zip
|
||||||
|
|
||||||
|
The scene is "already lit", i.e. the textures reflect the lighting
|
||||||
|
already, so we want to render them with full ambient light.
|
||||||
|
"""
|
||||||
|
self.skipTest("Data not available")
|
||||||
|
|
||||||
|
glb = DATA_DIR / "apartment_1.glb"
|
||||||
|
self.assertTrue(glb.is_file())
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
mesh = _load(glb, device=device)
|
||||||
|
|
||||||
|
if DEBUG:
|
||||||
|
texturesuv_image_PIL(mesh.textures).save(DATA_DIR / "out_apartment.png")
|
||||||
|
|
||||||
|
for i in range(19):
|
||||||
|
# random locations in the apartment
|
||||||
|
eye = ((np.random.uniform(-6, 0.5), np.random.uniform(-8, 2), 0),)
|
||||||
|
at = ((np.random.uniform(-6, 0.5), np.random.uniform(-8, 2), 0),)
|
||||||
|
up = ((0, 0, -1),)
|
||||||
|
RT = look_at_view_transform(eye=eye, at=at, up=up)
|
||||||
|
_render(mesh, f"apartment_eau{i}", RT=RT, use_ambient=True)
|
||||||
|
|
||||||
|
for i in range(12):
|
||||||
|
# panning around the inner room from one location
|
||||||
|
pan = axis_angle_to_matrix(torch.FloatTensor([0, radians(30 * i), 0]))
|
||||||
|
_render(mesh, f"apartment{i}", 1.0, -90, pan, use_ambient=True)
|
||||||
|
|
||||||
|
def test_load_cow(self):
|
||||||
|
"""
|
||||||
|
Load the cow as converted to a single mesh in a glb file.
|
||||||
|
"""
|
||||||
|
glb = DATA_DIR / "cow.glb"
|
||||||
|
self.assertTrue(glb.is_file())
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
mesh = _load(glb, device=device)
|
||||||
|
self.assertEqual(mesh.device, device)
|
||||||
|
|
||||||
|
self.assertEqual(mesh.faces_packed().shape, (5856, 3))
|
||||||
|
self.assertEqual(mesh.verts_packed().shape, (3225, 3))
|
||||||
|
mesh_obj = _load(TUTORIAL_DATA_DIR / "cow_mesh/cow.obj")
|
||||||
|
self.assertClose(
|
||||||
|
mesh_obj.get_bounding_boxes().cpu(), mesh_obj.get_bounding_boxes()
|
||||||
|
)
|
||||||
|
|
||||||
|
if DEBUG:
|
||||||
|
texturesuv_image_PIL(mesh.textures).save(DATA_DIR / "out_cow.png")
|
||||||
|
image = _render(mesh, "cow", azim=4)
|
||||||
|
|
||||||
|
with Image.open(DATA_DIR / "glb_cow.png") as f:
|
||||||
|
expected = np.array(f)
|
||||||
|
|
||||||
|
self.assertClose(image, expected)
|
||||||
|
|
||||||
|
def test_load_cow_no_texture(self):
|
||||||
|
"""
|
||||||
|
Load the cow as converted to a single mesh in a glb file.
|
||||||
|
"""
|
||||||
|
glb = DATA_DIR / "cow.glb"
|
||||||
|
self.assertTrue(glb.is_file())
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
mesh = _load(glb, device=device, include_textures=False)
|
||||||
|
self.assertEqual(len(mesh), 1)
|
||||||
|
self.assertIsNone(mesh.textures)
|
||||||
|
|
||||||
|
self.assertEqual(mesh.faces_packed().shape, (5856, 3))
|
||||||
|
self.assertEqual(mesh.verts_packed().shape, (3225, 3))
|
||||||
|
mesh_obj = _load(TUTORIAL_DATA_DIR / "cow_mesh/cow.obj")
|
||||||
|
self.assertClose(
|
||||||
|
mesh_obj.get_bounding_boxes().cpu(), mesh_obj.get_bounding_boxes()
|
||||||
|
)
|
||||||
|
|
||||||
|
mesh.textures = TexturesVertex(0.5 * torch.ones_like(mesh.verts_padded()))
|
||||||
|
|
||||||
|
image = _render(mesh, "cow_gray")
|
||||||
|
|
||||||
|
with Image.open(DATA_DIR / "glb_cow_gray.png") as f:
|
||||||
|
expected = np.array(f)
|
||||||
|
|
||||||
|
self.assertClose(image, expected)
|
Loading…
x
Reference in New Issue
Block a user