Initial commit

fbshipit-source-id: ad58e416e3ceeca85fae0583308968d04e78fe0d
This commit is contained in:
facebook-github-bot
2020-01-23 11:53:41 -08:00
commit dbf06b504b
211 changed files with 47362 additions and 0 deletions

View File

@@ -0,0 +1,9 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .chamfer import chamfer_distance
from .mesh_edge_loss import mesh_edge_loss
from .mesh_laplacian_smoothing import mesh_laplacian_smoothing
from .mesh_normal_consistency import mesh_normal_consistency
__all__ = [k for k in globals().keys() if not k.startswith("_")]

152
pytorch3d/loss/chamfer.py Normal file
View File

@@ -0,0 +1,152 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn.functional as F
from pytorch3d.ops.nearest_neighbor_points import nn_points_idx
def _validate_chamfer_reduction_inputs(
batch_reduction: str, point_reduction: str
):
"""Check the requested reductions are valid.
Args:
batch_reduction: Reduction operation to apply for the loss across the
batch, can be one of ["none", "mean", "sum"].
point_reduction: Reduction operation to apply for the loss across the
points, can be one of ["none", "mean", "sum"].
"""
if batch_reduction not in ["none", "mean", "sum"]:
raise ValueError(
'batch_reduction must be one of ["none", "mean", "sum"]'
)
if point_reduction not in ["none", "mean", "sum"]:
raise ValueError(
'point_reduction must be one of ["none", "mean", "sum"]'
)
if batch_reduction == "none" and point_reduction == "none":
raise ValueError(
'batch_reduction and point_reduction cannot both be "none".'
)
def chamfer_distance(
x,
y,
x_normals=None,
y_normals=None,
weights=None,
batch_reduction: str = "mean",
point_reduction: str = "mean",
):
"""
Chamfer distance between two pointclouds x and y.
Args:
x: FloatTensor of shape (N, P1, D) representing a batch of point clouds
with P1 points in each batch element, batch size N and feature
dimension D.
y: FloatTensor of shape (N, P2, D) representing a batch of point clouds
with P2 points in each batch element, batch size N and feature
dimension D.
x_normals: Optional FloatTensor of shape (N, P1, D).
y_normals: Optional FloatTensor of shape (N, P2, D).
weights: Optional FloatTensor of shape (N,) giving weights for
batch elements for reduction operation.
batch_reduction: Reduction operation to apply for the loss across the
batch, can be one of ["none", "mean", "sum"].
point_reduction: Reduction operation to apply for the loss across the
points, can be one of ["none", "mean", "sum"].
Returns:
2-element tuple containing
- **loss**: Tensor giving the reduced distance between the pointclouds
in x and the pointclouds in y.
- **loss_normals**: Tensor giving the reduced cosine distance of normals
between pointclouds in x and pointclouds in y. Returns None if
x_normals and y_normals are None.
"""
_validate_chamfer_reduction_inputs(batch_reduction, point_reduction)
N, P1, D = x.shape
P2 = y.shape[1]
if y.shape[0] != N or y.shape[2] != D:
raise ValueError("y does not have the correct shape.")
if weights is not None:
if weights.size(0) != N:
raise ValueError("weights must be of shape (N,).")
if not (weights >= 0).all():
raise ValueError("weights can not be nonnegative.")
if weights.sum() == 0.0:
weights = weights.view(N, 1)
if batch_reduction in ["mean", "sum"]:
return (
(x.sum((1, 2)) * weights).sum() * 0.0,
(x.sum((1, 2)) * weights).sum() * 0.0,
)
return (
(x.sum((1, 2)) * weights) * 0.0,
(x.sum((1, 2)) * weights) * 0.0,
)
return_normals = x_normals is not None and y_normals is not None
cham_norm_x = x.new_zeros(())
cham_norm_y = x.new_zeros(())
x_near, xidx_near, x_normals_near = nn_points_idx(x, y, y_normals)
y_near, yidx_near, y_normals_near = nn_points_idx(y, x, x_normals)
cham_x = (x - x_near).norm(dim=2, p=2) ** 2.0 # (N, P1)
cham_y = (y - y_near).norm(dim=2, p=2) ** 2.0 # (N, P2)
if weights is not None:
cham_x *= weights.view(N, 1)
cham_y *= weights.view(N, 1)
if return_normals:
cham_norm_x = 1 - torch.abs(
F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6)
)
cham_norm_y = 1 - torch.abs(
F.cosine_similarity(y_normals, y_normals_near, dim=2, eps=1e-6)
)
if weights is not None:
cham_norm_x *= weights.view(N, 1)
cham_norm_y *= weights.view(N, 1)
if point_reduction != "none":
# If not 'none' then either 'sum' or 'mean'.
cham_x = cham_x.sum(1) # (N,)
cham_y = cham_y.sum(1) # (N,)
if return_normals:
cham_norm_x = cham_norm_x.sum(1) # (N,)
cham_norm_y = cham_norm_y.sum(1) # (N,)
if point_reduction == "mean":
cham_x /= P1
cham_y /= P2
if return_normals:
cham_norm_x /= P1
cham_norm_y /= P2
if batch_reduction != "none":
cham_x = cham_x.sum()
cham_y = cham_y.sum()
if return_normals:
cham_norm_x = cham_norm_x.sum()
cham_norm_y = cham_norm_y.sum()
if batch_reduction == "mean":
div = weights.sum() if weights is not None else N
cham_x /= div
cham_y /= div
if return_normals:
cham_norm_x /= div
cham_norm_y /= div
cham_dist = cham_x + cham_y
cham_normals = cham_norm_x + cham_norm_y if return_normals else None
return cham_dist, cham_normals

View File

@@ -0,0 +1,47 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
def mesh_edge_loss(meshes, target_length: float = 0.0):
"""
Computes mesh edge length regularization loss averaged across all meshes
in a batch. Each edge contributes equally to the final loss, regardless of
numbers of edges per mesh in the batch by weighting each mesh with the
inverse number of edges. For example, if mesh 3 (out of N) has only E=4
edges, then the loss for each edge in mesh 3 should be multiplied by 1/E to
contribute to the final loss.
Args:
meshes: Meshes object with a batch of meshes.
target_length: Resting value for the edge length.
Returns:
loss: Average loss across the batch. Returns 0 if meshes contains
no meshes or all empty meshes.
"""
if meshes.isempty():
return torch.tensor(
[0.0], dtype=torch.float32, device=meshes.device, requires_grad=True
)
N = len(meshes)
edges_packed = meshes.edges_packed() # (sum(E_n), 3)
verts_packed = meshes.verts_packed() # (sum(V_n), 3)
edge_to_mesh_idx = meshes.edges_packed_to_mesh_idx() # (sum(E_n), )
num_edges_per_mesh = meshes.num_edges_per_mesh() # N
# Determine the weight for each edge based on the number of edges in the
# mesh it corresponds to.
# TODO (nikhilar) Find a faster way of computing the weights for each edge
# as this is currently a bottleneck for meshes with a large number of faces.
weights = num_edges_per_mesh.gather(0, edge_to_mesh_idx)
weights = 1.0 / weights.float()
verts_edges = verts_packed[edges_packed]
v0, v1 = verts_edges.unbind(1)
loss = ((v0 - v1).norm(dim=1, p=2) - target_length) ** 2.0
loss = loss * weights
return loss.sum() / N

View File

@@ -0,0 +1,195 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
def mesh_laplacian_smoothing(meshes, method: str = "uniform"):
r"""
Computes the laplacian smoothing objective for a batch of meshes.
This function supports three variants of Laplacian smoothing,
namely with uniform weights("uniform"), with cotangent weights ("cot"),
and cotangent cuvature ("cotcurv").For more details read [1, 2].
Args:
meshes: Meshes object with a batch of meshes.
method: str specifying the method for the laplacian.
Returns:
loss: Average laplacian smoothing loss across the batch.
Returns 0 if meshes contains no meshes or all empty meshes.
Consider a mesh M = (V, F), with verts of shape Nx3 and faces of shape Mx3.
The Laplacian matrix L is a NxN tensor such that LV gives a tensor of vectors:
for a uniform Laplacian, LuV[i] points to the centroid of its neighboring
vertices, a cotangent Laplacian LcV[i] is known to be an approximation of
the surface normal, while the curvature variant LckV[i] scales the normals
by the discrete mean curvature. For vertex i, assume S[i] is the set of
neighboring vertices to i, a_ij and b_ij are the "outside" angles in the
two triangles connecting vertex v_i and its neighboring vertex v_j
for j in S[i], as seen in the diagram below.
.. code-block:: python
a_ij
/\
/ \
/ \
/ \
v_i /________\ v_j
\ /
\ /
\ /
\ /
\/
b_ij
The definition of the Laplacian is LV[i] = sum_j w_ij (v_j - v_i)
For the uniform variant, w_ij = 1 / |S[i]|
For the cotangent variant,
w_ij = (cot a_ij + cot b_ij) / (sum_k cot a_ik + cot b_ik)
For the cotangent curvature, w_ij = (cot a_ij + cot b_ij) / (4 A[i])
where A[i] is the sum of the areas of all triangles containing vertex v_i.
There is a nice trigonometry identity to compute cotangents. Consider a triangle
with side lengths A, B, C and angles a, b, c.
.. code-block:: python
c
/|\
/ | \
/ | \
B / H| \ A
/ | \
/ | \
/a_____|_____b\
C
Then cot a = (B^2 + C^2 - A^2) / 4 * area
We know that area = CH/2, and by the law of cosines we have
A^2 = B^2 + C^2 - 2BC cos a => B^2 + C^2 - A^2 = 2BC cos a
Putting these together, we get:
B^2 + C^2 - A^2 2BC cos a
_______________ = _________ = (B/H) cos a = cos a / sin a = cot a
4 * area 2CH
[1] Desbrun et al, "Implicit fairing of irregular meshes using diffusion
and curvature flow", SIGGRAPH 1999.
[2] Nealan et al, "Laplacian Mesh Optimization", Graphite 2006.
"""
if meshes.isempty():
return torch.tensor(
[0.0], dtype=torch.float32, device=meshes.device, requires_grad=True
)
N = len(meshes)
verts_packed = meshes.verts_packed() # (sum(V_n), 3)
num_verts_per_mesh = meshes.num_verts_per_mesh() # (N,)
verts_packed_idx = meshes.verts_packed_to_mesh_idx() # (sum(V_n),)
weights = num_verts_per_mesh.gather(0, verts_packed_idx) # (sum(V_n),)
weights = 1.0 / weights.float()
# We don't want to backprop through the computation of the Laplacian;
# just treat it as a magic constant matrix that is used to transform
# verts into normals
with torch.no_grad():
if method == "uniform":
L = meshes.laplacian_packed()
elif method in ["cot", "cotcurv"]:
L, inv_areas = laplacian_cot(meshes)
if method == "cot":
norm_w = torch.sparse.sum(L, dim=1).to_dense().view(-1, 1)
idx = norm_w > 0
norm_w[idx] = 1.0 / norm_w[idx]
else:
norm_w = 0.25 * inv_areas
else:
raise ValueError("Method should be one of {uniform, cot, cotcurv}")
if method == "uniform":
loss = L.mm(verts_packed)
elif method == "cot":
loss = L.mm(verts_packed) * norm_w - verts_packed
elif method == "cotcurv":
loss = (L.mm(verts_packed) - verts_packed) * norm_w
loss = loss.norm(dim=1)
loss = loss * weights
return loss.sum() / N
def laplacian_cot(meshes):
"""
Returns the Laplacian matrix with cotangent weights and the inverse of the
face areas.
Args:
meshes: Meshes object with a batch of meshes.
Returns:
2-element tuple containing
- **L**: FloatTensor of shape (V,V) for the Laplacian matrix (V = sum(V_n))
Here, L[i, j] = cot a_ij + cot b_ij iff (i, j) is an edge in meshes.
See the description above for more clarity.
- **inv_areas**: FloatTensor of shape (V,) containing the inverse of sum of
face areas containing each vertex
"""
verts_packed = meshes.verts_packed() # (sum(V_n), 3)
faces_packed = meshes.faces_packed() # (sum(F_n), 3)
# V = sum(V_n), F = sum(F_n)
V, F = verts_packed.shape[0], faces_packed.shape[0]
face_verts = verts_packed[faces_packed]
v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2]
# Side lengths of each triangle, of shape (sum(F_n),)
# A is the side opposite v1, B is opposite v2, and C is opposite v3
A = (v1 - v2).norm(dim=1)
B = (v0 - v2).norm(dim=1)
C = (v0 - v1).norm(dim=1)
# Area of each triangle (with Heron's formula); shape is (sum(F_n),)
s = 0.5 * (A + B + C)
# note that the area can be negative (close to 0) causing nans after sqrt()
# we clip it to a small positive value
area = (s * (s - A) * (s - B) * (s - C)).clamp_(min=1e-12).sqrt()
# Compute cotangents of angles, of shape (sum(F_n), 3)
A2, B2, C2 = A * A, B * B, C * C
cota = (B2 + C2 - A2) / area
cotb = (A2 + C2 - B2) / area
cotc = (A2 + B2 - C2) / area
cot = torch.stack([cota, cotb, cotc], dim=1)
cot /= 4.0
# Construct a sparse matrix by basically doing:
# L[v1, v2] = cota
# L[v2, v0] = cotb
# L[v0, v1] = cotc
ii = faces_packed[:, [1, 2, 0]]
jj = faces_packed[:, [2, 0, 1]]
idx = torch.stack([ii, jj], dim=0).view(2, F * 3)
L = torch.sparse.FloatTensor(idx, cot.view(-1), (V, V))
# Make it symmetric; this means we are also setting
# L[v2, v1] = cota
# L[v0, v2] = cotb
# L[v1, v0] = cotc
L += L.t()
# For each vertex, compute the sum of areas for triangles containing it.
idx = faces_packed.view(-1)
inv_areas = torch.zeros(V, dtype=torch.float32, device=meshes.device)
val = torch.stack([area] * 3, dim=1).view(-1)
inv_areas.scatter_add_(0, idx, val)
idx = inv_areas > 0
inv_areas[idx] = 1.0 / inv_areas[idx]
inv_areas = inv_areas.view(-1, 1)
return L, inv_areas

View File

@@ -0,0 +1,148 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from itertools import islice
import torch
def mesh_normal_consistency(meshes):
r"""
Computes the normal consistency of each mesh in meshes.
We compute the normal consistency for each pair of neighboring faces.
If e = (v0, v1) is the connecting edge of two neighboring faces f0 and f1,
then the normal consistency between f0 and f1
.. code-block:: python
a
/\
/ \
/ f0 \
/ \
v0 /____e___\ v1
\ /
\ /
\ f1 /
\ /
\/
b
The normal consistency is
.. code-block:: python
nc(f0, f1) = 1 - cos(n0, n1)
where cos(n0, n1) = n0^n1 / ||n0|| / ||n1|| is the cosine of the angle
between the normals n0 and n1, and
n0 = (v1 - v0) x (a - v0)
n1 = - (v1 - v0) x (b - v0) = (b - v0) x (v1 - v0)
This means that if nc(f0, f1) = 0 then n0 and n1 point to the same
direction, while if nc(f0, f1) = 2 then n0 and n1 point opposite direction.
.. note::
For well-constructed meshes the assumption that only two faces share an
edge is true. This assumption could make the implementation easier and faster.
This implementation does not follow this assumption. All the faces sharing e,
which can be any in number, are discovered.
Args:
meshes: Meshes object with a batch of meshes.
Returns:
loss: Average normal consistency across the batch.
Returns 0 if meshes contains no meshes or all empty meshes.
"""
if meshes.isempty():
return torch.tensor(
[0.0], dtype=torch.float32, device=meshes.device, requires_grad=True
)
N = len(meshes)
verts_packed = meshes.verts_packed() # (sum(V_n), 3)
faces_packed = meshes.faces_packed() # (sum(F_n), 3)
edges_packed = meshes.edges_packed() # (sum(E_n), 2)
verts_packed_to_mesh_idx = meshes.verts_packed_to_mesh_idx() # (sum(V_n),)
face_to_edge = meshes.faces_packed_to_edges_packed() # (sum(F_n), 3)
E = edges_packed.shape[0] # sum(E_n)
F = faces_packed.shape[0] # sum(F_n)
# We don't want gradients for the following operation. The goal is to
# find for each edge e all the vertices associated with e. In the example above,
# the vertices associated with e are (v0, v1, a, b), i.e. points on e (=v0, v1)
# and points connected on faces to e (=a, b).
with torch.no_grad():
edge_idx = face_to_edge.reshape(F * 3) # (3 * F,) indexes into edges
vert_idx = (
faces_packed.view(1, F, 3)
.expand(3, F, 3)
.transpose(0, 1)
.reshape(3 * F, 3)
)
edge_idx, edge_sort_idx = edge_idx.sort()
vert_idx = vert_idx[edge_sort_idx]
# In well constructed meshes each edge is shared by precisely 2 faces
# However, in many meshes, this assumption is not always satisfied.
# We want to find all faces that share an edge, a number which can
# vary and which depends on the topology.
# In particular, we find the vertices not on the edge on the shared faces.
# In the example above, we want to associate edge e with vertices a and b.
# This operation is done more efficiently in cpu with lists.
# TODO(gkioxari) find a better way to do this.
# edge_idx represents the index of the edge for each vertex. We can count
# the number of vertices which are associated with each edge.
# There can be a different number for each edge.
edge_num = edge_idx.bincount(minlength=E)
# Create pairs of vertices associated to e. We generate a list of lists:
# each list has the indices of the vertices which are opposite to one edge.
# The length of the list for each edge will vary.
vert_edge_pair_idx = split_list(
list(range(edge_idx.shape[0])), edge_num.tolist()
)
# For each list find all combinations of pairs in the list. This represents
# all pairs of vertices which are opposite to the same edge.
vert_edge_pair_idx = [
[e[i], e[j]]
for e in vert_edge_pair_idx
for i in range(len(e) - 1)
for j in range(1, len(e))
if i != j
]
vert_edge_pair_idx = torch.tensor(
vert_edge_pair_idx, device=meshes.device, dtype=torch.int64
)
v0_idx = edges_packed[edge_idx, 0]
v0 = verts_packed[v0_idx]
v1_idx = edges_packed[edge_idx, 1]
v1 = verts_packed[v1_idx]
# two of the following cross products are zeros as they are cross product
# with either (v1-v0)x(v1-v0) or (v1-v0)x(v0-v0)
n_temp0 = (v1 - v0).cross(verts_packed[vert_idx[:, 0]] - v0, dim=1)
n_temp1 = (v1 - v0).cross(verts_packed[vert_idx[:, 1]] - v0, dim=1)
n_temp2 = (v1 - v0).cross(verts_packed[vert_idx[:, 2]] - v0, dim=1)
n = n_temp0 + n_temp1 + n_temp2
n0 = n[vert_edge_pair_idx[:, 0]]
n1 = -n[vert_edge_pair_idx[:, 1]]
loss = 1 - torch.cosine_similarity(n0, n1, dim=1)
verts_packed_to_mesh_idx = verts_packed_to_mesh_idx[vert_idx[:, 0]]
verts_packed_to_mesh_idx = verts_packed_to_mesh_idx[
vert_edge_pair_idx[:, 0]
]
num_normals = verts_packed_to_mesh_idx.bincount(minlength=N)
weights = 1.0 / num_normals[verts_packed_to_mesh_idx].float()
loss = loss * weights
return loss.sum() / N
def split_list(input, length_to_split):
inputt = iter(input)
return [list(islice(inputt, elem)) for elem in length_to_split]