diff --git a/pytorch3d/io/experimental_gltf_io.py b/pytorch3d/io/experimental_gltf_io.py
index 1dd3d164..ecfe95c4 100644
--- a/pytorch3d/io/experimental_gltf_io.py
+++ b/pytorch3d/io/experimental_gltf_io.py
@@ -39,7 +39,7 @@ import json
 import struct
 import warnings
 from base64 import b64decode
-from collections import deque
+from collections import defaultdict, deque
 from enum import IntEnum
 from io import BytesIO
 from typing import Any, BinaryIO, cast, Dict, List, Optional, Tuple, Union
@@ -102,6 +102,27 @@ _ELEMENT_SHAPES: Dict[str, _ElementShape] = {
     "MAT4": (4, 4),
 }
 
+_DTYPE_BYTES: Dict[Any, int] = {
+    np.int8: 1,
+    np.uint8: 1,
+    np.int16: 2,
+    np.uint16: 2,
+    np.uint32: 4,
+    np.float32: 4,
+}
+
+
+class _TargetType(IntEnum):
+    ARRAY_BUFFER = 34962
+    ELEMENT_ARRAY_BUFFER = 34963
+
+
+class OurEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, np.integer):
+            return int(obj)  # glTF fields must stay JSON numbers, not strings
+        return super().default(obj)
+
 
 def _read_header(stream: BinaryIO) -> Optional[Tuple[int, int]]:
     header = stream.read(12)
@@ -109,7 +130,6 @@ def _read_header(stream: BinaryIO) -> Optional[Tuple[int, int]]:
 
     if magic != _GLTF_MAGIC:
         return None
-
     return version, length
 
 
@@ -227,7 +247,6 @@ class _GLTFLoader:
         offset = buffer_view.get("byteOffset", 0)
 
         binary_data = self.get_binary_data(buffer_view["buffer"])
-
         bytesio = BytesIO(binary_data[offset : offset + length].tobytes())
         with Image.open(bytesio) as f:
             array = np.array(f)
@@ -521,6 +540,223 @@ def load_meshes(
     return names_meshes_list
 
 
+class _GLTFWriter:
+    def __init__(self, data: Meshes, buffer_stream: BinaryIO) -> None:
+        self._json_data = defaultdict(list)
+        self.mesh = data
+        self.buffer_stream = buffer_stream
+
+        # initialize json with one scene and one node
+        scene_index = 0
+        # pyre-fixme[6]: Incompatible parameter type
+        self._json_data["scene"] = scene_index
+        self._json_data["scenes"].append({"nodes": [scene_index]})
+        self._json_data["asset"] = {"version": "2.0"}
+        node = {"name": "Node", "mesh": 0}
+        self._json_data["nodes"].append(node)
+
+        # mesh primitives
+        meshes = defaultdict(list)
+        # pyre-fixme[6]: Incompatible parameter type
+        meshes["name"] = "Node-Mesh"
+        primitives = {
+            "attributes": {"POSITION": 0, "TEXCOORD_0": 2},
+            "indices": 1,
+            "material": 0,  # default material
+            "mode": _PrimitiveMode.TRIANGLES,
+        }
+        meshes["primitives"].append(primitives)
+        self._json_data["meshes"].append(meshes)
+
+        # default material
+        material = {
+            "name": "material_1",
+            "pbrMetallicRoughness": {
+                "baseColorTexture": {"index": 0},
+                "baseColorFactor": [1, 1, 1, 1],
+                "metallicFactor": 0,
+                "roughnessFactor": 0.99,
+            },
+            "emissiveFactor": [0, 0, 0],
+            "alphaMode": "OPAQUE",
+        }
+        self._json_data["materials"].append(material)
+
+        # default sampler
+        sampler = {"magFilter": 9729, "minFilter": 9986, "wrapS": 10497, "wrapT": 10497}
+        self._json_data["samplers"].append(sampler)
+
+        # default textures
+        texture = {"sampler": 0, "source": 0}
+        self._json_data["textures"].append(texture)
+
+    def _write_accessor_json(self, key: str) -> Tuple[int, np.ndarray]:
+        name = "Node-Mesh_%s" % key
+        byte_offset = 0
+        if key == "positions":
+            data = self.mesh.verts_packed().cpu().numpy()
+            component_type = _ComponentType.FLOAT
+            element_type = "VEC3"
+            buffer_view = 0
+            element_min = list(map(float, np.min(data, axis=0)))
+            element_max = list(map(float, np.max(data, axis=0)))
+            byte_per_element = 3 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]]
+        elif key == "texcoords":
+            component_type = _ComponentType.FLOAT
+            data = self.mesh.textures.verts_uvs_list()[0].cpu().numpy().copy()
+            data[:, 1] = 1 - data[:, 1]  # flip v on a copy, not the mesh's tensor
+            element_type = "VEC2"
+            buffer_view = 2
+            element_min = list(map(float, np.min(data, axis=0)))
+            element_max = list(map(float, np.max(data, axis=0)))
+            byte_per_element = 2 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]]
+        elif key == "indices":
+            component_type = _ComponentType.UNSIGNED_SHORT
+            data = (
+                self.mesh.faces_packed()
+                .cpu()
+                .numpy()
+                .astype(_ITEM_TYPES[component_type])
+            )
+            element_type = "SCALAR"
+            buffer_view = 1
+            element_min = [int(data.min())]  # scalar reduce; no size-1 array cast
+            element_max = [int(data.max())]  # scalar reduce; no size-1 array cast
+            byte_per_element = (
+                3 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.UNSIGNED_SHORT]]
+            )
+        else:
+            raise NotImplementedError(
+                "invalid key accessor, should be one of positions, indices or texcoords"
+            )
+
+        count = int(data.shape[0])
+        byte_length = count * byte_per_element
+        accessor_json = {
+            "name": name,
+            "componentType": component_type,
+            "type": element_type,
+            "bufferView": buffer_view,
+            "byteOffset": byte_offset,
+            "min": element_min,
+            "max": element_max,
+            "count": count * 3 if key == "indices" else count,
+        }
+        self._json_data["accessors"].append(accessor_json)
+        return (byte_length, data)
+
+    def _write_bufferview(self, key: str, **kwargs):
+        if key not in ["positions", "texcoords", "indices"]:
+            raise ValueError("key must be one of positions, texcoords or indices")
+
+        bufferview = {
+            "name": "bufferView_%s" % key,
+            "buffer": 0,
+        }
+        target = _TargetType.ARRAY_BUFFER
+        if key == "positions":
+            byte_per_element = 3 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]]
+            bufferview["byteStride"] = int(byte_per_element)
+        elif key == "texcoords":
+            byte_per_element = 2 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]]
+            target = _TargetType.ARRAY_BUFFER
+            bufferview["byteStride"] = int(byte_per_element)
+        elif key == "indices":
+            byte_per_element = (
+                3 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.UNSIGNED_SHORT]]
+            )
+            target = _TargetType.ELEMENT_ARRAY_BUFFER
+
+        bufferview["target"] = target
+        bufferview["byteOffset"] = kwargs.get("offset")
+        bufferview["byteLength"] = kwargs.get("byte_length")
+        self._json_data["bufferViews"].append(bufferview)
+
+    def _write_image_buffer(self, **kwargs) -> Tuple[int, bytes]:
+        image_np = self.mesh.textures.maps_list()[0].cpu().numpy()
+        image_array = (image_np * 255.0).astype(np.uint8)
+        im = Image.fromarray(image_array)
+        with BytesIO() as f:
+            im.save(f, format="PNG")
+            image_data = f.getvalue()
+
+        image_data_byte_length = len(image_data)
+        bufferview_image = {
+            "buffer": 0,
+        }
+        bufferview_image["byteOffset"] = kwargs.get("offset")
+        bufferview_image["byteLength"] = image_data_byte_length
+        self._json_data["bufferViews"].append(bufferview_image)
+
+        image = {"name": "texture", "mimeType": "image/png", "bufferView": 3}
+        self._json_data["images"].append(image)
+        return (image_data_byte_length, image_data)
+
+    def save(self):
+        # check validity of mesh
+        if self.mesh.verts_packed() is None or self.mesh.faces_packed() is None:
+            raise ValueError("invalid mesh to save, verts or face indices are empty")
+
+        # accessors for positions, texture uvs and face indices
+        pos_byte, pos_data = self._write_accessor_json("positions")
+        idx_byte, idx_data = self._write_accessor_json("indices")
+        include_textures = False
+        if (
+            self.mesh.textures is not None
+            and self.mesh.textures.verts_uvs_list()[0] is not None
+        ):
+            tex_byte, tex_data = self._write_accessor_json("texcoords")
+            include_textures = True
+
+        # bufferViews for positions, texture coords and indices
+        byte_offset = 0
+        self._write_bufferview("positions", byte_length=pos_byte, offset=byte_offset)
+        byte_offset += pos_byte
+
+        self._write_bufferview("indices", byte_length=idx_byte, offset=byte_offset)
+        byte_offset += idx_byte
+
+        if include_textures:
+            self._write_bufferview(
+                "texcoords", byte_length=tex_byte, offset=byte_offset
+            )
+            byte_offset += tex_byte
+
+        # image bufferView
+        include_image = False
+        if (
+            self.mesh.textures is not None
+            and self.mesh.textures.maps_list()[0] is not None
+        ):
+            include_image = True
+            image_byte, image_data = self._write_image_buffer(offset=byte_offset)
+            byte_offset += image_byte
+
+        # buffers
+        self._json_data["buffers"].append({"byteLength": int(byte_offset)})
+
+        # organize into a glb
+        json_bytes = bytes(json.dumps(self._json_data, cls=OurEncoder), "utf-8")
+        json_length = len(json_bytes)
+
+        # write header
+        header = struct.pack(" bool:
-        return False
+        """
+        Writes the given meshes to a GLB file.
+
+        Args:
+            data: meshes to save
+            path: path of the GLB file to write into
+            path_manager: PathManager object for interpreting the path
+
+        Return True if saving succeeds and False otherwise
+        """
+
+        if not endswith(path, self.known_suffixes):
+            return False
+
+        with _open_file(path, path_manager, "wb") as f:
+            writer = _GLTFWriter(data, cast(BinaryIO, f))
+            writer.save()
+        return True
diff --git a/tests/test_io_gltf.py b/tests/test_io_gltf.py
index 709c24dd..8c7de0dd 100644
--- a/tests/test_io_gltf.py
+++ b/tests/test_io_gltf.py
@@ -29,6 +29,7 @@ from pytorch3d.renderer.mesh import (
 )
 from pytorch3d.structures import Meshes
 from pytorch3d.transforms import axis_angle_to_matrix
+from pytorch3d.utils import ico_sphere
 from pytorch3d.vis.texture_vis import texturesuv_image_PIL
 
 from .common_testing import get_pytorch3d_dir, get_tests_dir, TestCaseMixin
@@ -45,6 +46,12 @@ def _load(path, **kwargs) -> Meshes:
     return io.load_mesh(path, **kwargs)
 
 
+def _write(mesh, path, **kwargs) -> bool:
+    io = IO()
+    io.register_meshes_format(MeshGlbFormat())
+    return io.save_mesh(mesh, path, **kwargs)
+
+
 def _render(
     mesh: Meshes,
     name: str,
@@ -144,9 +151,7 @@ class TestMeshGltfIO(TestCaseMixin, unittest.TestCase):
         self.assertEqual(mesh.faces_packed().shape, (5856, 3))
         self.assertEqual(mesh.verts_packed().shape, (3225, 3))
         mesh_obj = _load(TUTORIAL_DATA_DIR / "cow_mesh/cow.obj")
-        self.assertClose(
-            mesh_obj.get_bounding_boxes().cpu(), mesh_obj.get_bounding_boxes()
-        )
+        self.assertClose(mesh.get_bounding_boxes().cpu(), mesh_obj.get_bounding_boxes())
 
         self.assertClose(
             mesh.textures.verts_uvs_padded().cpu(), mesh_obj.textures.verts_uvs_padded()
@@ -169,6 +174,69 @@ class TestMeshGltfIO(TestCaseMixin, unittest.TestCase):
 
         self.assertClose(image, expected)
 
+    def test_save_cow(self):
+        """
+        Save the cow mesh to a glb file
+        """
+        # load cow mesh from a glb file
+        glb = DATA_DIR / "cow.glb"
+        self.assertTrue(glb.is_file())
+        device = torch.device("cuda:0")
+        mesh = _load(glb, device=device)
+
+        # save the mesh to a glb file
+        glb = DATA_DIR / "cow_write.glb"
+        _write(mesh, glb)
+
+        # load again
+        glb_reload = DATA_DIR / "cow_write.glb"
+        self.assertTrue(glb_reload.is_file())
+        device = torch.device("cuda:0")
+        mesh_reload = _load(glb_reload, device=device)
+
+        # assertions
+        self.assertEqual(mesh_reload.faces_packed().shape, (5856, 3))
+        self.assertEqual(mesh_reload.verts_packed().shape, (3225, 3))
+        self.assertClose(
+            mesh_reload.get_bounding_boxes().cpu(), mesh.get_bounding_boxes().cpu()
+        )
+
+        self.assertClose(
+            mesh_reload.textures.verts_uvs_padded().cpu(),
+            mesh.textures.verts_uvs_padded().cpu(),
+        )
+
+        self.assertClose(
+            mesh_reload.textures.faces_uvs_padded().cpu(),
+            mesh.textures.faces_uvs_padded().cpu(),
+        )
+
+        self.assertClose(
+            mesh_reload.textures.maps_padded().cpu(), mesh.textures.maps_padded().cpu()
+        )
+
+    def test_save_ico_sphere(self):
+        """
+        save the ico_sphere mesh in a glb file
+        """
+        ico_sphere_mesh = ico_sphere(level=3)
+        glb = DATA_DIR / "ico_sphere.glb"
+        _write(ico_sphere_mesh, glb)
+
+        # reload the ico_sphere
+        device = torch.device("cuda:0")
+        mesh_reload = _load(glb, device=device, include_textures=False)
+
+        self.assertClose(
+            ico_sphere_mesh.verts_padded().cpu(),
+            mesh_reload.verts_padded().cpu(),
+        )
+
+        self.assertClose(
+            ico_sphere_mesh.faces_padded().cpu(),
+            mesh_reload.faces_padded().cpu(),
+        )
+
     def test_load_cow_no_texture(self):
         """
         Load the cow as converted to a single mesh in a glb file.
@@ -183,9 +251,7 @@ class TestMeshGltfIO(TestCaseMixin, unittest.TestCase):
         self.assertEqual(mesh.faces_packed().shape, (5856, 3))
         self.assertEqual(mesh.verts_packed().shape, (3225, 3))
         mesh_obj = _load(TUTORIAL_DATA_DIR / "cow_mesh/cow.obj")
-        self.assertClose(
-            mesh_obj.get_bounding_boxes().cpu(), mesh_obj.get_bounding_boxes()
-        )
+        self.assertClose(mesh.get_bounding_boxes().cpu(), mesh_obj.get_bounding_boxes())
 
         mesh.textures = TexturesVertex(0.5 * torch.ones_like(mesh.verts_padded()))
 