mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	include TexturesUV in IO.save_mesh(x.obj)
Summary: Added export of UV textures to IO.save_mesh in Pytorch3d MeshObjFormat now passes verts_uv, faces_uv, and texture_map as input to save_obj TODO: check if TexturesUV.verts_uv_list or TexturesUV.verts_uv_padded() should be passed to save_obj IO.save_mesh(obj_file, meshes, decimal_places=2) should be IO().save_mesh(obj_file, meshes, decimal_places=2) Reviewed By: bottler Differential Revision: D39617441 fbshipit-source-id: 4628b7f26f70e38c65f235852b990c8edb0ded23
This commit is contained in:
		
							parent
							
								
									305cf32f6b
								
							
						
					
					
						commit
						6ae6ff9cf7
					
				@ -334,12 +334,25 @@ class MeshObjFormat(MeshFormatInterpreter):
 | 
			
		||||
 | 
			
		||||
        verts = data.verts_list()[0]
 | 
			
		||||
        faces = data.faces_list()[0]
 | 
			
		||||
 | 
			
		||||
        verts_uvs: Optional[torch.Tensor] = None
 | 
			
		||||
        faces_uvs: Optional[torch.Tensor] = None
 | 
			
		||||
        texture_map: Optional[torch.Tensor] = None
 | 
			
		||||
 | 
			
		||||
        if isinstance(data.textures, TexturesUV):
 | 
			
		||||
            verts_uvs = data.textures.verts_uvs_padded()[0]
 | 
			
		||||
            faces_uvs = data.textures.faces_uvs_padded()[0]
 | 
			
		||||
            texture_map = data.textures.maps_padded()[0]
 | 
			
		||||
 | 
			
		||||
        save_obj(
 | 
			
		||||
            f=path,
 | 
			
		||||
            verts=verts,
 | 
			
		||||
            faces=faces,
 | 
			
		||||
            decimal_places=decimal_places,
 | 
			
		||||
            path_manager=path_manager,
 | 
			
		||||
            verts_uvs=verts_uvs,
 | 
			
		||||
            faces_uvs=faces_uvs,
 | 
			
		||||
            texture_map=texture_map,
 | 
			
		||||
        )
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1050,6 +1050,68 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
                texture_map=texture_map[..., 1],  # Incorrect shape
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def test_save_obj_with_texture_IO(self):
 | 
			
		||||
        verts = torch.tensor(
 | 
			
		||||
            [[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
        )
 | 
			
		||||
        faces = torch.tensor(
 | 
			
		||||
            [[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
 | 
			
		||||
        )
 | 
			
		||||
        verts_uvs = torch.tensor(
 | 
			
		||||
            [[0.02, 0.5], [0.3, 0.03], [0.32, 0.12], [0.36, 0.17]],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
        )
 | 
			
		||||
        faces_uvs = faces
 | 
			
		||||
        texture_map = torch.randint(size=(2, 2, 3), high=255) / 255.0
 | 
			
		||||
 | 
			
		||||
        with TemporaryDirectory() as temp_dir:
 | 
			
		||||
            obj_file = os.path.join(temp_dir, "mesh.obj")
 | 
			
		||||
            textures_uv = TexturesUV([texture_map], [faces_uvs], [verts_uvs])
 | 
			
		||||
            test_mesh = Meshes(verts=[verts], faces=[faces], textures=textures_uv)
 | 
			
		||||
 | 
			
		||||
            IO().save_mesh(data=test_mesh, path=obj_file, decimal_places=2)
 | 
			
		||||
 | 
			
		||||
            expected_obj_file = "\n".join(
 | 
			
		||||
                [
 | 
			
		||||
                    "",
 | 
			
		||||
                    "mtllib mesh.mtl",
 | 
			
		||||
                    "usemtl mesh",
 | 
			
		||||
                    "",
 | 
			
		||||
                    "v 0.01 0.20 0.30",
 | 
			
		||||
                    "v 0.20 0.03 0.41",
 | 
			
		||||
                    "v 0.30 0.40 0.05",
 | 
			
		||||
                    "v 0.60 0.70 0.80",
 | 
			
		||||
                    "vt 0.02 0.50",
 | 
			
		||||
                    "vt 0.30 0.03",
 | 
			
		||||
                    "vt 0.32 0.12",
 | 
			
		||||
                    "vt 0.36 0.17",
 | 
			
		||||
                    "f 1/1 3/3 2/2",
 | 
			
		||||
                    "f 1/1 2/2 3/3",
 | 
			
		||||
                    "f 4/4 3/3 2/2",
 | 
			
		||||
                    "f 4/4 2/2 1/1",
 | 
			
		||||
                ]
 | 
			
		||||
            )
 | 
			
		||||
            expected_mtl_file = "\n".join(["newmtl mesh", "map_Kd mesh.png", ""])
 | 
			
		||||
 | 
			
		||||
            # Check there are only 3 files in the temp dir
 | 
			
		||||
            tempfiles = ["mesh.obj", "mesh.png", "mesh.mtl"]
 | 
			
		||||
            tempfiles_dir = os.listdir(temp_dir)
 | 
			
		||||
            self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
 | 
			
		||||
 | 
			
		||||
            # Check the obj file is saved correctly
 | 
			
		||||
            actual_file = open(obj_file, "r")
 | 
			
		||||
            self.assertEqual(actual_file.read(), expected_obj_file)
 | 
			
		||||
 | 
			
		||||
            # Check the mtl file is saved correctly
 | 
			
		||||
            mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
 | 
			
		||||
            mtl_file = open(mtl_file_name, "r")
 | 
			
		||||
            self.assertEqual(mtl_file.read(), expected_mtl_file)
 | 
			
		||||
 | 
			
		||||
            # Check the texture image file is saved correctly
 | 
			
		||||
            texture_image = load_rgb_image("mesh.png", temp_dir)
 | 
			
		||||
            self.assertClose(texture_image, texture_map)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _bm_save_obj(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
 | 
			
		||||
        return lambda: save_obj(StringIO(), verts, faces, decimal_places)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user