mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
taubin smoothing
Summary: Taubin Smoothing for filtering meshes and making them smoother. Taubin smoothing is an iterative approach. Reviewed By: nikhilaravi Differential Revision: D24751149 fbshipit-source-id: fb779e955f1a1f6750e704f1b4c6dfa37aebac1a
This commit is contained in:
parent
fc7a4cacc3
commit
112959e087
@ -23,6 +23,7 @@ from .utils import (
|
|||||||
wmean,
|
wmean,
|
||||||
)
|
)
|
||||||
from .vert_align import vert_align
|
from .vert_align import vert_align
|
||||||
|
from .mesh_filtering import taubin_smoothing
|
||||||
|
|
||||||
|
|
||||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||||
|
83
pytorch3d/ops/mesh_filtering.py
Normal file
83
pytorch3d/ops/mesh_filtering.py
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
import torch
|
||||||
|
from pytorch3d.structures import Meshes, utils as struct_utils
|
||||||
|
|
||||||
|
# ------------------------ Mesh Smoothing ------------------------ #
|
||||||
|
# This file contains differentiable operators to filter meshes
|
||||||
|
# The ops include
|
||||||
|
# 1) Taubin Smoothing
|
||||||
|
# TODO(gkioxari) add more! :)
|
||||||
|
# ---------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------- Taubin Smoothing ----------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
def norm_laplacian(verts: torch.Tensor, edges: torch.Tensor, eps: float = 1e-12):
|
||||||
|
"""
|
||||||
|
Norm laplacian computes a variant of the laplacian matrix which weights each
|
||||||
|
affinity with the normalized distance of the neighboring nodes.
|
||||||
|
More concretely,
|
||||||
|
L[i, j] = 1. / wij where wij = ||vi - vj|| if (vi, vj) are neighboring nodes
|
||||||
|
|
||||||
|
Args:
|
||||||
|
verts: tensor of shape (V, 3) containing the vertices of the graph
|
||||||
|
edges: tensor of shape (E, 2) containing the vertex indices of each edge
|
||||||
|
"""
|
||||||
|
edge_verts = verts[edges] # (E, 2, 3)
|
||||||
|
v0, v1 = edge_verts[:, 0], edge_verts[:, 1]
|
||||||
|
|
||||||
|
# Side lengths of each edge, of shape (E,)
|
||||||
|
w01 = 1.0 / ((v0 - v1).norm(dim=1) + eps)
|
||||||
|
|
||||||
|
# Construct a sparse matrix by basically doing:
|
||||||
|
# L[v0, v1] = w01
|
||||||
|
# L[v1, v0] = w01
|
||||||
|
e01 = edges.t() # (2, E)
|
||||||
|
|
||||||
|
V = verts.shape[0]
|
||||||
|
L = torch.sparse.FloatTensor(e01, w01, (V, V))
|
||||||
|
L = L + L.t()
|
||||||
|
|
||||||
|
return L
|
||||||
|
|
||||||
|
|
||||||
|
def taubin_smoothing(
|
||||||
|
meshes: Meshes, lambd: float = 0.53, mu: float = -0.53, num_iter: int = 10
|
||||||
|
) -> Meshes:
|
||||||
|
"""
|
||||||
|
Taubin smoothing [1] is an iterative smoothing operator for meshes.
|
||||||
|
At each iteration
|
||||||
|
verts := (1 - λ) * verts + λ * L * verts
|
||||||
|
verts := (1 - μ) * verts + μ * L * verts
|
||||||
|
|
||||||
|
This function returns a new mesh with smoothed vertices.
|
||||||
|
Args:
|
||||||
|
meshes: Meshes input to be smoothed
|
||||||
|
lambd, mu: float parameters for Taubin smoothing,
|
||||||
|
lambd > 0, mu < 0
|
||||||
|
num_iter: number of iterations to execute smoothing
|
||||||
|
Returns:
|
||||||
|
mesh: Smoothed input Meshes
|
||||||
|
|
||||||
|
[1] Curve and Surface Smoothing without Shrinkage,
|
||||||
|
Gabriel Taubin, ICCV 1997
|
||||||
|
"""
|
||||||
|
verts = meshes.verts_packed() # V x 3
|
||||||
|
edges = meshes.edges_packed() # E x 3
|
||||||
|
|
||||||
|
for _ in range(num_iter):
|
||||||
|
L = norm_laplacian(verts, edges)
|
||||||
|
total_weight = torch.sparse.sum(L, dim=1).to_dense().view(-1, 1)
|
||||||
|
verts = (1 - lambd) * verts + lambd * torch.mm(L, verts) / total_weight
|
||||||
|
|
||||||
|
# pyre-ignore
|
||||||
|
L = norm_laplacian(verts, edges)
|
||||||
|
total_weight = torch.sparse.sum(L, dim=1).to_dense().view(-1, 1)
|
||||||
|
verts = (1 - mu) * verts + mu * torch.mm(L, verts) / total_weight
|
||||||
|
|
||||||
|
verts_list = struct_utils.packed_to_list(
|
||||||
|
verts, meshes.num_verts_per_mesh().tolist()
|
||||||
|
)
|
||||||
|
mesh = Meshes(verts=list(verts_list), faces=meshes.faces_list())
|
||||||
|
return mesh
|
74
tests/test_mesh_filtering.py
Normal file
74
tests/test_mesh_filtering.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from common_testing import TestCaseMixin, get_random_cuda_device
|
||||||
|
from pytorch3d.ops import taubin_smoothing
|
||||||
|
from pytorch3d.ops.mesh_filtering import norm_laplacian
|
||||||
|
from pytorch3d.structures import Meshes
|
||||||
|
from pytorch3d.utils import ico_sphere
|
||||||
|
|
||||||
|
|
||||||
|
class TestTaubinSmoothing(TestCaseMixin, unittest.TestCase):
|
||||||
|
def setUp(self) -> None:
|
||||||
|
super().setUp()
|
||||||
|
torch.manual_seed(1)
|
||||||
|
|
||||||
|
def test_taubin(self):
|
||||||
|
N = 3
|
||||||
|
device = get_random_cuda_device()
|
||||||
|
|
||||||
|
mesh = ico_sphere(4, device).extend(N)
|
||||||
|
ico_verts = mesh.verts_padded()
|
||||||
|
ico_faces = mesh.faces_padded()
|
||||||
|
|
||||||
|
rand_noise = torch.rand_like(ico_verts) * 0.2 - 0.1
|
||||||
|
z_mask = (ico_verts[:, :, -1] > 0).view(N, -1, 1)
|
||||||
|
rand_noise = rand_noise * z_mask
|
||||||
|
verts = ico_verts + rand_noise
|
||||||
|
mesh = Meshes(verts=verts, faces=ico_faces)
|
||||||
|
|
||||||
|
smooth_mesh = taubin_smoothing(mesh, num_iter=50)
|
||||||
|
smooth_verts = smooth_mesh.verts_padded()
|
||||||
|
|
||||||
|
smooth_dist = (smooth_verts - ico_verts).norm(dim=-1).mean()
|
||||||
|
dist = (verts - ico_verts).norm(dim=-1).mean()
|
||||||
|
self.assertTrue(smooth_dist < dist)
|
||||||
|
|
||||||
|
def test_norm_laplacian(self):
|
||||||
|
V = 32
|
||||||
|
F = 64
|
||||||
|
device = get_random_cuda_device()
|
||||||
|
# random vertices
|
||||||
|
verts = torch.rand((V, 3), dtype=torch.float32, device=device)
|
||||||
|
# random valid faces (no self circles, e.g. (v0, v0, v1))
|
||||||
|
faces = torch.stack([torch.randperm(V) for f in range(F)], dim=0)[:, :3]
|
||||||
|
faces = faces.to(device=device)
|
||||||
|
mesh = Meshes(verts=[verts], faces=[faces])
|
||||||
|
edges = mesh.edges_packed()
|
||||||
|
|
||||||
|
eps = 1e-12
|
||||||
|
|
||||||
|
L = norm_laplacian(verts, edges, eps=eps)
|
||||||
|
|
||||||
|
Lnaive = torch.zeros((V, V), dtype=torch.float32, device=device)
|
||||||
|
for f in range(F):
|
||||||
|
f0, f1, f2 = faces[f]
|
||||||
|
v0 = verts[f0]
|
||||||
|
v1 = verts[f1]
|
||||||
|
v2 = verts[f2]
|
||||||
|
|
||||||
|
w12 = 1.0 / ((v1 - v2).norm() + eps)
|
||||||
|
w02 = 1.0 / ((v0 - v2).norm() + eps)
|
||||||
|
w01 = 1.0 / ((v0 - v1).norm() + eps)
|
||||||
|
|
||||||
|
Lnaive[f0, f1] = w01
|
||||||
|
Lnaive[f1, f0] = w01
|
||||||
|
Lnaive[f0, f2] = w02
|
||||||
|
Lnaive[f2, f0] = w02
|
||||||
|
Lnaive[f1, f2] = w12
|
||||||
|
Lnaive[f2, f1] = w12
|
||||||
|
|
||||||
|
self.assertClose(L.to_dense(), Lnaive)
|
Loading…
x
Reference in New Issue
Block a user