mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	allow saving vertex normal in save_obj (#1511)
Summary:
Although we can load per-vertex normals in `load_obj`, saving per-vertex normals is not supported in `save_obj`.
This patch fixes this by allowing passing per-vertex normal data in `save_obj`:
``` python
def save_obj(
    f: PathOrStr,
    verts,
    faces,
    decimal_places: Optional[int] = None,
    path_manager: Optional[PathManager] = None,
    *,
    verts_normals: Optional[torch.Tensor] = None,
    faces_normals: Optional[torch.Tensor] = None,
    verts_uvs: Optional[torch.Tensor] = None,
    faces_uvs: Optional[torch.Tensor] = None,
    texture_map: Optional[torch.Tensor] = None,
) -> None:
    """
    Save a mesh to an .obj file.
    Args:
        f: File (str or path) to which the mesh should be written.
        verts: FloatTensor of shape (V, 3) giving vertex coordinates.
        faces: LongTensor of shape (F, 3) giving faces.
        decimal_places: Number of decimal places for saving.
        path_manager: Optional PathManager for interpreting f if
            it is a str.
        verts_normals: FloatTensor of shape (V, 3) giving the normal per vertex.
        faces_normals: LongTensor of shape (F, 3) giving the index into verts_normals
            for each vertex in the face.
        verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinate per vertex.
        faces_uvs: LongTensor of shape (F, 3) giving the index into verts_uvs for
            each vertex in the face.
        texture_map: FloatTensor of shape (H, W, 3) representing the texture map
            for the mesh which will be saved as an image. The values are expected
            to be in the range [0, 1],
    """
```
Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1511
Reviewed By: shapovalov
Differential Revision: D45086045
Pulled By: bottler
fbshipit-source-id: 666efb0d2c302df6cf9f2f6601d83a07856bf32f
			
			
This commit is contained in:
		
							parent
							
								
									ec87284c4b
								
							
						
					
					
						commit
						092400f1e7
					
				@ -684,6 +684,8 @@ def save_obj(
 | 
			
		||||
    decimal_places: Optional[int] = None,
 | 
			
		||||
    path_manager: Optional[PathManager] = None,
 | 
			
		||||
    *,
 | 
			
		||||
    normals: Optional[torch.Tensor] = None,
 | 
			
		||||
    faces_normals_idx: Optional[torch.Tensor] = None,
 | 
			
		||||
    verts_uvs: Optional[torch.Tensor] = None,
 | 
			
		||||
    faces_uvs: Optional[torch.Tensor] = None,
 | 
			
		||||
    texture_map: Optional[torch.Tensor] = None,
 | 
			
		||||
@ -698,6 +700,10 @@ def save_obj(
 | 
			
		||||
        decimal_places: Number of decimal places for saving.
 | 
			
		||||
        path_manager: Optional PathManager for interpreting f if
 | 
			
		||||
            it is a str.
 | 
			
		||||
        normals: FloatTensor of shape (V, 3) giving normals for faces_normals_idx
 | 
			
		||||
            to index into.
 | 
			
		||||
        faces_normals_idx: LongTensor of shape (F, 3) giving the index into
 | 
			
		||||
            normals for each vertex in the face.
 | 
			
		||||
        verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinate per vertex.
 | 
			
		||||
        faces_uvs: LongTensor of shape (F, 3) giving the index into verts_uvs for
 | 
			
		||||
            each vertex in the face.
 | 
			
		||||
@ -713,6 +719,22 @@ def save_obj(
 | 
			
		||||
        message = "'faces' should either be empty or of shape (num_faces, 3)."
 | 
			
		||||
        raise ValueError(message)
 | 
			
		||||
 | 
			
		||||
    if (normals is None) != (faces_normals_idx is None):
 | 
			
		||||
        message = "'normals' and 'faces_normals_idx' must both be None or neither."
 | 
			
		||||
        raise ValueError(message)
 | 
			
		||||
 | 
			
		||||
    if faces_normals_idx is not None and (
 | 
			
		||||
        faces_normals_idx.dim() != 2 or faces_normals_idx.size(1) != 3
 | 
			
		||||
    ):
 | 
			
		||||
        message = (
 | 
			
		||||
            "'faces_normals_idx' should either be empty or of shape (num_faces, 3)."
 | 
			
		||||
        )
 | 
			
		||||
        raise ValueError(message)
 | 
			
		||||
 | 
			
		||||
    if normals is not None and (normals.dim() != 2 or normals.size(1) != 3):
 | 
			
		||||
        message = "'normals' should either be empty or of shape (num_verts, 3)."
 | 
			
		||||
        raise ValueError(message)
 | 
			
		||||
 | 
			
		||||
    if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3):
 | 
			
		||||
        message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
 | 
			
		||||
        raise ValueError(message)
 | 
			
		||||
@ -742,9 +764,12 @@ def save_obj(
 | 
			
		||||
            verts,
 | 
			
		||||
            faces,
 | 
			
		||||
            decimal_places,
 | 
			
		||||
            normals=normals,
 | 
			
		||||
            faces_normals_idx=faces_normals_idx,
 | 
			
		||||
            verts_uvs=verts_uvs,
 | 
			
		||||
            faces_uvs=faces_uvs,
 | 
			
		||||
            save_texture=save_texture,
 | 
			
		||||
            save_normals=normals is not None,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    # Save the .mtl and .png files associated with the texture
 | 
			
		||||
@ -777,9 +802,12 @@ def _save(
 | 
			
		||||
    faces,
 | 
			
		||||
    decimal_places: Optional[int] = None,
 | 
			
		||||
    *,
 | 
			
		||||
    normals: Optional[torch.Tensor] = None,
 | 
			
		||||
    faces_normals_idx: Optional[torch.Tensor] = None,
 | 
			
		||||
    verts_uvs: Optional[torch.Tensor] = None,
 | 
			
		||||
    faces_uvs: Optional[torch.Tensor] = None,
 | 
			
		||||
    save_texture: bool = False,
 | 
			
		||||
    save_normals: bool = False,
 | 
			
		||||
) -> None:
 | 
			
		||||
 | 
			
		||||
    if len(verts) and (verts.dim() != 2 or verts.size(1) != 3):
 | 
			
		||||
@ -798,18 +826,26 @@ def _save(
 | 
			
		||||
 | 
			
		||||
    lines = ""
 | 
			
		||||
 | 
			
		||||
    if len(verts):
 | 
			
		||||
        if decimal_places is None:
 | 
			
		||||
            float_str = "%f"
 | 
			
		||||
        else:
 | 
			
		||||
            float_str = "%" + ".%df" % decimal_places
 | 
			
		||||
    if decimal_places is None:
 | 
			
		||||
        float_str = "%f"
 | 
			
		||||
    else:
 | 
			
		||||
        float_str = "%" + ".%df" % decimal_places
 | 
			
		||||
 | 
			
		||||
    if len(verts):
 | 
			
		||||
        V, D = verts.shape
 | 
			
		||||
        for i in range(V):
 | 
			
		||||
            vert = [float_str % verts[i, j] for j in range(D)]
 | 
			
		||||
            lines += "v %s\n" % " ".join(vert)
 | 
			
		||||
 | 
			
		||||
    if save_normals:
 | 
			
		||||
        assert normals is not None
 | 
			
		||||
        assert faces_normals_idx is not None
 | 
			
		||||
        lines += _write_normals(normals, faces_normals_idx, float_str)
 | 
			
		||||
 | 
			
		||||
    if save_texture:
 | 
			
		||||
        assert faces_uvs is not None
 | 
			
		||||
        assert verts_uvs is not None
 | 
			
		||||
 | 
			
		||||
        if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3):
 | 
			
		||||
            message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
 | 
			
		||||
            raise ValueError(message)
 | 
			
		||||
@ -818,7 +854,6 @@ def _save(
 | 
			
		||||
            message = "'verts_uvs' should either be empty or of shape (num_verts, 2)."
 | 
			
		||||
            raise ValueError(message)
 | 
			
		||||
 | 
			
		||||
        # pyre-fixme[16] # undefined attribute cpu
 | 
			
		||||
        verts_uvs, faces_uvs = verts_uvs.cpu(), faces_uvs.cpu()
 | 
			
		||||
 | 
			
		||||
        # Save verts uvs after verts
 | 
			
		||||
@ -828,25 +863,77 @@ def _save(
 | 
			
		||||
                uv = [float_str % verts_uvs[i, j] for j in range(uD)]
 | 
			
		||||
                lines += "vt %s\n" % " ".join(uv)
 | 
			
		||||
 | 
			
		||||
    f.write(lines)
 | 
			
		||||
 | 
			
		||||
    if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0):
 | 
			
		||||
        warnings.warn("Faces have invalid indices")
 | 
			
		||||
 | 
			
		||||
    if len(faces):
 | 
			
		||||
        F, P = faces.shape
 | 
			
		||||
        for i in range(F):
 | 
			
		||||
            if save_texture:
 | 
			
		||||
                # Format faces as {verts_idx}/{verts_uvs_idx}
 | 
			
		||||
        _write_faces(
 | 
			
		||||
            f,
 | 
			
		||||
            faces,
 | 
			
		||||
            faces_uvs if save_texture else None,
 | 
			
		||||
            faces_normals_idx if save_normals else None,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _write_normals(
 | 
			
		||||
    normals: torch.Tensor, faces_normals_idx: torch.Tensor, float_str: str
 | 
			
		||||
) -> str:
 | 
			
		||||
    if faces_normals_idx.dim() != 2 or faces_normals_idx.size(1) != 3:
 | 
			
		||||
        message = (
 | 
			
		||||
            "'faces_normals_idx' should either be empty or of shape (num_faces, 3)."
 | 
			
		||||
        )
 | 
			
		||||
        raise ValueError(message)
 | 
			
		||||
 | 
			
		||||
    if normals.dim() != 2 or normals.size(1) != 3:
 | 
			
		||||
        message = "'normals' should either be empty or of shape (num_verts, 3)."
 | 
			
		||||
        raise ValueError(message)
 | 
			
		||||
 | 
			
		||||
    normals, faces_normals_idx = normals.cpu(), faces_normals_idx.cpu()
 | 
			
		||||
 | 
			
		||||
    lines = []
 | 
			
		||||
    V, D = normals.shape
 | 
			
		||||
    for i in range(V):
 | 
			
		||||
        normal = [float_str % normals[i, j] for j in range(D)]
 | 
			
		||||
        lines.append("vn %s\n" % " ".join(normal))
 | 
			
		||||
    return "".join(lines)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _write_faces(
 | 
			
		||||
    f,
 | 
			
		||||
    faces: torch.Tensor,
 | 
			
		||||
    faces_uvs: Optional[torch.Tensor],
 | 
			
		||||
    faces_normals_idx: Optional[torch.Tensor],
 | 
			
		||||
) -> None:
 | 
			
		||||
    F, P = faces.shape
 | 
			
		||||
    for i in range(F):
 | 
			
		||||
        if faces_normals_idx is not None:
 | 
			
		||||
            if faces_uvs is not None:
 | 
			
		||||
                # Format faces as {verts_idx}/{verts_uvs_idx}/{verts_normals_idx}
 | 
			
		||||
                face = [
 | 
			
		||||
                    "%d/%d" % (faces[i, j] + 1, faces_uvs[i, j] + 1) for j in range(P)
 | 
			
		||||
                    "%d/%d/%d"
 | 
			
		||||
                    % (
 | 
			
		||||
                        faces[i, j] + 1,
 | 
			
		||||
                        faces_uvs[i, j] + 1,
 | 
			
		||||
                        faces_normals_idx[i, j] + 1,
 | 
			
		||||
                    )
 | 
			
		||||
                    for j in range(P)
 | 
			
		||||
                ]
 | 
			
		||||
            else:
 | 
			
		||||
                face = ["%d" % (faces[i, j] + 1) for j in range(P)]
 | 
			
		||||
                # Format faces as {verts_idx}//{verts_normals_idx}
 | 
			
		||||
                face = [
 | 
			
		||||
                    "%d//%d" % (faces[i, j] + 1, faces_normals_idx[i, j] + 1)
 | 
			
		||||
                    for j in range(P)
 | 
			
		||||
                ]
 | 
			
		||||
        elif faces_uvs is not None:
 | 
			
		||||
            # Format faces as {verts_idx}/{verts_uvs_idx}
 | 
			
		||||
            face = ["%d/%d" % (faces[i, j] + 1, faces_uvs[i, j] + 1) for j in range(P)]
 | 
			
		||||
        else:
 | 
			
		||||
            face = ["%d" % (faces[i, j] + 1) for j in range(P)]
 | 
			
		||||
 | 
			
		||||
            if i + 1 < F:
 | 
			
		||||
                lines += "f %s\n" % " ".join(face)
 | 
			
		||||
 | 
			
		||||
            elif i + 1 == F:
 | 
			
		||||
                # No newline at the end of the file.
 | 
			
		||||
                lines += "f %s" % " ".join(face)
 | 
			
		||||
 | 
			
		||||
    f.write(lines)
 | 
			
		||||
        if i + 1 < F:
 | 
			
		||||
            f.write("f %s\n" % " ".join(face))
 | 
			
		||||
        else:
 | 
			
		||||
            # No newline at the end of the file.
 | 
			
		||||
            f.write("f %s" % " ".join(face))
 | 
			
		||||
 | 
			
		||||
@ -532,8 +532,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
                    "f 4 2 1",
 | 
			
		||||
                ]
 | 
			
		||||
            )
 | 
			
		||||
            actual_file = open(Path(f.name), "r")
 | 
			
		||||
            self.assertEqual(actual_file.read(), expected_file)
 | 
			
		||||
            self.assertEqual(Path(f.name).read_text(), expected_file)
 | 
			
		||||
 | 
			
		||||
    def test_load_mtl(self):
 | 
			
		||||
        obj_filename = "cow_mesh/cow.obj"
 | 
			
		||||
@ -895,6 +894,67 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "same type of texture"):
 | 
			
		||||
            join_meshes_as_batch([mesh_atlas, mesh_rgb, mesh_atlas])
 | 
			
		||||
 | 
			
		||||
    def test_save_obj_with_normal(self):
 | 
			
		||||
        verts = torch.tensor(
 | 
			
		||||
            [[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
        )
 | 
			
		||||
        faces = torch.tensor(
 | 
			
		||||
            [[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
 | 
			
		||||
        )
 | 
			
		||||
        normals = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [0.02, 0.5, 0.73],
 | 
			
		||||
                [0.3, 0.03, 0.361],
 | 
			
		||||
                [0.32, 0.12, 0.47],
 | 
			
		||||
                [0.36, 0.17, 0.9],
 | 
			
		||||
                [0.40, 0.7, 0.19],
 | 
			
		||||
                [1.0, 0.00, 0.000],
 | 
			
		||||
                [0.00, 1.00, 0.00],
 | 
			
		||||
                [0.00, 0.00, 1.0],
 | 
			
		||||
            ],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
        )
 | 
			
		||||
        faces_normals_idx = torch.tensor(
 | 
			
		||||
            [[0, 1, 2], [2, 3, 4], [4, 5, 6], [6, 7, 0]], dtype=torch.int64
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        with TemporaryDirectory() as temp_dir:
 | 
			
		||||
            obj_file = os.path.join(temp_dir, "mesh.obj")
 | 
			
		||||
            save_obj(
 | 
			
		||||
                obj_file,
 | 
			
		||||
                verts,
 | 
			
		||||
                faces,
 | 
			
		||||
                decimal_places=2,
 | 
			
		||||
                normals=normals,
 | 
			
		||||
                faces_normals_idx=faces_normals_idx,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            expected_obj_file = "\n".join(
 | 
			
		||||
                [
 | 
			
		||||
                    "v 0.01 0.20 0.30",
 | 
			
		||||
                    "v 0.20 0.03 0.41",
 | 
			
		||||
                    "v 0.30 0.40 0.05",
 | 
			
		||||
                    "v 0.60 0.70 0.80",
 | 
			
		||||
                    "vn 0.02 0.50 0.73",
 | 
			
		||||
                    "vn 0.30 0.03 0.36",
 | 
			
		||||
                    "vn 0.32 0.12 0.47",
 | 
			
		||||
                    "vn 0.36 0.17 0.90",
 | 
			
		||||
                    "vn 0.40 0.70 0.19",
 | 
			
		||||
                    "vn 1.00 0.00 0.00",
 | 
			
		||||
                    "vn 0.00 1.00 0.00",
 | 
			
		||||
                    "vn 0.00 0.00 1.00",
 | 
			
		||||
                    "f 1//1 3//2 2//3",
 | 
			
		||||
                    "f 1//3 2//4 3//5",
 | 
			
		||||
                    "f 4//5 3//6 2//7",
 | 
			
		||||
                    "f 4//7 2//8 1//1",
 | 
			
		||||
                ]
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            # Check the obj file is saved correctly
 | 
			
		||||
            with open(obj_file, "r") as actual_file:
 | 
			
		||||
                self.assertEqual(actual_file.read(), expected_obj_file)
 | 
			
		||||
 | 
			
		||||
    def test_save_obj_with_texture(self):
 | 
			
		||||
        verts = torch.tensor(
 | 
			
		||||
            [[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
 | 
			
		||||
@ -950,13 +1010,96 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
 | 
			
		||||
 | 
			
		||||
            # Check the obj file is saved correctly
 | 
			
		||||
            actual_file = open(obj_file, "r")
 | 
			
		||||
            self.assertEqual(actual_file.read(), expected_obj_file)
 | 
			
		||||
            with open(obj_file, "r") as actual_file:
 | 
			
		||||
                self.assertEqual(actual_file.read(), expected_obj_file)
 | 
			
		||||
 | 
			
		||||
            # Check the mtl file is saved correctly
 | 
			
		||||
            mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
 | 
			
		||||
            mtl_file = open(mtl_file_name, "r")
 | 
			
		||||
            self.assertEqual(mtl_file.read(), expected_mtl_file)
 | 
			
		||||
            with open(mtl_file_name, "r") as mtl_file:
 | 
			
		||||
                self.assertEqual(mtl_file.read(), expected_mtl_file)
 | 
			
		||||
 | 
			
		||||
            # Check the texture image file is saved correctly
 | 
			
		||||
            texture_image = load_rgb_image("mesh.png", temp_dir)
 | 
			
		||||
            self.assertClose(texture_image, texture_map)
 | 
			
		||||
 | 
			
		||||
    def test_save_obj_with_normal_and_texture(self):
 | 
			
		||||
        verts = torch.tensor(
 | 
			
		||||
            [[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
        )
 | 
			
		||||
        faces = torch.tensor(
 | 
			
		||||
            [[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
 | 
			
		||||
        )
 | 
			
		||||
        normals = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [0.02, 0.5, 0.73],
 | 
			
		||||
                [0.3, 0.03, 0.361],
 | 
			
		||||
                [0.32, 0.12, 0.47],
 | 
			
		||||
                [0.36, 0.17, 0.9],
 | 
			
		||||
            ],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
        )
 | 
			
		||||
        faces_normals_idx = faces
 | 
			
		||||
        verts_uvs = torch.tensor(
 | 
			
		||||
            [[0.02, 0.5], [0.3, 0.03], [0.32, 0.12], [0.36, 0.17]],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
        )
 | 
			
		||||
        faces_uvs = faces
 | 
			
		||||
        texture_map = torch.randint(size=(2, 2, 3), high=255) / 255.0
 | 
			
		||||
 | 
			
		||||
        with TemporaryDirectory() as temp_dir:
 | 
			
		||||
            obj_file = os.path.join(temp_dir, "mesh.obj")
 | 
			
		||||
            save_obj(
 | 
			
		||||
                obj_file,
 | 
			
		||||
                verts,
 | 
			
		||||
                faces,
 | 
			
		||||
                decimal_places=2,
 | 
			
		||||
                normals=normals,
 | 
			
		||||
                faces_normals_idx=faces_normals_idx,
 | 
			
		||||
                verts_uvs=verts_uvs,
 | 
			
		||||
                faces_uvs=faces_uvs,
 | 
			
		||||
                texture_map=texture_map,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            expected_obj_file = "\n".join(
 | 
			
		||||
                [
 | 
			
		||||
                    "",
 | 
			
		||||
                    "mtllib mesh.mtl",
 | 
			
		||||
                    "usemtl mesh",
 | 
			
		||||
                    "",
 | 
			
		||||
                    "v 0.01 0.20 0.30",
 | 
			
		||||
                    "v 0.20 0.03 0.41",
 | 
			
		||||
                    "v 0.30 0.40 0.05",
 | 
			
		||||
                    "v 0.60 0.70 0.80",
 | 
			
		||||
                    "vn 0.02 0.50 0.73",
 | 
			
		||||
                    "vn 0.30 0.03 0.36",
 | 
			
		||||
                    "vn 0.32 0.12 0.47",
 | 
			
		||||
                    "vn 0.36 0.17 0.90",
 | 
			
		||||
                    "vt 0.02 0.50",
 | 
			
		||||
                    "vt 0.30 0.03",
 | 
			
		||||
                    "vt 0.32 0.12",
 | 
			
		||||
                    "vt 0.36 0.17",
 | 
			
		||||
                    "f 1/1/1 3/3/3 2/2/2",
 | 
			
		||||
                    "f 1/1/1 2/2/2 3/3/3",
 | 
			
		||||
                    "f 4/4/4 3/3/3 2/2/2",
 | 
			
		||||
                    "f 4/4/4 2/2/2 1/1/1",
 | 
			
		||||
                ]
 | 
			
		||||
            )
 | 
			
		||||
            expected_mtl_file = "\n".join(["newmtl mesh", "map_Kd mesh.png", ""])
 | 
			
		||||
 | 
			
		||||
            # Check there are only 3 files in the temp dir
 | 
			
		||||
            tempfiles = ["mesh.obj", "mesh.png", "mesh.mtl"]
 | 
			
		||||
            tempfiles_dir = os.listdir(temp_dir)
 | 
			
		||||
            self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
 | 
			
		||||
 | 
			
		||||
            # Check the obj file is saved correctly
 | 
			
		||||
            with open(obj_file, "r") as actual_file:
 | 
			
		||||
                self.assertEqual(actual_file.read(), expected_obj_file)
 | 
			
		||||
 | 
			
		||||
            # Check the mtl file is saved correctly
 | 
			
		||||
            mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
 | 
			
		||||
            with open(mtl_file_name, "r") as mtl_file:
 | 
			
		||||
                self.assertEqual(mtl_file.read(), expected_mtl_file)
 | 
			
		||||
 | 
			
		||||
            # Check the texture image file is saved correctly
 | 
			
		||||
            texture_image = load_rgb_image("mesh.png", temp_dir)
 | 
			
		||||
@ -1013,8 +1156,8 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
                self.assertEqual(tempfiles, tempfiles_dir)
 | 
			
		||||
 | 
			
		||||
                # Check the obj file is saved correctly
 | 
			
		||||
                actual_file = open(obj_file, "r")
 | 
			
		||||
                self.assertEqual(actual_file.read(), expected_obj_file)
 | 
			
		||||
                with open(obj_file, "r") as actual_file:
 | 
			
		||||
                    self.assertEqual(actual_file.read(), expected_obj_file)
 | 
			
		||||
 | 
			
		||||
        obj_file = StringIO()
 | 
			
		||||
        with self.assertRaises(ValueError):
 | 
			
		||||
@ -1100,13 +1243,13 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
 | 
			
		||||
 | 
			
		||||
            # Check the obj file is saved correctly
 | 
			
		||||
            actual_file = open(obj_file, "r")
 | 
			
		||||
            self.assertEqual(actual_file.read(), expected_obj_file)
 | 
			
		||||
            with open(obj_file, "r") as actual_file:
 | 
			
		||||
                self.assertEqual(actual_file.read(), expected_obj_file)
 | 
			
		||||
 | 
			
		||||
            # Check the mtl file is saved correctly
 | 
			
		||||
            mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
 | 
			
		||||
            mtl_file = open(mtl_file_name, "r")
 | 
			
		||||
            self.assertEqual(mtl_file.read(), expected_mtl_file)
 | 
			
		||||
            with open(mtl_file_name, "r") as mtl_file:
 | 
			
		||||
                self.assertEqual(mtl_file.read(), expected_mtl_file)
 | 
			
		||||
 | 
			
		||||
            # Check the texture image file is saved correctly
 | 
			
		||||
            texture_image = load_rgb_image("mesh.png", temp_dir)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user