diff --git a/pytorch3d/io/ply_io.py b/pytorch3d/io/ply_io.py index 29c846e7..73bbdd92 100644 --- a/pytorch3d/io/ply_io.py +++ b/pytorch3d/io/ply_io.py @@ -780,9 +780,9 @@ def _load_ply_raw(f, path_manager: PathManager) -> Tuple[_PlyHeader, dict]: def _get_verts_column_indices( vertex_head: _PlyElementType, -) -> Tuple[List[int], Optional[List[int]], float]: +) -> Tuple[List[int], Optional[List[int]], float, Optional[List[int]]]: """ - Get the columns of verts and verts_colors in the vertex + Get the columns of verts, verts_colors, and verts_normals in the vertex element of a parsed ply file, together with a color scale factor. When the colors are in byte format, they are scaled from 0..255 to [0,1]. Otherwise they are not scaled. @@ -793,11 +793,14 @@ def _get_verts_column_indices( property double x property double y property double z + property double nx + property double ny + property double nz property uchar red property uchar green property uchar blue - then the return value will be ([0,1,2], [6,7,8], 1.0/255) + then the return value will be ([0,1,2], [6,7,8], 1.0/255, [3,4,5]) Args: vertex_head: as returned from load_ply_raw. @@ -807,9 +810,12 @@ def _get_verts_column_indices( color_idxs: List[int] of 3 color columns if they are present, otherwise None. color_scale: value to scale colors by. + normal_idxs: List[int] of 3 normals columns if they are present, + otherwise None. """ point_idxs: List[Optional[int]] = [None, None, None] color_idxs: List[Optional[int]] = [None, None, None] + normal_idxs: List[Optional[int]] = [None, None, None] for i, prop in enumerate(vertex_head.properties): if prop.list_size_type is not None: raise ValueError("Invalid vertices in file: did not expect list.") @@ -819,6 +825,9 @@ def _get_verts_column_indices( for j, name in enumerate(["red", "green", "blue"]): if prop.name == name: color_idxs[j] = i + for j, name in enumerate(["nx", "ny", "nz"]): + if prop.name == name: + normal_idxs[j] = i if None in point_idxs: raise ValueError("Invalid vertices in file.") color_scale = 1.0 @@ -831,14 +840,15 @@ def _get_verts_column_indices( point_idxs, None if None in color_idxs else cast(List[int], color_idxs), color_scale, + None if None in normal_idxs else cast(List[int], normal_idxs), ) def _get_verts( header: _PlyHeader, elements: dict -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: """ - Get the vertex locations and colors from a parsed ply file. + Get the vertex locations, colors and normals from a parsed ply file. Args: header, elements: as returned from load_ply_raw. @@ -846,6 +856,7 @@ def _get_verts( Returns: verts: FloatTensor of shape (V, 3). vertex_colors: None or FloatTensor of shape (V, 3). + vertex_normals: None or FloatTensor of shape (V, 3). """ vertex = elements.get("vertex", None) @@ -854,14 +865,16 @@ def _get_verts( if not isinstance(vertex, list): raise ValueError("Invalid vertices in file.") vertex_head = next(head for head in header.elements if head.name == "vertex") - point_idxs, color_idxs, color_scale = _get_verts_column_indices(vertex_head) + point_idxs, color_idxs, color_scale, normal_idxs = _get_verts_column_indices( + vertex_head + ) # Case of no vertices if vertex_head.count == 0: verts = torch.zeros((0, 3), dtype=torch.float32) if color_idxs is None: - return verts, None - return verts, torch.zeros((0, 3), dtype=torch.float32) + return verts, None, None + return verts, torch.zeros((0, 3), dtype=torch.float32), None # Simple case where the only data is the vertices themselves if ( @@ -870,9 +883,10 @@ def _get_verts( and vertex[0].ndim == 2 and vertex[0].shape[1] == 3 ): - return _make_tensor(vertex[0], cols=3, dtype=torch.float32), None + return _make_tensor(vertex[0], cols=3, dtype=torch.float32), None, None vertex_colors = None + vertex_normals = None if len(vertex) == 1: # This is the case where the whole vertex element has one type, @@ -882,6 +896,10 @@ def _get_verts( vertex_colors = color_scale * torch.tensor( vertex[0][:, color_idxs], dtype=torch.float32 ) + if normal_idxs is not None: + vertex_normals = torch.tensor( + vertex[0][:, normal_idxs], dtype=torch.float32 + ) else: # The vertex element is heterogeneous. It was read as several arrays, # part by part, where a part is a set of properties with the same type. @@ -913,13 +931,22 @@ def _get_verts( partnum, col = prop_to_partnum_col[color_idxs[color]] vertex_colors.numpy()[:, color] = vertex[partnum][:, col] vertex_colors *= color_scale + if normal_idxs is not None: + vertex_normals = torch.empty( + size=(vertex_head.count, 3), dtype=torch.float32 + ) + for axis in range(3): + partnum, col = prop_to_partnum_col[normal_idxs[axis]] + vertex_normals.numpy()[:, axis] = vertex[partnum][:, col] - return verts, vertex_colors + return verts, vertex_colors, vertex_normals def _load_ply( f, *, path_manager: PathManager -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: +) -> Tuple[ + torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor] +]: """ Load the data from a .ply file. @@ -935,10 +962,11 @@ def _load_ply( verts: FloatTensor of shape (V, 3). faces: None or LongTensor of vertex indices, shape (F, 3). vertex_colors: None or FloatTensor of shape (V, 3). + vertex_normals: None or FloatTensor of shape (V, 3). """ header, elements = _load_ply_raw(f, path_manager=path_manager) - verts, vertex_colors = _get_verts(header, elements) + verts, vertex_colors, vertex_normals = _get_verts(header, elements) face = elements.get("face", None) if face is not None: @@ -976,7 +1004,7 @@ def _load_ply( if faces is not None: _check_faces_indices(faces, max_index=verts.shape[0]) - return verts, faces, vertex_colors + return verts, faces, vertex_colors, vertex_normals def load_ply( @@ -1031,7 +1059,7 @@ def load_ply( if path_manager is None: path_manager = PathManager() - verts, faces, _ = _load_ply(f, path_manager=path_manager) + verts, faces, _, _ = _load_ply(f, path_manager=path_manager) if faces is None: faces = torch.zeros(0, 3, dtype=torch.int64) @@ -1211,18 +1239,23 @@ class MeshPlyFormat(MeshFormatInterpreter): if not endswith(path, self.known_suffixes): return None - verts, faces, verts_colors = _load_ply(f=path, path_manager=path_manager) + verts, faces, verts_colors, verts_normals = _load_ply( + f=path, path_manager=path_manager + ) if faces is None: faces = torch.zeros(0, 3, dtype=torch.int64) - textures = None + texture = None if include_textures and verts_colors is not None: - textures = TexturesVertex([verts_colors.to(device)]) + texture = TexturesVertex([verts_colors.to(device)]) + if verts_normals is not None: + verts_normals = [verts_normals] mesh = Meshes( verts=[verts.to(device)], faces=[faces.to(device)], - textures=textures, + textures=texture, + verts_normals=verts_normals, ) return mesh @@ -1286,12 +1319,14 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter): if not endswith(path, self.known_suffixes): return None - verts, faces, features = _load_ply(f=path, path_manager=path_manager) + verts, faces, features, normals = _load_ply(f=path, path_manager=path_manager) verts = verts.to(device) if features is not None: features = [features.to(device)] + if normals is not None: + normals = [normals.to(device)] - pointcloud = Pointclouds(points=[verts], features=features) + pointcloud = Pointclouds(points=[verts], features=features, normals=normals) return pointcloud def save( diff --git a/tests/test_io_ply.py b/tests/test_io_ply.py index 125c40fa..192170c8 100644 --- a/tests/test_io_ply.py +++ b/tests/test_io_ply.py @@ -216,14 +216,18 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32 ) faces = torch.tensor([[0, 1, 2], [0, 2, 3]]) + normals = torch.tensor( + [[0, 1, 0], [1, 0, 0], [1, 4, 1], [1, 0, 0]], dtype=torch.float32 + ) vert_colors = torch.rand_like(verts) texture = TexturesVertex(verts_features=[vert_colors]) - for do_textures in itertools.product([True, False]): + for do_textures, do_normals in itertools.product([True, False], [True, False]): mesh = Meshes( verts=[verts], faces=[faces], textures=texture if do_textures else None, + verts_normals=[normals] if do_normals else None, ) device = torch.device("cuda:0") @@ -236,12 +240,57 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): mesh2 = mesh2.cpu() self.assertClose(mesh2.verts_padded(), mesh.verts_padded()) self.assertClose(mesh2.faces_padded(), mesh.faces_padded()) + if do_normals: + self.assertTrue(mesh.has_verts_normals()) + self.assertTrue(mesh2.has_verts_normals()) + self.assertClose( + mesh2.verts_normals_padded(), mesh.verts_normals_padded() + ) + else: + self.assertFalse(mesh.has_verts_normals()) + self.assertFalse(mesh2.has_verts_normals()) + self.assertFalse(torch.allclose(mesh2.verts_normals_padded(), normals)) if do_textures: self.assertIsInstance(mesh2.textures, TexturesVertex) self.assertClose(mesh2.textures.verts_features_list()[0], vert_colors) else: self.assertIsNone(mesh2.textures) + def test_save_load_with_normals(self): + points = torch.tensor( + [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32 + ) + normals = torch.tensor( + [[0, 1, 0], [1, 0, 0], [1, 4, 1], [1, 0, 0]], dtype=torch.float32 + ) + features = torch.rand_like(points) + + for do_features, do_normals in itertools.product([True, False], [True, False]): + cloud = Pointclouds( + points=[points], + features=[features] if do_features else None, + normals=[normals] if do_normals else None, + ) + device = torch.device("cuda:0") + + io = IO() + with NamedTemporaryFile(mode="w", suffix=".ply") as f: + io.save_pointcloud(cloud.cuda(), f.name) + f.flush() + cloud2 = io.load_pointcloud(f.name, device=device) + self.assertEqual(cloud2.device, device) + cloud2 = cloud2.cpu() + self.assertClose(cloud2.points_padded(), cloud.points_padded()) + if do_normals: + self.assertClose(cloud2.normals_padded(), cloud.normals_padded()) + else: + self.assertIsNone(cloud.normals_padded()) + self.assertIsNone(cloud2.normals_padded()) + if do_features: + self.assertClose(cloud2.features_packed(), features) + else: + self.assertIsNone(cloud2.features_packed()) + def test_save_ply_invalid_shapes(self): # Invalid vertices shape with self.assertRaises(ValueError) as error: