mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Write meshes to GLB
Summary: Write the amalgamated mesh from the Mesh module to glb. In this version, the json header and the binary data specified by the buffer are merged into glb. The image texture attributes are added. Reviewed By: bottler Differential Revision: D41489778 fbshipit-source-id: 3af0e9a8f9e9098e73737a254177802e0fb6bd3c
This commit is contained in:
parent
dba48fb410
commit
cc2840eb44
@ -39,7 +39,7 @@ import json
|
||||
import struct
|
||||
import warnings
|
||||
from base64 import b64decode
|
||||
from collections import deque
|
||||
from collections import defaultdict, deque
|
||||
from enum import IntEnum
|
||||
from io import BytesIO
|
||||
from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union
|
||||
@ -102,6 +102,27 @@ _ELEMENT_SHAPES: Dict[str, _ElementShape] = {
|
||||
"MAT4": (4, 4),
|
||||
}
|
||||
|
||||
_DTYPE_BYTES: Dict[Any, int] = {
|
||||
np.int8: 1,
|
||||
np.uint8: 1,
|
||||
np.int16: 2,
|
||||
np.uint16: 2,
|
||||
np.uint32: 4,
|
||||
np.float32: 4,
|
||||
}
|
||||
|
||||
|
||||
class _TargetType(IntEnum):
|
||||
ARRAY_BUFFER = 34962
|
||||
ELEMENT_ARRAY_BUFFER = 34963
|
||||
|
||||
|
||||
class OurEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, np.int64):
|
||||
return str(obj)
|
||||
return super(OurEncoder, self).default(obj)
|
||||
|
||||
|
||||
def _read_header(stream: BinaryIO) -> Optional[Tuple[int, int]]:
|
||||
header = stream.read(12)
|
||||
@ -109,7 +130,6 @@ def _read_header(stream: BinaryIO) -> Optional[Tuple[int, int]]:
|
||||
|
||||
if magic != _GLTF_MAGIC:
|
||||
return None
|
||||
|
||||
return version, length
|
||||
|
||||
|
||||
@ -227,7 +247,6 @@ class _GLTFLoader:
|
||||
offset = buffer_view.get("byteOffset", 0)
|
||||
|
||||
binary_data = self.get_binary_data(buffer_view["buffer"])
|
||||
|
||||
bytesio = BytesIO(binary_data[offset : offset + length].tobytes())
|
||||
with Image.open(bytesio) as f:
|
||||
array = np.array(f)
|
||||
@ -521,6 +540,223 @@ def load_meshes(
|
||||
return names_meshes_list
|
||||
|
||||
|
||||
class _GLTFWriter:
|
||||
def __init__(self, data: Meshes, buffer_stream: BinaryIO) -> None:
|
||||
self._json_data = defaultdict(list)
|
||||
self.mesh = data
|
||||
self.buffer_stream = buffer_stream
|
||||
|
||||
# initialize json with one scene and one node
|
||||
scene_index = 0
|
||||
# pyre-fixme[6]: Incompatible parameter type
|
||||
self._json_data["scene"] = scene_index
|
||||
self._json_data["scenes"].append({"nodes": [scene_index]})
|
||||
self._json_data["asset"] = {"version": "2.0"}
|
||||
node = {"name": "Node", "mesh": 0}
|
||||
self._json_data["nodes"].append(node)
|
||||
|
||||
# mesh primitives
|
||||
meshes = defaultdict(list)
|
||||
# pyre-fixme[6]: Incompatible parameter type
|
||||
meshes["name"] = "Node-Mesh"
|
||||
primitives = {
|
||||
"attributes": {"POSITION": 0, "TEXCOORD_0": 2},
|
||||
"indices": 1,
|
||||
"material": 0, # default material
|
||||
"mode": _PrimitiveMode.TRIANGLES,
|
||||
}
|
||||
meshes["primitives"].append(primitives)
|
||||
self._json_data["meshes"].append(meshes)
|
||||
|
||||
# default material
|
||||
material = {
|
||||
"name": "material_1",
|
||||
"pbrMetallicRoughness": {
|
||||
"baseColorTexture": {"index": 0},
|
||||
"baseColorFactor": [1, 1, 1, 1],
|
||||
"metallicFactor": 0,
|
||||
"roughnessFactor": 0.99,
|
||||
},
|
||||
"emissiveFactor": [0, 0, 0],
|
||||
"alphaMode": "OPAQUE",
|
||||
}
|
||||
self._json_data["materials"].append(material)
|
||||
|
||||
# default sampler
|
||||
sampler = {"magFilter": 9729, "minFilter": 9986, "wrapS": 10497, "wrapT": 10497}
|
||||
self._json_data["samplers"].append(sampler)
|
||||
|
||||
# default textures
|
||||
texture = {"sampler": 0, "source": 0}
|
||||
self._json_data["textures"].append(texture)
|
||||
|
||||
def _write_accessor_json(self, key: str) -> Tuple[int, np.ndarray]:
|
||||
name = "Node-Mesh_%s" % key
|
||||
byte_offset = 0
|
||||
if key == "positions":
|
||||
data = self.mesh.verts_packed().cpu().numpy()
|
||||
component_type = _ComponentType.FLOAT
|
||||
element_type = "VEC3"
|
||||
buffer_view = 0
|
||||
element_min = list(map(float, np.min(data, axis=0)))
|
||||
element_max = list(map(float, np.max(data, axis=0)))
|
||||
byte_per_element = 3 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]]
|
||||
elif key == "texcoords":
|
||||
component_type = _ComponentType.FLOAT
|
||||
data = self.mesh.textures.verts_uvs_list()[0].cpu().numpy()
|
||||
data[:, 1] = 1 - data[:, -1] # flip y tex-coordinate
|
||||
element_type = "VEC2"
|
||||
buffer_view = 2
|
||||
element_min = list(map(float, np.min(data, axis=0)))
|
||||
element_max = list(map(float, np.max(data, axis=0)))
|
||||
byte_per_element = 2 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]]
|
||||
elif key == "indices":
|
||||
component_type = _ComponentType.UNSIGNED_SHORT
|
||||
data = (
|
||||
self.mesh.faces_packed()
|
||||
.cpu()
|
||||
.numpy()
|
||||
.astype(_ITEM_TYPES[component_type])
|
||||
)
|
||||
element_type = "SCALAR"
|
||||
buffer_view = 1
|
||||
element_min = list(map(int, np.min(data, keepdims=True)))
|
||||
element_max = list(map(int, np.max(data, keepdims=True)))
|
||||
byte_per_element = (
|
||||
3 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.UNSIGNED_SHORT]]
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"invalid key accessor, should be one of positions, indices or texcoords"
|
||||
)
|
||||
|
||||
count = int(data.shape[0])
|
||||
byte_length = count * byte_per_element
|
||||
accessor_json = {
|
||||
"name": name,
|
||||
"componentType": component_type,
|
||||
"type": element_type,
|
||||
"bufferView": buffer_view,
|
||||
"byteOffset": byte_offset,
|
||||
"min": element_min,
|
||||
"max": element_max,
|
||||
"count": count * 3 if key == "indices" else count,
|
||||
}
|
||||
self._json_data["accessors"].append(accessor_json)
|
||||
return (byte_length, data)
|
||||
|
||||
def _write_bufferview(self, key: str, **kwargs):
|
||||
if key not in ["positions", "texcoords", "indices"]:
|
||||
raise ValueError("key must be one of positions, texcoords or indices")
|
||||
|
||||
bufferview = {
|
||||
"name": "bufferView_%s" % key,
|
||||
"buffer": 0,
|
||||
}
|
||||
target = _TargetType.ARRAY_BUFFER
|
||||
if key == "positions":
|
||||
byte_per_element = 3 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]]
|
||||
bufferview["byteStride"] = int(byte_per_element)
|
||||
elif key == "texcoords":
|
||||
byte_per_element = 2 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]]
|
||||
target = _TargetType.ARRAY_BUFFER
|
||||
bufferview["byteStride"] = int(byte_per_element)
|
||||
elif key == "indices":
|
||||
byte_per_element = (
|
||||
3 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.UNSIGNED_SHORT]]
|
||||
)
|
||||
target = _TargetType.ELEMENT_ARRAY_BUFFER
|
||||
|
||||
bufferview["target"] = target
|
||||
bufferview["byteOffset"] = kwargs.get("offset")
|
||||
bufferview["byteLength"] = kwargs.get("byte_length")
|
||||
self._json_data["bufferViews"].append(bufferview)
|
||||
|
||||
def _write_image_buffer(self, **kwargs) -> Tuple[int, bytes]:
|
||||
image_np = self.mesh.textures.maps_list()[0].cpu().numpy()
|
||||
image_array = (image_np * 255.0).astype(np.uint8)
|
||||
im = Image.fromarray(image_array)
|
||||
with BytesIO() as f:
|
||||
im.save(f, format="PNG")
|
||||
image_data = f.getvalue()
|
||||
|
||||
image_data_byte_length = len(image_data)
|
||||
bufferview_image = {
|
||||
"buffer": 0,
|
||||
}
|
||||
bufferview_image["byteOffset"] = kwargs.get("offset")
|
||||
bufferview_image["byteLength"] = image_data_byte_length
|
||||
self._json_data["bufferViews"].append(bufferview_image)
|
||||
|
||||
image = {"name": "texture", "mimeType": "image/png", "bufferView": 3}
|
||||
self._json_data["images"].append(image)
|
||||
return (image_data_byte_length, image_data)
|
||||
|
||||
def save(self):
|
||||
# check validity of mesh
|
||||
if self.mesh.verts_packed() is None or self.mesh.faces_packed() is None:
|
||||
raise ValueError("invalid mesh to save, verts or face indices are empty")
|
||||
|
||||
# accessors for positions, texture uvs and face indices
|
||||
pos_byte, pos_data = self._write_accessor_json("positions")
|
||||
idx_byte, idx_data = self._write_accessor_json("indices")
|
||||
include_textures = False
|
||||
if (
|
||||
self.mesh.textures is not None
|
||||
and self.mesh.textures.verts_uvs_list()[0] is not None
|
||||
):
|
||||
tex_byte, tex_data = self._write_accessor_json("texcoords")
|
||||
include_textures = True
|
||||
|
||||
# bufferViews for positions, texture coords and indices
|
||||
byte_offset = 0
|
||||
self._write_bufferview("positions", byte_length=pos_byte, offset=byte_offset)
|
||||
byte_offset += pos_byte
|
||||
|
||||
self._write_bufferview("indices", byte_length=idx_byte, offset=byte_offset)
|
||||
byte_offset += idx_byte
|
||||
|
||||
if include_textures:
|
||||
self._write_bufferview(
|
||||
"texcoords", byte_length=tex_byte, offset=byte_offset
|
||||
)
|
||||
byte_offset += tex_byte
|
||||
|
||||
# image bufferView
|
||||
include_image = False
|
||||
if (
|
||||
self.mesh.textures is not None
|
||||
and self.mesh.textures.maps_list()[0] is not None
|
||||
):
|
||||
include_image = True
|
||||
image_byte, image_data = self._write_image_buffer(offset=byte_offset)
|
||||
byte_offset += image_byte
|
||||
|
||||
# buffers
|
||||
self._json_data["buffers"].append({"byteLength": int(byte_offset)})
|
||||
|
||||
# organize into a glb
|
||||
json_bytes = bytes(json.dumps(self._json_data, cls=OurEncoder), "utf-8")
|
||||
json_length = len(json_bytes)
|
||||
|
||||
# write header
|
||||
header = struct.pack("<III", _GLTF_MAGIC, 2, json_length + byte_offset)
|
||||
self.buffer_stream.write(header)
|
||||
|
||||
# write json
|
||||
self.buffer_stream.write(struct.pack("<II", json_length, _JSON_CHUNK_TYPE))
|
||||
self.buffer_stream.write(json_bytes)
|
||||
|
||||
# write binary data
|
||||
self.buffer_stream.write(struct.pack("<II", byte_offset, _BINARY_CHUNK_TYPE))
|
||||
self.buffer_stream.write(pos_data)
|
||||
self.buffer_stream.write(idx_data)
|
||||
if include_textures:
|
||||
self.buffer_stream.write(tex_data)
|
||||
if include_image:
|
||||
self.buffer_stream.write(image_data)
|
||||
|
||||
|
||||
class MeshGlbFormat(MeshFormatInterpreter):
|
||||
"""
|
||||
Implements loading meshes from glTF 2 assets stored in a
|
||||
@ -570,4 +806,21 @@ class MeshGlbFormat(MeshFormatInterpreter):
|
||||
binary: Optional[bool],
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
return False
|
||||
"""
|
||||
Writes all the meshes from the default scene to GLB file.
|
||||
|
||||
Args:
|
||||
data: meshes to save
|
||||
path: path of the GLB file to write into
|
||||
path_manager: PathManager object for interpreting the path
|
||||
|
||||
Return True if saving succeeds and False otherwise
|
||||
"""
|
||||
|
||||
if not endswith(path, self.known_suffixes):
|
||||
return False
|
||||
|
||||
with _open_file(path, path_manager, "wb") as f:
|
||||
writer = _GLTFWriter(data, cast(BinaryIO, f))
|
||||
writer.save()
|
||||
return True
|
||||
|
@ -29,6 +29,7 @@ from pytorch3d.renderer.mesh import (
|
||||
)
|
||||
from pytorch3d.structures import Meshes
|
||||
from pytorch3d.transforms import axis_angle_to_matrix
|
||||
from pytorch3d.utils import ico_sphere
|
||||
from pytorch3d.vis.texture_vis import texturesuv_image_PIL
|
||||
|
||||
from .common_testing import get_pytorch3d_dir, get_tests_dir, TestCaseMixin
|
||||
@ -45,6 +46,12 @@ def _load(path, **kwargs) -> Meshes:
|
||||
return io.load_mesh(path, **kwargs)
|
||||
|
||||
|
||||
def _write(mesh, path, **kwargs) -> bool:
|
||||
io = IO()
|
||||
io.register_meshes_format(MeshGlbFormat())
|
||||
return io.save_mesh(mesh, path, **kwargs)
|
||||
|
||||
|
||||
def _render(
|
||||
mesh: Meshes,
|
||||
name: str,
|
||||
@ -144,9 +151,7 @@ class TestMeshGltfIO(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(mesh.faces_packed().shape, (5856, 3))
|
||||
self.assertEqual(mesh.verts_packed().shape, (3225, 3))
|
||||
mesh_obj = _load(TUTORIAL_DATA_DIR / "cow_mesh/cow.obj")
|
||||
self.assertClose(
|
||||
mesh_obj.get_bounding_boxes().cpu(), mesh_obj.get_bounding_boxes()
|
||||
)
|
||||
self.assertClose(mesh.get_bounding_boxes().cpu(), mesh_obj.get_bounding_boxes())
|
||||
|
||||
self.assertClose(
|
||||
mesh.textures.verts_uvs_padded().cpu(), mesh_obj.textures.verts_uvs_padded()
|
||||
@ -169,6 +174,69 @@ class TestMeshGltfIO(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
self.assertClose(image, expected)
|
||||
|
||||
def test_save_cow(self):
|
||||
"""
|
||||
Save the cow mesh to a glb file
|
||||
"""
|
||||
# load cow mesh from a glb file
|
||||
glb = DATA_DIR / "cow.glb"
|
||||
self.assertTrue(glb.is_file())
|
||||
device = torch.device("cuda:0")
|
||||
mesh = _load(glb, device=device)
|
||||
|
||||
# save the mesh to a glb file
|
||||
glb = DATA_DIR / "cow_write.glb"
|
||||
_write(mesh, glb)
|
||||
|
||||
# load again
|
||||
glb_reload = DATA_DIR / "cow_write.glb"
|
||||
self.assertTrue(glb_reload.is_file())
|
||||
device = torch.device("cuda:0")
|
||||
mesh_reload = _load(glb_reload, device=device)
|
||||
|
||||
# assertions
|
||||
self.assertEqual(mesh_reload.faces_packed().shape, (5856, 3))
|
||||
self.assertEqual(mesh_reload.verts_packed().shape, (3225, 3))
|
||||
self.assertClose(
|
||||
mesh_reload.get_bounding_boxes().cpu(), mesh.get_bounding_boxes().cpu()
|
||||
)
|
||||
|
||||
self.assertClose(
|
||||
mesh_reload.textures.verts_uvs_padded().cpu(),
|
||||
mesh.textures.verts_uvs_padded().cpu(),
|
||||
)
|
||||
|
||||
self.assertClose(
|
||||
mesh_reload.textures.faces_uvs_padded().cpu(),
|
||||
mesh.textures.faces_uvs_padded().cpu(),
|
||||
)
|
||||
|
||||
self.assertClose(
|
||||
mesh_reload.textures.maps_padded().cpu(), mesh.textures.maps_padded().cpu()
|
||||
)
|
||||
|
||||
def test_save_ico_sphere(self):
|
||||
"""
|
||||
save the ico_sphere mesh in a glb file
|
||||
"""
|
||||
ico_sphere_mesh = ico_sphere(level=3)
|
||||
glb = DATA_DIR / "ico_sphere.glb"
|
||||
_write(ico_sphere_mesh, glb)
|
||||
|
||||
# reload the ico_sphere
|
||||
device = torch.device("cuda:0")
|
||||
mesh_reload = _load(glb, device=device, include_textures=False)
|
||||
|
||||
self.assertClose(
|
||||
ico_sphere_mesh.verts_padded().cpu(),
|
||||
mesh_reload.verts_padded().cpu(),
|
||||
)
|
||||
|
||||
self.assertClose(
|
||||
ico_sphere_mesh.faces_padded().cpu(),
|
||||
mesh_reload.faces_padded().cpu(),
|
||||
)
|
||||
|
||||
def test_load_cow_no_texture(self):
|
||||
"""
|
||||
Load the cow as converted to a single mesh in a glb file.
|
||||
@ -183,9 +251,7 @@ class TestMeshGltfIO(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(mesh.faces_packed().shape, (5856, 3))
|
||||
self.assertEqual(mesh.verts_packed().shape, (3225, 3))
|
||||
mesh_obj = _load(TUTORIAL_DATA_DIR / "cow_mesh/cow.obj")
|
||||
self.assertClose(
|
||||
mesh_obj.get_bounding_boxes().cpu(), mesh_obj.get_bounding_boxes()
|
||||
)
|
||||
self.assertClose(mesh.get_bounding_boxes().cpu(), mesh_obj.get_bounding_boxes())
|
||||
|
||||
mesh.textures = TexturesVertex(0.5 * torch.ones_like(mesh.verts_padded()))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user