mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
use workaround for points_normals
Summary: Use existing workaround for batched 3x3 symeig because it is faster than torch.symeig. Added benchmark showing speedup. True = workaround. ``` Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- normals_True_3000 16237 17233 31 normals_True_6000 33028 33391 16 normals_False_3000 18623069 18623069 1 normals_False_6000 36535475 36535475 1 ``` Should help https://github.com/facebookresearch/pytorch3d/issues/988 Reviewed By: nikhilaravi Differential Revision: D33660585 fbshipit-source-id: d1162b277f5d61ed67e367057a61f25e03888dce
This commit is contained in:
parent
5053142363
commit
c2862ff427
@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from ..common.workaround import symeig3x3
|
||||||
from .utils import convert_pointclouds_to_tensor, get_point_covariances
|
from .utils import convert_pointclouds_to_tensor, get_point_covariances
|
||||||
|
|
||||||
|
|
||||||
@ -19,6 +20,8 @@ def estimate_pointcloud_normals(
|
|||||||
pointclouds: Union[torch.Tensor, "Pointclouds"],
|
pointclouds: Union[torch.Tensor, "Pointclouds"],
|
||||||
neighborhood_size: int = 50,
|
neighborhood_size: int = 50,
|
||||||
disambiguate_directions: bool = True,
|
disambiguate_directions: bool = True,
|
||||||
|
*,
|
||||||
|
use_symeig_workaround: bool = True,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Estimates the normals of a batch of `pointclouds`.
|
Estimates the normals of a batch of `pointclouds`.
|
||||||
@ -33,6 +36,8 @@ def estimate_pointcloud_normals(
|
|||||||
geometry around each point.
|
geometry around each point.
|
||||||
**disambiguate_directions**: If `True`, uses the algorithm from [1] to
|
**disambiguate_directions**: If `True`, uses the algorithm from [1] to
|
||||||
ensure sign consistency of the normals of neighboring points.
|
ensure sign consistency of the normals of neighboring points.
|
||||||
|
**use_symeig_workaround**: If `True`, uses a custom eigenvalue
|
||||||
|
calculation.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
**normals**: A tensor of normals for each input point
|
**normals**: A tensor of normals for each input point
|
||||||
@ -48,6 +53,7 @@ def estimate_pointcloud_normals(
|
|||||||
pointclouds,
|
pointclouds,
|
||||||
neighborhood_size=neighborhood_size,
|
neighborhood_size=neighborhood_size,
|
||||||
disambiguate_directions=disambiguate_directions,
|
disambiguate_directions=disambiguate_directions,
|
||||||
|
use_symeig_workaround=use_symeig_workaround,
|
||||||
)
|
)
|
||||||
|
|
||||||
# the normals correspond to the first vector of each local coord frame
|
# the normals correspond to the first vector of each local coord frame
|
||||||
@ -60,6 +66,8 @@ def estimate_pointcloud_local_coord_frames(
|
|||||||
pointclouds: Union[torch.Tensor, "Pointclouds"],
|
pointclouds: Union[torch.Tensor, "Pointclouds"],
|
||||||
neighborhood_size: int = 50,
|
neighborhood_size: int = 50,
|
||||||
disambiguate_directions: bool = True,
|
disambiguate_directions: bool = True,
|
||||||
|
*,
|
||||||
|
use_symeig_workaround: bool = True,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Estimates the principal directions of curvature (which includes normals)
|
Estimates the principal directions of curvature (which includes normals)
|
||||||
@ -88,6 +96,8 @@ def estimate_pointcloud_local_coord_frames(
|
|||||||
geometry around each point.
|
geometry around each point.
|
||||||
**disambiguate_directions**: If `True`, uses the algorithm from [1] to
|
**disambiguate_directions**: If `True`, uses the algorithm from [1] to
|
||||||
ensure sign consistency of the normals of neighboring points.
|
ensure sign consistency of the normals of neighboring points.
|
||||||
|
**use_symeig_workaround**: If `True`, uses a custom eigenvalue
|
||||||
|
calculation.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
**curvatures**: The three principal curvatures of each point
|
**curvatures**: The three principal curvatures of each point
|
||||||
@ -133,6 +143,9 @@ def estimate_pointcloud_local_coord_frames(
|
|||||||
# eigenvectors (=principal directions) in an ascending order of their
|
# eigenvectors (=principal directions) in an ascending order of their
|
||||||
# corresponding eigenvalues, while the smallest eigenvalue's eigenvector
|
# corresponding eigenvalues, while the smallest eigenvalue's eigenvector
|
||||||
# corresponds to the normal direction
|
# corresponds to the normal direction
|
||||||
|
if use_symeig_workaround:
|
||||||
|
curvatures, local_coord_frames = symeig3x3(cov, eigenvectors=True)
|
||||||
|
else:
|
||||||
curvatures, local_coord_frames = torch.symeig(cov, eigenvectors=True)
|
curvatures, local_coord_frames = torch.symeig(cov, eigenvectors=True)
|
||||||
|
|
||||||
# disambiguate the directions of individual principal vectors
|
# disambiguate the directions of individual principal vectors
|
||||||
|
47
tests/benchmarks/bm_points_normals.py
Normal file
47
tests/benchmarks/bm_points_normals.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from fvcore.common.benchmark import benchmark
|
||||||
|
from pytorch3d.ops import estimate_pointcloud_normals
|
||||||
|
from test_points_normals import TestPCLNormals
|
||||||
|
|
||||||
|
|
||||||
|
def to_bm(num_points, use_symeig_workaround):
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
points_padded, _normals = TestPCLNormals.init_spherical_pcl(
|
||||||
|
num_points=num_points, device=device, use_pointclouds=False
|
||||||
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
def run():
|
||||||
|
estimate_pointcloud_normals(
|
||||||
|
points_padded, use_symeig_workaround=use_symeig_workaround
|
||||||
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
return run
|
||||||
|
|
||||||
|
|
||||||
|
def bm_points_normals() -> None:
|
||||||
|
case_grid = {
|
||||||
|
"use_symeig_workaround": [True, False],
|
||||||
|
"num_points": [3000, 6000],
|
||||||
|
}
|
||||||
|
test_cases = itertools.product(*case_grid.values())
|
||||||
|
kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases]
|
||||||
|
benchmark(
|
||||||
|
to_bm,
|
||||||
|
"normals",
|
||||||
|
kwargs_list,
|
||||||
|
warmup_iters=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
bm_points_normals()
|
Loading…
x
Reference in New Issue
Block a user