diff --git a/pytorch3d/io/ply_io.py b/pytorch3d/io/ply_io.py index 8c4182dd..2ceaaa15 100644 --- a/pytorch3d/io/ply_io.py +++ b/pytorch3d/io/ply_io.py @@ -700,24 +700,38 @@ def load_ply(f): return verts, faces -def _save_ply(f, verts, faces, decimal_places: Optional[int]) -> None: +def _save_ply( + f, + verts: torch.Tensor, + faces: torch.LongTensor, + verts_normals: torch.Tensor, + decimal_places: Optional[int] = None, +) -> None: """ - Internal implementation for saving a mesh to a .ply file. + Internal implementation for saving 3D data to a .ply file. Args: - f: File object to which the mesh should be written. + f: File object to which the 3D data should be written. verts: FloatTensor of shape (V, 3) giving vertex coordinates. - faces: LongTensor of shape (F, 3) giving faces. + faces: LongTensor of shsape (F, 3) giving faces. + verts_normals: FloatTensor of shape (V, 3) giving vertex normals. decimal_places: Number of decimal places for saving. """ assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3) assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3) + assert not len(verts_normals) or ( + verts_normals.dim() == 2 and verts_normals.size(1) == 3 + ) print("ply\nformat ascii 1.0", file=f) print(f"element vertex {verts.shape[0]}", file=f) print("property float x", file=f) print("property float y", file=f) print("property float z", file=f) + if verts_normals.numel() > 0: + print("property float nx", file=f) + print("property float ny", file=f) + print("property float nz", file=f) print(f"element face {faces.shape[0]}", file=f) print("property list uchar int vertex_index", file=f) print("end_header", file=f) @@ -731,8 +745,8 @@ def _save_ply(f, verts, faces, decimal_places: Optional[int]) -> None: else: float_str = "%" + ".%df" % decimal_places - verts_array = verts.detach().numpy() - np.savetxt(f, verts_array, float_str) + vert_data = torch.cat((verts, verts_normals), dim=1) + np.savetxt(f, vert_data.detach().numpy(), float_str) faces_array = faces.detach().numpy() @@ -743,7 +757,13 @@ def _save_ply(f, verts, faces, decimal_places: Optional[int]) -> None: np.savetxt(f, faces_array, "3 %d %d %d") -def save_ply(f, verts, faces, decimal_places: Optional[int] = None): +def save_ply( + f, + verts: torch.Tensor, + faces: Optional[torch.LongTensor] = None, + verts_normals: Optional[torch.Tensor] = None, + decimal_places: Optional[int] = None, +) -> None: """ Save a mesh to a .ply file. @@ -751,8 +771,13 @@ def save_ply(f, verts, faces, decimal_places: Optional[int] = None): f: File (or path) to which the mesh should be written. verts: FloatTensor of shape (V, 3) giving vertex coordinates. faces: LongTensor of shape (F, 3) giving faces. + verts_normals: FloatTensor of shape (V, 3) giving vertex normals. decimal_places: Number of decimal places for saving. """ + + verts_normals = torch.FloatTensor([]) if verts_normals is None else verts_normals + faces = torch.LongTensor([]) if faces is None else faces + if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3): message = "Argument 'verts' should either be empty or of shape (num_verts, 3)." raise ValueError(message) @@ -761,6 +786,14 @@ def save_ply(f, verts, faces, decimal_places: Optional[int] = None): message = "Argument 'faces' should either be empty or of shape (num_faces, 3)." raise ValueError(message) + if len(verts_normals) and not ( + verts_normals.dim() == 2 + and verts_normals.size(1) == 3 + and verts_normals.size(0) == verts.size(0) + ): + message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)." + raise ValueError(message) + new_f = False if isinstance(f, str): new_f = True @@ -769,7 +802,7 @@ def save_ply(f, verts, faces, decimal_places: Optional[int] = None): new_f = True f = f.open("w") try: - _save_ply(f, verts, faces, decimal_places) + _save_ply(f, verts, faces, verts_normals, decimal_places) finally: if new_f: f.close() diff --git a/pytorch3d/ops/__init__.py b/pytorch3d/ops/__init__.py index fe522d3d..dfbcec21 100644 --- a/pytorch3d/ops/__init__.py +++ b/pytorch3d/ops/__init__.py @@ -7,9 +7,19 @@ from .knn import knn_gather, knn_points from .mesh_face_areas_normals import mesh_face_areas_normals from .packed_to_padded import packed_to_padded, padded_to_packed from .points_alignment import corresponding_points_alignment, iterative_closest_point +from .points_normals import ( + estimate_pointcloud_local_coord_frames, + estimate_pointcloud_normals, +) from .sample_points_from_meshes import sample_points_from_meshes from .subdivide_meshes import SubdivideMeshes -from .utils import convert_pointclouds_to_tensor, eyes, is_pointclouds, wmean +from .utils import ( + convert_pointclouds_to_tensor, + eyes, + get_point_covariances, + is_pointclouds, + wmean, +) from .vert_align import vert_align diff --git a/pytorch3d/ops/points_normals.py b/pytorch3d/ops/points_normals.py new file mode 100644 index 00000000..09a082e5 --- /dev/null +++ b/pytorch3d/ops/points_normals.py @@ -0,0 +1,172 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +from typing import TYPE_CHECKING, Tuple, Union + +import torch + +from .utils import convert_pointclouds_to_tensor, get_point_covariances + + +if TYPE_CHECKING: + from ..structures import Pointclouds + + +def estimate_pointcloud_normals( + pointclouds: Union[torch.Tensor, "Pointclouds"], + neighborhood_size: int = 50, + disambiguate_directions: bool = True, +) -> torch.Tensor: + """ + Estimates the normals of a batch of `pointclouds`. + + The function uses `estimate_pointcloud_local_coord_frames` to estimate + the normals. Please refer to this function for more detailed information. + + Args: + **pointclouds**: Batch of 3-dimensional points of shape + `(minibatch, num_point, 3)` or a `Pointclouds` object. + **neighborhood_size**: The size of the neighborhood used to estimate the + geometry around each point. + **disambiguate_directions**: If `True`, uses the algorithm from [1] to + ensure sign consistency of the normals of neigboring points. + + Returns: + **normals**: A tensor of normals for each input point + of shape `(minibatch, num_point, 3)`. + If `pointclouds` are of `Pointclouds` class, returns a padded tensor. + + References: + [1] Tombari, Salti, Di Stefano: Unique Signatures of Histograms for + Local Surface Description, ECCV 2010. + """ + + curvatures, local_coord_frames = estimate_pointcloud_local_coord_frames( + pointclouds, + neighborhood_size=neighborhood_size, + disambiguate_directions=disambiguate_directions, + ) + + # the normals correspond to the first vector of each local coord frame + normals = local_coord_frames[:, :, :, 0] + + return normals + + +def estimate_pointcloud_local_coord_frames( + pointclouds: Union[torch.Tensor, "Pointclouds"], + neighborhood_size: int = 50, + disambiguate_directions: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Estimates the principal directions of curvature (which includes normals) + of a batch of `pointclouds`. + + The algorithm first finds `neighborhood_size` nearest neighbors for each + point of the point clouds, followed by obtaining principal vectors of + covariance matrices of each of the point neighborhoods. + The main principal vector corresponds to the normals, while the + other 2 are the direction of the highest curvature and the 2nd highest + curvature. + + Note that each principal direction is given up to a sign. Hence, + the function implements `disambiguate_directions` switch that allows + to ensure consistency of the sign of neighboring normals. The implementation + follows the sign disabiguation from SHOT descriptors [1]. + + The algorithm also returns the curvature values themselves. + These are the eigenvalues of the estimated covariance matrices + of each point neighborhood. + + Args: + **pointclouds**: Batch of 3-dimensional points of shape + `(minibatch, num_point, 3)` or a `Pointclouds` object. + **neighborhood_size**: The size of the neighborhood used to estimate the + geometry around each point. + **disambiguate_directions**: If `True`, uses the algorithm from [1] to + ensure sign consistency of the normals of neigboring points. + + Returns: + **curvatures**: The three principal curvatures of each point + of shape `(minibatch, num_point, 3)`. + If `pointclouds` are of `Pointclouds` class, returns a padded tensor. + **local_coord_frames**: The three principal directions of the curvature + around each point of shape `(minibatch, num_point, 3, 3)`. + The principal directions are stored in columns of the output. + E.g. `local_coord_frames[i, j, :, 0]` is the normal of + `j`-th point in the `i`-th pointcloud. + If `pointclouds` are of `Pointclouds` class, returns a padded tensor. + + References: + [1] Tombari, Salti, Di Stefano: Unique Signatures of Histograms for + Local Surface Description, ECCV 2010. + """ + + points_padded, num_points = convert_pointclouds_to_tensor(pointclouds) + + ba, N, dim = points_padded.shape + if dim != 3: + raise ValueError( + "The pointclouds argument has to be of shape (minibatch, N, 3)" + ) + + if (num_points <= neighborhood_size).any(): + raise ValueError( + "The neighborhood_size argument has to be" + + " >= size of each of the point clouds." + ) + + # undo global mean for stability + # TODO: replace with tutil.wmean once landed + pcl_mean = points_padded.sum(1) / num_points[:, None] + points_centered = points_padded - pcl_mean[:, None, :] + + # get the per-point covariance and nearest neighbors used to compute it + cov, knns = get_point_covariances(points_centered, num_points, neighborhood_size) + + # get the local coord frames as principal directions of + # the per-point covariance + # this is done with torch.symeig, which returns the + # eigenvectors (=principal directions) in an ascending order of their + # corresponding eigenvalues, while the smallest eigenvalue's eigenvector + # corresponds to the normal direction + curvatures, local_coord_frames = torch.symeig(cov, eigenvectors=True) + + # disambiguate the directions of individual principal vectors + if disambiguate_directions: + # disambiguate normal + n = _disambiguate_vector_directions( + points_centered, knns, local_coord_frames[:, :, :, 0] + ) + # disambiguate the main curvature + z = _disambiguate_vector_directions( + points_centered, knns, local_coord_frames[:, :, :, 2] + ) + # the secondary curvature is just a cross between n and z + y = torch.cross(n, z, dim=2) + # cat to form the set of principal directions + local_coord_frames = torch.stack((n, y, z), dim=3) + + return curvatures, local_coord_frames + + +def _disambiguate_vector_directions(pcl, knns, vecs): + """ + Disambiguates normal directions according to [1]. + + References: + [1] Tombari, Salti, Di Stefano: Unique Signatures of Histograms for + Local Surface Description, ECCV 2010. + """ + # parse out K from the shape of knns + K = knns.shape[2] + # the difference between the mean of each neighborhood and + # each element of the neighborhood + df = knns - pcl[:, :, None] + # projection of the difference on the principal direction + proj = (vecs[:, :, None] * df).sum(3) + # check how many projections are positive + n_pos = (proj > 0).type_as(knns).sum(2, keepdim=True) + # flip the principal directions where number of positive correlations + flip = (n_pos < (0.5 * K)).type_as(knns) + vecs = (1.0 - 2.0 * flip) * vecs + return vecs diff --git a/pytorch3d/ops/utils.py b/pytorch3d/ops/utils.py index 134172b0..41e00ace 100644 --- a/pytorch3d/ops/utils.py +++ b/pytorch3d/ops/utils.py @@ -3,6 +3,8 @@ from typing import TYPE_CHECKING, Optional, Tuple, Union import torch +from .knn import knn_points + if TYPE_CHECKING: from pytorch3d.structures import Pointclouds @@ -92,8 +94,53 @@ def convert_pointclouds_to_tensor(pcl: Union[torch.Tensor, "Pointclouds"]): def is_pointclouds(pcl: Union[torch.Tensor, "Pointclouds"]): - """ Checks whether the input `pcl` is an instance `Pointclouds` of + """ Checks whether the input `pcl` is an instance of `Pointclouds` by checking the existence of `points_padded` and `num_points_per_cloud` functions. """ return hasattr(pcl, "points_padded") and hasattr(pcl, "num_points_per_cloud") + + +def get_point_covariances( + points_padded: torch.Tensor, + num_points_per_cloud: torch.Tensor, + neighborhood_size: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Computes the per-point covariance matrices by of the 3D locations of + K-nearest neighbors of each point. + + Args: + **points_padded**: Input point clouds as a padded tensor + of shape `(minibatch, num_points, dim)`. + **num_points_per_cloud**: Number of points per cloud + of shape `(minibatch,)`. + **neighborhood_size**: Number of nearest neighbors for each point + used to estimate the covariance matrices. + + Returns: + **covariances**: A batch of per-point covariance matrices + of shape `(minibatch, dim, dim)`. + **k_nearest_neighbors**: A batch of `neighborhood_size` nearest + neighbors for each of the point cloud points + of shape `(minibatch, num_points, neighborhood_size, dim)`. + """ + # get K nearest neighbor idx for each point in the point cloud + _, _, k_nearest_neighbors = knn_points( + points_padded, + points_padded, + lengths1=num_points_per_cloud, + lengths2=num_points_per_cloud, + K=neighborhood_size, + return_nn=True, + ) + # obtain the mean of the neighborhood + pt_mean = k_nearest_neighbors.mean(2, keepdim=True) + # compute the diff of the neighborhood and the mean of the neighborhood + central_diff = k_nearest_neighbors - pt_mean + # per-nn-point covariances + per_pt_cov = central_diff.unsqueeze(4) * central_diff.unsqueeze(3) + # per-point covariances + covariances = per_pt_cov.mean(2) + + return covariances, k_nearest_neighbors diff --git a/pytorch3d/structures/pointclouds.py b/pytorch3d/structures/pointclouds.py index 23ff0e44..be88acc1 100644 --- a/pytorch3d/structures/pointclouds.py +++ b/pytorch3d/structures/pointclouds.py @@ -2,6 +2,7 @@ import torch +from .. import ops from . import utils as struct_utils @@ -847,6 +848,51 @@ class Pointclouds(object): bboxes = torch.stack([all_mins, all_maxes], dim=2) return bboxes + def estimate_normals( + self, + neighborhood_size: int = 50, + disambiguate_directions: bool = True, + assign_to_self: bool = False, + ): + """ + Estimates the normals of each point in each cloud and assigns + them to the internal tensors `self._normals_list` and `self._normals_padded` + + The function uses `ops.estimate_pointcloud_local_coord_frames` + to estimate the normals. Please refer to this function for more + detailed information about the implemented algorithm. + + Args: + **neighborhood_size**: The size of the neighborhood used to estimate the + geometry around each point. + **disambiguate_directions**: If `True`, uses the algorithm from [1] to + ensure sign consistency of the normals of neigboring points. + **normals**: A tensor of normals for each input point + of shape `(minibatch, num_point, 3)`. + If `pointclouds` are of `Pointclouds` class, returns a padded tensor. + **assign_to_self**: If `True`, assigns the computed normals to the + internal buffers overwriting any previously stored normals. + + References: + [1] Tombari, Salti, Di Stefano: Unique Signatures of Histograms for + Local Surface Description, ECCV 2010. + """ + + # estimate the normals + normals_est = ops.estimate_pointcloud_normals( + self, + neighborhood_size=neighborhood_size, + disambiguate_directions=disambiguate_directions, + ) + + # assign to self + if assign_to_self: + self._normals_list, self._normals_padded, _ = self._parse_auxiliary_input( + normals_est + ) + + return normals_est + def extend(self, N: int): """ Create new Pointclouds which contains each cloud N times. diff --git a/tests/test_ply_io.py b/tests/test_ply_io.py index 9d7e058b..73dfd78d 100644 --- a/tests/test_ply_io.py +++ b/tests/test_ply_io.py @@ -185,6 +185,18 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): self.assertClose(expected_verts, actual_verts) self.assertClose(expected_faces, actual_faces) + def test_normals_save(self): + verts = torch.tensor( + [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32 + ) + faces = torch.tensor([[0, 1, 2], [0, 2, 3]]) + normals = torch.tensor( + [[0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float32 + ) + file = StringIO() + save_ply(file, verts=verts, faces=faces, verts_normals=normals) + file.close() + def test_empty_save_load(self): # Vertices + empty faces verts = torch.tensor([[0.1, 0.2, 0.3]]) diff --git a/tests/test_points_normals.py b/tests/test_points_normals.py new file mode 100644 index 00000000..ad757425 --- /dev/null +++ b/tests/test_points_normals.py @@ -0,0 +1,152 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import unittest +from typing import Tuple, Union + +import torch +from common_testing import TestCaseMixin +from pytorch3d.ops import ( + estimate_pointcloud_local_coord_frames, + estimate_pointcloud_normals, +) +from pytorch3d.structures.pointclouds import Pointclouds + + +DEBUG = False + + +class TestPCLNormals(TestCaseMixin, unittest.TestCase): + def setUp(self) -> None: + super().setUp() + torch.manual_seed(42) + + @staticmethod + def init_spherical_pcl( + batch_size=3, num_points=3000, device=None, use_pointclouds=False + ) -> Tuple[Union[torch.Tensor, Pointclouds], torch.Tensor]: + # random spherical point cloud + pcl = torch.randn( + (batch_size, num_points, 3), device=device, dtype=torch.float32 + ) + pcl = torch.nn.functional.normalize(pcl, dim=2) + + # GT normals are the same as + # the location of each point on the 0-centered sphere + normals = pcl.clone() + + # scale and offset the sphere randomly + pcl *= torch.rand(batch_size, 1, 1).type_as(pcl) + 1.0 + pcl += torch.randn(batch_size, 1, 3).type_as(pcl) + + if use_pointclouds: + num_points = torch.randint( + size=(batch_size,), low=int(num_points * 0.7), high=num_points + ) + pcl, normals = [ + [x[:np] for x, np in zip(X, num_points)] for X in (pcl, normals) + ] + pcl = Pointclouds(pcl, normals=normals) + + return pcl, normals + + def test_pcl_normals(self, batch_size=3, num_points=300, neighborhood_size=50): + """ + Tests the normal estimation on a spherical point cloud, where + we know the ground truth normals. + """ + device = torch.device("cuda:0") + # run several times for different random point clouds + for run_idx in range(3): + # either use tensors or Pointclouds as input + for use_pointclouds in (True, False): + # get a spherical point cloud + pcl, normals_gt = TestPCLNormals.init_spherical_pcl( + num_points=num_points, + batch_size=batch_size, + device=device, + use_pointclouds=use_pointclouds, + ) + if use_pointclouds: + normals_gt = pcl.normals_padded() + num_pcl_points = pcl.num_points_per_cloud() + else: + num_pcl_points = [pcl.shape[1]] * batch_size + + # check for both disambiguation options + for disambiguate_directions in (True, False): + curvatures, local_coord_frames = estimate_pointcloud_local_coord_frames( + pcl, + neighborhood_size=neighborhood_size, + disambiguate_directions=disambiguate_directions, + ) + + # estimate the normals + normals = estimate_pointcloud_normals( + pcl, + neighborhood_size=neighborhood_size, + disambiguate_directions=disambiguate_directions, + ) + + # TODO: temporarily disabled + if use_pointclouds: + # test that the class method gives the same output + normals_pcl = pcl.estimate_normals( + neighborhood_size=neighborhood_size, + disambiguate_directions=disambiguate_directions, + assign_to_self=True, + ) + normals_from_pcl = pcl.normals_padded() + for nrm, nrm_from_pcl, nrm_pcl, np in zip( + normals, normals_from_pcl, normals_pcl, num_pcl_points + ): + self.assertClose(nrm[:np], nrm_pcl[:np], atol=1e-5) + self.assertClose(nrm[:np], nrm_from_pcl[:np], atol=1e-5) + + # check that local coord frames give the same normal + # as normals + for nrm, lcoord, np in zip( + normals, local_coord_frames, num_pcl_points + ): + self.assertClose(nrm[:np], lcoord[:np, :, 0], atol=1e-5) + + # dotp between normals and normals_gt + normal_parallel = (normals_gt * normals).sum(2) + + # check that normals are on average + # parallel to the expected ones + for normp, np in zip(normal_parallel, num_pcl_points): + abs_parallel = normp[:np].abs() + avg_parallel = abs_parallel.mean() + std_parallel = abs_parallel.std() + self.assertClose( + avg_parallel, torch.ones_like(avg_parallel), atol=1e-2 + ) + self.assertClose( + std_parallel, torch.zeros_like(std_parallel), atol=1e-2 + ) + + if disambiguate_directions: + # check that 95% of normal dot products + # have the same sign + for normp, np in zip(normal_parallel, num_pcl_points): + n_pos = (normp[:np] > 0).sum() + self.assertTrue((n_pos > np * 0.95) or (n_pos < np * 0.05)) + + if DEBUG and run_idx == 0 and not use_pointclouds: + import os + from pytorch3d.io.ply_io import save_ply + + # export to .ply + outdir = "/tmp/pt3d_pcl_normals_test/" + os.makedirs(outdir, exist_ok=True) + plyfile = os.path.join( + outdir, f"pcl_disamb={disambiguate_directions}.ply" + ) + print(f"Storing point cloud with normals to {plyfile}.") + pcl_idx = 0 + save_ply( + plyfile, + pcl[pcl_idx].cpu(), + faces=None, + verts_normals=normals[pcl_idx].cpu(), + )