mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
Fix saving / loading empty OBJ files
Summary: OBJ files without vertices or faces should be allowed: - an OBJ with only vertices can represent a point cloud - an OBJ without any vertex or face is just empty - an OBJ with faces referencing inexistent vertices has invalid data Reviewed By: gkioxari Differential Revision: D20392526 fbshipit-source-id: e72c846ff1e5787fb11d527af3fefa261f9eb0ee
This commit is contained in:
parent
870290df34
commit
3061c5b663
@ -15,6 +15,17 @@ from PIL import Image
|
|||||||
from pytorch3d.structures import Meshes, Textures, join_meshes
|
from pytorch3d.structures import Meshes, Textures, join_meshes
|
||||||
|
|
||||||
|
|
||||||
|
def _make_tensor(data, cols: int, dtype: torch.dtype) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Return a 2D tensor with the specified cols and dtype filled with data,
|
||||||
|
even when data is empty.
|
||||||
|
"""
|
||||||
|
if not data:
|
||||||
|
return torch.zeros((0, cols), dtype=dtype)
|
||||||
|
|
||||||
|
return torch.tensor(data, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
def _read_image(file_name: str, format=None):
|
def _read_image(file_name: str, format=None):
|
||||||
"""
|
"""
|
||||||
Read an image from a file using Pillow.
|
Read an image from a file using Pillow.
|
||||||
@ -61,7 +72,7 @@ def _format_faces_indices(faces_indices, max_index):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError if indices are not in a valid range.
|
ValueError if indices are not in a valid range.
|
||||||
"""
|
"""
|
||||||
faces_indices = torch.tensor(faces_indices, dtype=torch.int64)
|
faces_indices = _make_tensor(faces_indices, cols=3, dtype=torch.int64)
|
||||||
|
|
||||||
# Change to 0 based indexing.
|
# Change to 0 based indexing.
|
||||||
faces_indices[(faces_indices > 0)] -= 1
|
faces_indices[(faces_indices > 0)] -= 1
|
||||||
@ -70,10 +81,8 @@ def _format_faces_indices(faces_indices, max_index):
|
|||||||
faces_indices[(faces_indices < 0)] += max_index
|
faces_indices[(faces_indices < 0)] += max_index
|
||||||
|
|
||||||
# Check indices are valid.
|
# Check indices are valid.
|
||||||
if not (
|
if torch.any(faces_indices >= max_index) or torch.any(faces_indices < 0):
|
||||||
torch.all(faces_indices < max_index) and torch.all(faces_indices >= 0)
|
warnings.warn("Faces have invalid indices")
|
||||||
):
|
|
||||||
raise ValueError("Faces have invalid indices.")
|
|
||||||
|
|
||||||
return faces_indices
|
return faces_indices
|
||||||
|
|
||||||
@ -333,7 +342,7 @@ def _load(f_obj, data_dir, load_textures=True):
|
|||||||
|
|
||||||
# startswith expects each line to be a string. If the file is read in as
|
# startswith expects each line to be a string. If the file is read in as
|
||||||
# bytes then first decode to strings.
|
# bytes then first decode to strings.
|
||||||
if isinstance(lines[0], bytes):
|
if lines and isinstance(lines[0], bytes):
|
||||||
lines = [l.decode("utf-8") for l in lines]
|
lines = [l.decode("utf-8") for l in lines]
|
||||||
|
|
||||||
for line in lines:
|
for line in lines:
|
||||||
@ -380,9 +389,9 @@ def _load(f_obj, data_dir, load_textures=True):
|
|||||||
faces_materials_idx,
|
faces_materials_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
verts = torch.tensor(verts) # (V, 3)
|
verts = _make_tensor(verts, cols=3, dtype=torch.float32) # (V, 3)
|
||||||
normals = torch.tensor(normals) # (N, 3)
|
normals = _make_tensor(normals, cols=3, dtype=torch.float32) # (N, 3)
|
||||||
verts_uvs = torch.tensor(verts_uvs) # (T, 3)
|
verts_uvs = _make_tensor(verts_uvs, cols=2, dtype=torch.float32) # (T, 2)
|
||||||
|
|
||||||
faces_verts_idx = _format_faces_indices(faces_verts_idx, verts.shape[0])
|
faces_verts_idx = _format_faces_indices(faces_verts_idx, verts.shape[0])
|
||||||
|
|
||||||
@ -544,30 +553,46 @@ def save_obj(f, verts, faces, decimal_places: Optional[int] = None):
|
|||||||
|
|
||||||
# TODO (nikhilar) Speed up this function.
|
# TODO (nikhilar) Speed up this function.
|
||||||
def _save(f, verts, faces, decimal_places: Optional[int] = None):
|
def _save(f, verts, faces, decimal_places: Optional[int] = None):
|
||||||
if verts.dim() != 2 or verts.size(1) != 3:
|
if not (len(verts) or len(faces)):
|
||||||
raise ValueError("Argument 'verts' should be of shape (num_verts, 3).")
|
warnings.warn("Empty 'verts' and 'faces' arguments provided")
|
||||||
if faces.dim() != 2 or faces.size(1) != 3:
|
return
|
||||||
raise ValueError("Argument 'faces' should be of shape (num_faces, 3).")
|
|
||||||
|
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
|
||||||
|
raise ValueError(
|
||||||
|
"Argument 'verts' should either be empty or of shape (num_verts, 3)."
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
|
||||||
|
raise ValueError(
|
||||||
|
"Argument 'faces' should either be empty or of shape (num_faces, 3)."
|
||||||
|
)
|
||||||
|
|
||||||
verts, faces = verts.cpu(), faces.cpu()
|
verts, faces = verts.cpu(), faces.cpu()
|
||||||
|
|
||||||
if decimal_places is None:
|
|
||||||
float_str = "%f"
|
|
||||||
else:
|
|
||||||
float_str = "%" + ".%df" % decimal_places
|
|
||||||
|
|
||||||
lines = ""
|
lines = ""
|
||||||
V, D = verts.shape
|
|
||||||
for i in range(V):
|
|
||||||
vert = [float_str % verts[i, j] for j in range(D)]
|
|
||||||
lines += "v %s\n" % " ".join(vert)
|
|
||||||
|
|
||||||
F, P = faces.shape
|
if len(verts):
|
||||||
for i in range(F):
|
if decimal_places is None:
|
||||||
face = ["%d" % (faces[i, j] + 1) for j in range(P)]
|
float_str = "%f"
|
||||||
if i + 1 < F:
|
else:
|
||||||
lines += "f %s\n" % " ".join(face)
|
float_str = "%" + ".%df" % decimal_places
|
||||||
elif i + 1 == F:
|
|
||||||
# No newline at the end of the file.
|
V, D = verts.shape
|
||||||
lines += "f %s" % " ".join(face)
|
for i in range(V):
|
||||||
|
vert = [float_str % verts[i, j] for j in range(D)]
|
||||||
|
lines += "v %s\n" % " ".join(vert)
|
||||||
|
|
||||||
|
if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0):
|
||||||
|
warnings.warn("Faces have invalid indices")
|
||||||
|
|
||||||
|
if len(faces):
|
||||||
|
F, P = faces.shape
|
||||||
|
for i in range(F):
|
||||||
|
face = ["%d" % (faces[i, j] + 1) for j in range(P)]
|
||||||
|
if i + 1 < F:
|
||||||
|
lines += "f %s\n" % " ".join(face)
|
||||||
|
elif i + 1 == F:
|
||||||
|
# No newline at the end of the file.
|
||||||
|
lines += "f %s" % " ".join(face)
|
||||||
|
|
||||||
f.write(lines)
|
f.write(lines)
|
||||||
|
@ -277,9 +277,8 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
obj_file = StringIO(obj_file)
|
obj_file = StringIO(obj_file)
|
||||||
|
|
||||||
with self.assertRaises(ValueError) as err:
|
with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"):
|
||||||
load_obj(obj_file)
|
load_obj(obj_file)
|
||||||
self.assertTrue("Faces have invalid indices." in str(err.exception))
|
|
||||||
|
|
||||||
def test_load_obj_error_invalid_normal_indices(self):
|
def test_load_obj_error_invalid_normal_indices(self):
|
||||||
obj_file = "\n".join(
|
obj_file = "\n".join(
|
||||||
@ -295,9 +294,8 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
obj_file = StringIO(obj_file)
|
obj_file = StringIO(obj_file)
|
||||||
|
|
||||||
with self.assertRaises(ValueError) as err:
|
with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"):
|
||||||
load_obj(obj_file)
|
load_obj(obj_file)
|
||||||
self.assertTrue("Faces have invalid indices." in str(err.exception))
|
|
||||||
|
|
||||||
def test_load_obj_error_invalid_texture_indices(self):
|
def test_load_obj_error_invalid_texture_indices(self):
|
||||||
obj_file = "\n".join(
|
obj_file = "\n".join(
|
||||||
@ -313,9 +311,87 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
obj_file = StringIO(obj_file)
|
obj_file = StringIO(obj_file)
|
||||||
|
|
||||||
with self.assertRaises(ValueError) as err:
|
with self.assertWarnsRegex(UserWarning, "Faces have invalid indices"):
|
||||||
load_obj(obj_file)
|
load_obj(obj_file)
|
||||||
self.assertTrue("Faces have invalid indices." in str(err.exception))
|
|
||||||
|
def test_save_obj_invalid_shapes(self):
|
||||||
|
# Invalid vertices shape
|
||||||
|
with self.assertRaises(ValueError) as error:
|
||||||
|
verts = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4]]) # (V, 4)
|
||||||
|
faces = torch.LongTensor([[0, 1, 2]])
|
||||||
|
save_obj(StringIO(), verts, faces)
|
||||||
|
expected_message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
|
||||||
|
self.assertTrue(expected_message, error.exception)
|
||||||
|
|
||||||
|
# Invalid faces shape
|
||||||
|
with self.assertRaises(ValueError) as error:
|
||||||
|
verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
|
||||||
|
faces = torch.LongTensor([[0, 1, 2, 3]]) # (F, 4)
|
||||||
|
save_obj(StringIO(), verts, faces)
|
||||||
|
expected_message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
|
||||||
|
self.assertTrue(expected_message, error.exception)
|
||||||
|
|
||||||
|
def test_save_obj_invalid_indices(self):
|
||||||
|
message_regex = "Faces have invalid indices"
|
||||||
|
verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
|
||||||
|
faces = torch.LongTensor([[0, 1, 2]])
|
||||||
|
with self.assertWarnsRegex(UserWarning, message_regex):
|
||||||
|
save_obj(StringIO(), verts, faces)
|
||||||
|
|
||||||
|
faces = torch.LongTensor([[-1, 0, 1]])
|
||||||
|
with self.assertWarnsRegex(UserWarning, message_regex):
|
||||||
|
save_obj(StringIO(), verts, faces)
|
||||||
|
|
||||||
|
def _test_save_load(self, verts, faces):
|
||||||
|
f = StringIO()
|
||||||
|
save_obj(f, verts, faces)
|
||||||
|
f.seek(0)
|
||||||
|
expected_verts, expected_faces = verts, faces
|
||||||
|
if not len(expected_verts): # Always compare with a (V, 3) tensor
|
||||||
|
expected_verts = torch.zeros(size=(0, 3), dtype=torch.float32)
|
||||||
|
if not len(expected_faces): # Always compare with an (F, 3) tensor
|
||||||
|
expected_faces = torch.zeros(size=(0, 3), dtype=torch.int64)
|
||||||
|
actual_verts, actual_faces, _ = load_obj(f)
|
||||||
|
self.assertClose(expected_verts, actual_verts)
|
||||||
|
self.assertClose(expected_faces, actual_faces.verts_idx)
|
||||||
|
|
||||||
|
def test_empty_save_load_obj(self):
|
||||||
|
# Vertices + empty faces
|
||||||
|
verts = torch.FloatTensor([[0.1, 0.2, 0.3]])
|
||||||
|
faces = torch.LongTensor([])
|
||||||
|
self._test_save_load(verts, faces)
|
||||||
|
|
||||||
|
faces = torch.zeros(size=(0, 3), dtype=torch.int64)
|
||||||
|
self._test_save_load(verts, faces)
|
||||||
|
|
||||||
|
# Faces + empty vertices
|
||||||
|
message_regex = "Faces have invalid indices"
|
||||||
|
verts = torch.FloatTensor([])
|
||||||
|
faces = torch.LongTensor([[0, 1, 2]])
|
||||||
|
with self.assertWarnsRegex(UserWarning, message_regex):
|
||||||
|
self._test_save_load(verts, faces)
|
||||||
|
|
||||||
|
verts = torch.zeros(size=(0, 3), dtype=torch.float32)
|
||||||
|
with self.assertWarnsRegex(UserWarning, message_regex):
|
||||||
|
self._test_save_load(verts, faces)
|
||||||
|
|
||||||
|
# Empty vertices + empty faces
|
||||||
|
message_regex = "Empty 'verts' and 'faces' arguments provided"
|
||||||
|
verts0 = torch.FloatTensor([])
|
||||||
|
faces0 = torch.LongTensor([])
|
||||||
|
with self.assertWarnsRegex(UserWarning, message_regex):
|
||||||
|
self._test_save_load(verts0, faces0)
|
||||||
|
|
||||||
|
faces3 = torch.zeros(size=(0, 3), dtype=torch.int64)
|
||||||
|
with self.assertWarnsRegex(UserWarning, message_regex):
|
||||||
|
self._test_save_load(verts0, faces3)
|
||||||
|
|
||||||
|
verts3 = torch.zeros(size=(0, 3), dtype=torch.float32)
|
||||||
|
with self.assertWarnsRegex(UserWarning, message_regex):
|
||||||
|
self._test_save_load(verts3, faces0)
|
||||||
|
|
||||||
|
with self.assertWarnsRegex(UserWarning, message_regex):
|
||||||
|
self._test_save_load(verts3, faces3)
|
||||||
|
|
||||||
def test_save_obj(self):
|
def test_save_obj(self):
|
||||||
verts = torch.tensor(
|
verts = torch.tensor(
|
||||||
@ -410,7 +486,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
obj_file = StringIO(obj_file)
|
obj_file = StringIO(obj_file)
|
||||||
with self.assertWarnsRegex(Warning, "No mtl file provided"):
|
with self.assertWarnsRegex(UserWarning, "No mtl file provided"):
|
||||||
verts, faces, aux = load_obj(obj_file)
|
verts, faces, aux = load_obj(obj_file)
|
||||||
|
|
||||||
expected_verts = torch.tensor(
|
expected_verts = torch.tensor(
|
||||||
@ -434,7 +510,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
|||||||
DATA_DIR = Path(__file__).resolve().parent / "data"
|
DATA_DIR = Path(__file__).resolve().parent / "data"
|
||||||
obj_filename = "missing_files_obj/model.obj"
|
obj_filename = "missing_files_obj/model.obj"
|
||||||
filename = os.path.join(DATA_DIR, obj_filename)
|
filename = os.path.join(DATA_DIR, obj_filename)
|
||||||
with self.assertWarnsRegex(Warning, "Texture file does not exist"):
|
with self.assertWarnsRegex(UserWarning, "Texture file does not exist"):
|
||||||
verts, faces, aux = load_obj(filename)
|
verts, faces, aux = load_obj(filename)
|
||||||
|
|
||||||
expected_verts = torch.tensor(
|
expected_verts = torch.tensor(
|
||||||
@ -475,7 +551,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
|||||||
DATA_DIR = Path(__file__).resolve().parent / "data"
|
DATA_DIR = Path(__file__).resolve().parent / "data"
|
||||||
obj_filename = "missing_files_obj/model2.obj"
|
obj_filename = "missing_files_obj/model2.obj"
|
||||||
filename = os.path.join(DATA_DIR, obj_filename)
|
filename = os.path.join(DATA_DIR, obj_filename)
|
||||||
with self.assertWarnsRegex(Warning, "Mtl file does not exist"):
|
with self.assertWarnsRegex(UserWarning, "Mtl file does not exist"):
|
||||||
verts, faces, aux = load_obj(filename)
|
verts, faces, aux = load_obj(filename)
|
||||||
|
|
||||||
expected_verts = torch.tensor(
|
expected_verts = torch.tensor(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user