# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import List, Tuple, Union

import torch
from pytorch3d.ops import (
    estimate_pointcloud_local_coord_frames,
    estimate_pointcloud_normals,
)
from pytorch3d.structures.pointclouds import Pointclouds

from .common_testing import TestCaseMixin


# Set to True to export the estimated normals of the first run to .ply files.
DEBUG = False


class TestPCLNormals(TestCaseMixin, unittest.TestCase):
    """Checks point cloud normal estimation against an analytic sphere."""

    def setUp(self) -> None:
        """Seed torch's RNG so every test draws the same random clouds."""
        super().setUp()
        torch.manual_seed(42)

    @staticmethod
    def init_spherical_pcl(
        batch_size=3, num_points=3000, device=None, use_pointclouds=False
    ) -> Tuple[
        Union[torch.Tensor, Pointclouds], Union[torch.Tensor, List[torch.Tensor]]
    ]:
        """
        Generate a batch of randomly scaled and shifted spherical point clouds.

        Args:
            batch_size: Number of point clouds in the batch.
            num_points: Number of points sampled per cloud.
            device: Device on which the points are allocated.
            use_pointclouds: If True, each cloud is truncated to a random
                length in `[int(0.7 * num_points), num_points)` and the batch
                is wrapped in a `Pointclouds` object carrying the normals.

        Returns:
            A tuple `(pcl, normals)`. With `use_pointclouds=False` both are
            tensors of shape `(batch_size, num_points, 3)`. With
            `use_pointclouds=True`, `pcl` is a `Pointclouds` object and
            `normals` is a list of per-cloud normal tensors.
        """
        # random spherical point cloud
        pcl = torch.randn(
            (batch_size, num_points, 3), device=device, dtype=torch.float32
        )
        pcl = torch.nn.functional.normalize(pcl, dim=2)

        # GT normals are the same as
        # the location of each point on the 0-centered sphere
        normals = pcl.clone()

        # scale and offset the sphere randomly
        pcl *= torch.rand(batch_size, 1, 1).type_as(pcl) + 1.0
        pcl += torch.randn(batch_size, 1, 3).type_as(pcl)

        if use_pointclouds:
            # keep a different (random) number of points in every cloud
            num_points = torch.randint(
                size=(batch_size,), low=int(num_points * 0.7), high=num_points
            )
            pcl, normals = [
                [x[:n_keep] for x, n_keep in zip(X, num_points)]
                for X in (pcl, normals)
            ]
            pcl = Pointclouds(pcl, normals=normals)

        return pcl, normals

    def _check_normals_against_gt(
        self, normals, normals_gt, num_pcl_points, disambiguate_directions
    ) -> None:
        """Assert that estimated normals are (anti-)parallel to the GT ones."""
        # dotp between normals and normals_gt
        normal_parallel = (normals_gt * normals).sum(2)

        # check that normals are on average
        # parallel to the expected ones
        for normp, n_pts in zip(normal_parallel, num_pcl_points):
            abs_parallel = normp[:n_pts].abs()
            avg_parallel = abs_parallel.mean()
            std_parallel = abs_parallel.std()
            self.assertClose(
                avg_parallel, torch.ones_like(avg_parallel), atol=1e-2
            )
            self.assertClose(
                std_parallel, torch.zeros_like(std_parallel), atol=1e-2
            )

        if disambiguate_directions:
            # check that 95% of normal dot products
            # have the same sign
            for normp, n_pts in zip(normal_parallel, num_pcl_points):
                n_pos = (normp[:n_pts] > 0).sum()
                self.assertTrue((n_pos > n_pts * 0.95) or (n_pos < n_pts * 0.05))

    @staticmethod
    def _save_debug_ply(pcl, normals, disambiguate_directions) -> None:
        """Write the first cloud of a tensor batch and its normals to a .ply."""
        import os

        from pytorch3d.io.ply_io import save_ply

        # export to .ply
        outdir = "/tmp/pt3d_pcl_normals_test/"
        os.makedirs(outdir, exist_ok=True)
        plyfile = os.path.join(outdir, f"pcl_disamb={disambiguate_directions}.ply")
        print(f"Storing point cloud with normals to {plyfile}.")
        pcl_idx = 0
        save_ply(
            plyfile,
            pcl[pcl_idx].cpu(),
            faces=None,
            verts_normals=normals[pcl_idx].cpu(),
        )

    def test_pcl_normals(self, batch_size=3, num_points=300, neighborhood_size=50):
        """
        Tests the normal estimation on a spherical point cloud, where
        we know the ground truth normals.

        Runs on "cuda:0" when CUDA is available and on the CPU otherwise.
        """
        # Fall back to the CPU so the test still runs on machines without a GPU.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # run several times for different random point clouds
        for run_idx in range(3):
            # either use tensors or Pointclouds as input
            for use_pointclouds in (True, False):
                # get a spherical point cloud
                pcl, normals_gt = TestPCLNormals.init_spherical_pcl(
                    num_points=num_points,
                    batch_size=batch_size,
                    device=device,
                    use_pointclouds=use_pointclouds,
                )
                if use_pointclouds:
                    # Read the GT normals here, before the
                    # estimate_normals(assign_to_self=True) call below.
                    normals_gt = pcl.normals_padded()
                    num_pcl_points = pcl.num_points_per_cloud()
                else:
                    num_pcl_points = [pcl.shape[1]] * batch_size

                # check for both disambiguation options
                for disambiguate_directions in (True, False):
                    _, local_coord_frames = estimate_pointcloud_local_coord_frames(
                        pcl,
                        neighborhood_size=neighborhood_size,
                        disambiguate_directions=disambiguate_directions,
                    )

                    # estimate the normals
                    normals = estimate_pointcloud_normals(
                        pcl,
                        neighborhood_size=neighborhood_size,
                        disambiguate_directions=disambiguate_directions,
                    )

                    if use_pointclouds:
                        # test that the class method gives the same output
                        normals_pcl = pcl.estimate_normals(
                            neighborhood_size=neighborhood_size,
                            disambiguate_directions=disambiguate_directions,
                            assign_to_self=True,
                        )
                        normals_from_pcl = pcl.normals_padded()
                        for nrm, nrm_from_pcl, nrm_pcl, n_pts in zip(
                            normals, normals_from_pcl, normals_pcl, num_pcl_points
                        ):
                            self.assertClose(
                                nrm[:n_pts], nrm_pcl[:n_pts], atol=1e-5
                            )
                            self.assertClose(
                                nrm[:n_pts], nrm_from_pcl[:n_pts], atol=1e-5
                            )

                    # check that local coord frames give the same normal
                    # as normals
                    for nrm, lcoord, n_pts in zip(
                        normals, local_coord_frames, num_pcl_points
                    ):
                        self.assertClose(
                            nrm[:n_pts], lcoord[:n_pts, :, 0], atol=1e-5
                        )

                    self._check_normals_against_gt(
                        normals, normals_gt, num_pcl_points, disambiguate_directions
                    )

                    if DEBUG and run_idx == 0 and not use_pointclouds:
                        self._save_debug_ply(pcl, normals, disambiguate_directions)