PLY load normals

Summary: Add ability to load normals when they are present in a PLY file.

Reviewed By: nikhilaravi

Differential Revision: D26458971

fbshipit-source-id: 658270b611f7624eab4f5f62ff438038e1d25723
This commit is contained in:
Jeremy Reizenstein 2021-05-04 05:35:24 -07:00 committed by Facebook GitHub Bot
parent b314beeda1
commit 6fa66f5534
2 changed files with 105 additions and 21 deletions

View File

@ -780,9 +780,9 @@ def _load_ply_raw(f, path_manager: PathManager) -> Tuple[_PlyHeader, dict]:
def _get_verts_column_indices( def _get_verts_column_indices(
vertex_head: _PlyElementType, vertex_head: _PlyElementType,
) -> Tuple[List[int], Optional[List[int]], float]: ) -> Tuple[List[int], Optional[List[int]], float, Optional[List[int]]]:
""" """
Get the columns of verts and verts_colors in the vertex Get the columns of verts, verts_colors, and verts_normals in the vertex
element of a parsed ply file, together with a color scale factor. element of a parsed ply file, together with a color scale factor.
When the colors are in byte format, they are scaled from 0..255 to [0,1]. When the colors are in byte format, they are scaled from 0..255 to [0,1].
Otherwise they are not scaled. Otherwise they are not scaled.
@ -793,11 +793,14 @@ def _get_verts_column_indices(
property double x property double x
property double y property double y
property double z property double z
property double nx
property double ny
property double nz
property uchar red property uchar red
property uchar green property uchar green
property uchar blue property uchar blue
then the return value will be ([0,1,2], [6,7,8], 1.0/255) then the return value will be ([0,1,2], [6,7,8], 1.0/255, [3,4,5])
Args: Args:
vertex_head: as returned from load_ply_raw. vertex_head: as returned from load_ply_raw.
@ -807,9 +810,12 @@ def _get_verts_column_indices(
color_idxs: List[int] of 3 color columns if they are present, color_idxs: List[int] of 3 color columns if they are present,
otherwise None. otherwise None.
color_scale: value to scale colors by. color_scale: value to scale colors by.
normal_idxs: List[int] of 3 normals columns if they are present,
otherwise None.
""" """
point_idxs: List[Optional[int]] = [None, None, None] point_idxs: List[Optional[int]] = [None, None, None]
color_idxs: List[Optional[int]] = [None, None, None] color_idxs: List[Optional[int]] = [None, None, None]
normal_idxs: List[Optional[int]] = [None, None, None]
for i, prop in enumerate(vertex_head.properties): for i, prop in enumerate(vertex_head.properties):
if prop.list_size_type is not None: if prop.list_size_type is not None:
raise ValueError("Invalid vertices in file: did not expect list.") raise ValueError("Invalid vertices in file: did not expect list.")
@ -819,6 +825,9 @@ def _get_verts_column_indices(
for j, name in enumerate(["red", "green", "blue"]): for j, name in enumerate(["red", "green", "blue"]):
if prop.name == name: if prop.name == name:
color_idxs[j] = i color_idxs[j] = i
for j, name in enumerate(["nx", "ny", "nz"]):
if prop.name == name:
normal_idxs[j] = i
if None in point_idxs: if None in point_idxs:
raise ValueError("Invalid vertices in file.") raise ValueError("Invalid vertices in file.")
color_scale = 1.0 color_scale = 1.0
@ -831,14 +840,15 @@ def _get_verts_column_indices(
point_idxs, point_idxs,
None if None in color_idxs else cast(List[int], color_idxs), None if None in color_idxs else cast(List[int], color_idxs),
color_scale, color_scale,
None if None in normal_idxs else cast(List[int], normal_idxs),
) )
def _get_verts( def _get_verts(
header: _PlyHeader, elements: dict header: _PlyHeader, elements: dict
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
""" """
Get the vertex locations and colors from a parsed ply file. Get the vertex locations, colors and normals from a parsed ply file.
Args: Args:
header, elements: as returned from load_ply_raw. header, elements: as returned from load_ply_raw.
@ -846,6 +856,7 @@ def _get_verts(
Returns: Returns:
verts: FloatTensor of shape (V, 3). verts: FloatTensor of shape (V, 3).
vertex_colors: None or FloatTensor of shape (V, 3). vertex_colors: None or FloatTensor of shape (V, 3).
vertex_normals: None or FloatTensor of shape (V, 3).
""" """
vertex = elements.get("vertex", None) vertex = elements.get("vertex", None)
@ -854,14 +865,16 @@ def _get_verts(
if not isinstance(vertex, list): if not isinstance(vertex, list):
raise ValueError("Invalid vertices in file.") raise ValueError("Invalid vertices in file.")
vertex_head = next(head for head in header.elements if head.name == "vertex") vertex_head = next(head for head in header.elements if head.name == "vertex")
point_idxs, color_idxs, color_scale = _get_verts_column_indices(vertex_head) point_idxs, color_idxs, color_scale, normal_idxs = _get_verts_column_indices(
vertex_head
)
# Case of no vertices # Case of no vertices
if vertex_head.count == 0: if vertex_head.count == 0:
verts = torch.zeros((0, 3), dtype=torch.float32) verts = torch.zeros((0, 3), dtype=torch.float32)
if color_idxs is None: if color_idxs is None:
return verts, None return verts, None, None
return verts, torch.zeros((0, 3), dtype=torch.float32) return verts, torch.zeros((0, 3), dtype=torch.float32), None
# Simple case where the only data is the vertices themselves # Simple case where the only data is the vertices themselves
if ( if (
@ -870,9 +883,10 @@ def _get_verts(
and vertex[0].ndim == 2 and vertex[0].ndim == 2
and vertex[0].shape[1] == 3 and vertex[0].shape[1] == 3
): ):
return _make_tensor(vertex[0], cols=3, dtype=torch.float32), None return _make_tensor(vertex[0], cols=3, dtype=torch.float32), None, None
vertex_colors = None vertex_colors = None
vertex_normals = None
if len(vertex) == 1: if len(vertex) == 1:
# This is the case where the whole vertex element has one type, # This is the case where the whole vertex element has one type,
@ -882,6 +896,10 @@ def _get_verts(
vertex_colors = color_scale * torch.tensor( vertex_colors = color_scale * torch.tensor(
vertex[0][:, color_idxs], dtype=torch.float32 vertex[0][:, color_idxs], dtype=torch.float32
) )
if normal_idxs is not None:
vertex_normals = torch.tensor(
vertex[0][:, normal_idxs], dtype=torch.float32
)
else: else:
# The vertex element is heterogeneous. It was read as several arrays, # The vertex element is heterogeneous. It was read as several arrays,
# part by part, where a part is a set of properties with the same type. # part by part, where a part is a set of properties with the same type.
@ -913,13 +931,22 @@ def _get_verts(
partnum, col = prop_to_partnum_col[color_idxs[color]] partnum, col = prop_to_partnum_col[color_idxs[color]]
vertex_colors.numpy()[:, color] = vertex[partnum][:, col] vertex_colors.numpy()[:, color] = vertex[partnum][:, col]
vertex_colors *= color_scale vertex_colors *= color_scale
if normal_idxs is not None:
vertex_normals = torch.empty(
size=(vertex_head.count, 3), dtype=torch.float32
)
for axis in range(3):
partnum, col = prop_to_partnum_col[normal_idxs[axis]]
vertex_normals.numpy()[:, axis] = vertex[partnum][:, col]
return verts, vertex_colors return verts, vertex_colors, vertex_normals
def _load_ply( def _load_ply(
f, *, path_manager: PathManager f, *, path_manager: PathManager
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: ) -> Tuple[
torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]
]:
""" """
Load the data from a .ply file. Load the data from a .ply file.
@ -935,10 +962,11 @@ def _load_ply(
verts: FloatTensor of shape (V, 3). verts: FloatTensor of shape (V, 3).
faces: None or LongTensor of vertex indices, shape (F, 3). faces: None or LongTensor of vertex indices, shape (F, 3).
vertex_colors: None or FloatTensor of shape (V, 3). vertex_colors: None or FloatTensor of shape (V, 3).
vertex_normals: None or FloatTensor of shape (V, 3).
""" """
header, elements = _load_ply_raw(f, path_manager=path_manager) header, elements = _load_ply_raw(f, path_manager=path_manager)
verts, vertex_colors = _get_verts(header, elements) verts, vertex_colors, vertex_normals = _get_verts(header, elements)
face = elements.get("face", None) face = elements.get("face", None)
if face is not None: if face is not None:
@ -976,7 +1004,7 @@ def _load_ply(
if faces is not None: if faces is not None:
_check_faces_indices(faces, max_index=verts.shape[0]) _check_faces_indices(faces, max_index=verts.shape[0])
return verts, faces, vertex_colors return verts, faces, vertex_colors, vertex_normals
def load_ply( def load_ply(
@ -1031,7 +1059,7 @@ def load_ply(
if path_manager is None: if path_manager is None:
path_manager = PathManager() path_manager = PathManager()
verts, faces, _ = _load_ply(f, path_manager=path_manager) verts, faces, _, _ = _load_ply(f, path_manager=path_manager)
if faces is None: if faces is None:
faces = torch.zeros(0, 3, dtype=torch.int64) faces = torch.zeros(0, 3, dtype=torch.int64)
@ -1211,18 +1239,23 @@ class MeshPlyFormat(MeshFormatInterpreter):
if not endswith(path, self.known_suffixes): if not endswith(path, self.known_suffixes):
return None return None
verts, faces, verts_colors = _load_ply(f=path, path_manager=path_manager) verts, faces, verts_colors, verts_normals = _load_ply(
f=path, path_manager=path_manager
)
if faces is None: if faces is None:
faces = torch.zeros(0, 3, dtype=torch.int64) faces = torch.zeros(0, 3, dtype=torch.int64)
textures = None texture = None
if include_textures and verts_colors is not None: if include_textures and verts_colors is not None:
textures = TexturesVertex([verts_colors.to(device)]) texture = TexturesVertex([verts_colors.to(device)])
if verts_normals is not None:
verts_normals = [verts_normals]
mesh = Meshes( mesh = Meshes(
verts=[verts.to(device)], verts=[verts.to(device)],
faces=[faces.to(device)], faces=[faces.to(device)],
textures=textures, textures=texture,
verts_normals=verts_normals,
) )
return mesh return mesh
@ -1286,12 +1319,14 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter):
if not endswith(path, self.known_suffixes): if not endswith(path, self.known_suffixes):
return None return None
verts, faces, features = _load_ply(f=path, path_manager=path_manager) verts, faces, features, normals = _load_ply(f=path, path_manager=path_manager)
verts = verts.to(device) verts = verts.to(device)
if features is not None: if features is not None:
features = [features.to(device)] features = [features.to(device)]
if normals is not None:
normals = [normals.to(device)]
pointcloud = Pointclouds(points=[verts], features=features) pointcloud = Pointclouds(points=[verts], features=features, normals=normals)
return pointcloud return pointcloud
def save( def save(

View File

@ -216,14 +216,18 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32 [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
) )
faces = torch.tensor([[0, 1, 2], [0, 2, 3]]) faces = torch.tensor([[0, 1, 2], [0, 2, 3]])
normals = torch.tensor(
[[0, 1, 0], [1, 0, 0], [1, 4, 1], [1, 0, 0]], dtype=torch.float32
)
vert_colors = torch.rand_like(verts) vert_colors = torch.rand_like(verts)
texture = TexturesVertex(verts_features=[vert_colors]) texture = TexturesVertex(verts_features=[vert_colors])
for do_textures in itertools.product([True, False]): for do_textures, do_normals in itertools.product([True, False], [True, False]):
mesh = Meshes( mesh = Meshes(
verts=[verts], verts=[verts],
faces=[faces], faces=[faces],
textures=texture if do_textures else None, textures=texture if do_textures else None,
verts_normals=[normals] if do_normals else None,
) )
device = torch.device("cuda:0") device = torch.device("cuda:0")
@ -236,12 +240,57 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
mesh2 = mesh2.cpu() mesh2 = mesh2.cpu()
self.assertClose(mesh2.verts_padded(), mesh.verts_padded()) self.assertClose(mesh2.verts_padded(), mesh.verts_padded())
self.assertClose(mesh2.faces_padded(), mesh.faces_padded()) self.assertClose(mesh2.faces_padded(), mesh.faces_padded())
if do_normals:
self.assertTrue(mesh.has_verts_normals())
self.assertTrue(mesh2.has_verts_normals())
self.assertClose(
mesh2.verts_normals_padded(), mesh.verts_normals_padded()
)
else:
self.assertFalse(mesh.has_verts_normals())
self.assertFalse(mesh2.has_verts_normals())
self.assertFalse(torch.allclose(mesh2.verts_normals_padded(), normals))
if do_textures: if do_textures:
self.assertIsInstance(mesh2.textures, TexturesVertex) self.assertIsInstance(mesh2.textures, TexturesVertex)
self.assertClose(mesh2.textures.verts_features_list()[0], vert_colors) self.assertClose(mesh2.textures.verts_features_list()[0], vert_colors)
else: else:
self.assertIsNone(mesh2.textures) self.assertIsNone(mesh2.textures)
def test_save_load_with_normals(self):
points = torch.tensor(
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
)
normals = torch.tensor(
[[0, 1, 0], [1, 0, 0], [1, 4, 1], [1, 0, 0]], dtype=torch.float32
)
features = torch.rand_like(points)
for do_features, do_normals in itertools.product([True, False], [True, False]):
cloud = Pointclouds(
points=[points],
features=[features] if do_features else None,
normals=[normals] if do_normals else None,
)
device = torch.device("cuda:0")
io = IO()
with NamedTemporaryFile(mode="w", suffix=".ply") as f:
io.save_pointcloud(cloud.cuda(), f.name)
f.flush()
cloud2 = io.load_pointcloud(f.name, device=device)
self.assertEqual(cloud2.device, device)
cloud2 = cloud2.cpu()
self.assertClose(cloud2.points_padded(), cloud.points_padded())
if do_normals:
self.assertClose(cloud2.normals_padded(), cloud.normals_padded())
else:
self.assertIsNone(cloud.normals_padded())
self.assertIsNone(cloud2.normals_padded())
if do_features:
self.assertClose(cloud2.features_packed(), features)
else:
self.assertIsNone(cloud2.features_packed())
def test_save_ply_invalid_shapes(self): def test_save_ply_invalid_shapes(self):
# Invalid vertices shape # Invalid vertices shape
with self.assertRaises(ValueError) as error: with self.assertRaises(ValueError) as error: