mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
PLY load normals
Summary: Add ability to load normals when they are present in a PLY file. Reviewed By: nikhilaravi Differential Revision: D26458971 fbshipit-source-id: 658270b611f7624eab4f5f62ff438038e1d25723
This commit is contained in:
parent
b314beeda1
commit
6fa66f5534
@ -780,9 +780,9 @@ def _load_ply_raw(f, path_manager: PathManager) -> Tuple[_PlyHeader, dict]:
|
||||
|
||||
def _get_verts_column_indices(
|
||||
vertex_head: _PlyElementType,
|
||||
) -> Tuple[List[int], Optional[List[int]], float]:
|
||||
) -> Tuple[List[int], Optional[List[int]], float, Optional[List[int]]]:
|
||||
"""
|
||||
Get the columns of verts and verts_colors in the vertex
|
||||
Get the columns of verts, verts_colors, and verts_normals in the vertex
|
||||
element of a parsed ply file, together with a color scale factor.
|
||||
When the colors are in byte format, they are scaled from 0..255 to [0,1].
|
||||
Otherwise they are not scaled.
|
||||
@ -793,11 +793,14 @@ def _get_verts_column_indices(
|
||||
property double x
|
||||
property double y
|
||||
property double z
|
||||
property double nx
|
||||
property double ny
|
||||
property double nz
|
||||
property uchar red
|
||||
property uchar green
|
||||
property uchar blue
|
||||
|
||||
then the return value will be ([0,1,2], [6,7,8], 1.0/255)
|
||||
then the return value will be ([0,1,2], [6,7,8], 1.0/255, [3,4,5])
|
||||
|
||||
Args:
|
||||
vertex_head: as returned from load_ply_raw.
|
||||
@ -807,9 +810,12 @@ def _get_verts_column_indices(
|
||||
color_idxs: List[int] of 3 color columns if they are present,
|
||||
otherwise None.
|
||||
color_scale: value to scale colors by.
|
||||
normal_idxs: List[int] of 3 normals columns if they are present,
|
||||
otherwise None.
|
||||
"""
|
||||
point_idxs: List[Optional[int]] = [None, None, None]
|
||||
color_idxs: List[Optional[int]] = [None, None, None]
|
||||
normal_idxs: List[Optional[int]] = [None, None, None]
|
||||
for i, prop in enumerate(vertex_head.properties):
|
||||
if prop.list_size_type is not None:
|
||||
raise ValueError("Invalid vertices in file: did not expect list.")
|
||||
@ -819,6 +825,9 @@ def _get_verts_column_indices(
|
||||
for j, name in enumerate(["red", "green", "blue"]):
|
||||
if prop.name == name:
|
||||
color_idxs[j] = i
|
||||
for j, name in enumerate(["nx", "ny", "nz"]):
|
||||
if prop.name == name:
|
||||
normal_idxs[j] = i
|
||||
if None in point_idxs:
|
||||
raise ValueError("Invalid vertices in file.")
|
||||
color_scale = 1.0
|
||||
@ -831,14 +840,15 @@ def _get_verts_column_indices(
|
||||
point_idxs,
|
||||
None if None in color_idxs else cast(List[int], color_idxs),
|
||||
color_scale,
|
||||
None if None in normal_idxs else cast(List[int], normal_idxs),
|
||||
)
|
||||
|
||||
|
||||
def _get_verts(
|
||||
header: _PlyHeader, elements: dict
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Get the vertex locations and colors from a parsed ply file.
|
||||
Get the vertex locations, colors and normals from a parsed ply file.
|
||||
|
||||
Args:
|
||||
header, elements: as returned from load_ply_raw.
|
||||
@ -846,6 +856,7 @@ def _get_verts(
|
||||
Returns:
|
||||
verts: FloatTensor of shape (V, 3).
|
||||
vertex_colors: None or FloatTensor of shape (V, 3).
|
||||
vertex_normals: None or FloatTensor of shape (V, 3).
|
||||
"""
|
||||
|
||||
vertex = elements.get("vertex", None)
|
||||
@ -854,14 +865,16 @@ def _get_verts(
|
||||
if not isinstance(vertex, list):
|
||||
raise ValueError("Invalid vertices in file.")
|
||||
vertex_head = next(head for head in header.elements if head.name == "vertex")
|
||||
point_idxs, color_idxs, color_scale = _get_verts_column_indices(vertex_head)
|
||||
point_idxs, color_idxs, color_scale, normal_idxs = _get_verts_column_indices(
|
||||
vertex_head
|
||||
)
|
||||
|
||||
# Case of no vertices
|
||||
if vertex_head.count == 0:
|
||||
verts = torch.zeros((0, 3), dtype=torch.float32)
|
||||
if color_idxs is None:
|
||||
return verts, None
|
||||
return verts, torch.zeros((0, 3), dtype=torch.float32)
|
||||
return verts, None, None
|
||||
return verts, torch.zeros((0, 3), dtype=torch.float32), None
|
||||
|
||||
# Simple case where the only data is the vertices themselves
|
||||
if (
|
||||
@ -870,9 +883,10 @@ def _get_verts(
|
||||
and vertex[0].ndim == 2
|
||||
and vertex[0].shape[1] == 3
|
||||
):
|
||||
return _make_tensor(vertex[0], cols=3, dtype=torch.float32), None
|
||||
return _make_tensor(vertex[0], cols=3, dtype=torch.float32), None, None
|
||||
|
||||
vertex_colors = None
|
||||
vertex_normals = None
|
||||
|
||||
if len(vertex) == 1:
|
||||
# This is the case where the whole vertex element has one type,
|
||||
@ -882,6 +896,10 @@ def _get_verts(
|
||||
vertex_colors = color_scale * torch.tensor(
|
||||
vertex[0][:, color_idxs], dtype=torch.float32
|
||||
)
|
||||
if normal_idxs is not None:
|
||||
vertex_normals = torch.tensor(
|
||||
vertex[0][:, normal_idxs], dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
# The vertex element is heterogeneous. It was read as several arrays,
|
||||
# part by part, where a part is a set of properties with the same type.
|
||||
@ -913,13 +931,22 @@ def _get_verts(
|
||||
partnum, col = prop_to_partnum_col[color_idxs[color]]
|
||||
vertex_colors.numpy()[:, color] = vertex[partnum][:, col]
|
||||
vertex_colors *= color_scale
|
||||
if normal_idxs is not None:
|
||||
vertex_normals = torch.empty(
|
||||
size=(vertex_head.count, 3), dtype=torch.float32
|
||||
)
|
||||
for axis in range(3):
|
||||
partnum, col = prop_to_partnum_col[normal_idxs[axis]]
|
||||
vertex_normals.numpy()[:, axis] = vertex[partnum][:, col]
|
||||
|
||||
return verts, vertex_colors
|
||||
return verts, vertex_colors, vertex_normals
|
||||
|
||||
|
||||
def _load_ply(
|
||||
f, *, path_manager: PathManager
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
) -> Tuple[
|
||||
torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]
|
||||
]:
|
||||
"""
|
||||
Load the data from a .ply file.
|
||||
|
||||
@ -935,10 +962,11 @@ def _load_ply(
|
||||
verts: FloatTensor of shape (V, 3).
|
||||
faces: None or LongTensor of vertex indices, shape (F, 3).
|
||||
vertex_colors: None or FloatTensor of shape (V, 3).
|
||||
vertex_normals: None or FloatTensor of shape (V, 3).
|
||||
"""
|
||||
header, elements = _load_ply_raw(f, path_manager=path_manager)
|
||||
|
||||
verts, vertex_colors = _get_verts(header, elements)
|
||||
verts, vertex_colors, vertex_normals = _get_verts(header, elements)
|
||||
|
||||
face = elements.get("face", None)
|
||||
if face is not None:
|
||||
@ -976,7 +1004,7 @@ def _load_ply(
|
||||
if faces is not None:
|
||||
_check_faces_indices(faces, max_index=verts.shape[0])
|
||||
|
||||
return verts, faces, vertex_colors
|
||||
return verts, faces, vertex_colors, vertex_normals
|
||||
|
||||
|
||||
def load_ply(
|
||||
@ -1031,7 +1059,7 @@ def load_ply(
|
||||
|
||||
if path_manager is None:
|
||||
path_manager = PathManager()
|
||||
verts, faces, _ = _load_ply(f, path_manager=path_manager)
|
||||
verts, faces, _, _ = _load_ply(f, path_manager=path_manager)
|
||||
if faces is None:
|
||||
faces = torch.zeros(0, 3, dtype=torch.int64)
|
||||
|
||||
@ -1211,18 +1239,23 @@ class MeshPlyFormat(MeshFormatInterpreter):
|
||||
if not endswith(path, self.known_suffixes):
|
||||
return None
|
||||
|
||||
verts, faces, verts_colors = _load_ply(f=path, path_manager=path_manager)
|
||||
verts, faces, verts_colors, verts_normals = _load_ply(
|
||||
f=path, path_manager=path_manager
|
||||
)
|
||||
if faces is None:
|
||||
faces = torch.zeros(0, 3, dtype=torch.int64)
|
||||
|
||||
textures = None
|
||||
texture = None
|
||||
if include_textures and verts_colors is not None:
|
||||
textures = TexturesVertex([verts_colors.to(device)])
|
||||
texture = TexturesVertex([verts_colors.to(device)])
|
||||
|
||||
if verts_normals is not None:
|
||||
verts_normals = [verts_normals]
|
||||
mesh = Meshes(
|
||||
verts=[verts.to(device)],
|
||||
faces=[faces.to(device)],
|
||||
textures=textures,
|
||||
textures=texture,
|
||||
verts_normals=verts_normals,
|
||||
)
|
||||
return mesh
|
||||
|
||||
@ -1286,12 +1319,14 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter):
|
||||
if not endswith(path, self.known_suffixes):
|
||||
return None
|
||||
|
||||
verts, faces, features = _load_ply(f=path, path_manager=path_manager)
|
||||
verts, faces, features, normals = _load_ply(f=path, path_manager=path_manager)
|
||||
verts = verts.to(device)
|
||||
if features is not None:
|
||||
features = [features.to(device)]
|
||||
if normals is not None:
|
||||
normals = [normals.to(device)]
|
||||
|
||||
pointcloud = Pointclouds(points=[verts], features=features)
|
||||
pointcloud = Pointclouds(points=[verts], features=features, normals=normals)
|
||||
return pointcloud
|
||||
|
||||
def save(
|
||||
|
@ -216,14 +216,18 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
||||
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
|
||||
)
|
||||
faces = torch.tensor([[0, 1, 2], [0, 2, 3]])
|
||||
normals = torch.tensor(
|
||||
[[0, 1, 0], [1, 0, 0], [1, 4, 1], [1, 0, 0]], dtype=torch.float32
|
||||
)
|
||||
vert_colors = torch.rand_like(verts)
|
||||
texture = TexturesVertex(verts_features=[vert_colors])
|
||||
|
||||
for do_textures in itertools.product([True, False]):
|
||||
for do_textures, do_normals in itertools.product([True, False], [True, False]):
|
||||
mesh = Meshes(
|
||||
verts=[verts],
|
||||
faces=[faces],
|
||||
textures=texture if do_textures else None,
|
||||
verts_normals=[normals] if do_normals else None,
|
||||
)
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
@ -236,12 +240,57 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
|
||||
mesh2 = mesh2.cpu()
|
||||
self.assertClose(mesh2.verts_padded(), mesh.verts_padded())
|
||||
self.assertClose(mesh2.faces_padded(), mesh.faces_padded())
|
||||
if do_normals:
|
||||
self.assertTrue(mesh.has_verts_normals())
|
||||
self.assertTrue(mesh2.has_verts_normals())
|
||||
self.assertClose(
|
||||
mesh2.verts_normals_padded(), mesh.verts_normals_padded()
|
||||
)
|
||||
else:
|
||||
self.assertFalse(mesh.has_verts_normals())
|
||||
self.assertFalse(mesh2.has_verts_normals())
|
||||
self.assertFalse(torch.allclose(mesh2.verts_normals_padded(), normals))
|
||||
if do_textures:
|
||||
self.assertIsInstance(mesh2.textures, TexturesVertex)
|
||||
self.assertClose(mesh2.textures.verts_features_list()[0], vert_colors)
|
||||
else:
|
||||
self.assertIsNone(mesh2.textures)
|
||||
|
||||
def test_save_load_with_normals(self):
|
||||
points = torch.tensor(
|
||||
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
|
||||
)
|
||||
normals = torch.tensor(
|
||||
[[0, 1, 0], [1, 0, 0], [1, 4, 1], [1, 0, 0]], dtype=torch.float32
|
||||
)
|
||||
features = torch.rand_like(points)
|
||||
|
||||
for do_features, do_normals in itertools.product([True, False], [True, False]):
|
||||
cloud = Pointclouds(
|
||||
points=[points],
|
||||
features=[features] if do_features else None,
|
||||
normals=[normals] if do_normals else None,
|
||||
)
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
io = IO()
|
||||
with NamedTemporaryFile(mode="w", suffix=".ply") as f:
|
||||
io.save_pointcloud(cloud.cuda(), f.name)
|
||||
f.flush()
|
||||
cloud2 = io.load_pointcloud(f.name, device=device)
|
||||
self.assertEqual(cloud2.device, device)
|
||||
cloud2 = cloud2.cpu()
|
||||
self.assertClose(cloud2.points_padded(), cloud.points_padded())
|
||||
if do_normals:
|
||||
self.assertClose(cloud2.normals_padded(), cloud.normals_padded())
|
||||
else:
|
||||
self.assertIsNone(cloud.normals_padded())
|
||||
self.assertIsNone(cloud2.normals_padded())
|
||||
if do_features:
|
||||
self.assertClose(cloud2.features_packed(), features)
|
||||
else:
|
||||
self.assertIsNone(cloud2.features_packed())
|
||||
|
||||
def test_save_ply_invalid_shapes(self):
|
||||
# Invalid vertices shape
|
||||
with self.assertRaises(ValueError) as error:
|
||||
|
Loading…
x
Reference in New Issue
Block a user