mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-01 03:12:49 +08:00
allow saving vertex normal in save_obj (#1511)
Summary: Although we can load per-vertex normals in `load_obj`, saving per-vertex normals is not supported in `save_obj`. This patch fixes this by allowing passing per-vertex normal data in `save_obj`: ``` python def save_obj( f: PathOrStr, verts, faces, decimal_places: Optional[int] = None, path_manager: Optional[PathManager] = None, *, verts_normals: Optional[torch.Tensor] = None, faces_normals: Optional[torch.Tensor] = None, verts_uvs: Optional[torch.Tensor] = None, faces_uvs: Optional[torch.Tensor] = None, texture_map: Optional[torch.Tensor] = None, ) -> None: """ Save a mesh to an .obj file. Args: f: File (str or path) to which the mesh should be written. verts: FloatTensor of shape (V, 3) giving vertex coordinates. faces: LongTensor of shape (F, 3) giving faces. decimal_places: Number of decimal places for saving. path_manager: Optional PathManager for interpreting f if it is a str. verts_normals: FloatTensor of shape (V, 3) giving the normal per vertex. faces_normals: LongTensor of shape (F, 3) giving the index into verts_normals for each vertex in the face. verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinate per vertex. faces_uvs: LongTensor of shape (F, 3) giving the index into verts_uvs for each vertex in the face. texture_map: FloatTensor of shape (H, W, 3) representing the texture map for the mesh which will be saved as an image. The values are expected to be in the range [0, 1], """ ``` Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1511 Reviewed By: shapovalov Differential Revision: D45086045 Pulled By: bottler fbshipit-source-id: 666efb0d2c302df6cf9f2f6601d83a07856bf32f
This commit is contained in:
parent
ec87284c4b
commit
092400f1e7
@ -684,6 +684,8 @@ def save_obj(
|
||||
decimal_places: Optional[int] = None,
|
||||
path_manager: Optional[PathManager] = None,
|
||||
*,
|
||||
normals: Optional[torch.Tensor] = None,
|
||||
faces_normals_idx: Optional[torch.Tensor] = None,
|
||||
verts_uvs: Optional[torch.Tensor] = None,
|
||||
faces_uvs: Optional[torch.Tensor] = None,
|
||||
texture_map: Optional[torch.Tensor] = None,
|
||||
@ -698,6 +700,10 @@ def save_obj(
|
||||
decimal_places: Number of decimal places for saving.
|
||||
path_manager: Optional PathManager for interpreting f if
|
||||
it is a str.
|
||||
normals: FloatTensor of shape (V, 3) giving normals for faces_normals_idx
|
||||
to index into.
|
||||
faces_normals_idx: LongTensor of shape (F, 3) giving the index into
|
||||
normals for each vertex in the face.
|
||||
verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinate per vertex.
|
||||
faces_uvs: LongTensor of shape (F, 3) giving the index into verts_uvs for
|
||||
each vertex in the face.
|
||||
@ -713,6 +719,22 @@ def save_obj(
|
||||
message = "'faces' should either be empty or of shape (num_faces, 3)."
|
||||
raise ValueError(message)
|
||||
|
||||
if (normals is None) != (faces_normals_idx is None):
|
||||
message = "'normals' and 'faces_normals_idx' must both be None or neither."
|
||||
raise ValueError(message)
|
||||
|
||||
if faces_normals_idx is not None and (
|
||||
faces_normals_idx.dim() != 2 or faces_normals_idx.size(1) != 3
|
||||
):
|
||||
message = (
|
||||
"'faces_normals_idx' should either be empty or of shape (num_faces, 3)."
|
||||
)
|
||||
raise ValueError(message)
|
||||
|
||||
if normals is not None and (normals.dim() != 2 or normals.size(1) != 3):
|
||||
message = "'normals' should either be empty or of shape (num_verts, 3)."
|
||||
raise ValueError(message)
|
||||
|
||||
if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3):
|
||||
message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
|
||||
raise ValueError(message)
|
||||
@ -742,9 +764,12 @@ def save_obj(
|
||||
verts,
|
||||
faces,
|
||||
decimal_places,
|
||||
normals=normals,
|
||||
faces_normals_idx=faces_normals_idx,
|
||||
verts_uvs=verts_uvs,
|
||||
faces_uvs=faces_uvs,
|
||||
save_texture=save_texture,
|
||||
save_normals=normals is not None,
|
||||
)
|
||||
|
||||
# Save the .mtl and .png files associated with the texture
|
||||
@ -777,9 +802,12 @@ def _save(
|
||||
faces,
|
||||
decimal_places: Optional[int] = None,
|
||||
*,
|
||||
normals: Optional[torch.Tensor] = None,
|
||||
faces_normals_idx: Optional[torch.Tensor] = None,
|
||||
verts_uvs: Optional[torch.Tensor] = None,
|
||||
faces_uvs: Optional[torch.Tensor] = None,
|
||||
save_texture: bool = False,
|
||||
save_normals: bool = False,
|
||||
) -> None:
|
||||
|
||||
if len(verts) and (verts.dim() != 2 or verts.size(1) != 3):
|
||||
@ -798,18 +826,26 @@ def _save(
|
||||
|
||||
lines = ""
|
||||
|
||||
if len(verts):
|
||||
if decimal_places is None:
|
||||
float_str = "%f"
|
||||
else:
|
||||
float_str = "%" + ".%df" % decimal_places
|
||||
if decimal_places is None:
|
||||
float_str = "%f"
|
||||
else:
|
||||
float_str = "%" + ".%df" % decimal_places
|
||||
|
||||
if len(verts):
|
||||
V, D = verts.shape
|
||||
for i in range(V):
|
||||
vert = [float_str % verts[i, j] for j in range(D)]
|
||||
lines += "v %s\n" % " ".join(vert)
|
||||
|
||||
if save_normals:
|
||||
assert normals is not None
|
||||
assert faces_normals_idx is not None
|
||||
lines += _write_normals(normals, faces_normals_idx, float_str)
|
||||
|
||||
if save_texture:
|
||||
assert faces_uvs is not None
|
||||
assert verts_uvs is not None
|
||||
|
||||
if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3):
|
||||
message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
|
||||
raise ValueError(message)
|
||||
@ -818,7 +854,6 @@ def _save(
|
||||
message = "'verts_uvs' should either be empty or of shape (num_verts, 2)."
|
||||
raise ValueError(message)
|
||||
|
||||
# pyre-fixme[16] # undefined attribute cpu
|
||||
verts_uvs, faces_uvs = verts_uvs.cpu(), faces_uvs.cpu()
|
||||
|
||||
# Save verts uvs after verts
|
||||
@ -828,25 +863,77 @@ def _save(
|
||||
uv = [float_str % verts_uvs[i, j] for j in range(uD)]
|
||||
lines += "vt %s\n" % " ".join(uv)
|
||||
|
||||
f.write(lines)
|
||||
|
||||
if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0):
|
||||
warnings.warn("Faces have invalid indices")
|
||||
|
||||
if len(faces):
|
||||
F, P = faces.shape
|
||||
for i in range(F):
|
||||
if save_texture:
|
||||
# Format faces as {verts_idx}/{verts_uvs_idx}
|
||||
_write_faces(
|
||||
f,
|
||||
faces,
|
||||
faces_uvs if save_texture else None,
|
||||
faces_normals_idx if save_normals else None,
|
||||
)
|
||||
|
||||
|
||||
def _write_normals(
|
||||
normals: torch.Tensor, faces_normals_idx: torch.Tensor, float_str: str
|
||||
) -> str:
|
||||
if faces_normals_idx.dim() != 2 or faces_normals_idx.size(1) != 3:
|
||||
message = (
|
||||
"'faces_normals_idx' should either be empty or of shape (num_faces, 3)."
|
||||
)
|
||||
raise ValueError(message)
|
||||
|
||||
if normals.dim() != 2 or normals.size(1) != 3:
|
||||
message = "'normals' should either be empty or of shape (num_verts, 3)."
|
||||
raise ValueError(message)
|
||||
|
||||
normals, faces_normals_idx = normals.cpu(), faces_normals_idx.cpu()
|
||||
|
||||
lines = []
|
||||
V, D = normals.shape
|
||||
for i in range(V):
|
||||
normal = [float_str % normals[i, j] for j in range(D)]
|
||||
lines.append("vn %s\n" % " ".join(normal))
|
||||
return "".join(lines)
|
||||
|
||||
|
||||
def _write_faces(
|
||||
f,
|
||||
faces: torch.Tensor,
|
||||
faces_uvs: Optional[torch.Tensor],
|
||||
faces_normals_idx: Optional[torch.Tensor],
|
||||
) -> None:
|
||||
F, P = faces.shape
|
||||
for i in range(F):
|
||||
if faces_normals_idx is not None:
|
||||
if faces_uvs is not None:
|
||||
# Format faces as {verts_idx}/{verts_uvs_idx}/{verts_normals_idx}
|
||||
face = [
|
||||
"%d/%d" % (faces[i, j] + 1, faces_uvs[i, j] + 1) for j in range(P)
|
||||
"%d/%d/%d"
|
||||
% (
|
||||
faces[i, j] + 1,
|
||||
faces_uvs[i, j] + 1,
|
||||
faces_normals_idx[i, j] + 1,
|
||||
)
|
||||
for j in range(P)
|
||||
]
|
||||
else:
|
||||
face = ["%d" % (faces[i, j] + 1) for j in range(P)]
|
||||
# Format faces as {verts_idx}//{verts_normals_idx}
|
||||
face = [
|
||||
"%d//%d" % (faces[i, j] + 1, faces_normals_idx[i, j] + 1)
|
||||
for j in range(P)
|
||||
]
|
||||
elif faces_uvs is not None:
|
||||
# Format faces as {verts_idx}/{verts_uvs_idx}
|
||||
face = ["%d/%d" % (faces[i, j] + 1, faces_uvs[i, j] + 1) for j in range(P)]
|
||||
else:
|
||||
face = ["%d" % (faces[i, j] + 1) for j in range(P)]
|
||||
|
||||
if i + 1 < F:
|
||||
lines += "f %s\n" % " ".join(face)
|
||||
|
||||
elif i + 1 == F:
|
||||
# No newline at the end of the file.
|
||||
lines += "f %s" % " ".join(face)
|
||||
|
||||
f.write(lines)
|
||||
if i + 1 < F:
|
||||
f.write("f %s\n" % " ".join(face))
|
||||
else:
|
||||
# No newline at the end of the file.
|
||||
f.write("f %s" % " ".join(face))
|
||||
|
@ -532,8 +532,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
"f 4 2 1",
|
||||
]
|
||||
)
|
||||
actual_file = open(Path(f.name), "r")
|
||||
self.assertEqual(actual_file.read(), expected_file)
|
||||
self.assertEqual(Path(f.name).read_text(), expected_file)
|
||||
|
||||
def test_load_mtl(self):
|
||||
obj_filename = "cow_mesh/cow.obj"
|
||||
@ -895,6 +894,67 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
with self.assertRaisesRegex(ValueError, "same type of texture"):
|
||||
join_meshes_as_batch([mesh_atlas, mesh_rgb, mesh_atlas])
|
||||
|
||||
def test_save_obj_with_normal(self):
|
||||
verts = torch.tensor(
|
||||
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
faces = torch.tensor(
|
||||
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
|
||||
)
|
||||
normals = torch.tensor(
|
||||
[
|
||||
[0.02, 0.5, 0.73],
|
||||
[0.3, 0.03, 0.361],
|
||||
[0.32, 0.12, 0.47],
|
||||
[0.36, 0.17, 0.9],
|
||||
[0.40, 0.7, 0.19],
|
||||
[1.0, 0.00, 0.000],
|
||||
[0.00, 1.00, 0.00],
|
||||
[0.00, 0.00, 1.0],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
faces_normals_idx = torch.tensor(
|
||||
[[0, 1, 2], [2, 3, 4], [4, 5, 6], [6, 7, 0]], dtype=torch.int64
|
||||
)
|
||||
|
||||
with TemporaryDirectory() as temp_dir:
|
||||
obj_file = os.path.join(temp_dir, "mesh.obj")
|
||||
save_obj(
|
||||
obj_file,
|
||||
verts,
|
||||
faces,
|
||||
decimal_places=2,
|
||||
normals=normals,
|
||||
faces_normals_idx=faces_normals_idx,
|
||||
)
|
||||
|
||||
expected_obj_file = "\n".join(
|
||||
[
|
||||
"v 0.01 0.20 0.30",
|
||||
"v 0.20 0.03 0.41",
|
||||
"v 0.30 0.40 0.05",
|
||||
"v 0.60 0.70 0.80",
|
||||
"vn 0.02 0.50 0.73",
|
||||
"vn 0.30 0.03 0.36",
|
||||
"vn 0.32 0.12 0.47",
|
||||
"vn 0.36 0.17 0.90",
|
||||
"vn 0.40 0.70 0.19",
|
||||
"vn 1.00 0.00 0.00",
|
||||
"vn 0.00 1.00 0.00",
|
||||
"vn 0.00 0.00 1.00",
|
||||
"f 1//1 3//2 2//3",
|
||||
"f 1//3 2//4 3//5",
|
||||
"f 4//5 3//6 2//7",
|
||||
"f 4//7 2//8 1//1",
|
||||
]
|
||||
)
|
||||
|
||||
# Check the obj file is saved correctly
|
||||
with open(obj_file, "r") as actual_file:
|
||||
self.assertEqual(actual_file.read(), expected_obj_file)
|
||||
|
||||
def test_save_obj_with_texture(self):
|
||||
verts = torch.tensor(
|
||||
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
|
||||
@ -950,13 +1010,96 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
|
||||
|
||||
# Check the obj file is saved correctly
|
||||
actual_file = open(obj_file, "r")
|
||||
self.assertEqual(actual_file.read(), expected_obj_file)
|
||||
with open(obj_file, "r") as actual_file:
|
||||
self.assertEqual(actual_file.read(), expected_obj_file)
|
||||
|
||||
# Check the mtl file is saved correctly
|
||||
mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
|
||||
mtl_file = open(mtl_file_name, "r")
|
||||
self.assertEqual(mtl_file.read(), expected_mtl_file)
|
||||
with open(mtl_file_name, "r") as mtl_file:
|
||||
self.assertEqual(mtl_file.read(), expected_mtl_file)
|
||||
|
||||
# Check the texture image file is saved correctly
|
||||
texture_image = load_rgb_image("mesh.png", temp_dir)
|
||||
self.assertClose(texture_image, texture_map)
|
||||
|
||||
def test_save_obj_with_normal_and_texture(self):
|
||||
verts = torch.tensor(
|
||||
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
faces = torch.tensor(
|
||||
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
|
||||
)
|
||||
normals = torch.tensor(
|
||||
[
|
||||
[0.02, 0.5, 0.73],
|
||||
[0.3, 0.03, 0.361],
|
||||
[0.32, 0.12, 0.47],
|
||||
[0.36, 0.17, 0.9],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
faces_normals_idx = faces
|
||||
verts_uvs = torch.tensor(
|
||||
[[0.02, 0.5], [0.3, 0.03], [0.32, 0.12], [0.36, 0.17]],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
faces_uvs = faces
|
||||
texture_map = torch.randint(size=(2, 2, 3), high=255) / 255.0
|
||||
|
||||
with TemporaryDirectory() as temp_dir:
|
||||
obj_file = os.path.join(temp_dir, "mesh.obj")
|
||||
save_obj(
|
||||
obj_file,
|
||||
verts,
|
||||
faces,
|
||||
decimal_places=2,
|
||||
normals=normals,
|
||||
faces_normals_idx=faces_normals_idx,
|
||||
verts_uvs=verts_uvs,
|
||||
faces_uvs=faces_uvs,
|
||||
texture_map=texture_map,
|
||||
)
|
||||
|
||||
expected_obj_file = "\n".join(
|
||||
[
|
||||
"",
|
||||
"mtllib mesh.mtl",
|
||||
"usemtl mesh",
|
||||
"",
|
||||
"v 0.01 0.20 0.30",
|
||||
"v 0.20 0.03 0.41",
|
||||
"v 0.30 0.40 0.05",
|
||||
"v 0.60 0.70 0.80",
|
||||
"vn 0.02 0.50 0.73",
|
||||
"vn 0.30 0.03 0.36",
|
||||
"vn 0.32 0.12 0.47",
|
||||
"vn 0.36 0.17 0.90",
|
||||
"vt 0.02 0.50",
|
||||
"vt 0.30 0.03",
|
||||
"vt 0.32 0.12",
|
||||
"vt 0.36 0.17",
|
||||
"f 1/1/1 3/3/3 2/2/2",
|
||||
"f 1/1/1 2/2/2 3/3/3",
|
||||
"f 4/4/4 3/3/3 2/2/2",
|
||||
"f 4/4/4 2/2/2 1/1/1",
|
||||
]
|
||||
)
|
||||
expected_mtl_file = "\n".join(["newmtl mesh", "map_Kd mesh.png", ""])
|
||||
|
||||
# Check there are only 3 files in the temp dir
|
||||
tempfiles = ["mesh.obj", "mesh.png", "mesh.mtl"]
|
||||
tempfiles_dir = os.listdir(temp_dir)
|
||||
self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
|
||||
|
||||
# Check the obj file is saved correctly
|
||||
with open(obj_file, "r") as actual_file:
|
||||
self.assertEqual(actual_file.read(), expected_obj_file)
|
||||
|
||||
# Check the mtl file is saved correctly
|
||||
mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
|
||||
with open(mtl_file_name, "r") as mtl_file:
|
||||
self.assertEqual(mtl_file.read(), expected_mtl_file)
|
||||
|
||||
# Check the texture image file is saved correctly
|
||||
texture_image = load_rgb_image("mesh.png", temp_dir)
|
||||
@ -1013,8 +1156,8 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(tempfiles, tempfiles_dir)
|
||||
|
||||
# Check the obj file is saved correctly
|
||||
actual_file = open(obj_file, "r")
|
||||
self.assertEqual(actual_file.read(), expected_obj_file)
|
||||
with open(obj_file, "r") as actual_file:
|
||||
self.assertEqual(actual_file.read(), expected_obj_file)
|
||||
|
||||
obj_file = StringIO()
|
||||
with self.assertRaises(ValueError):
|
||||
@ -1100,13 +1243,13 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
|
||||
|
||||
# Check the obj file is saved correctly
|
||||
actual_file = open(obj_file, "r")
|
||||
self.assertEqual(actual_file.read(), expected_obj_file)
|
||||
with open(obj_file, "r") as actual_file:
|
||||
self.assertEqual(actual_file.read(), expected_obj_file)
|
||||
|
||||
# Check the mtl file is saved correctly
|
||||
mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
|
||||
mtl_file = open(mtl_file_name, "r")
|
||||
self.assertEqual(mtl_file.read(), expected_mtl_file)
|
||||
with open(mtl_file_name, "r") as mtl_file:
|
||||
self.assertEqual(mtl_file.read(), expected_mtl_file)
|
||||
|
||||
# Check the texture image file is saved correctly
|
||||
texture_image = load_rgb_image("mesh.png", temp_dir)
|
||||
|
Loading…
x
Reference in New Issue
Block a user