mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 14:50:36 +08:00
Initial commit
fbshipit-source-id: ad58e416e3ceeca85fae0583308968d04e78fe0d
This commit is contained in:
9
pytorch3d/loss/__init__.py
Normal file
9
pytorch3d/loss/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
from .chamfer import chamfer_distance
|
||||
from .mesh_edge_loss import mesh_edge_loss
|
||||
from .mesh_laplacian_smoothing import mesh_laplacian_smoothing
|
||||
from .mesh_normal_consistency import mesh_normal_consistency
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||
152
pytorch3d/loss/chamfer.py
Normal file
152
pytorch3d/loss/chamfer.py
Normal file
@@ -0,0 +1,152 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from pytorch3d.ops.nearest_neighbor_points import nn_points_idx
|
||||
|
||||
|
||||
def _validate_chamfer_reduction_inputs(
|
||||
batch_reduction: str, point_reduction: str
|
||||
):
|
||||
"""Check the requested reductions are valid.
|
||||
|
||||
Args:
|
||||
batch_reduction: Reduction operation to apply for the loss across the
|
||||
batch, can be one of ["none", "mean", "sum"].
|
||||
point_reduction: Reduction operation to apply for the loss across the
|
||||
points, can be one of ["none", "mean", "sum"].
|
||||
"""
|
||||
if batch_reduction not in ["none", "mean", "sum"]:
|
||||
raise ValueError(
|
||||
'batch_reduction must be one of ["none", "mean", "sum"]'
|
||||
)
|
||||
if point_reduction not in ["none", "mean", "sum"]:
|
||||
raise ValueError(
|
||||
'point_reduction must be one of ["none", "mean", "sum"]'
|
||||
)
|
||||
if batch_reduction == "none" and point_reduction == "none":
|
||||
raise ValueError(
|
||||
'batch_reduction and point_reduction cannot both be "none".'
|
||||
)
|
||||
|
||||
|
||||
def chamfer_distance(
|
||||
x,
|
||||
y,
|
||||
x_normals=None,
|
||||
y_normals=None,
|
||||
weights=None,
|
||||
batch_reduction: str = "mean",
|
||||
point_reduction: str = "mean",
|
||||
):
|
||||
"""
|
||||
Chamfer distance between two pointclouds x and y.
|
||||
|
||||
Args:
|
||||
x: FloatTensor of shape (N, P1, D) representing a batch of point clouds
|
||||
with P1 points in each batch element, batch size N and feature
|
||||
dimension D.
|
||||
y: FloatTensor of shape (N, P2, D) representing a batch of point clouds
|
||||
with P2 points in each batch element, batch size N and feature
|
||||
dimension D.
|
||||
x_normals: Optional FloatTensor of shape (N, P1, D).
|
||||
y_normals: Optional FloatTensor of shape (N, P2, D).
|
||||
weights: Optional FloatTensor of shape (N,) giving weights for
|
||||
batch elements for reduction operation.
|
||||
batch_reduction: Reduction operation to apply for the loss across the
|
||||
batch, can be one of ["none", "mean", "sum"].
|
||||
point_reduction: Reduction operation to apply for the loss across the
|
||||
points, can be one of ["none", "mean", "sum"].
|
||||
|
||||
Returns:
|
||||
2-element tuple containing
|
||||
|
||||
- **loss**: Tensor giving the reduced distance between the pointclouds
|
||||
in x and the pointclouds in y.
|
||||
- **loss_normals**: Tensor giving the reduced cosine distance of normals
|
||||
between pointclouds in x and pointclouds in y. Returns None if
|
||||
x_normals and y_normals are None.
|
||||
"""
|
||||
_validate_chamfer_reduction_inputs(batch_reduction, point_reduction)
|
||||
|
||||
N, P1, D = x.shape
|
||||
P2 = y.shape[1]
|
||||
|
||||
if y.shape[0] != N or y.shape[2] != D:
|
||||
raise ValueError("y does not have the correct shape.")
|
||||
if weights is not None:
|
||||
if weights.size(0) != N:
|
||||
raise ValueError("weights must be of shape (N,).")
|
||||
if not (weights >= 0).all():
|
||||
raise ValueError("weights can not be nonnegative.")
|
||||
if weights.sum() == 0.0:
|
||||
weights = weights.view(N, 1)
|
||||
if batch_reduction in ["mean", "sum"]:
|
||||
return (
|
||||
(x.sum((1, 2)) * weights).sum() * 0.0,
|
||||
(x.sum((1, 2)) * weights).sum() * 0.0,
|
||||
)
|
||||
return (
|
||||
(x.sum((1, 2)) * weights) * 0.0,
|
||||
(x.sum((1, 2)) * weights) * 0.0,
|
||||
)
|
||||
|
||||
return_normals = x_normals is not None and y_normals is not None
|
||||
cham_norm_x = x.new_zeros(())
|
||||
cham_norm_y = x.new_zeros(())
|
||||
|
||||
x_near, xidx_near, x_normals_near = nn_points_idx(x, y, y_normals)
|
||||
y_near, yidx_near, y_normals_near = nn_points_idx(y, x, x_normals)
|
||||
|
||||
cham_x = (x - x_near).norm(dim=2, p=2) ** 2.0 # (N, P1)
|
||||
cham_y = (y - y_near).norm(dim=2, p=2) ** 2.0 # (N, P2)
|
||||
|
||||
if weights is not None:
|
||||
cham_x *= weights.view(N, 1)
|
||||
cham_y *= weights.view(N, 1)
|
||||
|
||||
if return_normals:
|
||||
cham_norm_x = 1 - torch.abs(
|
||||
F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6)
|
||||
)
|
||||
cham_norm_y = 1 - torch.abs(
|
||||
F.cosine_similarity(y_normals, y_normals_near, dim=2, eps=1e-6)
|
||||
)
|
||||
if weights is not None:
|
||||
cham_norm_x *= weights.view(N, 1)
|
||||
cham_norm_y *= weights.view(N, 1)
|
||||
|
||||
if point_reduction != "none":
|
||||
# If not 'none' then either 'sum' or 'mean'.
|
||||
cham_x = cham_x.sum(1) # (N,)
|
||||
cham_y = cham_y.sum(1) # (N,)
|
||||
if return_normals:
|
||||
cham_norm_x = cham_norm_x.sum(1) # (N,)
|
||||
cham_norm_y = cham_norm_y.sum(1) # (N,)
|
||||
if point_reduction == "mean":
|
||||
cham_x /= P1
|
||||
cham_y /= P2
|
||||
if return_normals:
|
||||
cham_norm_x /= P1
|
||||
cham_norm_y /= P2
|
||||
|
||||
if batch_reduction != "none":
|
||||
cham_x = cham_x.sum()
|
||||
cham_y = cham_y.sum()
|
||||
if return_normals:
|
||||
cham_norm_x = cham_norm_x.sum()
|
||||
cham_norm_y = cham_norm_y.sum()
|
||||
if batch_reduction == "mean":
|
||||
div = weights.sum() if weights is not None else N
|
||||
cham_x /= div
|
||||
cham_y /= div
|
||||
if return_normals:
|
||||
cham_norm_x /= div
|
||||
cham_norm_y /= div
|
||||
|
||||
cham_dist = cham_x + cham_y
|
||||
cham_normals = cham_norm_x + cham_norm_y if return_normals else None
|
||||
|
||||
return cham_dist, cham_normals
|
||||
47
pytorch3d/loss/mesh_edge_loss.py
Normal file
47
pytorch3d/loss/mesh_edge_loss.py
Normal file
@@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def mesh_edge_loss(meshes, target_length: float = 0.0):
|
||||
"""
|
||||
Computes mesh edge length regularization loss averaged across all meshes
|
||||
in a batch. Each edge contributes equally to the final loss, regardless of
|
||||
numbers of edges per mesh in the batch by weighting each mesh with the
|
||||
inverse number of edges. For example, if mesh 3 (out of N) has only E=4
|
||||
edges, then the loss for each edge in mesh 3 should be multiplied by 1/E to
|
||||
contribute to the final loss.
|
||||
|
||||
Args:
|
||||
meshes: Meshes object with a batch of meshes.
|
||||
target_length: Resting value for the edge length.
|
||||
|
||||
Returns:
|
||||
loss: Average loss across the batch. Returns 0 if meshes contains
|
||||
no meshes or all empty meshes.
|
||||
"""
|
||||
if meshes.isempty():
|
||||
return torch.tensor(
|
||||
[0.0], dtype=torch.float32, device=meshes.device, requires_grad=True
|
||||
)
|
||||
|
||||
N = len(meshes)
|
||||
edges_packed = meshes.edges_packed() # (sum(E_n), 3)
|
||||
verts_packed = meshes.verts_packed() # (sum(V_n), 3)
|
||||
edge_to_mesh_idx = meshes.edges_packed_to_mesh_idx() # (sum(E_n), )
|
||||
num_edges_per_mesh = meshes.num_edges_per_mesh() # N
|
||||
|
||||
# Determine the weight for each edge based on the number of edges in the
|
||||
# mesh it corresponds to.
|
||||
# TODO (nikhilar) Find a faster way of computing the weights for each edge
|
||||
# as this is currently a bottleneck for meshes with a large number of faces.
|
||||
weights = num_edges_per_mesh.gather(0, edge_to_mesh_idx)
|
||||
weights = 1.0 / weights.float()
|
||||
|
||||
verts_edges = verts_packed[edges_packed]
|
||||
v0, v1 = verts_edges.unbind(1)
|
||||
loss = ((v0 - v1).norm(dim=1, p=2) - target_length) ** 2.0
|
||||
loss = loss * weights
|
||||
|
||||
return loss.sum() / N
|
||||
195
pytorch3d/loss/mesh_laplacian_smoothing.py
Normal file
195
pytorch3d/loss/mesh_laplacian_smoothing.py
Normal file
@@ -0,0 +1,195 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def mesh_laplacian_smoothing(meshes, method: str = "uniform"):
|
||||
r"""
|
||||
Computes the laplacian smoothing objective for a batch of meshes.
|
||||
This function supports three variants of Laplacian smoothing,
|
||||
namely with uniform weights("uniform"), with cotangent weights ("cot"),
|
||||
and cotangent cuvature ("cotcurv").For more details read [1, 2].
|
||||
|
||||
Args:
|
||||
meshes: Meshes object with a batch of meshes.
|
||||
method: str specifying the method for the laplacian.
|
||||
Returns:
|
||||
loss: Average laplacian smoothing loss across the batch.
|
||||
Returns 0 if meshes contains no meshes or all empty meshes.
|
||||
|
||||
Consider a mesh M = (V, F), with verts of shape Nx3 and faces of shape Mx3.
|
||||
The Laplacian matrix L is a NxN tensor such that LV gives a tensor of vectors:
|
||||
for a uniform Laplacian, LuV[i] points to the centroid of its neighboring
|
||||
vertices, a cotangent Laplacian LcV[i] is known to be an approximation of
|
||||
the surface normal, while the curvature variant LckV[i] scales the normals
|
||||
by the discrete mean curvature. For vertex i, assume S[i] is the set of
|
||||
neighboring vertices to i, a_ij and b_ij are the "outside" angles in the
|
||||
two triangles connecting vertex v_i and its neighboring vertex v_j
|
||||
for j in S[i], as seen in the diagram below.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
a_ij
|
||||
/\
|
||||
/ \
|
||||
/ \
|
||||
/ \
|
||||
v_i /________\ v_j
|
||||
\ /
|
||||
\ /
|
||||
\ /
|
||||
\ /
|
||||
\/
|
||||
b_ij
|
||||
|
||||
The definition of the Laplacian is LV[i] = sum_j w_ij (v_j - v_i)
|
||||
For the uniform variant, w_ij = 1 / |S[i]|
|
||||
For the cotangent variant,
|
||||
w_ij = (cot a_ij + cot b_ij) / (sum_k cot a_ik + cot b_ik)
|
||||
For the cotangent curvature, w_ij = (cot a_ij + cot b_ij) / (4 A[i])
|
||||
where A[i] is the sum of the areas of all triangles containing vertex v_i.
|
||||
|
||||
There is a nice trigonometry identity to compute cotangents. Consider a triangle
|
||||
with side lengths A, B, C and angles a, b, c.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
c
|
||||
/|\
|
||||
/ | \
|
||||
/ | \
|
||||
B / H| \ A
|
||||
/ | \
|
||||
/ | \
|
||||
/a_____|_____b\
|
||||
C
|
||||
|
||||
Then cot a = (B^2 + C^2 - A^2) / 4 * area
|
||||
We know that area = CH/2, and by the law of cosines we have
|
||||
|
||||
A^2 = B^2 + C^2 - 2BC cos a => B^2 + C^2 - A^2 = 2BC cos a
|
||||
|
||||
Putting these together, we get:
|
||||
|
||||
B^2 + C^2 - A^2 2BC cos a
|
||||
_______________ = _________ = (B/H) cos a = cos a / sin a = cot a
|
||||
4 * area 2CH
|
||||
|
||||
|
||||
[1] Desbrun et al, "Implicit fairing of irregular meshes using diffusion
|
||||
and curvature flow", SIGGRAPH 1999.
|
||||
|
||||
[2] Nealan et al, "Laplacian Mesh Optimization", Graphite 2006.
|
||||
"""
|
||||
|
||||
if meshes.isempty():
|
||||
return torch.tensor(
|
||||
[0.0], dtype=torch.float32, device=meshes.device, requires_grad=True
|
||||
)
|
||||
|
||||
N = len(meshes)
|
||||
verts_packed = meshes.verts_packed() # (sum(V_n), 3)
|
||||
num_verts_per_mesh = meshes.num_verts_per_mesh() # (N,)
|
||||
verts_packed_idx = meshes.verts_packed_to_mesh_idx() # (sum(V_n),)
|
||||
weights = num_verts_per_mesh.gather(0, verts_packed_idx) # (sum(V_n),)
|
||||
weights = 1.0 / weights.float()
|
||||
|
||||
# We don't want to backprop through the computation of the Laplacian;
|
||||
# just treat it as a magic constant matrix that is used to transform
|
||||
# verts into normals
|
||||
with torch.no_grad():
|
||||
if method == "uniform":
|
||||
L = meshes.laplacian_packed()
|
||||
elif method in ["cot", "cotcurv"]:
|
||||
L, inv_areas = laplacian_cot(meshes)
|
||||
if method == "cot":
|
||||
norm_w = torch.sparse.sum(L, dim=1).to_dense().view(-1, 1)
|
||||
idx = norm_w > 0
|
||||
norm_w[idx] = 1.0 / norm_w[idx]
|
||||
else:
|
||||
norm_w = 0.25 * inv_areas
|
||||
else:
|
||||
raise ValueError("Method should be one of {uniform, cot, cotcurv}")
|
||||
|
||||
if method == "uniform":
|
||||
loss = L.mm(verts_packed)
|
||||
elif method == "cot":
|
||||
loss = L.mm(verts_packed) * norm_w - verts_packed
|
||||
elif method == "cotcurv":
|
||||
loss = (L.mm(verts_packed) - verts_packed) * norm_w
|
||||
loss = loss.norm(dim=1)
|
||||
|
||||
loss = loss * weights
|
||||
return loss.sum() / N
|
||||
|
||||
|
||||
def laplacian_cot(meshes):
|
||||
"""
|
||||
Returns the Laplacian matrix with cotangent weights and the inverse of the
|
||||
face areas.
|
||||
|
||||
Args:
|
||||
meshes: Meshes object with a batch of meshes.
|
||||
Returns:
|
||||
2-element tuple containing
|
||||
- **L**: FloatTensor of shape (V,V) for the Laplacian matrix (V = sum(V_n))
|
||||
Here, L[i, j] = cot a_ij + cot b_ij iff (i, j) is an edge in meshes.
|
||||
See the description above for more clarity.
|
||||
- **inv_areas**: FloatTensor of shape (V,) containing the inverse of sum of
|
||||
face areas containing each vertex
|
||||
"""
|
||||
verts_packed = meshes.verts_packed() # (sum(V_n), 3)
|
||||
faces_packed = meshes.faces_packed() # (sum(F_n), 3)
|
||||
# V = sum(V_n), F = sum(F_n)
|
||||
V, F = verts_packed.shape[0], faces_packed.shape[0]
|
||||
|
||||
face_verts = verts_packed[faces_packed]
|
||||
v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2]
|
||||
|
||||
# Side lengths of each triangle, of shape (sum(F_n),)
|
||||
# A is the side opposite v1, B is opposite v2, and C is opposite v3
|
||||
A = (v1 - v2).norm(dim=1)
|
||||
B = (v0 - v2).norm(dim=1)
|
||||
C = (v0 - v1).norm(dim=1)
|
||||
|
||||
# Area of each triangle (with Heron's formula); shape is (sum(F_n),)
|
||||
s = 0.5 * (A + B + C)
|
||||
# note that the area can be negative (close to 0) causing nans after sqrt()
|
||||
# we clip it to a small positive value
|
||||
area = (s * (s - A) * (s - B) * (s - C)).clamp_(min=1e-12).sqrt()
|
||||
|
||||
# Compute cotangents of angles, of shape (sum(F_n), 3)
|
||||
A2, B2, C2 = A * A, B * B, C * C
|
||||
cota = (B2 + C2 - A2) / area
|
||||
cotb = (A2 + C2 - B2) / area
|
||||
cotc = (A2 + B2 - C2) / area
|
||||
cot = torch.stack([cota, cotb, cotc], dim=1)
|
||||
cot /= 4.0
|
||||
|
||||
# Construct a sparse matrix by basically doing:
|
||||
# L[v1, v2] = cota
|
||||
# L[v2, v0] = cotb
|
||||
# L[v0, v1] = cotc
|
||||
ii = faces_packed[:, [1, 2, 0]]
|
||||
jj = faces_packed[:, [2, 0, 1]]
|
||||
idx = torch.stack([ii, jj], dim=0).view(2, F * 3)
|
||||
L = torch.sparse.FloatTensor(idx, cot.view(-1), (V, V))
|
||||
|
||||
# Make it symmetric; this means we are also setting
|
||||
# L[v2, v1] = cota
|
||||
# L[v0, v2] = cotb
|
||||
# L[v1, v0] = cotc
|
||||
L += L.t()
|
||||
|
||||
# For each vertex, compute the sum of areas for triangles containing it.
|
||||
idx = faces_packed.view(-1)
|
||||
inv_areas = torch.zeros(V, dtype=torch.float32, device=meshes.device)
|
||||
val = torch.stack([area] * 3, dim=1).view(-1)
|
||||
inv_areas.scatter_add_(0, idx, val)
|
||||
idx = inv_areas > 0
|
||||
inv_areas[idx] = 1.0 / inv_areas[idx]
|
||||
inv_areas = inv_areas.view(-1, 1)
|
||||
|
||||
return L, inv_areas
|
||||
148
pytorch3d/loss/mesh_normal_consistency.py
Normal file
148
pytorch3d/loss/mesh_normal_consistency.py
Normal file
@@ -0,0 +1,148 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
|
||||
from itertools import islice
|
||||
import torch
|
||||
|
||||
|
||||
def mesh_normal_consistency(meshes):
|
||||
r"""
|
||||
Computes the normal consistency of each mesh in meshes.
|
||||
We compute the normal consistency for each pair of neighboring faces.
|
||||
If e = (v0, v1) is the connecting edge of two neighboring faces f0 and f1,
|
||||
then the normal consistency between f0 and f1
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
a
|
||||
/\
|
||||
/ \
|
||||
/ f0 \
|
||||
/ \
|
||||
v0 /____e___\ v1
|
||||
\ /
|
||||
\ /
|
||||
\ f1 /
|
||||
\ /
|
||||
\/
|
||||
b
|
||||
|
||||
The normal consistency is
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
nc(f0, f1) = 1 - cos(n0, n1)
|
||||
|
||||
where cos(n0, n1) = n0^n1 / ||n0|| / ||n1|| is the cosine of the angle
|
||||
between the normals n0 and n1, and
|
||||
|
||||
n0 = (v1 - v0) x (a - v0)
|
||||
n1 = - (v1 - v0) x (b - v0) = (b - v0) x (v1 - v0)
|
||||
|
||||
This means that if nc(f0, f1) = 0 then n0 and n1 point to the same
|
||||
direction, while if nc(f0, f1) = 2 then n0 and n1 point opposite direction.
|
||||
|
||||
.. note::
|
||||
For well-constructed meshes the assumption that only two faces share an
|
||||
edge is true. This assumption could make the implementation easier and faster.
|
||||
This implementation does not follow this assumption. All the faces sharing e,
|
||||
which can be any in number, are discovered.
|
||||
|
||||
Args:
|
||||
meshes: Meshes object with a batch of meshes.
|
||||
|
||||
Returns:
|
||||
loss: Average normal consistency across the batch.
|
||||
Returns 0 if meshes contains no meshes or all empty meshes.
|
||||
"""
|
||||
if meshes.isempty():
|
||||
return torch.tensor(
|
||||
[0.0], dtype=torch.float32, device=meshes.device, requires_grad=True
|
||||
)
|
||||
|
||||
N = len(meshes)
|
||||
verts_packed = meshes.verts_packed() # (sum(V_n), 3)
|
||||
faces_packed = meshes.faces_packed() # (sum(F_n), 3)
|
||||
edges_packed = meshes.edges_packed() # (sum(E_n), 2)
|
||||
verts_packed_to_mesh_idx = meshes.verts_packed_to_mesh_idx() # (sum(V_n),)
|
||||
face_to_edge = meshes.faces_packed_to_edges_packed() # (sum(F_n), 3)
|
||||
E = edges_packed.shape[0] # sum(E_n)
|
||||
F = faces_packed.shape[0] # sum(F_n)
|
||||
|
||||
# We don't want gradients for the following operation. The goal is to
|
||||
# find for each edge e all the vertices associated with e. In the example above,
|
||||
# the vertices associated with e are (v0, v1, a, b), i.e. points on e (=v0, v1)
|
||||
# and points connected on faces to e (=a, b).
|
||||
with torch.no_grad():
|
||||
edge_idx = face_to_edge.reshape(F * 3) # (3 * F,) indexes into edges
|
||||
vert_idx = (
|
||||
faces_packed.view(1, F, 3)
|
||||
.expand(3, F, 3)
|
||||
.transpose(0, 1)
|
||||
.reshape(3 * F, 3)
|
||||
)
|
||||
edge_idx, edge_sort_idx = edge_idx.sort()
|
||||
vert_idx = vert_idx[edge_sort_idx]
|
||||
|
||||
# In well constructed meshes each edge is shared by precisely 2 faces
|
||||
# However, in many meshes, this assumption is not always satisfied.
|
||||
# We want to find all faces that share an edge, a number which can
|
||||
# vary and which depends on the topology.
|
||||
# In particular, we find the vertices not on the edge on the shared faces.
|
||||
# In the example above, we want to associate edge e with vertices a and b.
|
||||
# This operation is done more efficiently in cpu with lists.
|
||||
# TODO(gkioxari) find a better way to do this.
|
||||
|
||||
# edge_idx represents the index of the edge for each vertex. We can count
|
||||
# the number of vertices which are associated with each edge.
|
||||
# There can be a different number for each edge.
|
||||
edge_num = edge_idx.bincount(minlength=E)
|
||||
# Create pairs of vertices associated to e. We generate a list of lists:
|
||||
# each list has the indices of the vertices which are opposite to one edge.
|
||||
# The length of the list for each edge will vary.
|
||||
vert_edge_pair_idx = split_list(
|
||||
list(range(edge_idx.shape[0])), edge_num.tolist()
|
||||
)
|
||||
# For each list find all combinations of pairs in the list. This represents
|
||||
# all pairs of vertices which are opposite to the same edge.
|
||||
vert_edge_pair_idx = [
|
||||
[e[i], e[j]]
|
||||
for e in vert_edge_pair_idx
|
||||
for i in range(len(e) - 1)
|
||||
for j in range(1, len(e))
|
||||
if i != j
|
||||
]
|
||||
vert_edge_pair_idx = torch.tensor(
|
||||
vert_edge_pair_idx, device=meshes.device, dtype=torch.int64
|
||||
)
|
||||
|
||||
v0_idx = edges_packed[edge_idx, 0]
|
||||
v0 = verts_packed[v0_idx]
|
||||
v1_idx = edges_packed[edge_idx, 1]
|
||||
v1 = verts_packed[v1_idx]
|
||||
|
||||
# two of the following cross products are zeros as they are cross product
|
||||
# with either (v1-v0)x(v1-v0) or (v1-v0)x(v0-v0)
|
||||
n_temp0 = (v1 - v0).cross(verts_packed[vert_idx[:, 0]] - v0, dim=1)
|
||||
n_temp1 = (v1 - v0).cross(verts_packed[vert_idx[:, 1]] - v0, dim=1)
|
||||
n_temp2 = (v1 - v0).cross(verts_packed[vert_idx[:, 2]] - v0, dim=1)
|
||||
n = n_temp0 + n_temp1 + n_temp2
|
||||
n0 = n[vert_edge_pair_idx[:, 0]]
|
||||
n1 = -n[vert_edge_pair_idx[:, 1]]
|
||||
loss = 1 - torch.cosine_similarity(n0, n1, dim=1)
|
||||
|
||||
verts_packed_to_mesh_idx = verts_packed_to_mesh_idx[vert_idx[:, 0]]
|
||||
verts_packed_to_mesh_idx = verts_packed_to_mesh_idx[
|
||||
vert_edge_pair_idx[:, 0]
|
||||
]
|
||||
num_normals = verts_packed_to_mesh_idx.bincount(minlength=N)
|
||||
weights = 1.0 / num_normals[verts_packed_to_mesh_idx].float()
|
||||
|
||||
loss = loss * weights
|
||||
return loss.sum() / N
|
||||
|
||||
|
||||
def split_list(input, length_to_split):
|
||||
inputt = iter(input)
|
||||
return [list(islice(inputt, elem)) for elem in length_to_split]
|
||||
Reference in New Issue
Block a user