mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 14:20:38 +08:00
allow saving vertex normal in save_obj (#1511)
Summary:
Although we can load per-vertex normals in `load_obj`, saving per-vertex normals is not supported in `save_obj`.
This patch fixes this by allowing passing per-vertex normal data in `save_obj`:
``` python
def save_obj(
f: PathOrStr,
verts,
faces,
decimal_places: Optional[int] = None,
path_manager: Optional[PathManager] = None,
*,
verts_normals: Optional[torch.Tensor] = None,
faces_normals: Optional[torch.Tensor] = None,
verts_uvs: Optional[torch.Tensor] = None,
faces_uvs: Optional[torch.Tensor] = None,
texture_map: Optional[torch.Tensor] = None,
) -> None:
"""
Save a mesh to an .obj file.
Args:
f: File (str or path) to which the mesh should be written.
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
faces: LongTensor of shape (F, 3) giving faces.
decimal_places: Number of decimal places for saving.
path_manager: Optional PathManager for interpreting f if
it is a str.
verts_normals: FloatTensor of shape (V, 3) giving the normal per vertex.
faces_normals: LongTensor of shape (F, 3) giving the index into verts_normals
for each vertex in the face.
verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinate per vertex.
faces_uvs: LongTensor of shape (F, 3) giving the index into verts_uvs for
each vertex in the face.
texture_map: FloatTensor of shape (H, W, 3) representing the texture map
for the mesh which will be saved as an image. The values are expected
to be in the range [0, 1],
"""
```
Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1511
Reviewed By: shapovalov
Differential Revision: D45086045
Pulled By: bottler
fbshipit-source-id: 666efb0d2c302df6cf9f2f6601d83a07856bf32f
This commit is contained in:
@@ -532,8 +532,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
"f 4 2 1",
|
||||
]
|
||||
)
|
||||
actual_file = open(Path(f.name), "r")
|
||||
self.assertEqual(actual_file.read(), expected_file)
|
||||
self.assertEqual(Path(f.name).read_text(), expected_file)
|
||||
|
||||
def test_load_mtl(self):
|
||||
obj_filename = "cow_mesh/cow.obj"
|
||||
@@ -895,6 +894,67 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
with self.assertRaisesRegex(ValueError, "same type of texture"):
|
||||
join_meshes_as_batch([mesh_atlas, mesh_rgb, mesh_atlas])
|
||||
|
||||
def test_save_obj_with_normal(self):
|
||||
verts = torch.tensor(
|
||||
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
faces = torch.tensor(
|
||||
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
|
||||
)
|
||||
normals = torch.tensor(
|
||||
[
|
||||
[0.02, 0.5, 0.73],
|
||||
[0.3, 0.03, 0.361],
|
||||
[0.32, 0.12, 0.47],
|
||||
[0.36, 0.17, 0.9],
|
||||
[0.40, 0.7, 0.19],
|
||||
[1.0, 0.00, 0.000],
|
||||
[0.00, 1.00, 0.00],
|
||||
[0.00, 0.00, 1.0],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
faces_normals_idx = torch.tensor(
|
||||
[[0, 1, 2], [2, 3, 4], [4, 5, 6], [6, 7, 0]], dtype=torch.int64
|
||||
)
|
||||
|
||||
with TemporaryDirectory() as temp_dir:
|
||||
obj_file = os.path.join(temp_dir, "mesh.obj")
|
||||
save_obj(
|
||||
obj_file,
|
||||
verts,
|
||||
faces,
|
||||
decimal_places=2,
|
||||
normals=normals,
|
||||
faces_normals_idx=faces_normals_idx,
|
||||
)
|
||||
|
||||
expected_obj_file = "\n".join(
|
||||
[
|
||||
"v 0.01 0.20 0.30",
|
||||
"v 0.20 0.03 0.41",
|
||||
"v 0.30 0.40 0.05",
|
||||
"v 0.60 0.70 0.80",
|
||||
"vn 0.02 0.50 0.73",
|
||||
"vn 0.30 0.03 0.36",
|
||||
"vn 0.32 0.12 0.47",
|
||||
"vn 0.36 0.17 0.90",
|
||||
"vn 0.40 0.70 0.19",
|
||||
"vn 1.00 0.00 0.00",
|
||||
"vn 0.00 1.00 0.00",
|
||||
"vn 0.00 0.00 1.00",
|
||||
"f 1//1 3//2 2//3",
|
||||
"f 1//3 2//4 3//5",
|
||||
"f 4//5 3//6 2//7",
|
||||
"f 4//7 2//8 1//1",
|
||||
]
|
||||
)
|
||||
|
||||
# Check the obj file is saved correctly
|
||||
with open(obj_file, "r") as actual_file:
|
||||
self.assertEqual(actual_file.read(), expected_obj_file)
|
||||
|
||||
def test_save_obj_with_texture(self):
|
||||
verts = torch.tensor(
|
||||
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
|
||||
@@ -950,13 +1010,96 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
|
||||
|
||||
# Check the obj file is saved correctly
|
||||
actual_file = open(obj_file, "r")
|
||||
self.assertEqual(actual_file.read(), expected_obj_file)
|
||||
with open(obj_file, "r") as actual_file:
|
||||
self.assertEqual(actual_file.read(), expected_obj_file)
|
||||
|
||||
# Check the mtl file is saved correctly
|
||||
mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
|
||||
mtl_file = open(mtl_file_name, "r")
|
||||
self.assertEqual(mtl_file.read(), expected_mtl_file)
|
||||
with open(mtl_file_name, "r") as mtl_file:
|
||||
self.assertEqual(mtl_file.read(), expected_mtl_file)
|
||||
|
||||
# Check the texture image file is saved correctly
|
||||
texture_image = load_rgb_image("mesh.png", temp_dir)
|
||||
self.assertClose(texture_image, texture_map)
|
||||
|
||||
def test_save_obj_with_normal_and_texture(self):
|
||||
verts = torch.tensor(
|
||||
[[0.01, 0.2, 0.301], [0.2, 0.03, 0.408], [0.3, 0.4, 0.05], [0.6, 0.7, 0.8]],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
faces = torch.tensor(
|
||||
[[0, 2, 1], [0, 1, 2], [3, 2, 1], [3, 1, 0]], dtype=torch.int64
|
||||
)
|
||||
normals = torch.tensor(
|
||||
[
|
||||
[0.02, 0.5, 0.73],
|
||||
[0.3, 0.03, 0.361],
|
||||
[0.32, 0.12, 0.47],
|
||||
[0.36, 0.17, 0.9],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
faces_normals_idx = faces
|
||||
verts_uvs = torch.tensor(
|
||||
[[0.02, 0.5], [0.3, 0.03], [0.32, 0.12], [0.36, 0.17]],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
faces_uvs = faces
|
||||
texture_map = torch.randint(size=(2, 2, 3), high=255) / 255.0
|
||||
|
||||
with TemporaryDirectory() as temp_dir:
|
||||
obj_file = os.path.join(temp_dir, "mesh.obj")
|
||||
save_obj(
|
||||
obj_file,
|
||||
verts,
|
||||
faces,
|
||||
decimal_places=2,
|
||||
normals=normals,
|
||||
faces_normals_idx=faces_normals_idx,
|
||||
verts_uvs=verts_uvs,
|
||||
faces_uvs=faces_uvs,
|
||||
texture_map=texture_map,
|
||||
)
|
||||
|
||||
expected_obj_file = "\n".join(
|
||||
[
|
||||
"",
|
||||
"mtllib mesh.mtl",
|
||||
"usemtl mesh",
|
||||
"",
|
||||
"v 0.01 0.20 0.30",
|
||||
"v 0.20 0.03 0.41",
|
||||
"v 0.30 0.40 0.05",
|
||||
"v 0.60 0.70 0.80",
|
||||
"vn 0.02 0.50 0.73",
|
||||
"vn 0.30 0.03 0.36",
|
||||
"vn 0.32 0.12 0.47",
|
||||
"vn 0.36 0.17 0.90",
|
||||
"vt 0.02 0.50",
|
||||
"vt 0.30 0.03",
|
||||
"vt 0.32 0.12",
|
||||
"vt 0.36 0.17",
|
||||
"f 1/1/1 3/3/3 2/2/2",
|
||||
"f 1/1/1 2/2/2 3/3/3",
|
||||
"f 4/4/4 3/3/3 2/2/2",
|
||||
"f 4/4/4 2/2/2 1/1/1",
|
||||
]
|
||||
)
|
||||
expected_mtl_file = "\n".join(["newmtl mesh", "map_Kd mesh.png", ""])
|
||||
|
||||
# Check there are only 3 files in the temp dir
|
||||
tempfiles = ["mesh.obj", "mesh.png", "mesh.mtl"]
|
||||
tempfiles_dir = os.listdir(temp_dir)
|
||||
self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
|
||||
|
||||
# Check the obj file is saved correctly
|
||||
with open(obj_file, "r") as actual_file:
|
||||
self.assertEqual(actual_file.read(), expected_obj_file)
|
||||
|
||||
# Check the mtl file is saved correctly
|
||||
mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
|
||||
with open(mtl_file_name, "r") as mtl_file:
|
||||
self.assertEqual(mtl_file.read(), expected_mtl_file)
|
||||
|
||||
# Check the texture image file is saved correctly
|
||||
texture_image = load_rgb_image("mesh.png", temp_dir)
|
||||
@@ -1013,8 +1156,8 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(tempfiles, tempfiles_dir)
|
||||
|
||||
# Check the obj file is saved correctly
|
||||
actual_file = open(obj_file, "r")
|
||||
self.assertEqual(actual_file.read(), expected_obj_file)
|
||||
with open(obj_file, "r") as actual_file:
|
||||
self.assertEqual(actual_file.read(), expected_obj_file)
|
||||
|
||||
obj_file = StringIO()
|
||||
with self.assertRaises(ValueError):
|
||||
@@ -1100,13 +1243,13 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(Counter(tempfiles), Counter(tempfiles_dir))
|
||||
|
||||
# Check the obj file is saved correctly
|
||||
actual_file = open(obj_file, "r")
|
||||
self.assertEqual(actual_file.read(), expected_obj_file)
|
||||
with open(obj_file, "r") as actual_file:
|
||||
self.assertEqual(actual_file.read(), expected_obj_file)
|
||||
|
||||
# Check the mtl file is saved correctly
|
||||
mtl_file_name = os.path.join(temp_dir, "mesh.mtl")
|
||||
mtl_file = open(mtl_file_name, "r")
|
||||
self.assertEqual(mtl_file.read(), expected_mtl_file)
|
||||
with open(mtl_file_name, "r") as mtl_file:
|
||||
self.assertEqual(mtl_file.read(), expected_mtl_file)
|
||||
|
||||
# Check the texture image file is saved correctly
|
||||
texture_image = load_rgb_image("mesh.png", temp_dir)
|
||||
|
||||
Reference in New Issue
Block a user