refactor laplacian matrices

Summary:
Refactor of all functions to compute laplacian matrices in one file.
Support for:
* Standard Laplacian
* Cotangent Laplacian
* Norm Laplacian

Reviewed By: nikhilaravi

Differential Revision: D29297466

fbshipit-source-id: b96b88915ce8ef0c2f5693ec9b179fd27b70abf9
This commit is contained in:
Georgia Gkioxari
2021-06-24 03:52:30 -07:00
committed by Facebook GitHub Bot
parent da9974b416
commit 07a5a68d50
8 changed files with 297 additions and 197 deletions

View File

@@ -1142,6 +1142,8 @@ class Meshes:
Sparse FloatTensor of shape (V, V) where V = sum(V_n)
"""
from ..ops import laplacian
if not (refresh or self._laplacian_packed is None):
return
@@ -1153,39 +1155,8 @@ class Meshes:
verts_packed = self.verts_packed() # (sum(V_n), 3)
edges_packed = self.edges_packed() # (sum(E_n), 3)
V = verts_packed.shape[0] # sum(V_n)
e0, e1 = edges_packed.unbind(1)
idx01 = torch.stack([e0, e1], dim=1) # (sum(E_n), 2)
idx10 = torch.stack([e1, e0], dim=1) # (sum(E_n), 2)
idx = torch.cat([idx01, idx10], dim=0).t() # (2, 2*sum(E_n))
# First, we construct the adjacency matrix,
# i.e. A[i, j] = 1 if (i,j) is an edge, or
# A[e0, e1] = 1 & A[e1, e0] = 1
ones = torch.ones(idx.shape[1], dtype=torch.float32, device=self.device)
A = torch.sparse.FloatTensor(idx, ones, (V, V))
# the sum of i-th row of A gives the degree of the i-th vertex
deg = torch.sparse.sum(A, dim=1).to_dense()
# We construct the Laplacian matrix by adding the non diagonal values
# i.e. L[i, j] = 1 ./ deg(i) if (i, j) is an edge
deg0 = deg[e0]
deg0 = torch.where(deg0 > 0.0, 1.0 / deg0, deg0)
deg1 = deg[e1]
deg1 = torch.where(deg1 > 0.0, 1.0 / deg1, deg1)
val = torch.cat([deg0, deg1])
L = torch.sparse.FloatTensor(idx, val, (V, V))
# Then we add the diagonal values L[i, i] = -1.
idx = torch.arange(V, device=self.device)
idx = torch.stack([idx, idx], dim=0)
ones = torch.ones(idx.shape[1], dtype=torch.float32, device=self.device)
L -= torch.sparse.FloatTensor(idx, ones, (V, V))
self._laplacian_packed = L
self._laplacian_packed = laplacian(verts_packed, edges_packed)
def clone(self):
"""