diff --git a/pytorch3d/loss/mesh_laplacian_smoothing.py b/pytorch3d/loss/mesh_laplacian_smoothing.py index e86ddfce..d2f6dc79 100644 --- a/pytorch3d/loss/mesh_laplacian_smoothing.py +++ b/pytorch3d/loss/mesh_laplacian_smoothing.py @@ -6,6 +6,7 @@ import torch +from pytorch3d.ops import cot_laplacian def mesh_laplacian_smoothing(meshes, method: str = "uniform"): @@ -94,6 +95,7 @@ def mesh_laplacian_smoothing(meshes, method: str = "uniform"): N = len(meshes) verts_packed = meshes.verts_packed() # (sum(V_n), 3) + faces_packed = meshes.faces_packed() # (sum(F_n), 3) num_verts_per_mesh = meshes.num_verts_per_mesh() # (N,) verts_packed_idx = meshes.verts_packed_to_mesh_idx() # (sum(V_n),) weights = num_verts_per_mesh.gather(0, verts_packed_idx) # (sum(V_n),) @@ -106,7 +108,7 @@ def mesh_laplacian_smoothing(meshes, method: str = "uniform"): if method == "uniform": L = meshes.laplacian_packed() elif method in ["cot", "cotcurv"]: - L, inv_areas = laplacian_cot(meshes) + L, inv_areas = cot_laplacian(verts_packed, faces_packed) if method == "cot": norm_w = torch.sparse.sum(L, dim=1).to_dense().view(-1, 1) idx = norm_w > 0 @@ -127,73 +129,3 @@ def mesh_laplacian_smoothing(meshes, method: str = "uniform"): loss = loss * weights return loss.sum() / N - - -def laplacian_cot(meshes): - """ - Returns the Laplacian matrix with cotangent weights and the inverse of the - face areas. - - Args: - meshes: Meshes object with a batch of meshes. - Returns: - 2-element tuple containing - - **L**: FloatTensor of shape (V,V) for the Laplacian matrix (V = sum(V_n)) - Here, L[i, j] = cot a_ij + cot b_ij iff (i, j) is an edge in meshes. - See the description above for more clarity. - - **inv_areas**: FloatTensor of shape (V,) containing the inverse of sum of - face areas containing each vertex - """ - verts_packed = meshes.verts_packed() # (sum(V_n), 3) - faces_packed = meshes.faces_packed() # (sum(F_n), 3) - # V = sum(V_n), F = sum(F_n) - V, F = verts_packed.shape[0], faces_packed.shape[0] - - face_verts = verts_packed[faces_packed] - v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2] - - # Side lengths of each triangle, of shape (sum(F_n),) - # A is the side opposite v1, B is opposite v2, and C is opposite v3 - A = (v1 - v2).norm(dim=1) - B = (v0 - v2).norm(dim=1) - C = (v0 - v1).norm(dim=1) - - # Area of each triangle (with Heron's formula); shape is (sum(F_n),) - s = 0.5 * (A + B + C) - # note that the area can be negative (close to 0) causing nans after sqrt() - # we clip it to a small positive value - area = (s * (s - A) * (s - B) * (s - C)).clamp_(min=1e-12).sqrt() - - # Compute cotangents of angles, of shape (sum(F_n), 3) - A2, B2, C2 = A * A, B * B, C * C - cota = (B2 + C2 - A2) / area - cotb = (A2 + C2 - B2) / area - cotc = (A2 + B2 - C2) / area - cot = torch.stack([cota, cotb, cotc], dim=1) - cot /= 4.0 - - # Construct a sparse matrix by basically doing: - # L[v1, v2] = cota - # L[v2, v0] = cotb - # L[v0, v1] = cotc - ii = faces_packed[:, [1, 2, 0]] - jj = faces_packed[:, [2, 0, 1]] - idx = torch.stack([ii, jj], dim=0).view(2, F * 3) - L = torch.sparse.FloatTensor(idx, cot.view(-1), (V, V)) - - # Make it symmetric; this means we are also setting - # L[v2, v1] = cota - # L[v0, v2] = cotb - # L[v1, v0] = cotc - L += L.t() - - # For each vertex, compute the sum of areas for triangles containing it. - idx = faces_packed.view(-1) - inv_areas = torch.zeros(V, dtype=torch.float32, device=meshes.device) - val = torch.stack([area] * 3, dim=1).view(-1) - inv_areas.scatter_add_(0, idx, val) - idx = inv_areas > 0 - inv_areas[idx] = 1.0 / inv_areas[idx] - inv_areas = inv_areas.view(-1, 1) - - return L, inv_areas diff --git a/pytorch3d/ops/__init__.py b/pytorch3d/ops/__init__.py index a9bf1196..80aa09a7 100644 --- a/pytorch3d/ops/__init__.py +++ b/pytorch3d/ops/__init__.py @@ -9,6 +9,7 @@ from .cubify import cubify from .graph_conv import GraphConv from .interp_face_attrs import interpolate_face_attributes from .knn import knn_gather, knn_points +from .laplacian_matrices import laplacian, cot_laplacian, norm_laplacian from .mesh_face_areas_normals import mesh_face_areas_normals from .mesh_filtering import taubin_smoothing from .packed_to_padded import packed_to_padded, padded_to_packed diff --git a/pytorch3d/ops/laplacian_matrices.py b/pytorch3d/ops/laplacian_matrices.py new file mode 100644 index 00000000..c1b5b391 --- /dev/null +++ b/pytorch3d/ops/laplacian_matrices.py @@ -0,0 +1,170 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch + +# ------------------------ Laplacian Matrices ------------------------ # +# This file contains implementations of differentiable laplacian matrices. +# These include +# 1) Standard Laplacian matrix +# 2) Cotangent Laplacian matrix +# 3) Norm Laplacian matrix +# -------------------------------------------------------------------- # + + +def laplacian(verts: torch.Tensor, edges: torch.Tensor) -> torch.Tensor: + """ + Computes the laplacian matrix. + The definition of the laplacian is + L[i, j] = -1 , if i == j + L[i, j] = 1 / deg(i) , if (i, j) is an edge + L[i, j] = 0 , otherwise + where deg(i) is the degree of the i-th vertex in the graph. + + Args: + verts: tensor of shape (V, 3) containing the vertices of the graph + edges: tensor of shape (E, 2) containing the vertex indices of each edge + Returns: + L: Sparse FloatTensor of shape (V, V) + """ + V = verts.shape[0] + + e0, e1 = edges.unbind(1) + + idx01 = torch.stack([e0, e1], dim=1) # (E, 2) + idx10 = torch.stack([e1, e0], dim=1) # (E, 2) + idx = torch.cat([idx01, idx10], dim=0).t() # (2, 2*E) + + # First, we construct the adjacency matrix, + # i.e. A[i, j] = 1 if (i,j) is an edge, or + # A[e0, e1] = 1 & A[e1, e0] = 1 + ones = torch.ones(idx.shape[1], dtype=torch.float32, device=verts.device) + A = torch.sparse.FloatTensor(idx, ones, (V, V)) + + # the sum of i-th row of A gives the degree of the i-th vertex + deg = torch.sparse.sum(A, dim=1).to_dense() + + # We construct the Laplacian matrix by adding the non diagonal values + # i.e. L[i, j] = 1 ./ deg(i) if (i, j) is an edge + deg0 = deg[e0] + deg0 = torch.where(deg0 > 0.0, 1.0 / deg0, deg0) + deg1 = deg[e1] + deg1 = torch.where(deg1 > 0.0, 1.0 / deg1, deg1) + val = torch.cat([deg0, deg1]) + L = torch.sparse.FloatTensor(idx, val, (V, V)) + + # Then we add the diagonal values L[i, i] = -1. + idx = torch.arange(V, device=verts.device) + idx = torch.stack([idx, idx], dim=0) + ones = torch.ones(idx.shape[1], dtype=torch.float32, device=verts.device) + L -= torch.sparse.FloatTensor(idx, ones, (V, V)) + + return L + + +def cot_laplacian( + verts: torch.Tensor, faces: torch.Tensor, eps: float = 1e-12 +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Returns the Laplacian matrix with cotangent weights and the inverse of the + face areas. + + Args: + verts: tensor of shape (V, 3) containing the vertices of the graph + faces: tensor of shape (F, 3) containing the vertex indices of each face + Returns: + 2-element tuple containing + - **L**: Sparse FloatTensor of shape (V,V) for the Laplacian matrix. + Here, L[i, j] = cot a_ij + cot b_ij iff (i, j) is an edge in meshes. + See the description above for more clarity. + - **inv_areas**: FloatTensor of shape (V,) containing the inverse of sum of + face areas containing each vertex + """ + V, F = verts.shape[0], faces.shape[0] + + face_verts = verts[faces] + v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2] + + # Side lengths of each triangle, of shape (sum(F_n),) + # A is the side opposite v1, B is opposite v2, and C is opposite v3 + A = (v1 - v2).norm(dim=1) + B = (v0 - v2).norm(dim=1) + C = (v0 - v1).norm(dim=1) + + # Area of each triangle (with Heron's formula); shape is (sum(F_n),) + s = 0.5 * (A + B + C) + # note that the area can be negative (close to 0) causing nans after sqrt() + # we clip it to a small positive value + area = (s * (s - A) * (s - B) * (s - C)).clamp_(min=eps).sqrt() + + # Compute cotangents of angles, of shape (sum(F_n), 3) + A2, B2, C2 = A * A, B * B, C * C + cota = (B2 + C2 - A2) / area + cotb = (A2 + C2 - B2) / area + cotc = (A2 + B2 - C2) / area + cot = torch.stack([cota, cotb, cotc], dim=1) + cot /= 4.0 + + # Construct a sparse matrix by basically doing: + # L[v1, v2] = cota + # L[v2, v0] = cotb + # L[v0, v1] = cotc + ii = faces[:, [1, 2, 0]] + jj = faces[:, [2, 0, 1]] + idx = torch.stack([ii, jj], dim=0).view(2, F * 3) + L = torch.sparse.FloatTensor(idx, cot.view(-1), (V, V)) + + # Make it symmetric; this means we are also setting + # L[v2, v1] = cota + # L[v0, v2] = cotb + # L[v1, v0] = cotc + L += L.t() + + # For each vertex, compute the sum of areas for triangles containing it. + idx = faces.view(-1) + inv_areas = torch.zeros(V, dtype=torch.float32, device=verts.device) + val = torch.stack([area] * 3, dim=1).view(-1) + inv_areas.scatter_add_(0, idx, val) + idx = inv_areas > 0 + inv_areas[idx] = 1.0 / inv_areas[idx] + inv_areas = inv_areas.view(-1, 1) + + return L, inv_areas + + +def norm_laplacian( + verts: torch.Tensor, edges: torch.Tensor, eps: float = 1e-12 +) -> torch.Tensor: + """ + Norm laplacian computes a variant of the laplacian matrix which weights each + affinity with the normalized distance of the neighboring nodes. + More concretely, + L[i, j] = 1. / wij where wij = ||vi - vj|| if (vi, vj) are neighboring nodes + + Args: + verts: tensor of shape (V, 3) containing the vertices of the graph + edges: tensor of shape (E, 2) containing the vertex indices of each edge + Returns: + L: Sparse FloatTensor of shape (V, V) + """ + edge_verts = verts[edges] # (E, 2, 3) + v0, v1 = edge_verts[:, 0], edge_verts[:, 1] + + # Side lengths of each edge, of shape (E,) + w01 = 1.0 / ((v0 - v1).norm(dim=1) + eps) + + # Construct a sparse matrix by basically doing: + # L[v0, v1] = w01 + # L[v1, v0] = w01 + e01 = edges.t() # (2, E) + + V = verts.shape[0] + L = torch.sparse.FloatTensor(e01, w01, (V, V)) + L = L + L.t() + + return L diff --git a/pytorch3d/ops/mesh_filtering.py b/pytorch3d/ops/mesh_filtering.py index e662b01a..8d359650 100644 --- a/pytorch3d/ops/mesh_filtering.py +++ b/pytorch3d/ops/mesh_filtering.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import torch +from pytorch3d.ops import norm_laplacian from pytorch3d.structures import Meshes, utils as struct_utils @@ -19,35 +20,6 @@ from pytorch3d.structures import Meshes, utils as struct_utils # ----------------------- Taubin Smoothing ----------------------- # -def norm_laplacian(verts: torch.Tensor, edges: torch.Tensor, eps: float = 1e-12): - """ - Norm laplacian computes a variant of the laplacian matrix which weights each - affinity with the normalized distance of the neighboring nodes. - More concretely, - L[i, j] = 1. / wij where wij = ||vi - vj|| if (vi, vj) are neighboring nodes - - Args: - verts: tensor of shape (V, 3) containing the vertices of the graph - edges: tensor of shape (E, 2) containing the vertex indices of each edge - """ - edge_verts = verts[edges] # (E, 2, 3) - v0, v1 = edge_verts[:, 0], edge_verts[:, 1] - - # Side lengths of each edge, of shape (E,) - w01 = 1.0 / ((v0 - v1).norm(dim=1) + eps) - - # Construct a sparse matrix by basically doing: - # L[v0, v1] = w01 - # L[v1, v0] = w01 - e01 = edges.t() # (2, E) - - V = verts.shape[0] - L = torch.sparse.FloatTensor(e01, w01, (V, V)) - L = L + L.t() - - return L - - def taubin_smoothing( meshes: Meshes, lambd: float = 0.53, mu: float = -0.53, num_iter: int = 10 ) -> Meshes: diff --git a/pytorch3d/structures/meshes.py b/pytorch3d/structures/meshes.py index f24877bf..37f8d315 100644 --- a/pytorch3d/structures/meshes.py +++ b/pytorch3d/structures/meshes.py @@ -1142,6 +1142,8 @@ class Meshes: Sparse FloatTensor of shape (V, V) where V = sum(V_n) """ + from ..ops import laplacian + if not (refresh or self._laplacian_packed is None): return @@ -1153,39 +1155,8 @@ class Meshes: verts_packed = self.verts_packed() # (sum(V_n), 3) edges_packed = self.edges_packed() # (sum(E_n), 3) - V = verts_packed.shape[0] # sum(V_n) - e0, e1 = edges_packed.unbind(1) - - idx01 = torch.stack([e0, e1], dim=1) # (sum(E_n), 2) - idx10 = torch.stack([e1, e0], dim=1) # (sum(E_n), 2) - idx = torch.cat([idx01, idx10], dim=0).t() # (2, 2*sum(E_n)) - - # First, we construct the adjacency matrix, - # i.e. A[i, j] = 1 if (i,j) is an edge, or - # A[e0, e1] = 1 & A[e1, e0] = 1 - ones = torch.ones(idx.shape[1], dtype=torch.float32, device=self.device) - A = torch.sparse.FloatTensor(idx, ones, (V, V)) - - # the sum of i-th row of A gives the degree of the i-th vertex - deg = torch.sparse.sum(A, dim=1).to_dense() - - # We construct the Laplacian matrix by adding the non diagonal values - # i.e. L[i, j] = 1 ./ deg(i) if (i, j) is an edge - deg0 = deg[e0] - deg0 = torch.where(deg0 > 0.0, 1.0 / deg0, deg0) - deg1 = deg[e1] - deg1 = torch.where(deg1 > 0.0, 1.0 / deg1, deg1) - val = torch.cat([deg0, deg1]) - L = torch.sparse.FloatTensor(idx, val, (V, V)) - - # Then we add the diagonal values L[i, i] = -1. - idx = torch.arange(V, device=self.device) - idx = torch.stack([idx, idx], dim=0) - ones = torch.ones(idx.shape[1], dtype=torch.float32, device=self.device) - L -= torch.sparse.FloatTensor(idx, ones, (V, V)) - - self._laplacian_packed = L + self._laplacian_packed = laplacian(verts_packed, edges_packed) def clone(self): """ diff --git a/tests/test_laplacian_matrices.py b/tests/test_laplacian_matrices.py new file mode 100644 index 00000000..33927d28 --- /dev/null +++ b/tests/test_laplacian_matrices.py @@ -0,0 +1,119 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from common_testing import TestCaseMixin, get_random_cuda_device +from pytorch3d.ops import laplacian, norm_laplacian, cot_laplacian +from pytorch3d.structures.meshes import Meshes + + +class TestLaplacianMatrices(TestCaseMixin, unittest.TestCase): + def setUp(self) -> None: + super().setUp() + torch.manual_seed(1) + + def init_mesh(self) -> Meshes: + V, F = 32, 64 + device = get_random_cuda_device() + # random vertices + verts = torch.rand((V, 3), dtype=torch.float32, device=device) + # random valid faces (no self circles, e.g. (v0, v0, v1)) + faces = torch.stack([torch.randperm(V) for f in range(F)], dim=0)[:, :3] + faces = faces.to(device=device) + return Meshes(verts=[verts], faces=[faces]) + + def test_laplacian(self): + mesh = self.init_mesh() + verts = mesh.verts_packed() + edges = mesh.edges_packed() + V, E = verts.shape[0], edges.shape[0] + + L = laplacian(verts, edges) + + Lnaive = torch.zeros((V, V), dtype=torch.float32, device=verts.device) + for e in range(E): + e0, e1 = edges[e] + Lnaive[e0, e1] = 1 + # symetric + Lnaive[e1, e0] = 1 + + deg = Lnaive.sum(1).view(-1, 1) + deg[deg > 0] = 1.0 / deg[deg > 0] + Lnaive = Lnaive * deg + diag = torch.eye(V, dtype=torch.float32, device=mesh.device) + Lnaive.masked_fill_(diag > 0, -1) + + self.assertClose(L.to_dense(), Lnaive) + + def test_cot_laplacian(self): + mesh = self.init_mesh() + verts = mesh.verts_packed() + faces = mesh.faces_packed() + V, F = verts.shape[0], faces.shape[0] + + eps = 1e-12 + + L, inv_areas = cot_laplacian(verts, faces, eps=eps) + + Lnaive = torch.zeros((V, V), dtype=torch.float32, device=verts.device) + inv_areas_naive = torch.zeros((V, 1), dtype=torch.float32, device=verts.device) + + for f in faces: + v0 = verts[f[0], :] + v1 = verts[f[1], :] + v2 = verts[f[2], :] + A = (v1 - v2).norm() + B = (v0 - v2).norm() + C = (v0 - v1).norm() + s = 0.5 * (A + B + C) + + face_area = (s * (s - A) * (s - B) * (s - C)).clamp_(min=1e-12).sqrt() + inv_areas_naive[f[0]] += face_area + inv_areas_naive[f[1]] += face_area + inv_areas_naive[f[2]] += face_area + + A2, B2, C2 = A * A, B * B, C * C + cota = (B2 + C2 - A2) / face_area / 4.0 + cotb = (A2 + C2 - B2) / face_area / 4.0 + cotc = (A2 + B2 - C2) / face_area / 4.0 + + Lnaive[f[1], f[2]] += cota + Lnaive[f[2], f[0]] += cotb + Lnaive[f[0], f[1]] += cotc + # symetric + Lnaive[f[2], f[1]] += cota + Lnaive[f[0], f[2]] += cotb + Lnaive[f[1], f[0]] += cotc + + idx = inv_areas_naive > 0 + inv_areas_naive[idx] = 1.0 / inv_areas_naive[idx] + + self.assertClose(inv_areas, inv_areas_naive) + self.assertClose(L.to_dense(), Lnaive) + + def test_norm_laplacian(self): + mesh = self.init_mesh() + verts = mesh.verts_packed() + edges = mesh.edges_packed() + V, E = verts.shape[0], edges.shape[0] + + eps = 1e-12 + + L = norm_laplacian(verts, edges, eps=eps) + + Lnaive = torch.zeros((V, V), dtype=torch.float32, device=verts.device) + for e in range(E): + e0, e1 = edges[e] + v0 = verts[e0] + v1 = verts[e1] + + w01 = 1.0 / ((v0 - v1).norm() + eps) + Lnaive[e0, e1] += w01 + Lnaive[e1, e0] += w01 + + self.assertClose(L.to_dense(), Lnaive) diff --git a/tests/test_mesh_filtering.py b/tests/test_mesh_filtering.py index f6a06d31..146e6bbd 100644 --- a/tests/test_mesh_filtering.py +++ b/tests/test_mesh_filtering.py @@ -10,7 +10,6 @@ import unittest import torch from common_testing import TestCaseMixin, get_random_cuda_device from pytorch3d.ops import taubin_smoothing -from pytorch3d.ops.mesh_filtering import norm_laplacian from pytorch3d.structures import Meshes from pytorch3d.utils import ico_sphere @@ -40,39 +39,3 @@ class TestTaubinSmoothing(TestCaseMixin, unittest.TestCase): smooth_dist = (smooth_verts - ico_verts).norm(dim=-1).mean() dist = (verts - ico_verts).norm(dim=-1).mean() self.assertTrue(smooth_dist < dist) - - def test_norm_laplacian(self): - V = 32 - F = 64 - device = get_random_cuda_device() - # random vertices - verts = torch.rand((V, 3), dtype=torch.float32, device=device) - # random valid faces (no self circles, e.g. (v0, v0, v1)) - faces = torch.stack([torch.randperm(V) for f in range(F)], dim=0)[:, :3] - faces = faces.to(device=device) - mesh = Meshes(verts=[verts], faces=[faces]) - edges = mesh.edges_packed() - - eps = 1e-12 - - L = norm_laplacian(verts, edges, eps=eps) - - Lnaive = torch.zeros((V, V), dtype=torch.float32, device=device) - for f in range(F): - f0, f1, f2 = faces[f] - v0 = verts[f0] - v1 = verts[f1] - v2 = verts[f2] - - w12 = 1.0 / ((v1 - v2).norm() + eps) - w02 = 1.0 / ((v0 - v2).norm() + eps) - w01 = 1.0 / ((v0 - v1).norm() + eps) - - Lnaive[f0, f1] = w01 - Lnaive[f1, f0] = w01 - Lnaive[f0, f2] = w02 - Lnaive[f2, f0] = w02 - Lnaive[f1, f2] = w12 - Lnaive[f2, f1] = w12 - - self.assertClose(L.to_dense(), Lnaive) diff --git a/tests/test_meshes.py b/tests/test_meshes.py index da9d78f0..f350a989 100644 --- a/tests/test_meshes.py +++ b/tests/test_meshes.py @@ -406,34 +406,6 @@ class TestMeshes(TestCaseMixin, unittest.TestCase): self.assertFalse(newv.requires_grad) self.assertClose(newv, v) - def test_laplacian_packed(self): - def naive_laplacian_packed(meshes): - verts_packed = meshes.verts_packed() - edges_packed = meshes.edges_packed() - V = verts_packed.shape[0] - - L = torch.zeros((V, V), dtype=torch.float32, device=meshes.device) - for e in edges_packed: - L[e[0], e[1]] = 1 - # symetric - L[e[1], e[0]] = 1 - - deg = L.sum(1).view(-1, 1) - deg[deg > 0] = 1.0 / deg[deg > 0] - L = L * deg - diag = torch.eye(V, dtype=torch.float32, device=meshes.device) - L.masked_fill_(diag > 0, -1) - return L - - # Note that we don't test with random meshes for this case, as the - # definition of Laplacian is defined for simple graphs (aka valid meshes) - meshes = init_simple_mesh("cuda:0") - - lapl_naive = naive_laplacian_packed(meshes) - lapl = meshes.laplacian_packed().to_dense() - # check with naive - self.assertClose(lapl, lapl_naive) - def test_offset_verts(self): def naive_offset_verts(mesh, vert_offsets_packed): # new Meshes class