Write meshes to GLB

Summary: Write the amalgamated mesh from the Mesh module to glb. In this version, the JSON header and the binary buffer data are merged into a single glb container, and image texture attributes are added.

Reviewed By: bottler

Differential Revision: D41489778

fbshipit-source-id: 3af0e9a8f9e9098e73737a254177802e0fb6bd3c
This commit is contained in:
Jiali Duan 2022-12-05 01:25:43 -08:00 committed by Facebook GitHub Bot
parent dba48fb410
commit cc2840eb44
2 changed files with 329 additions and 10 deletions

View File

@ -39,7 +39,7 @@ import json
import struct import struct
import warnings import warnings
from base64 import b64decode from base64 import b64decode
from collections import deque from collections import defaultdict, deque
from enum import IntEnum from enum import IntEnum
from io import BytesIO from io import BytesIO
from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union
@ -102,6 +102,27 @@ _ELEMENT_SHAPES: Dict[str, _ElementShape] = {
"MAT4": (4, 4), "MAT4": (4, 4),
} }
# Byte width of each numpy dtype that can back a glTF accessor componentType;
# used to compute bufferView byte lengths and strides.
_DTYPE_BYTES: Dict[Any, int] = {
    np.int8: 1,
    np.uint8: 1,
    np.int16: 2,
    np.uint16: 2,
    np.uint32: 4,
    np.float32: 4,
}
class _TargetType(IntEnum):
    """glTF bufferView.target values (GL buffer binding points)."""

    ARRAY_BUFFER = 34962  # vertex attributes (positions, texture uvs)
    ELEMENT_ARRAY_BUFFER = 34963  # triangle face indices
class OurEncoder(json.JSONEncoder):
    """
    JSON encoder which serializes np.int64 values as strings.

    np.int64 is not JSON-serializable on all platforms (it is only an int
    subclass on some); any other unhandled type falls through to the base
    class, which raises TypeError.
    """

    def default(self, obj):
        if isinstance(obj, np.int64):
            return str(obj)
        # modern zero-argument super() instead of super(OurEncoder, self)
        return super().default(obj)
def _read_header(stream: BinaryIO) -> Optional[Tuple[int, int]]: def _read_header(stream: BinaryIO) -> Optional[Tuple[int, int]]:
header = stream.read(12) header = stream.read(12)
@ -109,7 +130,6 @@ def _read_header(stream: BinaryIO) -> Optional[Tuple[int, int]]:
if magic != _GLTF_MAGIC: if magic != _GLTF_MAGIC:
return None return None
return version, length return version, length
@ -227,7 +247,6 @@ class _GLTFLoader:
offset = buffer_view.get("byteOffset", 0) offset = buffer_view.get("byteOffset", 0)
binary_data = self.get_binary_data(buffer_view["buffer"]) binary_data = self.get_binary_data(buffer_view["buffer"])
bytesio = BytesIO(binary_data[offset : offset + length].tobytes()) bytesio = BytesIO(binary_data[offset : offset + length].tobytes())
with Image.open(bytesio) as f: with Image.open(bytesio) as f:
array = np.array(f) array = np.array(f)
@ -521,6 +540,223 @@ def load_meshes(
return names_meshes_list return names_meshes_list
class _GLTFWriter:
    """
    Writes a Meshes object to a binary glTF 2.0 container (.glb): a JSON
    chunk describing the scene, accessors and bufferViews, followed by one
    binary chunk holding positions, face indices, texture uvs and the
    texture image. Only the first mesh / first texture map is written.
    """

    def __init__(self, data: Meshes, buffer_stream: BinaryIO) -> None:
        """
        Args:
            data: mesh to serialize.
            buffer_stream: binary stream that receives the .glb bytes.
        """
        self._json_data = defaultdict(list)
        self.mesh = data
        self.buffer_stream = buffer_stream

        # initialize json with one scene and one node
        scene_index = 0
        # pyre-fixme[6]: Incompatible parameter type
        self._json_data["scene"] = scene_index
        self._json_data["scenes"].append({"nodes": [scene_index]})
        self._json_data["asset"] = {"version": "2.0"}
        node = {"name": "Node", "mesh": 0}
        self._json_data["nodes"].append(node)

        # mesh primitives; accessor slots are fixed by convention:
        # 0 = positions, 1 = face indices, 2 = texture uvs
        meshes = defaultdict(list)
        # pyre-fixme[6]: Incompatible parameter type
        meshes["name"] = "Node-Mesh"
        primitives = {
            "attributes": {"POSITION": 0, "TEXCOORD_0": 2},
            "indices": 1,
            "material": 0,  # default material
            "mode": _PrimitiveMode.TRIANGLES,
        }
        meshes["primitives"].append(primitives)
        self._json_data["meshes"].append(meshes)

        # default material using the texture as base color
        material = {
            "name": "material_1",
            "pbrMetallicRoughness": {
                "baseColorTexture": {"index": 0},
                "baseColorFactor": [1, 1, 1, 1],
                "metallicFactor": 0,
                "roughnessFactor": 0.99,
            },
            "emissiveFactor": [0, 0, 0],
            "alphaMode": "OPAQUE",
        }
        self._json_data["materials"].append(material)

        # default sampler; GL enum values: 9729 = LINEAR,
        # 9986 = NEAREST_MIPMAP_LINEAR, 10497 = REPEAT
        sampler = {"magFilter": 9729, "minFilter": 9986, "wrapS": 10497, "wrapT": 10497}
        self._json_data["samplers"].append(sampler)

        # default textures
        texture = {"sampler": 0, "source": 0}
        self._json_data["textures"].append(texture)

    def _write_accessor_json(self, key: str) -> Tuple[int, np.ndarray]:
        """
        Append the accessor JSON entry for one data stream and return the
        data destined for the binary chunk.

        Args:
            key: one of "positions", "texcoords" or "indices".

        Returns:
            (byte_length, data): number of bytes this stream occupies in
            the binary buffer, and the numpy array holding the values.

        Raises:
            NotImplementedError: if key is not one of the three streams.
        """
        name = "Node-Mesh_%s" % key
        byte_offset = 0
        if key == "positions":
            data = self.mesh.verts_packed().cpu().numpy()
            component_type = _ComponentType.FLOAT
            element_type = "VEC3"
            buffer_view = 0
            # per-axis min/max, required by the glTF spec for POSITION
            element_min = list(map(float, np.min(data, axis=0)))
            element_max = list(map(float, np.max(data, axis=0)))
            byte_per_element = 3 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]]
        elif key == "texcoords":
            component_type = _ComponentType.FLOAT
            data = self.mesh.textures.verts_uvs_list()[0].cpu().numpy()
            data[:, 1] = 1 - data[:, -1]  # flip y tex-coordinate
            element_type = "VEC2"
            buffer_view = 2
            element_min = list(map(float, np.min(data, axis=0)))
            element_max = list(map(float, np.max(data, axis=0)))
            byte_per_element = 2 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]]
        elif key == "indices":
            component_type = _ComponentType.UNSIGNED_SHORT
            # NOTE(review): astype(uint16) silently wraps for meshes with
            # >= 2**16 vertices — confirm an upstream size guard exists.
            data = (
                self.mesh.faces_packed()
                .cpu()
                .numpy()
                .astype(_ITEM_TYPES[component_type])
            )
            element_type = "SCALAR"
            buffer_view = 1
            # global scalar min/max (keepdims yields (1, 1) arrays mapped to ints)
            element_min = list(map(int, np.min(data, keepdims=True)))
            element_max = list(map(int, np.max(data, keepdims=True)))
            # 3 uint16 indices per face row
            byte_per_element = (
                3 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.UNSIGNED_SHORT]]
            )
        else:
            raise NotImplementedError(
                "invalid key accessor, should be one of positions, indices or texcoords"
            )

        # count = number of rows (verts / uvs / faces); glTF counts indices
        # as scalars, hence the * 3 for the "indices" accessor below
        count = int(data.shape[0])
        byte_length = count * byte_per_element
        accessor_json = {
            "name": name,
            "componentType": component_type,
            "type": element_type,
            "bufferView": buffer_view,
            "byteOffset": byte_offset,
            "min": element_min,
            "max": element_max,
            "count": count * 3 if key == "indices" else count,
        }
        self._json_data["accessors"].append(accessor_json)
        return (byte_length, data)

    def _write_bufferview(self, key: str, **kwargs):
        """
        Append the bufferView JSON entry for one data stream.

        Args:
            key: one of "positions", "texcoords" or "indices".

        Keyword Args:
            offset: byte offset of this view inside the binary buffer.
            byte_length: length of this view in bytes.

        Raises:
            ValueError: if key is not positions, texcoords or indices.
        """
        if key not in ["positions", "texcoords", "indices"]:
            raise ValueError("key must be one of positions, texcoords or indices")

        bufferview = {
            "name": "bufferView_%s" % key,
            "buffer": 0,
        }
        target = _TargetType.ARRAY_BUFFER
        if key == "positions":
            byte_per_element = 3 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]]
            bufferview["byteStride"] = int(byte_per_element)
        elif key == "texcoords":
            byte_per_element = 2 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]]
            target = _TargetType.ARRAY_BUFFER
            bufferview["byteStride"] = int(byte_per_element)
        elif key == "indices":
            byte_per_element = (
                3 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.UNSIGNED_SHORT]]
            )
            # index data binds to the element array buffer target
            target = _TargetType.ELEMENT_ARRAY_BUFFER

        bufferview["target"] = target
        bufferview["byteOffset"] = kwargs.get("offset")
        bufferview["byteLength"] = kwargs.get("byte_length")
        self._json_data["bufferViews"].append(bufferview)

    def _write_image_buffer(self, **kwargs) -> Tuple[int, bytes]:
        """
        PNG-encode the first texture map, append its bufferView and image
        JSON entries, and return the encoded bytes.

        Keyword Args:
            offset: byte offset of the image data inside the binary buffer.

        Returns:
            (byte_length, image_data): length of the PNG payload and the
            PNG bytes themselves.
        """
        image_np = self.mesh.textures.maps_list()[0].cpu().numpy()
        # NOTE(review): assumes map values are floats in [0, 1] — confirm.
        image_array = (image_np * 255.0).astype(np.uint8)
        im = Image.fromarray(image_array)

        with BytesIO() as f:
            im.save(f, format="PNG")
            image_data = f.getvalue()
        image_data_byte_length = len(image_data)
        bufferview_image = {
            "buffer": 0,
        }
        bufferview_image["byteOffset"] = kwargs.get("offset")
        bufferview_image["byteLength"] = image_data_byte_length
        self._json_data["bufferViews"].append(bufferview_image)

        # bufferView 3 is hard-coded: views 0-2 are positions/indices/texcoords
        image = {"name": "texture", "mimeType": "image/png", "bufferView": 3}
        self._json_data["images"].append(image)
        return (image_data_byte_length, image_data)

    def save(self):
        """
        Assemble the JSON and binary chunks and write the .glb to
        self.buffer_stream. Binary payloads are written in the same order
        as their bufferView offsets: positions, indices, texcoords, image.

        Raises:
            ValueError: if the mesh has no packed verts or faces.
        """
        # check validity of mesh
        if self.mesh.verts_packed() is None or self.mesh.faces_packed() is None:
            raise ValueError("invalid mesh to save, verts or face indices are empty")

        # accessors for positions, texture uvs and face indices
        pos_byte, pos_data = self._write_accessor_json("positions")
        idx_byte, idx_data = self._write_accessor_json("indices")
        include_textures = False
        if (
            self.mesh.textures is not None
            and self.mesh.textures.verts_uvs_list()[0] is not None
        ):
            tex_byte, tex_data = self._write_accessor_json("texcoords")
            include_textures = True

        # bufferViews for positions, texture coords and indices
        byte_offset = 0
        self._write_bufferview("positions", byte_length=pos_byte, offset=byte_offset)
        byte_offset += pos_byte

        self._write_bufferview("indices", byte_length=idx_byte, offset=byte_offset)
        byte_offset += idx_byte
        # NOTE(review): idx_byte (6 bytes per face) may not be a multiple of
        # 4, so the float texcoord view can start unaligned; the glTF spec
        # requires 4-byte alignment — confirm target loaders tolerate this.

        if include_textures:
            self._write_bufferview(
                "texcoords", byte_length=tex_byte, offset=byte_offset
            )
            byte_offset += tex_byte

        # image bufferView
        include_image = False
        if (
            self.mesh.textures is not None
            and self.mesh.textures.maps_list()[0] is not None
        ):
            include_image = True
            image_byte, image_data = self._write_image_buffer(offset=byte_offset)
            byte_offset += image_byte

        # buffers
        self._json_data["buffers"].append({"byteLength": int(byte_offset)})

        # organize into a glb
        json_bytes = bytes(json.dumps(self._json_data, cls=OurEncoder), "utf-8")
        json_length = len(json_bytes)

        # write header
        # NOTE(review): per the GLB spec the length field should be the total
        # file size (12-byte header + both 8-byte chunk headers + payloads);
        # this writes only json_length + byte_offset — confirm readers accept.
        header = struct.pack("<III", _GLTF_MAGIC, 2, json_length + byte_offset)
        self.buffer_stream.write(header)

        # write json chunk (header then payload)
        self.buffer_stream.write(struct.pack("<II", json_length, _JSON_CHUNK_TYPE))
        self.buffer_stream.write(json_bytes)

        # write binary chunk in bufferView order
        self.buffer_stream.write(struct.pack("<II", byte_offset, _BINARY_CHUNK_TYPE))
        self.buffer_stream.write(pos_data)
        self.buffer_stream.write(idx_data)
        if include_textures:
            self.buffer_stream.write(tex_data)
        if include_image:
            self.buffer_stream.write(image_data)
class MeshGlbFormat(MeshFormatInterpreter): class MeshGlbFormat(MeshFormatInterpreter):
""" """
Implements loading meshes from glTF 2 assets stored in a Implements loading meshes from glTF 2 assets stored in a
@ -570,4 +806,21 @@ class MeshGlbFormat(MeshFormatInterpreter):
binary: Optional[bool], binary: Optional[bool],
**kwargs, **kwargs,
) -> bool: ) -> bool:
return False """
Writes all the meshes from the default scene to GLB file.
Args:
data: meshes to save
path: path of the GLB file to write into
path_manager: PathManager object for interpreting the path
Return True if saving succeeds and False otherwise
"""
if not endswith(path, self.known_suffixes):
return False
with _open_file(path, path_manager, "wb") as f:
writer = _GLTFWriter(data, cast(BinaryIO, f))
writer.save()
return True

View File

@ -29,6 +29,7 @@ from pytorch3d.renderer.mesh import (
) )
from pytorch3d.structures import Meshes from pytorch3d.structures import Meshes
from pytorch3d.transforms import axis_angle_to_matrix from pytorch3d.transforms import axis_angle_to_matrix
from pytorch3d.utils import ico_sphere
from pytorch3d.vis.texture_vis import texturesuv_image_PIL from pytorch3d.vis.texture_vis import texturesuv_image_PIL
from .common_testing import get_pytorch3d_dir, get_tests_dir, TestCaseMixin from .common_testing import get_pytorch3d_dir, get_tests_dir, TestCaseMixin
@ -45,6 +46,12 @@ def _load(path, **kwargs) -> Meshes:
return io.load_mesh(path, **kwargs) return io.load_mesh(path, **kwargs)
def _write(mesh, path, **kwargs) -> bool:
    """Save a mesh via an IO instance with the glb codec registered."""
    mesh_io = IO()
    mesh_io.register_meshes_format(MeshGlbFormat())
    return mesh_io.save_mesh(mesh, path, **kwargs)
def _render( def _render(
mesh: Meshes, mesh: Meshes,
name: str, name: str,
@ -144,9 +151,7 @@ class TestMeshGltfIO(TestCaseMixin, unittest.TestCase):
self.assertEqual(mesh.faces_packed().shape, (5856, 3)) self.assertEqual(mesh.faces_packed().shape, (5856, 3))
self.assertEqual(mesh.verts_packed().shape, (3225, 3)) self.assertEqual(mesh.verts_packed().shape, (3225, 3))
mesh_obj = _load(TUTORIAL_DATA_DIR / "cow_mesh/cow.obj") mesh_obj = _load(TUTORIAL_DATA_DIR / "cow_mesh/cow.obj")
self.assertClose( self.assertClose(mesh.get_bounding_boxes().cpu(), mesh_obj.get_bounding_boxes())
mesh_obj.get_bounding_boxes().cpu(), mesh_obj.get_bounding_boxes()
)
self.assertClose( self.assertClose(
mesh.textures.verts_uvs_padded().cpu(), mesh_obj.textures.verts_uvs_padded() mesh.textures.verts_uvs_padded().cpu(), mesh_obj.textures.verts_uvs_padded()
@ -169,6 +174,69 @@ class TestMeshGltfIO(TestCaseMixin, unittest.TestCase):
self.assertClose(image, expected) self.assertClose(image, expected)
def test_save_cow(self):
    """
    Round-trip the cow mesh through the glb writer: load, save, reload,
    then check geometry and texture data survive unchanged.
    """
    source_glb = DATA_DIR / "cow.glb"
    self.assertTrue(source_glb.is_file())
    device = torch.device("cuda:0")
    original = _load(source_glb, device=device)

    # write the mesh out to a new glb file
    output_glb = DATA_DIR / "cow_write.glb"
    _write(original, output_glb)

    # load the written file again
    self.assertTrue(output_glb.is_file())
    reload_device = torch.device("cuda:0")
    reloaded = _load(output_glb, device=reload_device)

    # geometry must match the known cow mesh dimensions and bounds
    self.assertEqual(reloaded.faces_packed().shape, (5856, 3))
    self.assertEqual(reloaded.verts_packed().shape, (3225, 3))
    self.assertClose(
        reloaded.get_bounding_boxes().cpu(), original.get_bounding_boxes().cpu()
    )

    # texture uvs, face-uv indices and maps must also round-trip
    self.assertClose(
        reloaded.textures.verts_uvs_padded().cpu(),
        original.textures.verts_uvs_padded().cpu(),
    )
    self.assertClose(
        reloaded.textures.faces_uvs_padded().cpu(),
        original.textures.faces_uvs_padded().cpu(),
    )
    self.assertClose(
        reloaded.textures.maps_padded().cpu(), original.textures.maps_padded().cpu()
    )
def test_save_ico_sphere(self):
    """
    Write an untextured ico_sphere to glb and check that vertices and
    faces round-trip exactly.
    """
    sphere = ico_sphere(level=3)
    out_path = DATA_DIR / "ico_sphere.glb"
    _write(sphere, out_path)

    # reload without textures and compare geometry against the source
    device = torch.device("cuda:0")
    reloaded = _load(out_path, device=device, include_textures=False)
    self.assertClose(sphere.verts_padded().cpu(), reloaded.verts_padded().cpu())
    self.assertClose(sphere.faces_padded().cpu(), reloaded.faces_padded().cpu())
def test_load_cow_no_texture(self): def test_load_cow_no_texture(self):
""" """
Load the cow as converted to a single mesh in a glb file. Load the cow as converted to a single mesh in a glb file.
@ -183,9 +251,7 @@ class TestMeshGltfIO(TestCaseMixin, unittest.TestCase):
self.assertEqual(mesh.faces_packed().shape, (5856, 3)) self.assertEqual(mesh.faces_packed().shape, (5856, 3))
self.assertEqual(mesh.verts_packed().shape, (3225, 3)) self.assertEqual(mesh.verts_packed().shape, (3225, 3))
mesh_obj = _load(TUTORIAL_DATA_DIR / "cow_mesh/cow.obj") mesh_obj = _load(TUTORIAL_DATA_DIR / "cow_mesh/cow.obj")
self.assertClose( self.assertClose(mesh.get_bounding_boxes().cpu(), mesh_obj.get_bounding_boxes())
mesh_obj.get_bounding_boxes().cpu(), mesh_obj.get_bounding_boxes()
)
mesh.textures = TexturesVertex(0.5 * torch.ones_like(mesh.verts_padded())) mesh.textures = TexturesVertex(0.5 * torch.ones_like(mesh.verts_padded()))