Fix saving / loading empty OBJ files

Summary:
OBJ files without vertices or faces should be allowed:
- an OBJ with only vertices can represent a point cloud
- an OBJ without any vertex or face is just empty
- an OBJ with faces referencing inexistent vertices has invalid data

Reviewed By: gkioxari

Differential Revision: D20392526

fbshipit-source-id: e72c846ff1e5787fb11d527af3fefa261f9eb0ee
This commit is contained in:
Patrick Labatut 2020-03-28 08:12:07 -07:00 committed by Facebook GitHub Bot
parent 870290df34
commit 3061c5b663
2 changed files with 140 additions and 39 deletions

View File

@ -15,6 +15,17 @@ from PIL import Image
from pytorch3d.structures import Meshes, Textures, join_meshes
def _make_tensor(data, cols: int, dtype: torch.dtype) -> torch.Tensor:
"""
Return a 2D tensor with the specified cols and dtype filled with data,
even when data is empty.
"""
if not data:
return torch.zeros((0, cols), dtype=dtype)
return torch.tensor(data, dtype=dtype)
def _read_image(file_name: str, format=None):
"""
Read an image from a file using Pillow.
@ -61,7 +72,7 @@ def _format_faces_indices(faces_indices, max_index):
Raises:
ValueError if indices are not in a valid range.
"""
faces_indices = torch.tensor(faces_indices, dtype=torch.int64)
faces_indices = _make_tensor(faces_indices, cols=3, dtype=torch.int64)
# Change to 0 based indexing.
faces_indices[(faces_indices > 0)] -= 1
@ -70,10 +81,8 @@ def _format_faces_indices(faces_indices, max_index):
faces_indices[(faces_indices < 0)] += max_index
# Check indices are valid.
if not (
torch.all(faces_indices < max_index) and torch.all(faces_indices >= 0)
):
raise ValueError("Faces have invalid indices.")
if torch.any(faces_indices >= max_index) or torch.any(faces_indices < 0):
warnings.warn("Faces have invalid indices")
return faces_indices
@ -333,7 +342,7 @@ def _load(f_obj, data_dir, load_textures=True):
# startswith expects each line to be a string. If the file is read in as
# bytes then first decode to strings.
if isinstance(lines[0], bytes):
if lines and isinstance(lines[0], bytes):
lines = [l.decode("utf-8") for l in lines]
for line in lines:
@ -380,9 +389,9 @@ def _load(f_obj, data_dir, load_textures=True):
faces_materials_idx,
)
verts = torch.tensor(verts) # (V, 3)
normals = torch.tensor(normals) # (N, 3)
verts_uvs = torch.tensor(verts_uvs) # (T, 3)
verts = _make_tensor(verts, cols=3, dtype=torch.float32) # (V, 3)
normals = _make_tensor(normals, cols=3, dtype=torch.float32) # (N, 3)
verts_uvs = _make_tensor(verts_uvs, cols=2, dtype=torch.float32) # (T, 2)
faces_verts_idx = _format_faces_indices(faces_verts_idx, verts.shape[0])
@ -544,30 +553,46 @@ def save_obj(f, verts, faces, decimal_places: Optional[int] = None):
# TODO (nikhilar) Speed up this function.
def _save(f, verts, faces, decimal_places: Optional[int] = None):
if verts.dim() != 2 or verts.size(1) != 3:
raise ValueError("Argument 'verts' should be of shape (num_verts, 3).")
if faces.dim() != 2 or faces.size(1) != 3:
raise ValueError("Argument 'faces' should be of shape (num_faces, 3).")
if not (len(verts) or len(faces)):
warnings.warn("Empty 'verts' and 'faces' arguments provided")
return
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
raise ValueError(
"Argument 'verts' should either be empty or of shape (num_verts, 3)."
)
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
raise ValueError(
"Argument 'faces' should either be empty or of shape (num_faces, 3)."
)
verts, faces = verts.cpu(), faces.cpu()
if decimal_places is None:
float_str = "%f"
else:
float_str = "%" + ".%df" % decimal_places
lines = ""
V, D = verts.shape
for i in range(V):
vert = [float_str % verts[i, j] for j in range(D)]
lines += "v %s\n" % " ".join(vert)
F, P = faces.shape
for i in range(F):
face = ["%d" % (faces[i, j] + 1) for j in range(P)]
if i + 1 < F:
lines += "f %s\n" % " ".join(face)
elif i + 1 == F:
# No newline at the end of the file.
lines += "f %s" % " ".join(face)
if len(verts):
if decimal_places is None:
float_str = "%f"
else:
float_str = "%" + ".%df" % decimal_places
V, D = verts.shape
for i in range(V):
vert = [float_str % verts[i, j] for j in range(D)]
lines += "v %s\n" % " ".join(vert)
if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0):
warnings.warn("Faces have invalid indices")
if len(faces):
F, P = faces.shape
for i in range(F):
face = ["%d" % (faces[i, j] + 1) for j in range(P)]
if i + 1 < F:
lines += "f %s\n" % " ".join(face)
elif i + 1 == F:
# No newline at the end of the file.
lines += "f %s" % " ".join(face)
f.write(lines)

View File

@ -277,9 +277,8 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
)
obj_file = StringIO(obj_file)
with self.assertRaises(ValueError) as err:
with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"):
load_obj(obj_file)
self.assertTrue("Faces have invalid indices." in str(err.exception))
def test_load_obj_error_invalid_normal_indices(self):
obj_file = "\n".join(
@ -295,9 +294,8 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
)
obj_file = StringIO(obj_file)
with self.assertRaises(ValueError) as err:
with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"):
load_obj(obj_file)
self.assertTrue("Faces have invalid indices." in str(err.exception))
def test_load_obj_error_invalid_texture_indices(self):
obj_file = "\n".join(
@ -313,9 +311,87 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
)
obj_file = StringIO(obj_file)
with self.assertRaises(ValueError) as err:
with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"):
load_obj(obj_file)
self.assertTrue("Faces have invalid indices." in str(err.exception))
def test_save_obj_invalid_shapes(self):
# Invalid vertices shape
with self.assertRaises(ValueError) as error:
verts = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4]]) # (V, 4)
faces = torch.LongTensor([[0, 1, 2]])
save_obj(StringIO(), verts, faces)
expected_message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
self.assertTrue(expected_message, error.exception)
# Invalid faces shape
with self.assertRaises(ValueError) as error:
verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
faces = torch.LongTensor([[0, 1, 2, 3]]) # (F, 4)
save_obj(StringIO(), verts, faces)
expected_message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
self.assertTrue(expected_message, error.exception)
def test_save_obj_invalid_indices(self):
message_regex = "Faces have invalid indices"
verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
faces = torch.LongTensor([[0, 1, 2]])
with self.assertWarnsRegex(UserWarning, message_regex):
save_obj(StringIO(), verts, faces)
faces = torch.LongTensor([[-1, 0, 1]])
with self.assertWarnsRegex(UserWarning, message_regex):
save_obj(StringIO(), verts, faces)
def _test_save_load(self, verts, faces):
f = StringIO()
save_obj(f, verts, faces)
f.seek(0)
expected_verts, expected_faces = verts, faces
if not len(expected_verts): # Always compare with a (V, 3) tensor
expected_verts = torch.zeros(size=(0, 3), dtype=torch.float32)
if not len(expected_faces): # Always compare with an (F, 3) tensor
expected_faces = torch.zeros(size=(0, 3), dtype=torch.int64)
actual_verts, actual_faces, _ = load_obj(f)
self.assertClose(expected_verts, actual_verts)
self.assertClose(expected_faces, actual_faces.verts_idx)
def test_empty_save_load_obj(self):
# Vertices + empty faces
verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
faces = torch.LongTensor([])
self._test_save_load(verts, faces)
faces = torch.zeros(size=(0, 3), dtype=torch.int64)
self._test_save_load(verts, faces)
# Faces + empty vertices
message_regex = "Faces have invalid indices"
verts = torch.FloatTensor([])
faces = torch.LongTensor([[0, 1, 2]])
with self.assertWarnsRegex(UserWarning, message_regex):
self._test_save_load(verts, faces)
verts = torch.zeros(size=(0, 3), dtype=torch.float32)
with self.assertWarnsRegex(UserWarning, message_regex):
self._test_save_load(verts, faces)
# Empty vertices + empty faces
message_regex = "Empty 'verts' and 'faces' arguments provided"
verts0 = torch.FloatTensor([])
faces0 = torch.LongTensor([])
with self.assertWarnsRegex(UserWarning, message_regex):
self._test_save_load(verts0, faces0)
faces3 = torch.zeros(size=(0, 3), dtype=torch.int64)
with self.assertWarnsRegex(UserWarning, message_regex):
self._test_save_load(verts0, faces3)
verts3 = torch.zeros(size=(0, 3), dtype=torch.float32)
with self.assertWarnsRegex(UserWarning, message_regex):
self._test_save_load(verts3, faces0)
with self.assertWarnsRegex(UserWarning, message_regex):
self._test_save_load(verts3, faces3)
def test_save_obj(self):
verts = torch.tensor(
@ -410,7 +486,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
]
)
obj_file = StringIO(obj_file)
with self.assertWarnsRegex(Warning, "No mtl file provided"):
with self.assertWarnsRegex(UserWarning, "No mtl file provided"):
verts, faces, aux = load_obj(obj_file)
expected_verts = torch.tensor(
@ -434,7 +510,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
DATA_DIR = Path(__file__).resolve().parent / "data"
obj_filename = "missing_files_obj/model.obj"
filename = os.path.join(DATA_DIR, obj_filename)
with self.assertWarnsRegex(Warning, "Texture file does not exist"):
with self.assertWarnsRegex(UserWarning, "Texture file does not exist"):
verts, faces, aux = load_obj(filename)
expected_verts = torch.tensor(
@ -475,7 +551,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
DATA_DIR = Path(__file__).resolve().parent / "data"
obj_filename = "missing_files_obj/model2.obj"
filename = os.path.join(DATA_DIR, obj_filename)
with self.assertWarnsRegex(Warning, "Mtl file does not exist"):
with self.assertWarnsRegex(UserWarning, "Mtl file does not exist"):
verts, faces, aux = load_obj(filename)
expected_verts = torch.tensor(