Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-08-02 03:42:50 +08:00)
Pointcloud normals estimation.

Summary: Estimates normals of a point cloud.
Reviewed By: gkioxari
Differential Revision: D20860182
fbshipit-source-id: 652ec2743fa645e02c01ffa37c2971bf27b89cef

This commit is contained in:
parent 8abbe22ffb
commit 365945b1fd
@@ -700,24 +700,38 @@ def load_ply(f):
    return verts, faces


def _save_ply(f, verts, faces, decimal_places: Optional[int]) -> None:
def _save_ply(
    f,
    verts: torch.Tensor,
    faces: torch.LongTensor,
    verts_normals: torch.Tensor,
    decimal_places: Optional[int] = None,
) -> None:
    """
    Internal implementation for saving a mesh to a .ply file.
    Internal implementation for saving 3D data to a .ply file.

    Args:
        f: File object to which the mesh should be written.
        f: File object to which the 3D data should be written.
        verts: FloatTensor of shape (V, 3) giving vertex coordinates.
        faces: LongTensor of shape (F, 3) giving faces.
        verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
        decimal_places: Number of decimal places for saving.
    """
    assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
    assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
    assert not len(verts_normals) or (
        verts_normals.dim() == 2 and verts_normals.size(1) == 3
    )

    print("ply\nformat ascii 1.0", file=f)
    print(f"element vertex {verts.shape[0]}", file=f)
    print("property float x", file=f)
    print("property float y", file=f)
    print("property float z", file=f)
    if verts_normals.numel() > 0:
        print("property float nx", file=f)
        print("property float ny", file=f)
        print("property float nz", file=f)
    print(f"element face {faces.shape[0]}", file=f)
    print("property list uchar int vertex_index", file=f)
    print("end_header", file=f)

@@ -731,8 +745,8 @@ def _save_ply(f, verts, faces, decimal_places: Optional[int]) -> None:
    else:
        float_str = "%" + ".%df" % decimal_places

    verts_array = verts.detach().numpy()
    np.savetxt(f, verts_array, float_str)
    vert_data = torch.cat((verts, verts_normals), dim=1)
    np.savetxt(f, vert_data.detach().numpy(), float_str)

    faces_array = faces.detach().numpy()

@@ -743,7 +757,13 @@ def _save_ply(f, verts, faces, decimal_places: Optional[int]) -> None:
    np.savetxt(f, faces_array, "3 %d %d %d")


def save_ply(f, verts, faces, decimal_places: Optional[int] = None):
def save_ply(
    f,
    verts: torch.Tensor,
    faces: Optional[torch.LongTensor] = None,
    verts_normals: Optional[torch.Tensor] = None,
    decimal_places: Optional[int] = None,
) -> None:
    """
    Save a mesh to a .ply file.

@@ -751,8 +771,13 @@ def save_ply(f, verts, faces, decimal_places: Optional[int] = None):
        f: File (or path) to which the mesh should be written.
        verts: FloatTensor of shape (V, 3) giving vertex coordinates.
        faces: LongTensor of shape (F, 3) giving faces.
        verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
        decimal_places: Number of decimal places for saving.
    """

    verts_normals = torch.FloatTensor([]) if verts_normals is None else verts_normals
    faces = torch.LongTensor([]) if faces is None else faces

    if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
        message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
        raise ValueError(message)

@@ -761,6 +786,14 @@ def save_ply(f, verts, faces, decimal_places: Optional[int] = None):
        message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
        raise ValueError(message)

    if len(verts_normals) and not (
        verts_normals.dim() == 2
        and verts_normals.size(1) == 3
        and verts_normals.size(0) == verts.size(0)
    ):
        message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)."
        raise ValueError(message)

    new_f = False
    if isinstance(f, str):
        new_f = True

@@ -769,7 +802,7 @@ def save_ply(f, verts, faces, decimal_places: Optional[int] = None):
        new_f = True
        f = f.open("w")
    try:
        _save_ply(f, verts, faces, decimal_places)
        _save_ply(f, verts, faces, verts_normals, decimal_places)
    finally:
        if new_f:
            f.close()
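With these changes, save_ply can write a faces-free point cloud together with per-vertex normals. A minimal usage sketch (not part of the diff; the expected output in the comments follows from the header-writing code above):

    from io import StringIO

    import torch
    from pytorch3d.io.ply_io import save_ply

    verts = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
    normals = torch.tensor([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]])
    buf = StringIO()
    save_ply(buf, verts=verts, verts_normals=normals, decimal_places=2)  # faces omitted
    # The header now advertises nx/ny/nz and "element face 0", and each vertex
    # row carries six values (x y z nx ny nz).
    print(buf.getvalue())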
@@ -7,9 +7,19 @@ from .knn import knn_gather, knn_points
from .mesh_face_areas_normals import mesh_face_areas_normals
from .packed_to_padded import packed_to_padded, padded_to_packed
from .points_alignment import corresponding_points_alignment, iterative_closest_point
from .points_normals import (
    estimate_pointcloud_local_coord_frames,
    estimate_pointcloud_normals,
)
from .sample_points_from_meshes import sample_points_from_meshes
from .subdivide_meshes import SubdivideMeshes
from .utils import convert_pointclouds_to_tensor, eyes, is_pointclouds, wmean
from .utils import (
    convert_pointclouds_to_tensor,
    eyes,
    get_point_covariances,
    is_pointclouds,
    wmean,
)
from .vert_align import vert_align
pytorch3d/ops/points_normals.py (new file, 172 lines)
@@ -0,0 +1,172 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from typing import TYPE_CHECKING, Tuple, Union

import torch

from .utils import convert_pointclouds_to_tensor, get_point_covariances


if TYPE_CHECKING:
    from ..structures import Pointclouds


def estimate_pointcloud_normals(
    pointclouds: Union[torch.Tensor, "Pointclouds"],
    neighborhood_size: int = 50,
    disambiguate_directions: bool = True,
) -> torch.Tensor:
    """
    Estimates the normals of a batch of `pointclouds`.

    The function uses `estimate_pointcloud_local_coord_frames` to estimate
    the normals. Please refer to that function for more detailed information.

    Args:
        **pointclouds**: Batch of 3-dimensional points of shape
            `(minibatch, num_point, 3)` or a `Pointclouds` object.
        **neighborhood_size**: The size of the neighborhood used to estimate the
            geometry around each point.
        **disambiguate_directions**: If `True`, uses the algorithm from [1] to
            ensure sign consistency of the normals of neighboring points.

    Returns:
        **normals**: A tensor of normals for each input point
            of shape `(minibatch, num_point, 3)`.
            If `pointclouds` is a `Pointclouds` object, returns a padded tensor.

    References:
        [1] Tombari, Salti, Di Stefano: Unique Signatures of Histograms for
        Local Surface Description, ECCV 2010.
    """

    curvatures, local_coord_frames = estimate_pointcloud_local_coord_frames(
        pointclouds,
        neighborhood_size=neighborhood_size,
        disambiguate_directions=disambiguate_directions,
    )

    # the normals correspond to the first vector of each local coord frame
    normals = local_coord_frames[:, :, :, 0]

    return normals
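A quick usage sketch of the new public entry point (not part of the diff; it assumes the KNN op is available on the device holding the points, and uses an arbitrary neighborhood size):

    import torch
    from pytorch3d.ops import estimate_pointcloud_normals

    points = torch.randn(2, 1000, 3)  # (minibatch, num_point, 3)
    # neighborhood_size must be smaller than the number of points in each cloud
    normals = estimate_pointcloud_normals(points, neighborhood_size=30)
    print(normals.shape)  # torch.Size([2, 1000, 3])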
def estimate_pointcloud_local_coord_frames(
    pointclouds: Union[torch.Tensor, "Pointclouds"],
    neighborhood_size: int = 50,
    disambiguate_directions: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Estimates the principal directions of curvature (which includes normals)
    of a batch of `pointclouds`.

    The algorithm first finds `neighborhood_size` nearest neighbors for each
    point of the point clouds, followed by obtaining principal vectors of
    covariance matrices of each of the point neighborhoods.
    The main principal vector corresponds to the normals, while the
    other 2 are the directions of the highest and the 2nd highest curvature.

    Note that each principal direction is given up to a sign. Hence,
    the function implements a `disambiguate_directions` switch that allows
    one to ensure consistency of the sign of neighboring normals. The
    implementation follows the sign disambiguation from SHOT descriptors [1].

    The algorithm also returns the curvature values themselves.
    These are the eigenvalues of the estimated covariance matrices
    of each point neighborhood.

    Args:
        **pointclouds**: Batch of 3-dimensional points of shape
            `(minibatch, num_point, 3)` or a `Pointclouds` object.
        **neighborhood_size**: The size of the neighborhood used to estimate the
            geometry around each point.
        **disambiguate_directions**: If `True`, uses the algorithm from [1] to
            ensure sign consistency of the normals of neighboring points.

    Returns:
        **curvatures**: The three principal curvatures of each point
            of shape `(minibatch, num_point, 3)`.
            If `pointclouds` is a `Pointclouds` object, returns a padded tensor.
        **local_coord_frames**: The three principal directions of the curvature
            around each point of shape `(minibatch, num_point, 3, 3)`.
            The principal directions are stored in columns of the output.
            E.g. `local_coord_frames[i, j, :, 0]` is the normal of
            the `j`-th point in the `i`-th point cloud.
            If `pointclouds` is a `Pointclouds` object, returns a padded tensor.

    References:
        [1] Tombari, Salti, Di Stefano: Unique Signatures of Histograms for
        Local Surface Description, ECCV 2010.
    """

    points_padded, num_points = convert_pointclouds_to_tensor(pointclouds)

    ba, N, dim = points_padded.shape
    if dim != 3:
        raise ValueError(
            "The pointclouds argument has to be of shape (minibatch, N, 3)"
        )

    if (num_points <= neighborhood_size).any():
        raise ValueError(
            "The neighborhood_size argument has to be"
            + " < size of each of the point clouds."
        )

    # undo global mean for stability
    # TODO: replace with tutil.wmean once landed
    pcl_mean = points_padded.sum(1) / num_points[:, None]
    points_centered = points_padded - pcl_mean[:, None, :]

    # get the per-point covariance and nearest neighbors used to compute it
    cov, knns = get_point_covariances(points_centered, num_points, neighborhood_size)

    # get the local coord frames as principal directions of
    # the per-point covariance
    # this is done with torch.symeig, which returns the
    # eigenvectors (=principal directions) in an ascending order of their
    # corresponding eigenvalues, while the smallest eigenvalue's eigenvector
    # corresponds to the normal direction
    curvatures, local_coord_frames = torch.symeig(cov, eigenvectors=True)

    # disambiguate the directions of individual principal vectors
    if disambiguate_directions:
        # disambiguate normal
        n = _disambiguate_vector_directions(
            points_centered, knns, local_coord_frames[:, :, :, 0]
        )
        # disambiguate the main curvature
        z = _disambiguate_vector_directions(
            points_centered, knns, local_coord_frames[:, :, :, 2]
        )
        # the secondary curvature is just a cross between n and z
        y = torch.cross(n, z, dim=2)
        # stack to form the set of principal directions
        local_coord_frames = torch.stack((n, y, z), dim=3)

    return curvatures, local_coord_frames
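The core of the estimator is the eigendecomposition of each neighborhood covariance. A tiny standalone check (not part of the commit; it uses torch.linalg.eigh, the modern counterpart of the torch.symeig call above, which also returns eigenvalues in ascending order) shows that the smallest-eigenvalue eigenvector of a roughly planar patch recovers the plane normal:

    import torch

    # points sampled from a thin slab around the z = 0 plane
    pts = torch.randn(500, 3) * torch.tensor([1.0, 1.0, 0.01])
    pts = pts - pts.mean(0, keepdim=True)
    cov = (pts.unsqueeze(2) * pts.unsqueeze(1)).mean(0)  # (3, 3) covariance
    evals, evecs = torch.linalg.eigh(cov)                # eigenvalues ascending
    normal = evecs[:, 0]                                 # smallest-eigenvalue eigenvector
    print(normal)                                        # close to (0, 0, +/-1)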
def _disambiguate_vector_directions(pcl, knns, vecs):
    """
    Disambiguates normal directions according to [1].

    References:
        [1] Tombari, Salti, Di Stefano: Unique Signatures of Histograms for
        Local Surface Description, ECCV 2010.
    """
    # parse out K from the shape of knns
    K = knns.shape[2]
    # the difference between each point and the points of its neighborhood
    df = knns - pcl[:, :, None]
    # projection of the difference on the principal direction
    proj = (vecs[:, :, None] * df).sum(3)
    # check how many projections are positive
    n_pos = (proj > 0).type_as(knns).sum(2, keepdim=True)
    # flip the principal directions where fewer than half of the projections
    # are positive
    flip = (n_pos < (0.5 * K)).type_as(knns)
    vecs = (1.0 - 2.0 * flip) * vecs
    return vecs
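The flipping rule is simple: a candidate direction is reversed whenever fewer than half of the neighbors lie on its positive side, so the final vector points toward the majority of the neighborhood. A toy check of the same arithmetic (hypothetical, not part of the commit):

    import torch

    # one point, one candidate vector, K = 4 neighbors, most at negative z
    pcl = torch.zeros(1, 1, 3)                            # (minibatch, num_points, 3)
    knns = torch.tensor([[[[0.0, 0.0, -1.0],
                           [0.0, 0.1, -0.9],
                           [0.1, 0.0, -1.1],
                           [0.0, 0.0, 0.2]]]])            # (1, 1, 4, 3)
    vecs = torch.tensor([[[0.0, 0.0, 1.0]]])              # (1, 1, 3) candidate normal

    df = knns - pcl[:, :, None]
    proj = (vecs[:, :, None] * df).sum(3)                 # projections onto the candidate
    n_pos = (proj > 0).sum(2, keepdim=True)               # only 1 of 4 is positive
    flip = (n_pos < 0.5 * 4).float()
    print((1.0 - 2.0 * flip) * vecs)                      # flipped to point along -z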
@@ -3,6 +3,8 @@ from typing import TYPE_CHECKING, Optional, Tuple, Union

import torch

from .knn import knn_points


if TYPE_CHECKING:
    from pytorch3d.structures import Pointclouds
@@ -92,8 +94,53 @@ def convert_pointclouds_to_tensor(pcl: Union[torch.Tensor, "Pointclouds"]):


def is_pointclouds(pcl: Union[torch.Tensor, "Pointclouds"]):
    """ Checks whether the input `pcl` is an instance `Pointclouds` of
    """ Checks whether the input `pcl` is an instance of `Pointclouds`
    by checking the existence of `points_padded` and `num_points_per_cloud`
    functions.
    """
    return hasattr(pcl, "points_padded") and hasattr(pcl, "num_points_per_cloud")


def get_point_covariances(
    points_padded: torch.Tensor,
    num_points_per_cloud: torch.Tensor,
    neighborhood_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes the per-point covariance matrices of the 3D locations of the
    K-nearest neighbors of each point.

    Args:
        **points_padded**: Input point clouds as a padded tensor
            of shape `(minibatch, num_points, dim)`.
        **num_points_per_cloud**: Number of points per cloud
            of shape `(minibatch,)`.
        **neighborhood_size**: Number of nearest neighbors for each point
            used to estimate the covariance matrices.

    Returns:
        **covariances**: A batch of per-point covariance matrices
            of shape `(minibatch, num_points, dim, dim)`.
        **k_nearest_neighbors**: A batch of `neighborhood_size` nearest
            neighbors for each of the point cloud points
            of shape `(minibatch, num_points, neighborhood_size, dim)`.
    """
    # get K nearest neighbor idx for each point in the point cloud
    _, _, k_nearest_neighbors = knn_points(
        points_padded,
        points_padded,
        lengths1=num_points_per_cloud,
        lengths2=num_points_per_cloud,
        K=neighborhood_size,
        return_nn=True,
    )
    # obtain the mean of the neighborhood
    pt_mean = k_nearest_neighbors.mean(2, keepdim=True)
    # compute the diff of the neighborhood and the mean of the neighborhood
    central_diff = k_nearest_neighbors - pt_mean
    # per-nn-point covariances
    per_pt_cov = central_diff.unsqueeze(4) * central_diff.unsqueeze(3)
    # per-point covariances
    covariances = per_pt_cov.mean(2)

    return covariances, k_nearest_neighbors
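The outer-product broadcast above (unsqueeze(4) * unsqueeze(3), then a mean over the neighbor dimension) is just the usual covariance of each neighborhood. A small standalone check of the same pattern, collapsed to a single neighborhood (not part of the commit):

    import torch

    nbrs = torch.randn(16, 3)                          # K = 16 neighbors of one point
    centered = nbrs - nbrs.mean(0, keepdim=True)
    cov_broadcast = (centered.unsqueeze(2) * centered.unsqueeze(1)).mean(0)  # (3, 3)
    cov_einsum = torch.einsum("ki,kj->ij", centered, centered) / nbrs.shape[0]
    print(torch.allclose(cov_broadcast, cov_einsum))   # True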
@@ -2,6 +2,7 @@

import torch

from .. import ops
from . import utils as struct_utils
@@ -847,6 +848,51 @@ class Pointclouds(object):
        bboxes = torch.stack([all_mins, all_maxes], dim=2)
        return bboxes

    def estimate_normals(
        self,
        neighborhood_size: int = 50,
        disambiguate_directions: bool = True,
        assign_to_self: bool = False,
    ):
        """
        Estimates the normals of each point in each cloud and assigns
        them to the internal tensors `self._normals_list` and `self._normals_padded`.

        The function uses `ops.estimate_pointcloud_local_coord_frames`
        to estimate the normals. Please refer to that function for more
        detailed information about the implemented algorithm.

        Args:
            **neighborhood_size**: The size of the neighborhood used to estimate the
                geometry around each point.
            **disambiguate_directions**: If `True`, uses the algorithm from [1] to
                ensure sign consistency of the normals of neighboring points.
            **assign_to_self**: If `True`, assigns the computed normals to the
                internal buffers, overwriting any previously stored normals.

        Returns:
            **normals**: A tensor of normals for each input point
                of shape `(minibatch, num_point, 3)`, returned as a padded tensor.

        References:
            [1] Tombari, Salti, Di Stefano: Unique Signatures of Histograms for
            Local Surface Description, ECCV 2010.
        """

        # estimate the normals
        normals_est = ops.estimate_pointcloud_normals(
            self,
            neighborhood_size=neighborhood_size,
            disambiguate_directions=disambiguate_directions,
        )

        # assign to self
        if assign_to_self:
            self._normals_list, self._normals_padded, _ = self._parse_auxiliary_input(
                normals_est
            )

        return normals_est

    def extend(self, N: int):
        """
        Create new Pointclouds which contains each cloud N times.
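A minimal usage sketch of the new method (not part of the diff; it assumes the KNN op is available on the device holding the clouds):

    import torch
    from pytorch3d.structures import Pointclouds

    clouds = Pointclouds(points=[torch.randn(500, 3), torch.randn(400, 3)])
    # neighborhood_size must stay below the number of points in each cloud
    normals = clouds.estimate_normals(neighborhood_size=30, assign_to_self=True)
    print(normals.shape)                   # padded: torch.Size([2, 500, 3])
    print(clouds.normals_padded().shape)   # the same normals, now stored on the object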
@@ -185,6 +185,18 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
        self.assertClose(expected_verts, actual_verts)
        self.assertClose(expected_faces, actual_faces)

    def test_normals_save(self):
        verts = torch.tensor(
            [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
        )
        faces = torch.tensor([[0, 1, 2], [0, 2, 3]])
        normals = torch.tensor(
            [[0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float32
        )
        file = StringIO()
        save_ply(file, verts=verts, faces=faces, verts_normals=normals)
        file.close()

    def test_empty_save_load(self):
        # Vertices + empty faces
        verts = torch.tensor([[0.1, 0.2, 0.3]])
tests/test_points_normals.py (new file, 152 lines)
@@ -0,0 +1,152 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import unittest
from typing import Tuple, Union

import torch
from common_testing import TestCaseMixin
from pytorch3d.ops import (
    estimate_pointcloud_local_coord_frames,
    estimate_pointcloud_normals,
)
from pytorch3d.structures.pointclouds import Pointclouds


DEBUG = False


class TestPCLNormals(TestCaseMixin, unittest.TestCase):
    def setUp(self) -> None:
        super().setUp()
        torch.manual_seed(42)

    @staticmethod
    def init_spherical_pcl(
        batch_size=3, num_points=3000, device=None, use_pointclouds=False
    ) -> Tuple[Union[torch.Tensor, Pointclouds], torch.Tensor]:
        # random spherical point cloud
        pcl = torch.randn(
            (batch_size, num_points, 3), device=device, dtype=torch.float32
        )
        pcl = torch.nn.functional.normalize(pcl, dim=2)

        # GT normals are the same as
        # the location of each point on the 0-centered sphere
        normals = pcl.clone()

        # scale and offset the sphere randomly
        pcl *= torch.rand(batch_size, 1, 1).type_as(pcl) + 1.0
        pcl += torch.randn(batch_size, 1, 3).type_as(pcl)

        if use_pointclouds:
            num_points = torch.randint(
                size=(batch_size,), low=int(num_points * 0.7), high=num_points
            )
            pcl, normals = [
                [x[:np] for x, np in zip(X, num_points)] for X in (pcl, normals)
            ]
            pcl = Pointclouds(pcl, normals=normals)

        return pcl, normals

    def test_pcl_normals(self, batch_size=3, num_points=300, neighborhood_size=50):
        """
        Tests the normal estimation on a spherical point cloud, where
        we know the ground truth normals.
        """
        device = torch.device("cuda:0")
        # run several times for different random point clouds
        for run_idx in range(3):
            # either use tensors or Pointclouds as input
            for use_pointclouds in (True, False):
                # get a spherical point cloud
                pcl, normals_gt = TestPCLNormals.init_spherical_pcl(
                    num_points=num_points,
                    batch_size=batch_size,
                    device=device,
                    use_pointclouds=use_pointclouds,
                )
                if use_pointclouds:
                    normals_gt = pcl.normals_padded()
                    num_pcl_points = pcl.num_points_per_cloud()
                else:
                    num_pcl_points = [pcl.shape[1]] * batch_size

                # check for both disambiguation options
                for disambiguate_directions in (True, False):
                    curvatures, local_coord_frames = estimate_pointcloud_local_coord_frames(
                        pcl,
                        neighborhood_size=neighborhood_size,
                        disambiguate_directions=disambiguate_directions,
                    )

                    # estimate the normals
                    normals = estimate_pointcloud_normals(
                        pcl,
                        neighborhood_size=neighborhood_size,
                        disambiguate_directions=disambiguate_directions,
                    )

                    # TODO: temporarily disabled
                    if use_pointclouds:
                        # test that the class method gives the same output
                        normals_pcl = pcl.estimate_normals(
                            neighborhood_size=neighborhood_size,
                            disambiguate_directions=disambiguate_directions,
                            assign_to_self=True,
                        )
                        normals_from_pcl = pcl.normals_padded()
                        for nrm, nrm_from_pcl, nrm_pcl, np in zip(
                            normals, normals_from_pcl, normals_pcl, num_pcl_points
                        ):
                            self.assertClose(nrm[:np], nrm_pcl[:np], atol=1e-5)
                            self.assertClose(nrm[:np], nrm_from_pcl[:np], atol=1e-5)

                    # check that local coord frames give the same normal
                    # as normals
                    for nrm, lcoord, np in zip(
                        normals, local_coord_frames, num_pcl_points
                    ):
                        self.assertClose(nrm[:np], lcoord[:np, :, 0], atol=1e-5)

                    # dotp between normals and normals_gt
                    normal_parallel = (normals_gt * normals).sum(2)

                    # check that normals are on average
                    # parallel to the expected ones
                    for normp, np in zip(normal_parallel, num_pcl_points):
                        abs_parallel = normp[:np].abs()
                        avg_parallel = abs_parallel.mean()
                        std_parallel = abs_parallel.std()
                        self.assertClose(
                            avg_parallel, torch.ones_like(avg_parallel), atol=1e-2
                        )
                        self.assertClose(
                            std_parallel, torch.zeros_like(std_parallel), atol=1e-2
                        )

                    if disambiguate_directions:
                        # check that 95% of normal dot products
                        # have the same sign
                        for normp, np in zip(normal_parallel, num_pcl_points):
                            n_pos = (normp[:np] > 0).sum()
                            self.assertTrue((n_pos > np * 0.95) or (n_pos < np * 0.05))

                    if DEBUG and run_idx == 0 and not use_pointclouds:
                        import os
                        from pytorch3d.io.ply_io import save_ply

                        # export to .ply
                        outdir = "/tmp/pt3d_pcl_normals_test/"
                        os.makedirs(outdir, exist_ok=True)
                        plyfile = os.path.join(
                            outdir, f"pcl_disamb={disambiguate_directions}.ply"
                        )
                        print(f"Storing point cloud with normals to {plyfile}.")
                        pcl_idx = 0
                        save_ply(
                            plyfile,
                            pcl[pcl_idx].cpu(),
                            faces=None,
                            verts_normals=normals[pcl_idx].cpu(),
                        )