diff --git a/pytorch3d/io/obj_io.py b/pytorch3d/io/obj_io.py index be9b108e..bc2f5789 100644 --- a/pytorch3d/io/obj_io.py +++ b/pytorch3d/io/obj_io.py @@ -684,6 +684,8 @@ def save_obj( decimal_places: Optional[int] = None, path_manager: Optional[PathManager] = None, *, + normals: Optional[torch.Tensor] = None, + faces_normals_idx: Optional[torch.Tensor] = None, verts_uvs: Optional[torch.Tensor] = None, faces_uvs: Optional[torch.Tensor] = None, texture_map: Optional[torch.Tensor] = None, @@ -698,6 +700,10 @@ def save_obj( decimal_places: Number of decimal places for saving. path_manager: Optional PathManager for interpreting f if it is a str. + normals: FloatTensor of shape (V, 3) giving normals for faces_normals_idx + to index into. + faces_normals_idx: LongTensor of shape (F, 3) giving the index into + normals for each vertex in the face. verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinate per vertex. faces_uvs: LongTensor of shape (F, 3) giving the index into verts_uvs for each vertex in the face. @@ -713,6 +719,22 @@ def save_obj( message = "'faces' should either be empty or of shape (num_faces, 3)." raise ValueError(message) + if (normals is None) != (faces_normals_idx is None): + message = "'normals' and 'faces_normals_idx' must both be None or neither." + raise ValueError(message) + + if faces_normals_idx is not None and ( + faces_normals_idx.dim() != 2 or faces_normals_idx.size(1) != 3 + ): + message = ( + "'faces_normals_idx' should either be empty or of shape (num_faces, 3)." + ) + raise ValueError(message) + + if normals is not None and (normals.dim() != 2 or normals.size(1) != 3): + message = "'normals' should either be empty or of shape (num_verts, 3)." + raise ValueError(message) + if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3): message = "'faces_uvs' should either be empty or of shape (num_faces, 3)." raise ValueError(message) @@ -742,9 +764,12 @@ def save_obj( verts, faces, decimal_places, + normals=normals, + faces_normals_idx=faces_normals_idx, verts_uvs=verts_uvs, faces_uvs=faces_uvs, save_texture=save_texture, + save_normals=normals is not None, ) # Save the .mtl and .png files associated with the texture @@ -777,9 +802,12 @@ def _save( faces, decimal_places: Optional[int] = None, *, + normals: Optional[torch.Tensor] = None, + faces_normals_idx: Optional[torch.Tensor] = None, verts_uvs: Optional[torch.Tensor] = None, faces_uvs: Optional[torch.Tensor] = None, save_texture: bool = False, + save_normals: bool = False, ) -> None: if len(verts) and (verts.dim() != 2 or verts.size(1) != 3): @@ -798,18 +826,26 @@ def _save( lines = "" - if len(verts): - if decimal_places is None: - float_str = "%f" - else: - float_str = "%" + ".%df" % decimal_places + if decimal_places is None: + float_str = "%f" + else: + float_str = "%" + ".%df" % decimal_places + if len(verts): V, D = verts.shape for i in range(V): vert = [float_str % verts[i, j] for j in range(D)] lines += "v %s\n" % " ".join(vert) + if save_normals: + assert normals is not None + assert faces_normals_idx is not None + lines += _write_normals(normals, faces_normals_idx, float_str) + if save_texture: + assert faces_uvs is not None + assert verts_uvs is not None + if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3): message = "'faces_uvs' should either be empty or of shape (num_faces, 3)." raise ValueError(message) @@ -818,7 +854,6 @@ def _save( message = "'verts_uvs' should either be empty or of shape (num_verts, 2)." raise ValueError(message) - # pyre-fixme[16] # undefined attribute cpu verts_uvs, faces_uvs = verts_uvs.cpu(), faces_uvs.cpu() # Save verts uvs after verts @@ -828,25 +863,77 @@ def _save( uv = [float_str % verts_uvs[i, j] for j in range(uD)] lines += "vt %s\n" % " ".join(uv) + f.write(lines) + if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0): warnings.warn("Faces have invalid indices") if len(faces): - F, P = faces.shape - for i in range(F): - if save_texture: - # Format faces as {verts_idx}/{verts_uvs_idx} + _write_faces( + f, + faces, + faces_uvs if save_texture else None, + faces_normals_idx if save_normals else None, + ) + + +def _write_normals( + normals: torch.Tensor, faces_normals_idx: torch.Tensor, float_str: str +) -> str: + if faces_normals_idx.dim() != 2 or faces_normals_idx.size(1) != 3: + message = ( + "'faces_normals_idx' should either be empty or of shape (num_faces, 3)." + ) + raise ValueError(message) + + if normals.dim() != 2 or normals.size(1) != 3: + message = "'normals' should either be empty or of shape (num_verts, 3)." + raise ValueError(message) + + normals, faces_normals_idx = normals.cpu(), faces_normals_idx.cpu() + + lines = [] + V, D = normals.shape + for i in range(V): + normal = [float_str % normals[i, j] for j in range(D)] + lines.append("vn %s\n" % " ".join(normal)) + return "".join(lines) + + +def _write_faces( + f, + faces: torch.Tensor, + faces_uvs: Optional[torch.Tensor], + faces_normals_idx: Optional[torch.Tensor], +) -> None: + F, P = faces.shape + for i in range(F): + if faces_normals_idx is not None: + if faces_uvs is not None: + # Format faces as {verts_idx}/{verts_uvs_idx}/{verts_normals_idx} face = [ - "%d/%d" % (faces[i, j] + 1, faces_uvs[i, j] + 1) for j in range(P) + "%d/%d/%d" + % ( + faces[i, j] + 1, + faces_uvs[i, j] + 1, + faces_normals_idx[i, j] + 1, + ) + for j in range(P) ] else: - face = ["%d" % (faces[i, j] + 1) for j in range(P)] + # Format faces as {verts_idx}//{verts_normals_idx} + face = [ + "%d//%d" % (faces[i, j] + 1, faces_normals_idx[i, j] + 1) + for j in range(P) + ] + elif faces_uvs is not None: + # Format faces as {verts_idx}/{verts_uvs_idx} + face = ["%d/%d" % (faces[i, j] + 1, faces_uvs[i, j] + 1) for j in range(P)] + else: + face = ["%d" % (faces[i, j] + 1) for j in range(P)] - if i + 1 < F: - lines += "f %s\n" % " ".join(face) - - elif i + 1 == F: - # No newline at the end of the file. - lines += "f %s" % " ".join(face) - - f.write(lines) + if i + 1 < F: + f.write("f %s\n" % " ".join(face)) + else: + # No newline at the end of the file. + f.write("f %s" % " ".join(face)) diff --git a/tests/test_io_obj.py b/tests/test_io_obj.py index 6b67932d..a6c7025d 100644 --- a/tests/test_io_obj.py +++ b/tests/test_io_obj.py @@ -532,8 +532,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): "f 4 2 1", ] ) - actual_file = open(Path(f.name), "r") - self.assertEqual(actual_file.read(), expected_file) + self.assertEqual(Path(f.name).read_text(), expected_file) def test_load_mtl(self): obj_filename = "cow_mesh/cow.obj" @@ -895,6 +894,67 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): with self.assertRaisesRegex(ValueError, "same type of texture"): join_meshes_as_batch([mesh_atlas, mesh_rgb, mesh_atlas]) + def test_save_obj_with_normal(self): + verts = torch.tensor( + [[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]], + dtype=torch.float32, + ) + faces = torch.tensor( + [[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64 + ) + normals = torch.tensor( + [ + [0.02, 0.5, 0.73], + [0.3, 0.03, 0.361], + [0.32, 0.12, 0.47], + [0.36, 0.17, 0.9], + [0.40, 0.7, 0.19], + [1.0, 0.00, 0.000], + [0.00, 1.00, 0.00], + [0.00, 0.00, 1.0], + ], + dtype=torch.float32, + ) + faces_normals_idx = torch.tensor( + [[0, 1, 2], [2, 3, 4], [4, 5, 6], [6, 7, 0]], dtype=torch.int64 + ) + + with TemporaryDirectory() as temp_dir: + obj_file = os.path.join(temp_dir, "mesh.obj") + save_obj( + obj_file, + verts, + faces, + decimal_places=2, + normals=normals, + faces_normals_idx=faces_normals_idx, + ) + + expected_obj_file = "\n".join( + [ + "v 0.01 0.20 0.30", + "v 0.20 0.03 0.41", + "v 0.30 0.40 0.05", + "v 0.60 0.70 0.80", + "vn 0.02 0.50 0.73", + "vn 0.30 0.03 0.36", + "vn 0.32 0.12 0.47", + "vn 0.36 0.17 0.90", + "vn 0.40 0.70 0.19", + "vn 1.00 0.00 0.00", + "vn 0.00 1.00 0.00", + "vn 0.00 0.00 1.00", + "f 1//1 3//2 2//3", + "f 1//3 2//4 3//5", + "f 4//5 3//6 2//7", + "f 4//7 2//8 1//1", + ] + ) + + # Check the obj file is saved correctly + with open(obj_file, "r") as actual_file: + self.assertEqual(actual_file.read(), expected_obj_file) + def test_save_obj_with_texture(self): verts = torch.tensor( [[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]], @@ -950,13 +1010,96 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir)) # Check the obj file is saved correctly - actual_file = open(obj_file, "r") - self.assertEqual(actual_file.read(), expected_obj_file) + with open(obj_file, "r") as actual_file: + self.assertEqual(actual_file.read(), expected_obj_file) # Check the mtl file is saved correctly mtl_file_name = os.path.join(temp_dir, "mesh.mtl") - mtl_file = open(mtl_file_name, "r") - self.assertEqual(mtl_file.read(), expected_mtl_file) + with open(mtl_file_name, "r") as mtl_file: + self.assertEqual(mtl_file.read(), expected_mtl_file) + + # Check the texture image file is saved correctly + texture_image = load_rgb_image("mesh.png", temp_dir) + self.assertClose(texture_image, texture_map) + + def test_save_obj_with_normal_and_texture(self): + verts = torch.tensor( + [[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]], + dtype=torch.float32, + ) + faces = torch.tensor( + [[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64 + ) + normals = torch.tensor( + [ + [0.02, 0.5, 0.73], + [0.3, 0.03, 0.361], + [0.32, 0.12, 0.47], + [0.36, 0.17, 0.9], + ], + dtype=torch.float32, + ) + faces_normals_idx = faces + verts_uvs = torch.tensor( + [[0.02, 0.5], [0.3, 0.03], [0.32, 0.12], [0.36, 0.17]], + dtype=torch.float32, + ) + faces_uvs = faces + texture_map = torch.randint(size=(2, 2, 3), high=255) / 255.0 + + with TemporaryDirectory() as temp_dir: + obj_file = os.path.join(temp_dir, "mesh.obj") + save_obj( + obj_file, + verts, + faces, + decimal_places=2, + normals=normals, + faces_normals_idx=faces_normals_idx, + verts_uvs=verts_uvs, + faces_uvs=faces_uvs, + texture_map=texture_map, + ) + + expected_obj_file = "\n".join( + [ + "", + "mtllib mesh.mtl", + "usemtl mesh", + "", + "v 0.01 0.20 0.30", + "v 0.20 0.03 0.41", + "v 0.30 0.40 0.05", + "v 0.60 0.70 0.80", + "vn 0.02 0.50 0.73", + "vn 0.30 0.03 0.36", + "vn 0.32 0.12 0.47", + "vn 0.36 0.17 0.90", + "vt 0.02 0.50", + "vt 0.30 0.03", + "vt 0.32 0.12", + "vt 0.36 0.17", + "f 1/1/1 3/3/3 2/2/2", + "f 1/1/1 2/2/2 3/3/3", + "f 4/4/4 3/3/3 2/2/2", + "f 4/4/4 2/2/2 1/1/1", + ] + ) + expected_mtl_file = "\n".join(["newmtl mesh", "map_Kd mesh.png", ""]) + + # Check there are only 3 files in the temp dir + tempfiles = ["mesh.obj", "mesh.png", "mesh.mtl"] + tempfiles_dir = os.listdir(temp_dir) + self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir)) + + # Check the obj file is saved correctly + with open(obj_file, "r") as actual_file: + self.assertEqual(actual_file.read(), expected_obj_file) + + # Check the mtl file is saved correctly + mtl_file_name = os.path.join(temp_dir, "mesh.mtl") + with open(mtl_file_name, "r") as mtl_file: + self.assertEqual(mtl_file.read(), expected_mtl_file) # Check the texture image file is saved correctly texture_image = load_rgb_image("mesh.png", temp_dir) @@ -1013,8 +1156,8 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): self.assertEqual(tempfiles, tempfiles_dir) # Check the obj file is saved correctly - actual_file = open(obj_file, "r") - self.assertEqual(actual_file.read(), expected_obj_file) + with open(obj_file, "r") as actual_file: + self.assertEqual(actual_file.read(), expected_obj_file) obj_file = StringIO() with self.assertRaises(ValueError): @@ -1100,13 +1243,13 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir)) # Check the obj file is saved correctly - actual_file = open(obj_file, "r") - self.assertEqual(actual_file.read(), expected_obj_file) + with open(obj_file, "r") as actual_file: + self.assertEqual(actual_file.read(), expected_obj_file) # Check the mtl file is saved correctly mtl_file_name = os.path.join(temp_dir, "mesh.mtl") - mtl_file = open(mtl_file_name, "r") - self.assertEqual(mtl_file.read(), expected_mtl_file) + with open(mtl_file_name, "r") as mtl_file: + self.assertEqual(mtl_file.read(), expected_mtl_file) # Check the texture image file is saved correctly texture_image = load_rgb_image("mesh.png", temp_dir)