allow saving vertex normal in save_obj (#1511)

Summary:
Although we can load per-vertex normals in `load_obj`, saving per-vertex normals is not supported in `save_obj`.

This patch fixes this by allowing passing per-vertex normal data in `save_obj`:
``` python
def save_obj(
    f: PathOrStr,
    verts,
    faces,
    decimal_places: Optional[int] = None,
    path_manager: Optional[PathManager] = None,
    *,
    verts_normals: Optional[torch.Tensor] = None,
    faces_normals: Optional[torch.Tensor] = None,
    verts_uvs: Optional[torch.Tensor] = None,
    faces_uvs: Optional[torch.Tensor] = None,
    texture_map: Optional[torch.Tensor] = None,
) -> None:
    """
    Save a mesh to an .obj file.

    Args:
        f: File (str or path) to which the mesh should be written.
        verts: FloatTensor of shape (V, 3) giving vertex coordinates.
        faces: LongTensor of shape (F, 3) giving faces.
        decimal_places: Number of decimal places for saving.
        path_manager: Optional PathManager for interpreting f if
            it is a str.
        verts_normals: FloatTensor of shape (V, 3) giving the normal per vertex.
        faces_normals: LongTensor of shape (F, 3) giving the index into verts_normals
            for each vertex in the face.
        verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinate per vertex.
        faces_uvs: LongTensor of shape (F, 3) giving the index into verts_uvs for
            each vertex in the face.
        texture_map: FloatTensor of shape (H, W, 3) representing the texture map
            for the mesh which will be saved as an image. The values are expected
            to be in the range [0, 1],
    """
```

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1511

Reviewed By: shapovalov

Differential Revision: D45086045

Pulled By: bottler

fbshipit-source-id: 666efb0d2c302df6cf9f2f6601d83a07856bf32f
This commit is contained in:
dhb 2023-05-07 06:32:02 -07:00 committed by Facebook GitHub Bot
parent ec87284c4b
commit 092400f1e7
2 changed files with 262 additions and 32 deletions

View File

@ -684,6 +684,8 @@ def save_obj(
decimal_places: Optional[int] = None, decimal_places: Optional[int] = None,
path_manager: Optional[PathManager] = None, path_manager: Optional[PathManager] = None,
*, *,
normals: Optional[torch.Tensor] = None,
faces_normals_idx: Optional[torch.Tensor] = None,
verts_uvs: Optional[torch.Tensor] = None, verts_uvs: Optional[torch.Tensor] = None,
faces_uvs: Optional[torch.Tensor] = None, faces_uvs: Optional[torch.Tensor] = None,
texture_map: Optional[torch.Tensor] = None, texture_map: Optional[torch.Tensor] = None,
@ -698,6 +700,10 @@ def save_obj(
decimal_places: Number of decimal places for saving. decimal_places: Number of decimal places for saving.
path_manager: Optional PathManager for interpreting f if path_manager: Optional PathManager for interpreting f if
it is a str. it is a str.
normals: FloatTensor of shape (V, 3) giving normals for faces_normals_idx
to index into.
faces_normals_idx: LongTensor of shape (F, 3) giving the index into
normals for each vertex in the face.
verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinate per vertex. verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinate per vertex.
faces_uvs: LongTensor of shape (F, 3) giving the index into verts_uvs for faces_uvs: LongTensor of shape (F, 3) giving the index into verts_uvs for
each vertex in the face. each vertex in the face.
@ -713,6 +719,22 @@ def save_obj(
message = "'faces' should either be empty or of shape (num_faces, 3)." message = "'faces' should either be empty or of shape (num_faces, 3)."
raise ValueError(message) raise ValueError(message)
if (normals is None) != (faces_normals_idx is None):
message = "'normals' and 'faces_normals_idx' must both be None or neither."
raise ValueError(message)
if faces_normals_idx is not None and (
faces_normals_idx.dim() != 2 or faces_normals_idx.size(1) != 3
):
message = (
"'faces_normals_idx' should either be empty or of shape (num_faces, 3)."
)
raise ValueError(message)
if normals is not None and (normals.dim() != 2 or normals.size(1) != 3):
message = "'normals' should either be empty or of shape (num_verts, 3)."
raise ValueError(message)
if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3): if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3):
message = "'faces_uvs' should either be empty or of shape (num_faces, 3)." message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
raise ValueError(message) raise ValueError(message)
@ -742,9 +764,12 @@ def save_obj(
verts, verts,
faces, faces,
decimal_places, decimal_places,
normals=normals,
faces_normals_idx=faces_normals_idx,
verts_uvs=verts_uvs, verts_uvs=verts_uvs,
faces_uvs=faces_uvs, faces_uvs=faces_uvs,
save_texture=save_texture, save_texture=save_texture,
save_normals=normals is not None,
) )
# Save the .mtl and .png files associated with the texture # Save the .mtl and .png files associated with the texture
@ -777,9 +802,12 @@ def _save(
faces, faces,
decimal_places: Optional[int] = None, decimal_places: Optional[int] = None,
*, *,
normals: Optional[torch.Tensor] = None,
faces_normals_idx: Optional[torch.Tensor] = None,
verts_uvs: Optional[torch.Tensor] = None, verts_uvs: Optional[torch.Tensor] = None,
faces_uvs: Optional[torch.Tensor] = None, faces_uvs: Optional[torch.Tensor] = None,
save_texture: bool = False, save_texture: bool = False,
save_normals: bool = False,
) -> None: ) -> None:
if len(verts) and (verts.dim() != 2 or verts.size(1) != 3): if len(verts) and (verts.dim() != 2 or verts.size(1) != 3):
@ -798,18 +826,26 @@ def _save(
lines = "" lines = ""
if len(verts): if decimal_places is None:
if decimal_places is None: float_str = "%f"
float_str = "%f" else:
else: float_str = "%" + ".%df" % decimal_places
float_str = "%" + ".%df" % decimal_places
if len(verts):
V, D = verts.shape V, D = verts.shape
for i in range(V): for i in range(V):
vert = [float_str % verts[i, j] for j in range(D)] vert = [float_str % verts[i, j] for j in range(D)]
lines += "v %s\n" % " ".join(vert) lines += "v %s\n" % " ".join(vert)
if save_normals:
assert normals is not None
assert faces_normals_idx is not None
lines += _write_normals(normals, faces_normals_idx, float_str)
if save_texture: if save_texture:
assert faces_uvs is not None
assert verts_uvs is not None
if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3): if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3):
message = "'faces_uvs' should either be empty or of shape (num_faces, 3)." message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
raise ValueError(message) raise ValueError(message)
@ -818,7 +854,6 @@ def _save(
message = "'verts_uvs' should either be empty or of shape (num_verts, 2)." message = "'verts_uvs' should either be empty or of shape (num_verts, 2)."
raise ValueError(message) raise ValueError(message)
# pyre-fixme[16] # undefined attribute cpu
verts_uvs, faces_uvs = verts_uvs.cpu(), faces_uvs.cpu() verts_uvs, faces_uvs = verts_uvs.cpu(), faces_uvs.cpu()
# Save verts uvs after verts # Save verts uvs after verts
@ -828,25 +863,77 @@ def _save(
uv = [float_str % verts_uvs[i, j] for j in range(uD)] uv = [float_str % verts_uvs[i, j] for j in range(uD)]
lines += "vt %s\n" % " ".join(uv) lines += "vt %s\n" % " ".join(uv)
f.write(lines)
if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0): if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0):
warnings.warn("Faces have invalid indices") warnings.warn("Faces have invalid indices")
if len(faces): if len(faces):
F, P = faces.shape _write_faces(
for i in range(F): f,
if save_texture: faces,
# Format faces as {verts_idx}/{verts_uvs_idx} faces_uvs if save_texture else None,
faces_normals_idx if save_normals else None,
)
def _write_normals(
normals: torch.Tensor, faces_normals_idx: torch.Tensor, float_str: str
) -> str:
if faces_normals_idx.dim() != 2 or faces_normals_idx.size(1) != 3:
message = (
"'faces_normals_idx' should either be empty or of shape (num_faces, 3)."
)
raise ValueError(message)
if normals.dim() != 2 or normals.size(1) != 3:
message = "'normals' should either be empty or of shape (num_verts, 3)."
raise ValueError(message)
normals, faces_normals_idx = normals.cpu(), faces_normals_idx.cpu()
lines = []
V, D = normals.shape
for i in range(V):
normal = [float_str % normals[i, j] for j in range(D)]
lines.append("vn %s\n" % " ".join(normal))
return "".join(lines)
def _write_faces(
f,
faces: torch.Tensor,
faces_uvs: Optional[torch.Tensor],
faces_normals_idx: Optional[torch.Tensor],
) -> None:
F, P = faces.shape
for i in range(F):
if faces_normals_idx is not None:
if faces_uvs is not None:
# Format faces as {verts_idx}/{verts_uvs_idx}/{verts_normals_idx}
face = [ face = [
"%d/%d" % (faces[i, j] + 1, faces_uvs[i, j] + 1) for j in range(P) "%d/%d/%d"
% (
faces[i, j] + 1,
faces_uvs[i, j] + 1,
faces_normals_idx[i, j] + 1,
)
for j in range(P)
] ]
else: else:
face = ["%d" % (faces[i, j] + 1) for j in range(P)] # Format faces as {verts_idx}//{verts_normals_idx}
face = [
"%d//%d" % (faces[i, j] + 1, faces_normals_idx[i, j] + 1)
for j in range(P)
]
elif faces_uvs is not None:
# Format faces as {verts_idx}/{verts_uvs_idx}
face = ["%d/%d" % (faces[i, j] + 1, faces_uvs[i, j] + 1) for j in range(P)]
else:
face = ["%d" % (faces[i, j] + 1) for j in range(P)]
if i + 1 < F: if i + 1 < F:
lines += "f %s\n" % " ".join(face) f.write("f %s\n" % " ".join(face))
else:
elif i + 1 == F: # No newline at the end of the file.
# No newline at the end of the file. f.write("f %s" % " ".join(face))
lines += "f %s" % " ".join(face)
f.write(lines)

View File

@ -532,8 +532,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"f 4 2 1", "f 4 2 1",
] ]
) )
actual_file = open(Path(f.name), "r") self.assertEqual(Path(f.name).read_text(), expected_file)
self.assertEqual(actual_file.read(), expected_file)
def test_load_mtl(self): def test_load_mtl(self):
obj_filename = "cow_mesh/cow.obj" obj_filename = "cow_mesh/cow.obj"
@ -895,6 +894,67 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
with self.assertRaisesRegex(ValueError, "same type of texture"): with self.assertRaisesRegex(ValueError, "same type of texture"):
join_meshes_as_batch([mesh_atlas, mesh_rgb, mesh_atlas]) join_meshes_as_batch([mesh_atlas, mesh_rgb, mesh_atlas])
def test_save_obj_with_normal(self):
verts = torch.tensor(
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
dtype=torch.float32,
)
faces = torch.tensor(
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
)
normals = torch.tensor(
[
[0.02, 0.5, 0.73],
[0.3, 0.03, 0.361],
[0.32, 0.12, 0.47],
[0.36, 0.17, 0.9],
[0.40, 0.7, 0.19],
[1.0, 0.00, 0.000],
[0.00, 1.00, 0.00],
[0.00, 0.00, 1.0],
],
dtype=torch.float32,
)
faces_normals_idx = torch.tensor(
[[0, 1, 2], [2, 3, 4], [4, 5, 6], [6, 7, 0]], dtype=torch.int64
)
with TemporaryDirectory() as temp_dir:
obj_file = os.path.join(temp_dir, "mesh.obj")
save_obj(
obj_file,
verts,
faces,
decimal_places=2,
normals=normals,
faces_normals_idx=faces_normals_idx,
)
expected_obj_file = "\n".join(
[
"v 0.01 0.20 0.30",
"v 0.20 0.03 0.41",
"v 0.30 0.40 0.05",
"v 0.60 0.70 0.80",
"vn 0.02 0.50 0.73",
"vn 0.30 0.03 0.36",
"vn 0.32 0.12 0.47",
"vn 0.36 0.17 0.90",
"vn 0.40 0.70 0.19",
"vn 1.00 0.00 0.00",
"vn 0.00 1.00 0.00",
"vn 0.00 0.00 1.00",
"f 1//1 3//2 2//3",
"f 1//3 2//4 3//5",
"f 4//5 3//6 2//7",
"f 4//7 2//8 1//1",
]
)
# Check the obj file is saved correctly
with open(obj_file, "r") as actual_file:
self.assertEqual(actual_file.read(), expected_obj_file)
def test_save_obj_with_texture(self): def test_save_obj_with_texture(self):
verts = torch.tensor( verts = torch.tensor(
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]], [[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
@ -950,13 +1010,96 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir)) self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
# Check the obj file is saved correctly # Check the obj file is saved correctly
actual_file = open(obj_file, "r") with open(obj_file, "r") as actual_file:
self.assertEqual(actual_file.read(), expected_obj_file) self.assertEqual(actual_file.read(), expected_obj_file)
# Check the mtl file is saved correctly # Check the mtl file is saved correctly
mtl_file_name = os.path.join(temp_dir, "mesh.mtl") mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
mtl_file = open(mtl_file_name, "r") with open(mtl_file_name, "r") as mtl_file:
self.assertEqual(mtl_file.read(), expected_mtl_file) self.assertEqual(mtl_file.read(), expected_mtl_file)
# Check the texture image file is saved correctly
texture_image = load_rgb_image("mesh.png", temp_dir)
self.assertClose(texture_image, texture_map)
def test_save_obj_with_normal_and_texture(self):
verts = torch.tensor(
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
dtype=torch.float32,
)
faces = torch.tensor(
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
)
normals = torch.tensor(
[
[0.02, 0.5, 0.73],
[0.3, 0.03, 0.361],
[0.32, 0.12, 0.47],
[0.36, 0.17, 0.9],
],
dtype=torch.float32,
)
faces_normals_idx = faces
verts_uvs = torch.tensor(
[[0.02, 0.5], [0.3, 0.03], [0.32, 0.12], [0.36, 0.17]],
dtype=torch.float32,
)
faces_uvs = faces
texture_map = torch.randint(size=(2, 2, 3), high=255) / 255.0
with TemporaryDirectory() as temp_dir:
obj_file = os.path.join(temp_dir, "mesh.obj")
save_obj(
obj_file,
verts,
faces,
decimal_places=2,
normals=normals,
faces_normals_idx=faces_normals_idx,
verts_uvs=verts_uvs,
faces_uvs=faces_uvs,
texture_map=texture_map,
)
expected_obj_file = "\n".join(
[
"",
"mtllib mesh.mtl",
"usemtl mesh",
"",
"v 0.01 0.20 0.30",
"v 0.20 0.03 0.41",
"v 0.30 0.40 0.05",
"v 0.60 0.70 0.80",
"vn 0.02 0.50 0.73",
"vn 0.30 0.03 0.36",
"vn 0.32 0.12 0.47",
"vn 0.36 0.17 0.90",
"vt 0.02 0.50",
"vt 0.30 0.03",
"vt 0.32 0.12",
"vt 0.36 0.17",
"f 1/1/1 3/3/3 2/2/2",
"f 1/1/1 2/2/2 3/3/3",
"f 4/4/4 3/3/3 2/2/2",
"f 4/4/4 2/2/2 1/1/1",
]
)
expected_mtl_file = "\n".join(["newmtl mesh", "map_Kd mesh.png", ""])
# Check there are only 3 files in the temp dir
tempfiles = ["mesh.obj", "mesh.png", "mesh.mtl"]
tempfiles_dir = os.listdir(temp_dir)
self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
# Check the obj file is saved correctly
with open(obj_file, "r") as actual_file:
self.assertEqual(actual_file.read(), expected_obj_file)
# Check the mtl file is saved correctly
mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
with open(mtl_file_name, "r") as mtl_file:
self.assertEqual(mtl_file.read(), expected_mtl_file)
# Check the texture image file is saved correctly # Check the texture image file is saved correctly
texture_image = load_rgb_image("mesh.png", temp_dir) texture_image = load_rgb_image("mesh.png", temp_dir)
@ -1013,8 +1156,8 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
self.assertEqual(tempfiles, tempfiles_dir) self.assertEqual(tempfiles, tempfiles_dir)
# Check the obj file is saved correctly # Check the obj file is saved correctly
actual_file = open(obj_file, "r") with open(obj_file, "r") as actual_file:
self.assertEqual(actual_file.read(), expected_obj_file) self.assertEqual(actual_file.read(), expected_obj_file)
obj_file = StringIO() obj_file = StringIO()
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
@ -1100,13 +1243,13 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir)) self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
# Check the obj file is saved correctly # Check the obj file is saved correctly
actual_file = open(obj_file, "r") with open(obj_file, "r") as actual_file:
self.assertEqual(actual_file.read(), expected_obj_file) self.assertEqual(actual_file.read(), expected_obj_file)
# Check the mtl file is saved correctly # Check the mtl file is saved correctly
mtl_file_name = os.path.join(temp_dir, "mesh.mtl") mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
mtl_file = open(mtl_file_name, "r") with open(mtl_file_name, "r") as mtl_file:
self.assertEqual(mtl_file.read(), expected_mtl_file) self.assertEqual(mtl_file.read(), expected_mtl_file)
# Check the texture image file is saved correctly # Check the texture image file is saved correctly
texture_image = load_rgb_image("mesh.png", temp_dir) texture_image = load_rgb_image("mesh.png", temp_dir)