refactor laplacian matrices

Summary:
Refactor of all functions to compute laplacian matrices in one file.
Support for:
* Standard Laplacian
* Cotangent Laplacian
* Norm Laplacian

Reviewed By: nikhilaravi

Differential Revision: D29297466

fbshipit-source-id: b96b88915ce8ef0c2f5693ec9b179fd27b70abf9
This commit is contained in:
Georgia Gkioxari
2021-06-24 03:52:30 -07:00
committed by Facebook GitHub Bot
parent da9974b416
commit 07a5a68d50
8 changed files with 297 additions and 197 deletions

View File

@@ -0,0 +1,119 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d.ops import laplacian, norm_laplacian, cot_laplacian
from pytorch3d.structures.meshes import Meshes
class TestLaplacianMatrices(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(1)
def init_mesh(self) -> Meshes:
V, F = 32, 64
device = get_random_cuda_device()
# random vertices
verts = torch.rand((V, 3), dtype=torch.float32, device=device)
# random valid faces (no self circles, e.g. (v0, v0, v1))
faces = torch.stack([torch.randperm(V) for f in range(F)], dim=0)[:, :3]
faces = faces.to(device=device)
return Meshes(verts=[verts], faces=[faces])
def test_laplacian(self):
mesh = self.init_mesh()
verts = mesh.verts_packed()
edges = mesh.edges_packed()
V, E = verts.shape[0], edges.shape[0]
L = laplacian(verts, edges)
Lnaive = torch.zeros((V, V), dtype=torch.float32, device=verts.device)
for e in range(E):
e0, e1 = edges[e]
Lnaive[e0, e1] = 1
# symetric
Lnaive[e1, e0] = 1
deg = Lnaive.sum(1).view(-1, 1)
deg[deg > 0] = 1.0 / deg[deg > 0]
Lnaive = Lnaive * deg
diag = torch.eye(V, dtype=torch.float32, device=mesh.device)
Lnaive.masked_fill_(diag > 0, -1)
self.assertClose(L.to_dense(), Lnaive)
def test_cot_laplacian(self):
mesh = self.init_mesh()
verts = mesh.verts_packed()
faces = mesh.faces_packed()
V, F = verts.shape[0], faces.shape[0]
eps = 1e-12
L, inv_areas = cot_laplacian(verts, faces, eps=eps)
Lnaive = torch.zeros((V, V), dtype=torch.float32, device=verts.device)
inv_areas_naive = torch.zeros((V, 1), dtype=torch.float32, device=verts.device)
for f in faces:
v0 = verts[f[0], :]
v1 = verts[f[1], :]
v2 = verts[f[2], :]
A = (v1 - v2).norm()
B = (v0 - v2).norm()
C = (v0 - v1).norm()
s = 0.5 * (A + B + C)
face_area = (s * (s - A) * (s - B) * (s - C)).clamp_(min=1e-12).sqrt()
inv_areas_naive[f[0]] += face_area
inv_areas_naive[f[1]] += face_area
inv_areas_naive[f[2]] += face_area
A2, B2, C2 = A * A, B * B, C * C
cota = (B2 + C2 - A2) / face_area / 4.0
cotb = (A2 + C2 - B2) / face_area / 4.0
cotc = (A2 + B2 - C2) / face_area / 4.0
Lnaive[f[1], f[2]] += cota
Lnaive[f[2], f[0]] += cotb
Lnaive[f[0], f[1]] += cotc
# symetric
Lnaive[f[2], f[1]] += cota
Lnaive[f[0], f[2]] += cotb
Lnaive[f[1], f[0]] += cotc
idx = inv_areas_naive > 0
inv_areas_naive[idx] = 1.0 / inv_areas_naive[idx]
self.assertClose(inv_areas, inv_areas_naive)
self.assertClose(L.to_dense(), Lnaive)
def test_norm_laplacian(self):
mesh = self.init_mesh()
verts = mesh.verts_packed()
edges = mesh.edges_packed()
V, E = verts.shape[0], edges.shape[0]
eps = 1e-12
L = norm_laplacian(verts, edges, eps=eps)
Lnaive = torch.zeros((V, V), dtype=torch.float32, device=verts.device)
for e in range(E):
e0, e1 = edges[e]
v0 = verts[e0]
v1 = verts[e1]
w01 = 1.0 / ((v0 - v1).norm() + eps)
Lnaive[e0, e1] += w01
Lnaive[e1, e0] += w01
self.assertClose(L.to_dense(), Lnaive)

View File

@@ -10,7 +10,6 @@ import unittest
import torch
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d.ops import taubin_smoothing
from pytorch3d.ops.mesh_filtering import norm_laplacian
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere
@@ -40,39 +39,3 @@ class TestTaubinSmoothing(TestCaseMixin, unittest.TestCase):
smooth_dist = (smooth_verts - ico_verts).norm(dim=-1).mean()
dist = (verts - ico_verts).norm(dim=-1).mean()
self.assertTrue(smooth_dist < dist)
def test_norm_laplacian(self):
V = 32
F = 64
device = get_random_cuda_device()
# random vertices
verts = torch.rand((V, 3), dtype=torch.float32, device=device)
# random valid faces (no self circles, e.g. (v0, v0, v1))
faces = torch.stack([torch.randperm(V) for f in range(F)], dim=0)[:, :3]
faces = faces.to(device=device)
mesh = Meshes(verts=[verts], faces=[faces])
edges = mesh.edges_packed()
eps = 1e-12
L = norm_laplacian(verts, edges, eps=eps)
Lnaive = torch.zeros((V, V), dtype=torch.float32, device=device)
for f in range(F):
f0, f1, f2 = faces[f]
v0 = verts[f0]
v1 = verts[f1]
v2 = verts[f2]
w12 = 1.0 / ((v1 - v2).norm() + eps)
w02 = 1.0 / ((v0 - v2).norm() + eps)
w01 = 1.0 / ((v0 - v1).norm() + eps)
Lnaive[f0, f1] = w01
Lnaive[f1, f0] = w01
Lnaive[f0, f2] = w02
Lnaive[f2, f0] = w02
Lnaive[f1, f2] = w12
Lnaive[f2, f1] = w12
self.assertClose(L.to_dense(), Lnaive)

View File

@@ -406,34 +406,6 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
self.assertFalse(newv.requires_grad)
self.assertClose(newv, v)
def test_laplacian_packed(self):
def naive_laplacian_packed(meshes):
verts_packed = meshes.verts_packed()
edges_packed = meshes.edges_packed()
V = verts_packed.shape[0]
L = torch.zeros((V, V), dtype=torch.float32, device=meshes.device)
for e in edges_packed:
L[e[0], e[1]] = 1
# symetric
L[e[1], e[0]] = 1
deg = L.sum(1).view(-1, 1)
deg[deg > 0] = 1.0 / deg[deg > 0]
L = L * deg
diag = torch.eye(V, dtype=torch.float32, device=meshes.device)
L.masked_fill_(diag > 0, -1)
return L
# Note that we don't test with random meshes for this case, as the
# definition of Laplacian is defined for simple graphs (aka valid meshes)
meshes = init_simple_mesh("cuda:0")
lapl_naive = naive_laplacian_packed(meshes)
lapl = meshes.laplacian_packed().to_dense()
# check with naive
self.assertClose(lapl, lapl_naive)
def test_offset_verts(self):
def naive_offset_verts(mesh, vert_offsets_packed):
# new Meshes class