Pointcloud normals estimation.

Summary: Estimates normals of a point cloud.

Reviewed By: gkioxari

Differential Revision: D20860182

fbshipit-source-id: 652ec2743fa645e02c01ffa37c2971bf27b89cef
This commit is contained in:
David Novotny 2020-04-16 18:33:43 -07:00 committed by Facebook GitHub Bot
parent 8abbe22ffb
commit 365945b1fd
7 changed files with 482 additions and 10 deletions

View File

@ -700,24 +700,38 @@ def load_ply(f):
return verts, faces
def _save_ply(f, verts, faces, decimal_places: Optional[int]) -> None:
def _save_ply(
f,
verts: torch.Tensor,
faces: torch.LongTensor,
verts_normals: torch.Tensor,
decimal_places: Optional[int] = None,
) -> None:
"""
Internal implementation for saving a mesh to a .ply file.
Internal implementation for saving 3D data to a .ply file.
Args:
f: File object to which the mesh should be written.
f: File object to which the 3D data should be written.
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
faces: LongTensor of shape (F, 3) giving faces.
faces: LongTensor of shsape (F, 3) giving faces.
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
decimal_places: Number of decimal places for saving.
"""
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
assert not len(verts_normals) or (
verts_normals.dim() == 2 and verts_normals.size(1) == 3
)
print("ply\nformat ascii 1.0", file=f)
print(f"element vertex {verts.shape[0]}", file=f)
print("property float x", file=f)
print("property float y", file=f)
print("property float z", file=f)
if verts_normals.numel() > 0:
print("property float nx", file=f)
print("property float ny", file=f)
print("property float nz", file=f)
print(f"element face {faces.shape[0]}", file=f)
print("property list uchar int vertex_index", file=f)
print("end_header", file=f)
@ -731,8 +745,8 @@ def _save_ply(f, verts, faces, decimal_places: Optional[int]) -> None:
else:
float_str = "%" + ".%df" % decimal_places
verts_array = verts.detach().numpy()
np.savetxt(f, verts_array, float_str)
vert_data = torch.cat((verts, verts_normals), dim=1)
np.savetxt(f, vert_data.detach().numpy(), float_str)
faces_array = faces.detach().numpy()
@ -743,7 +757,13 @@ def _save_ply(f, verts, faces, decimal_places: Optional[int]) -> None:
np.savetxt(f, faces_array, "3 %d %d %d")
def save_ply(f, verts, faces, decimal_places: Optional[int] = None):
def save_ply(
f,
verts: torch.Tensor,
faces: Optional[torch.LongTensor] = None,
verts_normals: Optional[torch.Tensor] = None,
decimal_places: Optional[int] = None,
) -> None:
"""
Save a mesh to a .ply file.
@ -751,8 +771,13 @@ def save_ply(f, verts, faces, decimal_places: Optional[int] = None):
f: File (or path) to which the mesh should be written.
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
faces: LongTensor of shape (F, 3) giving faces.
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
decimal_places: Number of decimal places for saving.
"""
verts_normals = torch.FloatTensor([]) if verts_normals is None else verts_normals
faces = torch.LongTensor([]) if faces is None else faces
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
raise ValueError(message)
@ -761,6 +786,14 @@ def save_ply(f, verts, faces, decimal_places: Optional[int] = None):
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
raise ValueError(message)
if len(verts_normals) and not (
verts_normals.dim() == 2
and verts_normals.size(1) == 3
and verts_normals.size(0) == verts.size(0)
):
message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)."
raise ValueError(message)
new_f = False
if isinstance(f, str):
new_f = True
@ -769,7 +802,7 @@ def save_ply(f, verts, faces, decimal_places: Optional[int] = None):
new_f = True
f = f.open("w")
try:
_save_ply(f, verts, faces, decimal_places)
_save_ply(f, verts, faces, verts_normals, decimal_places)
finally:
if new_f:
f.close()

View File

@ -7,9 +7,19 @@ from .knn import knn_gather, knn_points
from .mesh_face_areas_normals import mesh_face_areas_normals
from .packed_to_padded import packed_to_padded, padded_to_packed
from .points_alignment import corresponding_points_alignment, iterative_closest_point
from .points_normals import (
estimate_pointcloud_local_coord_frames,
estimate_pointcloud_normals,
)
from .sample_points_from_meshes import sample_points_from_meshes
from .subdivide_meshes import SubdivideMeshes
from .utils import convert_pointclouds_to_tensor, eyes, is_pointclouds, wmean
from .utils import (
convert_pointclouds_to_tensor,
eyes,
get_point_covariances,
is_pointclouds,
wmean,
)
from .vert_align import vert_align

View File

@ -0,0 +1,172 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import TYPE_CHECKING, Tuple, Union
import torch
from .utils import convert_pointclouds_to_tensor, get_point_covariances
if TYPE_CHECKING:
from ..structures import Pointclouds
def estimate_pointcloud_normals(
pointclouds: Union[torch.Tensor, "Pointclouds"],
neighborhood_size: int = 50,
disambiguate_directions: bool = True,
) -> torch.Tensor:
"""
Estimates the normals of a batch of `pointclouds`.
The function uses `estimate_pointcloud_local_coord_frames` to estimate
the normals. Please refer to this function for more detailed information.
Args:
**pointclouds**: Batch of 3-dimensional points of shape
`(minibatch, num_point, 3)` or a `Pointclouds` object.
**neighborhood_size**: The size of the neighborhood used to estimate the
geometry around each point.
**disambiguate_directions**: If `True`, uses the algorithm from [1] to
ensure sign consistency of the normals of neigboring points.
Returns:
**normals**: A tensor of normals for each input point
of shape `(minibatch, num_point, 3)`.
If `pointclouds` are of `Pointclouds` class, returns a padded tensor.
References:
[1] Tombari, Salti, Di Stefano: Unique Signatures of Histograms for
Local Surface Description, ECCV 2010.
"""
curvatures, local_coord_frames = estimate_pointcloud_local_coord_frames(
pointclouds,
neighborhood_size=neighborhood_size,
disambiguate_directions=disambiguate_directions,
)
# the normals correspond to the first vector of each local coord frame
normals = local_coord_frames[:, :, :, 0]
return normals
def estimate_pointcloud_local_coord_frames(
pointclouds: Union[torch.Tensor, "Pointclouds"],
neighborhood_size: int = 50,
disambiguate_directions: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Estimates the principal directions of curvature (which includes normals)
of a batch of `pointclouds`.
The algorithm first finds `neighborhood_size` nearest neighbors for each
point of the point clouds, followed by obtaining principal vectors of
covariance matrices of each of the point neighborhoods.
The main principal vector corresponds to the normals, while the
other 2 are the direction of the highest curvature and the 2nd highest
curvature.
Note that each principal direction is given up to a sign. Hence,
the function implements `disambiguate_directions` switch that allows
to ensure consistency of the sign of neighboring normals. The implementation
follows the sign disabiguation from SHOT descriptors [1].
The algorithm also returns the curvature values themselves.
These are the eigenvalues of the estimated covariance matrices
of each point neighborhood.
Args:
**pointclouds**: Batch of 3-dimensional points of shape
`(minibatch, num_point, 3)` or a `Pointclouds` object.
**neighborhood_size**: The size of the neighborhood used to estimate the
geometry around each point.
**disambiguate_directions**: If `True`, uses the algorithm from [1] to
ensure sign consistency of the normals of neigboring points.
Returns:
**curvatures**: The three principal curvatures of each point
of shape `(minibatch, num_point, 3)`.
If `pointclouds` are of `Pointclouds` class, returns a padded tensor.
**local_coord_frames**: The three principal directions of the curvature
around each point of shape `(minibatch, num_point, 3, 3)`.
The principal directions are stored in columns of the output.
E.g. `local_coord_frames[i, j, :, 0]` is the normal of
`j`-th point in the `i`-th pointcloud.
If `pointclouds` are of `Pointclouds` class, returns a padded tensor.
References:
[1] Tombari, Salti, Di Stefano: Unique Signatures of Histograms for
Local Surface Description, ECCV 2010.
"""
points_padded, num_points = convert_pointclouds_to_tensor(pointclouds)
ba, N, dim = points_padded.shape
if dim != 3:
raise ValueError(
"The pointclouds argument has to be of shape (minibatch, N, 3)"
)
if (num_points <= neighborhood_size).any():
raise ValueError(
"The neighborhood_size argument has to be"
+ " >= size of each of the point clouds."
)
# undo global mean for stability
# TODO: replace with tutil.wmean once landed
pcl_mean = points_padded.sum(1) / num_points[:, None]
points_centered = points_padded - pcl_mean[:, None, :]
# get the per-point covariance and nearest neighbors used to compute it
cov, knns = get_point_covariances(points_centered, num_points, neighborhood_size)
# get the local coord frames as principal directions of
# the per-point covariance
# this is done with torch.symeig, which returns the
# eigenvectors (=principal directions) in an ascending order of their
# corresponding eigenvalues, while the smallest eigenvalue's eigenvector
# corresponds to the normal direction
curvatures, local_coord_frames = torch.symeig(cov, eigenvectors=True)
# disambiguate the directions of individual principal vectors
if disambiguate_directions:
# disambiguate normal
n = _disambiguate_vector_directions(
points_centered, knns, local_coord_frames[:, :, :, 0]
)
# disambiguate the main curvature
z = _disambiguate_vector_directions(
points_centered, knns, local_coord_frames[:, :, :, 2]
)
# the secondary curvature is just a cross between n and z
y = torch.cross(n, z, dim=2)
# cat to form the set of principal directions
local_coord_frames = torch.stack((n, y, z), dim=3)
return curvatures, local_coord_frames
def _disambiguate_vector_directions(pcl, knns, vecs):
"""
Disambiguates normal directions according to [1].
References:
[1] Tombari, Salti, Di Stefano: Unique Signatures of Histograms for
Local Surface Description, ECCV 2010.
"""
# parse out K from the shape of knns
K = knns.shape[2]
# the difference between the mean of each neighborhood and
# each element of the neighborhood
df = knns - pcl[:, :, None]
# projection of the difference on the principal direction
proj = (vecs[:, :, None] * df).sum(3)
# check how many projections are positive
n_pos = (proj > 0).type_as(knns).sum(2, keepdim=True)
# flip the principal directions where number of positive correlations
flip = (n_pos < (0.5 * K)).type_as(knns)
vecs = (1.0 - 2.0 * flip) * vecs
return vecs

View File

@ -3,6 +3,8 @@ from typing import TYPE_CHECKING, Optional, Tuple, Union
import torch
from .knn import knn_points
if TYPE_CHECKING:
from pytorch3d.structures import Pointclouds
@ -92,8 +94,53 @@ def convert_pointclouds_to_tensor(pcl: Union[torch.Tensor, "Pointclouds"]):
def is_pointclouds(pcl: Union[torch.Tensor, "Pointclouds"]):
""" Checks whether the input `pcl` is an instance `Pointclouds` of
""" Checks whether the input `pcl` is an instance of `Pointclouds`
by checking the existence of `points_padded` and `num_points_per_cloud`
functions.
"""
return hasattr(pcl, "points_padded") and hasattr(pcl, "num_points_per_cloud")
def get_point_covariances(
points_padded: torch.Tensor,
num_points_per_cloud: torch.Tensor,
neighborhood_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Computes the per-point covariance matrices by of the 3D locations of
K-nearest neighbors of each point.
Args:
**points_padded**: Input point clouds as a padded tensor
of shape `(minibatch, num_points, dim)`.
**num_points_per_cloud**: Number of points per cloud
of shape `(minibatch,)`.
**neighborhood_size**: Number of nearest neighbors for each point
used to estimate the covariance matrices.
Returns:
**covariances**: A batch of per-point covariance matrices
of shape `(minibatch, dim, dim)`.
**k_nearest_neighbors**: A batch of `neighborhood_size` nearest
neighbors for each of the point cloud points
of shape `(minibatch, num_points, neighborhood_size, dim)`.
"""
# get K nearest neighbor idx for each point in the point cloud
_, _, k_nearest_neighbors = knn_points(
points_padded,
points_padded,
lengths1=num_points_per_cloud,
lengths2=num_points_per_cloud,
K=neighborhood_size,
return_nn=True,
)
# obtain the mean of the neighborhood
pt_mean = k_nearest_neighbors.mean(2, keepdim=True)
# compute the diff of the neighborhood and the mean of the neighborhood
central_diff = k_nearest_neighbors - pt_mean
# per-nn-point covariances
per_pt_cov = central_diff.unsqueeze(4) * central_diff.unsqueeze(3)
# per-point covariances
covariances = per_pt_cov.mean(2)
return covariances, k_nearest_neighbors

View File

@ -2,6 +2,7 @@
import torch
from .. import ops
from . import utils as struct_utils
@ -847,6 +848,51 @@ class Pointclouds(object):
bboxes = torch.stack([all_mins, all_maxes], dim=2)
return bboxes
def estimate_normals(
self,
neighborhood_size: int = 50,
disambiguate_directions: bool = True,
assign_to_self: bool = False,
):
"""
Estimates the normals of each point in each cloud and assigns
them to the internal tensors `self._normals_list` and `self._normals_padded`
The function uses `ops.estimate_pointcloud_local_coord_frames`
to estimate the normals. Please refer to this function for more
detailed information about the implemented algorithm.
Args:
**neighborhood_size**: The size of the neighborhood used to estimate the
geometry around each point.
**disambiguate_directions**: If `True`, uses the algorithm from [1] to
ensure sign consistency of the normals of neigboring points.
**normals**: A tensor of normals for each input point
of shape `(minibatch, num_point, 3)`.
If `pointclouds` are of `Pointclouds` class, returns a padded tensor.
**assign_to_self**: If `True`, assigns the computed normals to the
internal buffers overwriting any previously stored normals.
References:
[1] Tombari, Salti, Di Stefano: Unique Signatures of Histograms for
Local Surface Description, ECCV 2010.
"""
# estimate the normals
normals_est = ops.estimate_pointcloud_normals(
self,
neighborhood_size=neighborhood_size,
disambiguate_directions=disambiguate_directions,
)
# assign to self
if assign_to_self:
self._normals_list, self._normals_padded, _ = self._parse_auxiliary_input(
normals_est
)
return normals_est
def extend(self, N: int):
"""
Create new Pointclouds which contains each cloud N times.

View File

@ -185,6 +185,18 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
self.assertClose(expected_verts, actual_verts)
self.assertClose(expected_faces, actual_faces)
def test_normals_save(self):
verts = torch.tensor(
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
)
faces = torch.tensor([[0, 1, 2], [0, 2, 3]])
normals = torch.tensor(
[[0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float32
)
file = StringIO()
save_ply(file, verts=verts, faces=faces, verts_normals=normals)
file.close()
def test_empty_save_load(self):
# Vertices + empty faces
verts = torch.tensor([[0.1, 0.2, 0.3]])

View File

@ -0,0 +1,152 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
from typing import Tuple, Union
import torch
from common_testing import TestCaseMixin
from pytorch3d.ops import (
estimate_pointcloud_local_coord_frames,
estimate_pointcloud_normals,
)
from pytorch3d.structures.pointclouds import Pointclouds
DEBUG = False
class TestPCLNormals(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(42)
@staticmethod
def init_spherical_pcl(
batch_size=3, num_points=3000, device=None, use_pointclouds=False
) -> Tuple[Union[torch.Tensor, Pointclouds], torch.Tensor]:
# random spherical point cloud
pcl = torch.randn(
(batch_size, num_points, 3), device=device, dtype=torch.float32
)
pcl = torch.nn.functional.normalize(pcl, dim=2)
# GT normals are the same as
# the location of each point on the 0-centered sphere
normals = pcl.clone()
# scale and offset the sphere randomly
pcl *= torch.rand(batch_size, 1, 1).type_as(pcl) + 1.0
pcl += torch.randn(batch_size, 1, 3).type_as(pcl)
if use_pointclouds:
num_points = torch.randint(
size=(batch_size,), low=int(num_points * 0.7), high=num_points
)
pcl, normals = [
[x[:np] for x, np in zip(X, num_points)] for X in (pcl, normals)
]
pcl = Pointclouds(pcl, normals=normals)
return pcl, normals
def test_pcl_normals(self, batch_size=3, num_points=300, neighborhood_size=50):
"""
Tests the normal estimation on a spherical point cloud, where
we know the ground truth normals.
"""
device = torch.device("cuda:0")
# run several times for different random point clouds
for run_idx in range(3):
# either use tensors or Pointclouds as input
for use_pointclouds in (True, False):
# get a spherical point cloud
pcl, normals_gt = TestPCLNormals.init_spherical_pcl(
num_points=num_points,
batch_size=batch_size,
device=device,
use_pointclouds=use_pointclouds,
)
if use_pointclouds:
normals_gt = pcl.normals_padded()
num_pcl_points = pcl.num_points_per_cloud()
else:
num_pcl_points = [pcl.shape[1]] * batch_size
# check for both disambiguation options
for disambiguate_directions in (True, False):
curvatures, local_coord_frames = estimate_pointcloud_local_coord_frames(
pcl,
neighborhood_size=neighborhood_size,
disambiguate_directions=disambiguate_directions,
)
# estimate the normals
normals = estimate_pointcloud_normals(
pcl,
neighborhood_size=neighborhood_size,
disambiguate_directions=disambiguate_directions,
)
# TODO: temporarily disabled
if use_pointclouds:
# test that the class method gives the same output
normals_pcl = pcl.estimate_normals(
neighborhood_size=neighborhood_size,
disambiguate_directions=disambiguate_directions,
assign_to_self=True,
)
normals_from_pcl = pcl.normals_padded()
for nrm, nrm_from_pcl, nrm_pcl, np in zip(
normals, normals_from_pcl, normals_pcl, num_pcl_points
):
self.assertClose(nrm[:np], nrm_pcl[:np], atol=1e-5)
self.assertClose(nrm[:np], nrm_from_pcl[:np], atol=1e-5)
# check that local coord frames give the same normal
# as normals
for nrm, lcoord, np in zip(
normals, local_coord_frames, num_pcl_points
):
self.assertClose(nrm[:np], lcoord[:np, :, 0], atol=1e-5)
# dotp between normals and normals_gt
normal_parallel = (normals_gt * normals).sum(2)
# check that normals are on average
# parallel to the expected ones
for normp, np in zip(normal_parallel, num_pcl_points):
abs_parallel = normp[:np].abs()
avg_parallel = abs_parallel.mean()
std_parallel = abs_parallel.std()
self.assertClose(
avg_parallel, torch.ones_like(avg_parallel), atol=1e-2
)
self.assertClose(
std_parallel, torch.zeros_like(std_parallel), atol=1e-2
)
if disambiguate_directions:
# check that 95% of normal dot products
# have the same sign
for normp, np in zip(normal_parallel, num_pcl_points):
n_pos = (normp[:np] > 0).sum()
self.assertTrue((n_pos > np * 0.95) or (n_pos < np * 0.05))
if DEBUG and run_idx == 0 and not use_pointclouds:
import os
from pytorch3d.io.ply_io import save_ply
# export to .ply
outdir = "/tmp/pt3d_pcl_normals_test/"
os.makedirs(outdir, exist_ok=True)
plyfile = os.path.join(
outdir, f"pcl_disamb={disambiguate_directions}.ply"
)
print(f"Storing point cloud with normals to {plyfile}.")
pcl_idx = 0
save_ply(
plyfile,
pcl[pcl_idx].cpu(),
faces=None,
verts_normals=normals[pcl_idx].cpu(),
)