Fix saving / loading empty OBJ files

Summary:
OBJ files without vertices or faces should be allowed:
- an OBJ with only vertices can represent a point cloud
- an OBJ without any vertex or face is just empty
- an OBJ with faces referencing inexistent vertices has invalid data

Reviewed By: gkioxari

Differential Revision: D20392526

fbshipit-source-id: e72c846ff1e5787fb11d527af3fefa261f9eb0ee
This commit is contained in:
Patrick Labatut
2020-03-28 08:12:07 -07:00
committed by Facebook GitHub Bot
parent 870290df34
commit 3061c5b663
2 changed files with 140 additions and 39 deletions

View File

@@ -15,6 +15,17 @@ from PIL import Image
from pytorch3d.structures import Meshes, Textures, join_meshes
def _make_tensor(data, cols: int, dtype: torch.dtype) -> torch.Tensor:
"""
Return a 2D tensor with the specified cols and dtype filled with data,
even when data is empty.
"""
if not data:
return torch.zeros((0, cols), dtype=dtype)
return torch.tensor(data, dtype=dtype)
def _read_image(file_name: str, format=None):
"""
Read an image from a file using Pillow.
@@ -61,7 +72,7 @@ def _format_faces_indices(faces_indices, max_index):
Raises:
ValueError if indices are not in a valid range.
"""
faces_indices = torch.tensor(faces_indices, dtype=torch.int64)
faces_indices = _make_tensor(faces_indices, cols=3, dtype=torch.int64)
# Change to 0 based indexing.
faces_indices[(faces_indices > 0)] -= 1
@@ -70,10 +81,8 @@ def _format_faces_indices(faces_indices, max_index):
faces_indices[(faces_indices < 0)] += max_index
# Check indices are valid.
if not (
torch.all(faces_indices < max_index) and torch.all(faces_indices >= 0)
):
raise ValueError("Faces have invalid indices.")
if torch.any(faces_indices >= max_index) or torch.any(faces_indices < 0):
warnings.warn("Faces have invalid indices")
return faces_indices
@@ -333,7 +342,7 @@ def _load(f_obj, data_dir, load_textures=True):
# startswith expects each line to be a string. If the file is read in as
# bytes then first decode to strings.
if isinstance(lines[0], bytes):
if lines and isinstance(lines[0], bytes):
lines = [l.decode("utf-8") for l in lines]
for line in lines:
@@ -380,9 +389,9 @@ def _load(f_obj, data_dir, load_textures=True):
faces_materials_idx,
)
verts = torch.tensor(verts) # (V, 3)
normals = torch.tensor(normals) # (N, 3)
verts_uvs = torch.tensor(verts_uvs) # (T, 3)
verts = _make_tensor(verts, cols=3, dtype=torch.float32) # (V, 3)
normals = _make_tensor(normals, cols=3, dtype=torch.float32) # (N, 3)
verts_uvs = _make_tensor(verts_uvs, cols=2, dtype=torch.float32) # (T, 2)
faces_verts_idx = _format_faces_indices(faces_verts_idx, verts.shape[0])
@@ -544,30 +553,46 @@ def save_obj(f, verts, faces, decimal_places: Optional[int] = None):
# TODO (nikhilar) Speed up this function.
def _save(f, verts, faces, decimal_places: Optional[int] = None):
if verts.dim() != 2 or verts.size(1) != 3:
raise ValueError("Argument 'verts' should be of shape (num_verts, 3).")
if faces.dim() != 2 or faces.size(1) != 3:
raise ValueError("Argument 'faces' should be of shape (num_faces, 3).")
if not (len(verts) or len(faces)):
warnings.warn("Empty 'verts' and 'faces' arguments provided")
return
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
raise ValueError(
"Argument 'verts' should either be empty or of shape (num_verts, 3)."
)
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
raise ValueError(
"Argument 'faces' should either be empty or of shape (num_faces, 3)."
)
verts, faces = verts.cpu(), faces.cpu()
if decimal_places is None:
float_str = "%f"
else:
float_str = "%" + ".%df" % decimal_places
lines = ""
V, D = verts.shape
for i in range(V):
vert = [float_str % verts[i, j] for j in range(D)]
lines += "v %s\n" % " ".join(vert)
F, P = faces.shape
for i in range(F):
face = ["%d" % (faces[i, j] + 1) for j in range(P)]
if i + 1 < F:
lines += "f %s\n" % " ".join(face)
elif i + 1 == F:
# No newline at the end of the file.
lines += "f %s" % " ".join(face)
if len(verts):
if decimal_places is None:
float_str = "%f"
else:
float_str = "%" + ".%df" % decimal_places
V, D = verts.shape
for i in range(V):
vert = [float_str % verts[i, j] for j in range(D)]
lines += "v %s\n" % " ".join(vert)
if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0):
warnings.warn("Faces have invalid indices")
if len(faces):
F, P = faces.shape
for i in range(F):
face = ["%d" % (faces[i, j] + 1) for j in range(P)]
if i + 1 < F:
lines += "f %s\n" % " ".join(face)
elif i + 1 == F:
# No newline at the end of the file.
lines += "f %s" % " ".join(face)
f.write(lines)