diff --git a/pytorch3d/ops/__init__.py b/pytorch3d/ops/__init__.py index 4efdd0ea..01e26307 100644 --- a/pytorch3d/ops/__init__.py +++ b/pytorch3d/ops/__init__.py @@ -23,6 +23,7 @@ from .utils import ( wmean, ) from .vert_align import vert_align +from .mesh_filtering import taubin_smoothing __all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/pytorch3d/ops/mesh_filtering.py b/pytorch3d/ops/mesh_filtering.py new file mode 100644 index 00000000..a65bc542 --- /dev/null +++ b/pytorch3d/ops/mesh_filtering.py @@ -0,0 +1,83 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +import torch +from pytorch3d.structures import Meshes, utils as struct_utils + +# ------------------------ Mesh Smoothing ------------------------ # +# This file contains differentiable operators to filter meshes +# The ops include +# 1) Taubin Smoothing +# TODO(gkioxari) add more! :) +# ---------------------------------------------------------------- # + + +# ----------------------- Taubin Smoothing ----------------------- # + + +def norm_laplacian(verts: torch.Tensor, edges: torch.Tensor, eps: float = 1e-12): + """ + Norm laplacian computes a variant of the laplacian matrix which weights each + affinity with the normalized distance of the neighboring nodes. + More concretely, + L[i, j] = 1. / wij where wij = ||vi - vj|| if (vi, vj) are neighboring nodes + + Args: + verts: tensor of shape (V, 3) containing the vertices of the graph + edges: tensor of shape (E, 2) containing the vertex indices of each edge + """ + edge_verts = verts[edges] # (E, 2, 3) + v0, v1 = edge_verts[:, 0], edge_verts[:, 1] + + # Side lengths of each edge, of shape (E,) + w01 = 1.0 / ((v0 - v1).norm(dim=1) + eps) + + # Construct a sparse matrix by basically doing: + # L[v0, v1] = w01 + # L[v1, v0] = w01 + e01 = edges.t() # (2, E) + + V = verts.shape[0] + L = torch.sparse.FloatTensor(e01, w01, (V, V)) + L = L + L.t() + + return L + + +def taubin_smoothing( + meshes: Meshes, lambd: float = 0.53, mu: float = -0.53, num_iter: int = 10 +) -> Meshes: + """ + Taubin smoothing [1] is an iterative smoothing operator for meshes. + At each iteration + verts := (1 - λ) * verts + λ * L * verts + verts := (1 - μ) * verts + μ * L * verts + + This function returns a new mesh with smoothed vertices. + Args: + meshes: Meshes input to be smoothed + lambd, mu: float parameters for Taubin smoothing, + lambd > 0, mu < 0 + num_iter: number of iterations to execute smoothing + Returns: + mesh: Smoothed input Meshes + + [1] Curve and Surface Smoothing without Shrinkage, + Gabriel Taubin, ICCV 1997 + """ + verts = meshes.verts_packed() # V x 3 + edges = meshes.edges_packed() # E x 3 + + for _ in range(num_iter): + L = norm_laplacian(verts, edges) + total_weight = torch.sparse.sum(L, dim=1).to_dense().view(-1, 1) + verts = (1 - lambd) * verts + lambd * torch.mm(L, verts) / total_weight + + # pyre-ignore + L = norm_laplacian(verts, edges) + total_weight = torch.sparse.sum(L, dim=1).to_dense().view(-1, 1) + verts = (1 - mu) * verts + mu * torch.mm(L, verts) / total_weight + + verts_list = struct_utils.packed_to_list( + verts, meshes.num_verts_per_mesh().tolist() + ) + mesh = Meshes(verts=list(verts_list), faces=meshes.faces_list()) + return mesh diff --git a/tests/test_mesh_filtering.py b/tests/test_mesh_filtering.py new file mode 100644 index 00000000..098d44f3 --- /dev/null +++ b/tests/test_mesh_filtering.py @@ -0,0 +1,74 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + + +import unittest + +import torch +from common_testing import TestCaseMixin, get_random_cuda_device +from pytorch3d.ops import taubin_smoothing +from pytorch3d.ops.mesh_filtering import norm_laplacian +from pytorch3d.structures import Meshes +from pytorch3d.utils import ico_sphere + + +class TestTaubinSmoothing(TestCaseMixin, unittest.TestCase): + def setUp(self) -> None: + super().setUp() + torch.manual_seed(1) + + def test_taubin(self): + N = 3 + device = get_random_cuda_device() + + mesh = ico_sphere(4, device).extend(N) + ico_verts = mesh.verts_padded() + ico_faces = mesh.faces_padded() + + rand_noise = torch.rand_like(ico_verts) * 0.2 - 0.1 + z_mask = (ico_verts[:, :, -1] > 0).view(N, -1, 1) + rand_noise = rand_noise * z_mask + verts = ico_verts + rand_noise + mesh = Meshes(verts=verts, faces=ico_faces) + + smooth_mesh = taubin_smoothing(mesh, num_iter=50) + smooth_verts = smooth_mesh.verts_padded() + + smooth_dist = (smooth_verts - ico_verts).norm(dim=-1).mean() + dist = (verts - ico_verts).norm(dim=-1).mean() + self.assertTrue(smooth_dist < dist) + + def test_norm_laplacian(self): + V = 32 + F = 64 + device = get_random_cuda_device() + # random vertices + verts = torch.rand((V, 3), dtype=torch.float32, device=device) + # random valid faces (no self circles, e.g. (v0, v0, v1)) + faces = torch.stack([torch.randperm(V) for f in range(F)], dim=0)[:, :3] + faces = faces.to(device=device) + mesh = Meshes(verts=[verts], faces=[faces]) + edges = mesh.edges_packed() + + eps = 1e-12 + + L = norm_laplacian(verts, edges, eps=eps) + + Lnaive = torch.zeros((V, V), dtype=torch.float32, device=device) + for f in range(F): + f0, f1, f2 = faces[f] + v0 = verts[f0] + v1 = verts[f1] + v2 = verts[f2] + + w12 = 1.0 / ((v1 - v2).norm() + eps) + w02 = 1.0 / ((v0 - v2).norm() + eps) + w01 = 1.0 / ((v0 - v1).norm() + eps) + + Lnaive[f0, f1] = w01 + Lnaive[f1, f0] = w01 + Lnaive[f0, f2] = w02 + Lnaive[f2, f0] = w02 + Lnaive[f1, f2] = w12 + Lnaive[f2, f1] = w12 + + self.assertClose(L.to_dense(), Lnaive)