PLY TexturesVertex loading

Summary:
Include TexturesVertex colors when loading and saving Meshes to PLY files.

A couple of other improvements to the internals of ply_io, including using `None` instead of empty tensors for some missing data.

Reviewed By: gkioxari

Differential Revision: D27765260

fbshipit-source-id: b9857dc777c244b9d7d6643b608596d31435ecda
This commit is contained in:
Jeremy Reizenstein
2021-05-04 05:35:24 -07:00
committed by Facebook GitHub Bot
parent 097b0ef2c6
commit 6c3fe952d1
2 changed files with 135 additions and 53 deletions

View File

@@ -1,5 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import itertools
import struct
import unittest
from io import BytesIO, StringIO
@@ -12,7 +13,8 @@ from common_testing import TestCaseMixin
from iopath.common.file_io import PathManager
from pytorch3d.io import IO
from pytorch3d.io.ply_io import load_ply, save_ply
from pytorch3d.structures import Pointclouds
from pytorch3d.renderer.mesh import TexturesVertex
from pytorch3d.structures import Meshes, Pointclouds
from pytorch3d.utils import torus
@@ -189,6 +191,57 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
):
io.load_mesh(f3.name)
def test_save_too_many_colors(self):
verts = torch.tensor(
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
)
faces = torch.tensor([[0, 1, 2], [0, 2, 3]])
vert_colors = torch.rand((4, 7))
texture_with_seven_colors = TexturesVertex(verts_features=[vert_colors])
mesh = Meshes(
verts=[verts],
faces=[faces],
textures=texture_with_seven_colors,
)
io = IO()
msg = "Texture will not be saved as it has 7 colors, not 3."
with NamedTemporaryFile(mode="w", suffix=".ply") as f:
with self.assertWarnsRegex(UserWarning, msg):
io.save_mesh(mesh.cuda(), f.name)
def test_save_load_meshes(self):
verts = torch.tensor(
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
)
faces = torch.tensor([[0, 1, 2], [0, 2, 3]])
vert_colors = torch.rand_like(verts)
texture = TexturesVertex(verts_features=[vert_colors])
for do_textures in itertools.product([True, False]):
mesh = Meshes(
verts=[verts],
faces=[faces],
textures=texture if do_textures else None,
)
device = torch.device("cuda:0")
io = IO()
with NamedTemporaryFile(mode="w", suffix=".ply") as f:
io.save_mesh(mesh.cuda(), f.name)
f.flush()
mesh2 = io.load_mesh(f.name, device=device)
self.assertEqual(mesh2.device, device)
mesh2 = mesh2.cpu()
self.assertClose(mesh2.verts_padded(), mesh.verts_padded())
self.assertClose(mesh2.faces_padded(), mesh.faces_padded())
if do_textures:
self.assertIsInstance(mesh2.textures, TexturesVertex)
self.assertClose(mesh2.textures.verts_features_list()[0], vert_colors)
else:
self.assertIsNone(mesh2.textures)
def test_save_ply_invalid_shapes(self):
# Invalid vertices shape
with self.assertRaises(ValueError) as error: