mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
refactor laplacian matrices
Summary: Refactor of all functions to compute laplacian matrices in one file. Support for: * Standard Laplacian * Cotangent Laplacian * Norm Laplacian Reviewed By: nikhilaravi Differential Revision: D29297466 fbshipit-source-id: b96b88915ce8ef0c2f5693ec9b179fd27b70abf9
This commit is contained in:
parent
da9974b416
commit
07a5a68d50
@ -6,6 +6,7 @@
|
||||
|
||||
|
||||
import torch
|
||||
from pytorch3d.ops import cot_laplacian
|
||||
|
||||
|
||||
def mesh_laplacian_smoothing(meshes, method: str = "uniform"):
|
||||
@ -94,6 +95,7 @@ def mesh_laplacian_smoothing(meshes, method: str = "uniform"):
|
||||
|
||||
N = len(meshes)
|
||||
verts_packed = meshes.verts_packed() # (sum(V_n), 3)
|
||||
faces_packed = meshes.faces_packed() # (sum(F_n), 3)
|
||||
num_verts_per_mesh = meshes.num_verts_per_mesh() # (N,)
|
||||
verts_packed_idx = meshes.verts_packed_to_mesh_idx() # (sum(V_n),)
|
||||
weights = num_verts_per_mesh.gather(0, verts_packed_idx) # (sum(V_n),)
|
||||
@ -106,7 +108,7 @@ def mesh_laplacian_smoothing(meshes, method: str = "uniform"):
|
||||
if method == "uniform":
|
||||
L = meshes.laplacian_packed()
|
||||
elif method in ["cot", "cotcurv"]:
|
||||
L, inv_areas = laplacian_cot(meshes)
|
||||
L, inv_areas = cot_laplacian(verts_packed, faces_packed)
|
||||
if method == "cot":
|
||||
norm_w = torch.sparse.sum(L, dim=1).to_dense().view(-1, 1)
|
||||
idx = norm_w > 0
|
||||
@ -127,73 +129,3 @@ def mesh_laplacian_smoothing(meshes, method: str = "uniform"):
|
||||
|
||||
loss = loss * weights
|
||||
return loss.sum() / N
|
||||
|
||||
|
||||
def laplacian_cot(meshes):
|
||||
"""
|
||||
Returns the Laplacian matrix with cotangent weights and the inverse of the
|
||||
face areas.
|
||||
|
||||
Args:
|
||||
meshes: Meshes object with a batch of meshes.
|
||||
Returns:
|
||||
2-element tuple containing
|
||||
- **L**: FloatTensor of shape (V,V) for the Laplacian matrix (V = sum(V_n))
|
||||
Here, L[i, j] = cot a_ij + cot b_ij iff (i, j) is an edge in meshes.
|
||||
See the description above for more clarity.
|
||||
- **inv_areas**: FloatTensor of shape (V,) containing the inverse of sum of
|
||||
face areas containing each vertex
|
||||
"""
|
||||
verts_packed = meshes.verts_packed() # (sum(V_n), 3)
|
||||
faces_packed = meshes.faces_packed() # (sum(F_n), 3)
|
||||
# V = sum(V_n), F = sum(F_n)
|
||||
V, F = verts_packed.shape[0], faces_packed.shape[0]
|
||||
|
||||
face_verts = verts_packed[faces_packed]
|
||||
v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2]
|
||||
|
||||
# Side lengths of each triangle, of shape (sum(F_n),)
|
||||
# A is the side opposite v1, B is opposite v2, and C is opposite v3
|
||||
A = (v1 - v2).norm(dim=1)
|
||||
B = (v0 - v2).norm(dim=1)
|
||||
C = (v0 - v1).norm(dim=1)
|
||||
|
||||
# Area of each triangle (with Heron's formula); shape is (sum(F_n),)
|
||||
s = 0.5 * (A + B + C)
|
||||
# note that the area can be negative (close to 0) causing nans after sqrt()
|
||||
# we clip it to a small positive value
|
||||
area = (s * (s - A) * (s - B) * (s - C)).clamp_(min=1e-12).sqrt()
|
||||
|
||||
# Compute cotangents of angles, of shape (sum(F_n), 3)
|
||||
A2, B2, C2 = A * A, B * B, C * C
|
||||
cota = (B2 + C2 - A2) / area
|
||||
cotb = (A2 + C2 - B2) / area
|
||||
cotc = (A2 + B2 - C2) / area
|
||||
cot = torch.stack([cota, cotb, cotc], dim=1)
|
||||
cot /= 4.0
|
||||
|
||||
# Construct a sparse matrix by basically doing:
|
||||
# L[v1, v2] = cota
|
||||
# L[v2, v0] = cotb
|
||||
# L[v0, v1] = cotc
|
||||
ii = faces_packed[:, [1, 2, 0]]
|
||||
jj = faces_packed[:, [2, 0, 1]]
|
||||
idx = torch.stack([ii, jj], dim=0).view(2, F * 3)
|
||||
L = torch.sparse.FloatTensor(idx, cot.view(-1), (V, V))
|
||||
|
||||
# Make it symmetric; this means we are also setting
|
||||
# L[v2, v1] = cota
|
||||
# L[v0, v2] = cotb
|
||||
# L[v1, v0] = cotc
|
||||
L += L.t()
|
||||
|
||||
# For each vertex, compute the sum of areas for triangles containing it.
|
||||
idx = faces_packed.view(-1)
|
||||
inv_areas = torch.zeros(V, dtype=torch.float32, device=meshes.device)
|
||||
val = torch.stack([area] * 3, dim=1).view(-1)
|
||||
inv_areas.scatter_add_(0, idx, val)
|
||||
idx = inv_areas > 0
|
||||
inv_areas[idx] = 1.0 / inv_areas[idx]
|
||||
inv_areas = inv_areas.view(-1, 1)
|
||||
|
||||
return L, inv_areas
|
||||
|
@ -9,6 +9,7 @@ from .cubify import cubify
|
||||
from .graph_conv import GraphConv
|
||||
from .interp_face_attrs import interpolate_face_attributes
|
||||
from .knn import knn_gather, knn_points
|
||||
from .laplacian_matrices import laplacian, cot_laplacian, norm_laplacian
|
||||
from .mesh_face_areas_normals import mesh_face_areas_normals
|
||||
from .mesh_filtering import taubin_smoothing
|
||||
from .packed_to_padded import packed_to_padded, padded_to_packed
|
||||
|
170
pytorch3d/ops/laplacian_matrices.py
Normal file
170
pytorch3d/ops/laplacian_matrices.py
Normal file
@ -0,0 +1,170 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
# ------------------------ Laplacian Matrices ------------------------ #
|
||||
# This file contains implementations of differentiable laplacian matrices.
|
||||
# These include
|
||||
# 1) Standard Laplacian matrix
|
||||
# 2) Cotangent Laplacian matrix
|
||||
# 3) Norm Laplacian matrix
|
||||
# -------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def laplacian(verts: torch.Tensor, edges: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Computes the laplacian matrix.
|
||||
The definition of the laplacian is
|
||||
L[i, j] = -1 , if i == j
|
||||
L[i, j] = 1 / deg(i) , if (i, j) is an edge
|
||||
L[i, j] = 0 , otherwise
|
||||
where deg(i) is the degree of the i-th vertex in the graph.
|
||||
|
||||
Args:
|
||||
verts: tensor of shape (V, 3) containing the vertices of the graph
|
||||
edges: tensor of shape (E, 2) containing the vertex indices of each edge
|
||||
Returns:
|
||||
L: Sparse FloatTensor of shape (V, V)
|
||||
"""
|
||||
V = verts.shape[0]
|
||||
|
||||
e0, e1 = edges.unbind(1)
|
||||
|
||||
idx01 = torch.stack([e0, e1], dim=1) # (E, 2)
|
||||
idx10 = torch.stack([e1, e0], dim=1) # (E, 2)
|
||||
idx = torch.cat([idx01, idx10], dim=0).t() # (2, 2*E)
|
||||
|
||||
# First, we construct the adjacency matrix,
|
||||
# i.e. A[i, j] = 1 if (i,j) is an edge, or
|
||||
# A[e0, e1] = 1 & A[e1, e0] = 1
|
||||
ones = torch.ones(idx.shape[1], dtype=torch.float32, device=verts.device)
|
||||
A = torch.sparse.FloatTensor(idx, ones, (V, V))
|
||||
|
||||
# the sum of i-th row of A gives the degree of the i-th vertex
|
||||
deg = torch.sparse.sum(A, dim=1).to_dense()
|
||||
|
||||
# We construct the Laplacian matrix by adding the non diagonal values
|
||||
# i.e. L[i, j] = 1 ./ deg(i) if (i, j) is an edge
|
||||
deg0 = deg[e0]
|
||||
deg0 = torch.where(deg0 > 0.0, 1.0 / deg0, deg0)
|
||||
deg1 = deg[e1]
|
||||
deg1 = torch.where(deg1 > 0.0, 1.0 / deg1, deg1)
|
||||
val = torch.cat([deg0, deg1])
|
||||
L = torch.sparse.FloatTensor(idx, val, (V, V))
|
||||
|
||||
# Then we add the diagonal values L[i, i] = -1.
|
||||
idx = torch.arange(V, device=verts.device)
|
||||
idx = torch.stack([idx, idx], dim=0)
|
||||
ones = torch.ones(idx.shape[1], dtype=torch.float32, device=verts.device)
|
||||
L -= torch.sparse.FloatTensor(idx, ones, (V, V))
|
||||
|
||||
return L
|
||||
|
||||
|
||||
def cot_laplacian(
|
||||
verts: torch.Tensor, faces: torch.Tensor, eps: float = 1e-12
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Returns the Laplacian matrix with cotangent weights and the inverse of the
|
||||
face areas.
|
||||
|
||||
Args:
|
||||
verts: tensor of shape (V, 3) containing the vertices of the graph
|
||||
faces: tensor of shape (F, 3) containing the vertex indices of each face
|
||||
Returns:
|
||||
2-element tuple containing
|
||||
- **L**: Sparse FloatTensor of shape (V,V) for the Laplacian matrix.
|
||||
Here, L[i, j] = cot a_ij + cot b_ij iff (i, j) is an edge in meshes.
|
||||
See the description above for more clarity.
|
||||
- **inv_areas**: FloatTensor of shape (V,) containing the inverse of sum of
|
||||
face areas containing each vertex
|
||||
"""
|
||||
V, F = verts.shape[0], faces.shape[0]
|
||||
|
||||
face_verts = verts[faces]
|
||||
v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2]
|
||||
|
||||
# Side lengths of each triangle, of shape (sum(F_n),)
|
||||
# A is the side opposite v1, B is opposite v2, and C is opposite v3
|
||||
A = (v1 - v2).norm(dim=1)
|
||||
B = (v0 - v2).norm(dim=1)
|
||||
C = (v0 - v1).norm(dim=1)
|
||||
|
||||
# Area of each triangle (with Heron's formula); shape is (sum(F_n),)
|
||||
s = 0.5 * (A + B + C)
|
||||
# note that the area can be negative (close to 0) causing nans after sqrt()
|
||||
# we clip it to a small positive value
|
||||
area = (s * (s - A) * (s - B) * (s - C)).clamp_(min=eps).sqrt()
|
||||
|
||||
# Compute cotangents of angles, of shape (sum(F_n), 3)
|
||||
A2, B2, C2 = A * A, B * B, C * C
|
||||
cota = (B2 + C2 - A2) / area
|
||||
cotb = (A2 + C2 - B2) / area
|
||||
cotc = (A2 + B2 - C2) / area
|
||||
cot = torch.stack([cota, cotb, cotc], dim=1)
|
||||
cot /= 4.0
|
||||
|
||||
# Construct a sparse matrix by basically doing:
|
||||
# L[v1, v2] = cota
|
||||
# L[v2, v0] = cotb
|
||||
# L[v0, v1] = cotc
|
||||
ii = faces[:, [1, 2, 0]]
|
||||
jj = faces[:, [2, 0, 1]]
|
||||
idx = torch.stack([ii, jj], dim=0).view(2, F * 3)
|
||||
L = torch.sparse.FloatTensor(idx, cot.view(-1), (V, V))
|
||||
|
||||
# Make it symmetric; this means we are also setting
|
||||
# L[v2, v1] = cota
|
||||
# L[v0, v2] = cotb
|
||||
# L[v1, v0] = cotc
|
||||
L += L.t()
|
||||
|
||||
# For each vertex, compute the sum of areas for triangles containing it.
|
||||
idx = faces.view(-1)
|
||||
inv_areas = torch.zeros(V, dtype=torch.float32, device=verts.device)
|
||||
val = torch.stack([area] * 3, dim=1).view(-1)
|
||||
inv_areas.scatter_add_(0, idx, val)
|
||||
idx = inv_areas > 0
|
||||
inv_areas[idx] = 1.0 / inv_areas[idx]
|
||||
inv_areas = inv_areas.view(-1, 1)
|
||||
|
||||
return L, inv_areas
|
||||
|
||||
|
||||
def norm_laplacian(
|
||||
verts: torch.Tensor, edges: torch.Tensor, eps: float = 1e-12
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Norm laplacian computes a variant of the laplacian matrix which weights each
|
||||
affinity with the normalized distance of the neighboring nodes.
|
||||
More concretely,
|
||||
L[i, j] = 1. / wij where wij = ||vi - vj|| if (vi, vj) are neighboring nodes
|
||||
|
||||
Args:
|
||||
verts: tensor of shape (V, 3) containing the vertices of the graph
|
||||
edges: tensor of shape (E, 2) containing the vertex indices of each edge
|
||||
Returns:
|
||||
L: Sparse FloatTensor of shape (V, V)
|
||||
"""
|
||||
edge_verts = verts[edges] # (E, 2, 3)
|
||||
v0, v1 = edge_verts[:, 0], edge_verts[:, 1]
|
||||
|
||||
# Side lengths of each edge, of shape (E,)
|
||||
w01 = 1.0 / ((v0 - v1).norm(dim=1) + eps)
|
||||
|
||||
# Construct a sparse matrix by basically doing:
|
||||
# L[v0, v1] = w01
|
||||
# L[v1, v0] = w01
|
||||
e01 = edges.t() # (2, E)
|
||||
|
||||
V = verts.shape[0]
|
||||
L = torch.sparse.FloatTensor(e01, w01, (V, V))
|
||||
L = L + L.t()
|
||||
|
||||
return L
|
@ -5,6 +5,7 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
from pytorch3d.ops import norm_laplacian
|
||||
from pytorch3d.structures import Meshes, utils as struct_utils
|
||||
|
||||
|
||||
@ -19,35 +20,6 @@ from pytorch3d.structures import Meshes, utils as struct_utils
|
||||
# ----------------------- Taubin Smoothing ----------------------- #
|
||||
|
||||
|
||||
def norm_laplacian(verts: torch.Tensor, edges: torch.Tensor, eps: float = 1e-12):
|
||||
"""
|
||||
Norm laplacian computes a variant of the laplacian matrix which weights each
|
||||
affinity with the normalized distance of the neighboring nodes.
|
||||
More concretely,
|
||||
L[i, j] = 1. / wij where wij = ||vi - vj|| if (vi, vj) are neighboring nodes
|
||||
|
||||
Args:
|
||||
verts: tensor of shape (V, 3) containing the vertices of the graph
|
||||
edges: tensor of shape (E, 2) containing the vertex indices of each edge
|
||||
"""
|
||||
edge_verts = verts[edges] # (E, 2, 3)
|
||||
v0, v1 = edge_verts[:, 0], edge_verts[:, 1]
|
||||
|
||||
# Side lengths of each edge, of shape (E,)
|
||||
w01 = 1.0 / ((v0 - v1).norm(dim=1) + eps)
|
||||
|
||||
# Construct a sparse matrix by basically doing:
|
||||
# L[v0, v1] = w01
|
||||
# L[v1, v0] = w01
|
||||
e01 = edges.t() # (2, E)
|
||||
|
||||
V = verts.shape[0]
|
||||
L = torch.sparse.FloatTensor(e01, w01, (V, V))
|
||||
L = L + L.t()
|
||||
|
||||
return L
|
||||
|
||||
|
||||
def taubin_smoothing(
|
||||
meshes: Meshes, lambd: float = 0.53, mu: float = -0.53, num_iter: int = 10
|
||||
) -> Meshes:
|
||||
|
@ -1142,6 +1142,8 @@ class Meshes:
|
||||
Sparse FloatTensor of shape (V, V) where V = sum(V_n)
|
||||
|
||||
"""
|
||||
from ..ops import laplacian
|
||||
|
||||
if not (refresh or self._laplacian_packed is None):
|
||||
return
|
||||
|
||||
@ -1153,39 +1155,8 @@ class Meshes:
|
||||
|
||||
verts_packed = self.verts_packed() # (sum(V_n), 3)
|
||||
edges_packed = self.edges_packed() # (sum(E_n), 3)
|
||||
V = verts_packed.shape[0] # sum(V_n)
|
||||
|
||||
e0, e1 = edges_packed.unbind(1)
|
||||
|
||||
idx01 = torch.stack([e0, e1], dim=1) # (sum(E_n), 2)
|
||||
idx10 = torch.stack([e1, e0], dim=1) # (sum(E_n), 2)
|
||||
idx = torch.cat([idx01, idx10], dim=0).t() # (2, 2*sum(E_n))
|
||||
|
||||
# First, we construct the adjacency matrix,
|
||||
# i.e. A[i, j] = 1 if (i,j) is an edge, or
|
||||
# A[e0, e1] = 1 & A[e1, e0] = 1
|
||||
ones = torch.ones(idx.shape[1], dtype=torch.float32, device=self.device)
|
||||
A = torch.sparse.FloatTensor(idx, ones, (V, V))
|
||||
|
||||
# the sum of i-th row of A gives the degree of the i-th vertex
|
||||
deg = torch.sparse.sum(A, dim=1).to_dense()
|
||||
|
||||
# We construct the Laplacian matrix by adding the non diagonal values
|
||||
# i.e. L[i, j] = 1 ./ deg(i) if (i, j) is an edge
|
||||
deg0 = deg[e0]
|
||||
deg0 = torch.where(deg0 > 0.0, 1.0 / deg0, deg0)
|
||||
deg1 = deg[e1]
|
||||
deg1 = torch.where(deg1 > 0.0, 1.0 / deg1, deg1)
|
||||
val = torch.cat([deg0, deg1])
|
||||
L = torch.sparse.FloatTensor(idx, val, (V, V))
|
||||
|
||||
# Then we add the diagonal values L[i, i] = -1.
|
||||
idx = torch.arange(V, device=self.device)
|
||||
idx = torch.stack([idx, idx], dim=0)
|
||||
ones = torch.ones(idx.shape[1], dtype=torch.float32, device=self.device)
|
||||
L -= torch.sparse.FloatTensor(idx, ones, (V, V))
|
||||
|
||||
self._laplacian_packed = L
|
||||
self._laplacian_packed = laplacian(verts_packed, edges_packed)
|
||||
|
||||
def clone(self):
|
||||
"""
|
||||
|
119
tests/test_laplacian_matrices.py
Normal file
119
tests/test_laplacian_matrices.py
Normal file
@ -0,0 +1,119 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin, get_random_cuda_device
|
||||
from pytorch3d.ops import laplacian, norm_laplacian, cot_laplacian
|
||||
from pytorch3d.structures.meshes import Meshes
|
||||
|
||||
|
||||
class TestLaplacianMatrices(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
torch.manual_seed(1)
|
||||
|
||||
def init_mesh(self) -> Meshes:
|
||||
V, F = 32, 64
|
||||
device = get_random_cuda_device()
|
||||
# random vertices
|
||||
verts = torch.rand((V, 3), dtype=torch.float32, device=device)
|
||||
# random valid faces (no self circles, e.g. (v0, v0, v1))
|
||||
faces = torch.stack([torch.randperm(V) for f in range(F)], dim=0)[:, :3]
|
||||
faces = faces.to(device=device)
|
||||
return Meshes(verts=[verts], faces=[faces])
|
||||
|
||||
def test_laplacian(self):
|
||||
mesh = self.init_mesh()
|
||||
verts = mesh.verts_packed()
|
||||
edges = mesh.edges_packed()
|
||||
V, E = verts.shape[0], edges.shape[0]
|
||||
|
||||
L = laplacian(verts, edges)
|
||||
|
||||
Lnaive = torch.zeros((V, V), dtype=torch.float32, device=verts.device)
|
||||
for e in range(E):
|
||||
e0, e1 = edges[e]
|
||||
Lnaive[e0, e1] = 1
|
||||
# symetric
|
||||
Lnaive[e1, e0] = 1
|
||||
|
||||
deg = Lnaive.sum(1).view(-1, 1)
|
||||
deg[deg > 0] = 1.0 / deg[deg > 0]
|
||||
Lnaive = Lnaive * deg
|
||||
diag = torch.eye(V, dtype=torch.float32, device=mesh.device)
|
||||
Lnaive.masked_fill_(diag > 0, -1)
|
||||
|
||||
self.assertClose(L.to_dense(), Lnaive)
|
||||
|
||||
def test_cot_laplacian(self):
|
||||
mesh = self.init_mesh()
|
||||
verts = mesh.verts_packed()
|
||||
faces = mesh.faces_packed()
|
||||
V, F = verts.shape[0], faces.shape[0]
|
||||
|
||||
eps = 1e-12
|
||||
|
||||
L, inv_areas = cot_laplacian(verts, faces, eps=eps)
|
||||
|
||||
Lnaive = torch.zeros((V, V), dtype=torch.float32, device=verts.device)
|
||||
inv_areas_naive = torch.zeros((V, 1), dtype=torch.float32, device=verts.device)
|
||||
|
||||
for f in faces:
|
||||
v0 = verts[f[0], :]
|
||||
v1 = verts[f[1], :]
|
||||
v2 = verts[f[2], :]
|
||||
A = (v1 - v2).norm()
|
||||
B = (v0 - v2).norm()
|
||||
C = (v0 - v1).norm()
|
||||
s = 0.5 * (A + B + C)
|
||||
|
||||
face_area = (s * (s - A) * (s - B) * (s - C)).clamp_(min=1e-12).sqrt()
|
||||
inv_areas_naive[f[0]] += face_area
|
||||
inv_areas_naive[f[1]] += face_area
|
||||
inv_areas_naive[f[2]] += face_area
|
||||
|
||||
A2, B2, C2 = A * A, B * B, C * C
|
||||
cota = (B2 + C2 - A2) / face_area / 4.0
|
||||
cotb = (A2 + C2 - B2) / face_area / 4.0
|
||||
cotc = (A2 + B2 - C2) / face_area / 4.0
|
||||
|
||||
Lnaive[f[1], f[2]] += cota
|
||||
Lnaive[f[2], f[0]] += cotb
|
||||
Lnaive[f[0], f[1]] += cotc
|
||||
# symetric
|
||||
Lnaive[f[2], f[1]] += cota
|
||||
Lnaive[f[0], f[2]] += cotb
|
||||
Lnaive[f[1], f[0]] += cotc
|
||||
|
||||
idx = inv_areas_naive > 0
|
||||
inv_areas_naive[idx] = 1.0 / inv_areas_naive[idx]
|
||||
|
||||
self.assertClose(inv_areas, inv_areas_naive)
|
||||
self.assertClose(L.to_dense(), Lnaive)
|
||||
|
||||
def test_norm_laplacian(self):
|
||||
mesh = self.init_mesh()
|
||||
verts = mesh.verts_packed()
|
||||
edges = mesh.edges_packed()
|
||||
V, E = verts.shape[0], edges.shape[0]
|
||||
|
||||
eps = 1e-12
|
||||
|
||||
L = norm_laplacian(verts, edges, eps=eps)
|
||||
|
||||
Lnaive = torch.zeros((V, V), dtype=torch.float32, device=verts.device)
|
||||
for e in range(E):
|
||||
e0, e1 = edges[e]
|
||||
v0 = verts[e0]
|
||||
v1 = verts[e1]
|
||||
|
||||
w01 = 1.0 / ((v0 - v1).norm() + eps)
|
||||
Lnaive[e0, e1] += w01
|
||||
Lnaive[e1, e0] += w01
|
||||
|
||||
self.assertClose(L.to_dense(), Lnaive)
|
@ -10,7 +10,6 @@ import unittest
|
||||
import torch
|
||||
from common_testing import TestCaseMixin, get_random_cuda_device
|
||||
from pytorch3d.ops import taubin_smoothing
|
||||
from pytorch3d.ops.mesh_filtering import norm_laplacian
|
||||
from pytorch3d.structures import Meshes
|
||||
from pytorch3d.utils import ico_sphere
|
||||
|
||||
@ -40,39 +39,3 @@ class TestTaubinSmoothing(TestCaseMixin, unittest.TestCase):
|
||||
smooth_dist = (smooth_verts - ico_verts).norm(dim=-1).mean()
|
||||
dist = (verts - ico_verts).norm(dim=-1).mean()
|
||||
self.assertTrue(smooth_dist < dist)
|
||||
|
||||
def test_norm_laplacian(self):
|
||||
V = 32
|
||||
F = 64
|
||||
device = get_random_cuda_device()
|
||||
# random vertices
|
||||
verts = torch.rand((V, 3), dtype=torch.float32, device=device)
|
||||
# random valid faces (no self circles, e.g. (v0, v0, v1))
|
||||
faces = torch.stack([torch.randperm(V) for f in range(F)], dim=0)[:, :3]
|
||||
faces = faces.to(device=device)
|
||||
mesh = Meshes(verts=[verts], faces=[faces])
|
||||
edges = mesh.edges_packed()
|
||||
|
||||
eps = 1e-12
|
||||
|
||||
L = norm_laplacian(verts, edges, eps=eps)
|
||||
|
||||
Lnaive = torch.zeros((V, V), dtype=torch.float32, device=device)
|
||||
for f in range(F):
|
||||
f0, f1, f2 = faces[f]
|
||||
v0 = verts[f0]
|
||||
v1 = verts[f1]
|
||||
v2 = verts[f2]
|
||||
|
||||
w12 = 1.0 / ((v1 - v2).norm() + eps)
|
||||
w02 = 1.0 / ((v0 - v2).norm() + eps)
|
||||
w01 = 1.0 / ((v0 - v1).norm() + eps)
|
||||
|
||||
Lnaive[f0, f1] = w01
|
||||
Lnaive[f1, f0] = w01
|
||||
Lnaive[f0, f2] = w02
|
||||
Lnaive[f2, f0] = w02
|
||||
Lnaive[f1, f2] = w12
|
||||
Lnaive[f2, f1] = w12
|
||||
|
||||
self.assertClose(L.to_dense(), Lnaive)
|
||||
|
@ -406,34 +406,6 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
|
||||
self.assertFalse(newv.requires_grad)
|
||||
self.assertClose(newv, v)
|
||||
|
||||
def test_laplacian_packed(self):
|
||||
def naive_laplacian_packed(meshes):
|
||||
verts_packed = meshes.verts_packed()
|
||||
edges_packed = meshes.edges_packed()
|
||||
V = verts_packed.shape[0]
|
||||
|
||||
L = torch.zeros((V, V), dtype=torch.float32, device=meshes.device)
|
||||
for e in edges_packed:
|
||||
L[e[0], e[1]] = 1
|
||||
# symetric
|
||||
L[e[1], e[0]] = 1
|
||||
|
||||
deg = L.sum(1).view(-1, 1)
|
||||
deg[deg > 0] = 1.0 / deg[deg > 0]
|
||||
L = L * deg
|
||||
diag = torch.eye(V, dtype=torch.float32, device=meshes.device)
|
||||
L.masked_fill_(diag > 0, -1)
|
||||
return L
|
||||
|
||||
# Note that we don't test with random meshes for this case, as the
|
||||
# definition of Laplacian is defined for simple graphs (aka valid meshes)
|
||||
meshes = init_simple_mesh("cuda:0")
|
||||
|
||||
lapl_naive = naive_laplacian_packed(meshes)
|
||||
lapl = meshes.laplacian_packed().to_dense()
|
||||
# check with naive
|
||||
self.assertClose(lapl, lapl_naive)
|
||||
|
||||
def test_offset_verts(self):
|
||||
def naive_offset_verts(mesh, vert_offsets_packed):
|
||||
# new Meshes class
|
||||
|
Loading…
x
Reference in New Issue
Block a user