allow saving vertex normal in save_obj (#1511)

Summary:
Although we can load per-vertex normals in `load_obj`, saving per-vertex normals is not supported in `save_obj`.

This patch fixes this by allowing passing per-vertex normal data in `save_obj`:
``` python
def save_obj(
    f: PathOrStr,
    verts,
    faces,
    decimal_places: Optional[int] = None,
    path_manager: Optional[PathManager] = None,
    *,
    verts_normals: Optional[torch.Tensor] = None,
    faces_normals: Optional[torch.Tensor] = None,
    verts_uvs: Optional[torch.Tensor] = None,
    faces_uvs: Optional[torch.Tensor] = None,
    texture_map: Optional[torch.Tensor] = None,
) -> None:
    """
    Save a mesh to an .obj file.

    Args:
        f: File (str or path) to which the mesh should be written.
        verts: FloatTensor of shape (V, 3) giving vertex coordinates.
        faces: LongTensor of shape (F, 3) giving faces.
        decimal_places: Number of decimal places for saving.
        path_manager: Optional PathManager for interpreting f if
            it is a str.
        verts_normals: FloatTensor of shape (V, 3) giving the normal per vertex.
        faces_normals: LongTensor of shape (F, 3) giving the index into verts_normals
            for each vertex in the face.
        verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinate per vertex.
        faces_uvs: LongTensor of shape (F, 3) giving the index into verts_uvs for
            each vertex in the face.
        texture_map: FloatTensor of shape (H, W, 3) representing the texture map
            for the mesh which will be saved as an image. The values are expected
            to be in the range [0, 1],
    """
```

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1511

Reviewed By: shapovalov

Differential Revision: D45086045

Pulled By: bottler

fbshipit-source-id: 666efb0d2c302df6cf9f2f6601d83a07856bf32f
This commit is contained in:
dhb
2023-05-07 06:32:02 -07:00
committed by Facebook GitHub Bot
parent ec87284c4b
commit 092400f1e7
2 changed files with 262 additions and 32 deletions

View File

@@ -532,8 +532,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
"f 4 2 1",
]
)
actual_file = open(Path(f.name), "r")
self.assertEqual(actual_file.read(), expected_file)
self.assertEqual(Path(f.name).read_text(), expected_file)
def test_load_mtl(self):
obj_filename = "cow_mesh/cow.obj"
@@ -895,6 +894,67 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
with self.assertRaisesRegex(ValueError, "same type of texture"):
join_meshes_as_batch([mesh_atlas, mesh_rgb, mesh_atlas])
def test_save_obj_with_normal(self):
verts = torch.tensor(
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
dtype=torch.float32,
)
faces = torch.tensor(
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
)
normals = torch.tensor(
[
[0.02, 0.5, 0.73],
[0.3, 0.03, 0.361],
[0.32, 0.12, 0.47],
[0.36, 0.17, 0.9],
[0.40, 0.7, 0.19],
[1.0, 0.00, 0.000],
[0.00, 1.00, 0.00],
[0.00, 0.00, 1.0],
],
dtype=torch.float32,
)
faces_normals_idx = torch.tensor(
[[0, 1, 2], [2, 3, 4], [4, 5, 6], [6, 7, 0]], dtype=torch.int64
)
with TemporaryDirectory() as temp_dir:
obj_file = os.path.join(temp_dir, "mesh.obj")
save_obj(
obj_file,
verts,
faces,
decimal_places=2,
normals=normals,
faces_normals_idx=faces_normals_idx,
)
expected_obj_file = "\n".join(
[
"v 0.01 0.20 0.30",
"v 0.20 0.03 0.41",
"v 0.30 0.40 0.05",
"v 0.60 0.70 0.80",
"vn 0.02 0.50 0.73",
"vn 0.30 0.03 0.36",
"vn 0.32 0.12 0.47",
"vn 0.36 0.17 0.90",
"vn 0.40 0.70 0.19",
"vn 1.00 0.00 0.00",
"vn 0.00 1.00 0.00",
"vn 0.00 0.00 1.00",
"f 1//1 3//2 2//3",
"f 1//3 2//4 3//5",
"f 4//5 3//6 2//7",
"f 4//7 2//8 1//1",
]
)
# Check the obj file is saved correctly
with open(obj_file, "r") as actual_file:
self.assertEqual(actual_file.read(), expected_obj_file)
def test_save_obj_with_texture(self):
verts = torch.tensor(
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
@@ -950,13 +1010,96 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
# Check the obj file is saved correctly
actual_file = open(obj_file, "r")
self.assertEqual(actual_file.read(), expected_obj_file)
with open(obj_file, "r") as actual_file:
self.assertEqual(actual_file.read(), expected_obj_file)
# Check the mtl file is saved correctly
mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
mtl_file = open(mtl_file_name, "r")
self.assertEqual(mtl_file.read(), expected_mtl_file)
with open(mtl_file_name, "r") as mtl_file:
self.assertEqual(mtl_file.read(), expected_mtl_file)
# Check the texture image file is saved correctly
texture_image = load_rgb_image("mesh.png", temp_dir)
self.assertClose(texture_image, texture_map)
def test_save_obj_with_normal_and_texture(self):
verts = torch.tensor(
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
dtype=torch.float32,
)
faces = torch.tensor(
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
)
normals = torch.tensor(
[
[0.02, 0.5, 0.73],
[0.3, 0.03, 0.361],
[0.32, 0.12, 0.47],
[0.36, 0.17, 0.9],
],
dtype=torch.float32,
)
faces_normals_idx = faces
verts_uvs = torch.tensor(
[[0.02, 0.5], [0.3, 0.03], [0.32, 0.12], [0.36, 0.17]],
dtype=torch.float32,
)
faces_uvs = faces
texture_map = torch.randint(size=(2, 2, 3), high=255) / 255.0
with TemporaryDirectory() as temp_dir:
obj_file = os.path.join(temp_dir, "mesh.obj")
save_obj(
obj_file,
verts,
faces,
decimal_places=2,
normals=normals,
faces_normals_idx=faces_normals_idx,
verts_uvs=verts_uvs,
faces_uvs=faces_uvs,
texture_map=texture_map,
)
expected_obj_file = "\n".join(
[
"",
"mtllib mesh.mtl",
"usemtl mesh",
"",
"v 0.01 0.20 0.30",
"v 0.20 0.03 0.41",
"v 0.30 0.40 0.05",
"v 0.60 0.70 0.80",
"vn 0.02 0.50 0.73",
"vn 0.30 0.03 0.36",
"vn 0.32 0.12 0.47",
"vn 0.36 0.17 0.90",
"vt 0.02 0.50",
"vt 0.30 0.03",
"vt 0.32 0.12",
"vt 0.36 0.17",
"f 1/1/1 3/3/3 2/2/2",
"f 1/1/1 2/2/2 3/3/3",
"f 4/4/4 3/3/3 2/2/2",
"f 4/4/4 2/2/2 1/1/1",
]
)
expected_mtl_file = "\n".join(["newmtl mesh", "map_Kd mesh.png", ""])
# Check there are only 3 files in the temp dir
tempfiles = ["mesh.obj", "mesh.png", "mesh.mtl"]
tempfiles_dir = os.listdir(temp_dir)
self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
# Check the obj file is saved correctly
with open(obj_file, "r") as actual_file:
self.assertEqual(actual_file.read(), expected_obj_file)
# Check the mtl file is saved correctly
mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
with open(mtl_file_name, "r") as mtl_file:
self.assertEqual(mtl_file.read(), expected_mtl_file)
# Check the texture image file is saved correctly
texture_image = load_rgb_image("mesh.png", temp_dir)
@@ -1013,8 +1156,8 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
self.assertEqual(tempfiles, tempfiles_dir)
# Check the obj file is saved correctly
actual_file = open(obj_file, "r")
self.assertEqual(actual_file.read(), expected_obj_file)
with open(obj_file, "r") as actual_file:
self.assertEqual(actual_file.read(), expected_obj_file)
obj_file = StringIO()
with self.assertRaises(ValueError):
@@ -1100,13 +1243,13 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
# Check the obj file is saved correctly
actual_file = open(obj_file, "r")
self.assertEqual(actual_file.read(), expected_obj_file)
with open(obj_file, "r") as actual_file:
self.assertEqual(actual_file.read(), expected_obj_file)
# Check the mtl file is saved correctly
mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
mtl_file = open(mtl_file_name, "r")
self.assertEqual(mtl_file.read(), expected_mtl_file)
with open(mtl_file_name, "r") as mtl_file:
self.assertEqual(mtl_file.read(), expected_mtl_file)
# Check the texture image file is saved correctly
texture_image = load_rgb_image("mesh.png", temp_dir)