diff --git a/pytorch3d/io/obj_io.py b/pytorch3d/io/obj_io.py index 651b22c5..c555b917 100644 --- a/pytorch3d/io/obj_io.py +++ b/pytorch3d/io/obj_io.py @@ -15,6 +15,17 @@ from PIL import Image from pytorch3d.structures import Meshes, Textures, join_meshes +def _make_tensor(data, cols: int, dtype: torch.dtype) -> torch.Tensor: + """ + Return a 2D tensor with the specified cols and dtype filled with data, + even when data is empty. + """ + if not data: + return torch.zeros((0, cols), dtype=dtype) + + return torch.tensor(data, dtype=dtype) + + def _read_image(file_name: str, format=None): """ Read an image from a file using Pillow. @@ -61,7 +72,7 @@ def _format_faces_indices(faces_indices, max_index): Raises: ValueError if indices are not in a valid range. """ - faces_indices = torch.tensor(faces_indices, dtype=torch.int64) + faces_indices = _make_tensor(faces_indices, cols=3, dtype=torch.int64) # Change to 0 based indexing. faces_indices[(faces_indices > 0)] -= 1 @@ -70,10 +81,8 @@ def _format_faces_indices(faces_indices, max_index): faces_indices[(faces_indices < 0)] += max_index # Check indices are valid. - if not ( - torch.all(faces_indices < max_index) and torch.all(faces_indices >= 0) - ): - raise ValueError("Faces have invalid indices.") + if torch.any(faces_indices >= max_index) or torch.any(faces_indices < 0): + warnings.warn("Faces have invalid indices") return faces_indices @@ -333,7 +342,7 @@ def _load(f_obj, data_dir, load_textures=True): # startswith expects each line to be a string. If the file is read in as # bytes then first decode to strings. - if isinstance(lines[0], bytes): + if lines and isinstance(lines[0], bytes): lines = [l.decode("utf-8") for l in lines] for line in lines: @@ -380,9 +389,9 @@ def _load(f_obj, data_dir, load_textures=True): faces_materials_idx, ) - verts = torch.tensor(verts) # (V, 3) - normals = torch.tensor(normals) # (N, 3) - verts_uvs = torch.tensor(verts_uvs) # (T, 3) + verts = _make_tensor(verts, cols=3, dtype=torch.float32) # (V, 3) + normals = _make_tensor(normals, cols=3, dtype=torch.float32) # (N, 3) + verts_uvs = _make_tensor(verts_uvs, cols=2, dtype=torch.float32) # (T, 2) faces_verts_idx = _format_faces_indices(faces_verts_idx, verts.shape[0]) @@ -544,30 +553,46 @@ def save_obj(f, verts, faces, decimal_places: Optional[int] = None): # TODO (nikhilar) Speed up this function. def _save(f, verts, faces, decimal_places: Optional[int] = None): - if verts.dim() != 2 or verts.size(1) != 3: - raise ValueError("Argument 'verts' should be of shape (num_verts, 3).") - if faces.dim() != 2 or faces.size(1) != 3: - raise ValueError("Argument 'faces' should be of shape (num_faces, 3).") + if not (len(verts) or len(faces)): + warnings.warn("Empty 'verts' and 'faces' arguments provided") + return + + if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3): + raise ValueError( + "Argument 'verts' should either be empty or of shape (num_verts, 3)." + ) + + if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3): + raise ValueError( + "Argument 'faces' should either be empty or of shape (num_faces, 3)." + ) + verts, faces = verts.cpu(), faces.cpu() - if decimal_places is None: - float_str = "%f" - else: - float_str = "%" + ".%df" % decimal_places - lines = "" - V, D = verts.shape - for i in range(V): - vert = [float_str % verts[i, j] for j in range(D)] - lines += "v %s\n" % " ".join(vert) - F, P = faces.shape - for i in range(F): - face = ["%d" % (faces[i, j] + 1) for j in range(P)] - if i + 1 < F: - lines += "f %s\n" % " ".join(face) - elif i + 1 == F: - # No newline at the end of the file. - lines += "f %s" % " ".join(face) + if len(verts): + if decimal_places is None: + float_str = "%f" + else: + float_str = "%" + ".%df" % decimal_places + + V, D = verts.shape + for i in range(V): + vert = [float_str % verts[i, j] for j in range(D)] + lines += "v %s\n" % " ".join(vert) + + if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0): + warnings.warn("Faces have invalid indices") + + if len(faces): + F, P = faces.shape + for i in range(F): + face = ["%d" % (faces[i, j] + 1) for j in range(P)] + if i + 1 < F: + lines += "f %s\n" % " ".join(face) + elif i + 1 == F: + # No newline at the end of the file. + lines += "f %s" % " ".join(face) f.write(lines) diff --git a/tests/test_obj_io.py b/tests/test_obj_io.py index 5f00f543..3e6d7ef7 100644 --- a/tests/test_obj_io.py +++ b/tests/test_obj_io.py @@ -277,9 +277,8 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): ) obj_file = StringIO(obj_file) - with self.assertRaises(ValueError) as err: + with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"): load_obj(obj_file) - self.assertTrue("Faces have invalid indices." in str(err.exception)) def test_load_obj_error_invalid_normal_indices(self): obj_file = "\n".join( @@ -295,9 +294,8 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): ) obj_file = StringIO(obj_file) - with self.assertRaises(ValueError) as err: + with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"): load_obj(obj_file) - self.assertTrue("Faces have invalid indices." in str(err.exception)) def test_load_obj_error_invalid_texture_indices(self): obj_file = "\n".join( @@ -313,9 +311,87 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): ) obj_file = StringIO(obj_file) - with self.assertRaises(ValueError) as err: + with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"): load_obj(obj_file) - self.assertTrue("Faces have invalid indices." in str(err.exception)) + + def test_save_obj_invalid_shapes(self): + # Invalid vertices shape + with self.assertRaises(ValueError) as error: + verts = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4]]) # (V, 4) + faces = torch.LongTensor([[0, 1, 2]]) + save_obj(StringIO(), verts, faces) + expected_message = "Argument 'verts' should either be empty or of shape (num_verts, 3)." + self.assertTrue(expected_message, error.exception) + + # Invalid faces shape + with self.assertRaises(ValueError) as error: + verts = torch.FloatTensor([[0.1, 0.2, 0.3]]) + faces = torch.LongTensor([[0, 1, 2, 3]]) # (F, 4) + save_obj(StringIO(), verts, faces) + expected_message = "Argument 'faces' should either be empty or of shape (num_faces, 3)." + self.assertTrue(expected_message, error.exception) + + def test_save_obj_invalid_indices(self): + message_regex = "Faces have invalid indices" + verts = torch.FloatTensor([[0.1, 0.2, 0.3]]) + faces = torch.LongTensor([[0, 1, 2]]) + with self.assertWarnsRegex(UserWarning, message_regex): + save_obj(StringIO(), verts, faces) + + faces = torch.LongTensor([[-1, 0, 1]]) + with self.assertWarnsRegex(UserWarning, message_regex): + save_obj(StringIO(), verts, faces) + + def _test_save_load(self, verts, faces): + f = StringIO() + save_obj(f, verts, faces) + f.seek(0) + expected_verts, expected_faces = verts, faces + if not len(expected_verts): # Always compare with a (V, 3) tensor + expected_verts = torch.zeros(size=(0, 3), dtype=torch.float32) + if not len(expected_faces): # Always compare with an (F, 3) tensor + expected_faces = torch.zeros(size=(0, 3), dtype=torch.int64) + actual_verts, actual_faces, _ = load_obj(f) + self.assertClose(expected_verts, actual_verts) + self.assertClose(expected_faces, actual_faces.verts_idx) + + def test_empty_save_load_obj(self): + # Vertices + empty faces + verts = torch.FloatTensor([[0.1, 0.2, 0.3]]) + faces = torch.LongTensor([]) + self._test_save_load(verts, faces) + + faces = torch.zeros(size=(0, 3), dtype=torch.int64) + self._test_save_load(verts, faces) + + # Faces + empty vertices + message_regex = "Faces have invalid indices" + verts = torch.FloatTensor([]) + faces = torch.LongTensor([[0, 1, 2]]) + with self.assertWarnsRegex(UserWarning, message_regex): + self._test_save_load(verts, faces) + + verts = torch.zeros(size=(0, 3), dtype=torch.float32) + with self.assertWarnsRegex(UserWarning, message_regex): + self._test_save_load(verts, faces) + + # Empty vertices + empty faces + message_regex = "Empty 'verts' and 'faces' arguments provided" + verts0 = torch.FloatTensor([]) + faces0 = torch.LongTensor([]) + with self.assertWarnsRegex(UserWarning, message_regex): + self._test_save_load(verts0, faces0) + + faces3 = torch.zeros(size=(0, 3), dtype=torch.int64) + with self.assertWarnsRegex(UserWarning, message_regex): + self._test_save_load(verts0, faces3) + + verts3 = torch.zeros(size=(0, 3), dtype=torch.float32) + with self.assertWarnsRegex(UserWarning, message_regex): + self._test_save_load(verts3, faces0) + + with self.assertWarnsRegex(UserWarning, message_regex): + self._test_save_load(verts3, faces3) def test_save_obj(self): verts = torch.tensor( @@ -410,7 +486,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): ] ) obj_file = StringIO(obj_file) - with self.assertWarnsRegex(Warning, "No mtl file provided"): + with self.assertWarnsRegex(UserWarning, "No mtl file provided"): verts, faces, aux = load_obj(obj_file) expected_verts = torch.tensor( @@ -434,7 +510,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): DATA_DIR = Path(__file__).resolve().parent / "data" obj_filename = "missing_files_obj/model.obj" filename = os.path.join(DATA_DIR, obj_filename) - with self.assertWarnsRegex(Warning, "Texture file does not exist"): + with self.assertWarnsRegex(UserWarning, "Texture file does not exist"): verts, faces, aux = load_obj(filename) expected_verts = torch.tensor( @@ -475,7 +551,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): DATA_DIR = Path(__file__).resolve().parent / "data" obj_filename = "missing_files_obj/model2.obj" filename = os.path.join(DATA_DIR, obj_filename) - with self.assertWarnsRegex(Warning, "Mtl file does not exist"): + with self.assertWarnsRegex(UserWarning, "Mtl file does not exist"): verts, faces, aux = load_obj(filename) expected_verts = torch.tensor(