diff --git a/pytorch3d/io/experimental_gltf_io.py b/pytorch3d/io/experimental_gltf_io.py index d59831aa..10905227 100644 --- a/pytorch3d/io/experimental_gltf_io.py +++ b/pytorch3d/io/experimental_gltf_io.py @@ -393,7 +393,7 @@ class _GLTFLoader: attributes = primitive["attributes"] vertex_colors = self._get_primitive_attribute(attributes, "COLOR_0", np.float32) if vertex_colors is not None: - return TexturesVertex(torch.from_numpy(vertex_colors)) + return TexturesVertex([torch.from_numpy(vertex_colors)]) vertex_texcoords_0 = self._get_primitive_attribute( attributes, "TEXCOORD_0", np.float32 @@ -559,12 +559,26 @@ class _GLTFWriter: meshes = defaultdict(list) # pyre-fixme[6]: Incompatible parameter type meshes["name"] = "Node-Mesh" - primitives = { - "attributes": {"POSITION": 0, "TEXCOORD_0": 2}, - "indices": 1, - "material": 0, # default material - "mode": _PrimitiveMode.TRIANGLES, - } + if isinstance(self.mesh.textures, TexturesVertex): + primitives = { + "attributes": {"POSITION": 0, "COLOR_0": 2}, + "indices": 1, + "mode": _PrimitiveMode.TRIANGLES, + } + elif isinstance(self.mesh.textures, TexturesUV): + primitives = { + "attributes": {"POSITION": 0, "TEXCOORD_0": 2}, + "indices": 1, + "mode": _PrimitiveMode.TRIANGLES, + "material": 0, + } + else: + primitives = { + "attributes": {"POSITION": 0}, + "indices": 1, + "mode": _PrimitiveMode.TRIANGLES, + } + meshes["primitives"].append(primitives) self._json_data["meshes"].append(meshes) @@ -610,6 +624,14 @@ class _GLTFWriter: element_min = list(map(float, np.min(data, axis=0))) element_max = list(map(float, np.max(data, axis=0))) byte_per_element = 2 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]] + elif key == "texvertices": + component_type = _ComponentType.FLOAT + data = self.mesh.textures.verts_features_list()[0].cpu().numpy() + element_type = "VEC3" + buffer_view = 2 + element_min = list(map(float, np.min(data, axis=0))) + element_max = list(map(float, np.max(data, axis=0))) + byte_per_element = 3 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]] elif key == "indices": component_type = _ComponentType.UNSIGNED_SHORT data = ( @@ -646,8 +668,10 @@ class _GLTFWriter: return (byte_length, data) def _write_bufferview(self, key: str, **kwargs): - if key not in ["positions", "texcoords", "indices"]: - raise ValueError("key must be one of positions, texcoords or indices") + if key not in ["positions", "texcoords", "texvertices", "indices"]: + raise ValueError( + "key must be one of positions, texcoords, texvertices or indices" + ) bufferview = { "name": "bufferView_%s" % key, @@ -661,6 +685,10 @@ class _GLTFWriter: byte_per_element = 2 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]] target = _TargetType.ARRAY_BUFFER bufferview["byteStride"] = int(byte_per_element) + elif key == "texvertices": + byte_per_element = 3 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.FLOAT]] + target = _TargetType.ELEMENT_ARRAY_BUFFER + bufferview["byteStride"] = int(byte_per_element) elif key == "indices": byte_per_element = ( 3 * _DTYPE_BYTES[_ITEM_TYPES[_ComponentType.UNSIGNED_SHORT]] @@ -701,12 +729,15 @@ class _GLTFWriter: pos_byte, pos_data = self._write_accessor_json("positions") idx_byte, idx_data = self._write_accessor_json("indices") include_textures = False - if ( - self.mesh.textures is not None - and self.mesh.textures.verts_uvs_list()[0] is not None - ): - tex_byte, tex_data = self._write_accessor_json("texcoords") - include_textures = True + if self.mesh.textures is not None: + if hasattr(self.mesh.textures, "verts_features_list"): + tex_byte, tex_data = self._write_accessor_json("texvertices") + include_textures = True + texcoords = False + elif self.mesh.textures.verts_uvs_list()[0] is not None: + tex_byte, tex_data = self._write_accessor_json("texcoords") + include_textures = True + texcoords = True # bufferViews for positions, texture coords and indices byte_offset = 0 @@ -717,17 +748,19 @@ class _GLTFWriter: byte_offset += idx_byte if include_textures: - self._write_bufferview( - "texcoords", byte_length=tex_byte, offset=byte_offset - ) + if texcoords: + self._write_bufferview( + "texcoords", byte_length=tex_byte, offset=byte_offset + ) + else: + self._write_bufferview( + "texvertices", byte_length=tex_byte, offset=byte_offset + ) byte_offset += tex_byte # image bufferView include_image = False - if ( - self.mesh.textures is not None - and self.mesh.textures.maps_list()[0] is not None - ): + if self.mesh.textures is not None and hasattr(self.mesh.textures, "maps_list"): include_image = True image_byte, image_data = self._write_image_buffer(offset=byte_offset) byte_offset += image_byte diff --git a/tests/test_io_gltf.py b/tests/test_io_gltf.py index b3b04dac..3b13c31c 100644 --- a/tests/test_io_gltf.py +++ b/tests/test_io_gltf.py @@ -120,6 +120,7 @@ class TestMeshGltfIO(TestCaseMixin, unittest.TestCase): The scene is "already lit", i.e. the textures reflect the lighting already, so we want to render them with full ambient light. """ + self.skipTest("Data not available") glb = DATA_DIR / "apartment_1.glb" @@ -266,3 +267,117 @@ class TestMeshGltfIO(TestCaseMixin, unittest.TestCase): expected = np.array(f) self.assertClose(image, expected) + + def test_load_save_load_cow_texturesvertex(self): + """ + Load the cow as converted to a single mesh in a glb file and then save it to a glb file. + """ + + glb = DATA_DIR / "cow.glb" + self.assertTrue(glb.is_file()) + device = torch.device("cuda:0") + mesh = _load(glb, device=device, include_textures=False) + self.assertEqual(len(mesh), 1) + self.assertIsNone(mesh.textures) + + self.assertEqual(mesh.faces_packed().shape, (5856, 3)) + self.assertEqual(mesh.verts_packed().shape, (3225, 3)) + mesh_obj = _load(TUTORIAL_DATA_DIR / "cow_mesh/cow.obj") + self.assertClose(mesh.get_bounding_boxes().cpu(), mesh_obj.get_bounding_boxes()) + + mesh.textures = TexturesVertex(0.5 * torch.ones_like(mesh.verts_padded())) + + image = _render(mesh, "cow_gray") + + with Image.open(DATA_DIR / "glb_cow_gray.png") as f: + expected = np.array(f) + + self.assertClose(image, expected) + + # save the mesh to a glb file + glb = DATA_DIR / "cow_write_texturesvertex.glb" + _write(mesh, glb) + + # reload the mesh glb file saved in TexturesVertex format + glb = DATA_DIR / "cow_write_texturesvertex.glb" + self.assertTrue(glb.is_file()) + mesh_dash = _load(glb, device=device) + self.assertEqual(len(mesh_dash), 1) + + self.assertEqual(mesh_dash.faces_packed().shape, (5856, 3)) + self.assertEqual(mesh_dash.verts_packed().shape, (3225, 3)) + self.assertEqual(mesh_dash.textures.verts_features_list()[0].shape, (3225, 3)) + + # check the re-rendered image with expected + image_dash = _render(mesh, "cow_gray_texturesvertex") + self.assertClose(image_dash, expected) + + def test_save_toy(self): + """ + Construct a simple mesh and save it to a glb file in TexturesVertex mode. + """ + + example = {} + example["POSITION"] = torch.tensor( + [ + [ + [0.0, 0.0, 0.0], + [-1.0, 0.0, 0.0], + [-1.0, 0.0, 1.0], + [0.0, 0.0, 1.0], + [0.0, 1.0, 0.0], + [-1.0, 1.0, 0.0], + [-1.0, 1.0, 1.0], + [0.0, 1.0, 1.0], + ] + ] + ) + example["indices"] = torch.tensor( + [ + [ + [1, 4, 2], + [4, 3, 2], + [3, 7, 2], + [7, 6, 2], + [3, 4, 7], + [4, 8, 7], + [8, 5, 7], + [5, 6, 7], + [5, 2, 6], + [5, 1, 2], + [1, 5, 4], + [5, 8, 4], + ] + ] + ) + example["indices"] -= 1 + example["COLOR_0"] = torch.tensor( + [ + [ + [1.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + ] + ] + ) + # example['prop'] = {'material': + # {'pbrMetallicRoughness': + # {'baseColorFactor': + # torch.tensor([[0.7, 0.7, 1, 0.5]]), + # 'metallicFactor': torch.tensor([1]), + # 'roughnessFactor': torch.tensor([0.1])}, + # 'alphaMode': 'BLEND', + # 'doubleSided': True}} + + texture = TexturesVertex(example["COLOR_0"]) + mesh = Meshes( + verts=example["POSITION"], faces=example["indices"], textures=texture + ) + + glb = DATA_DIR / "example_write_texturesvertex.glb" + _write(mesh, glb)