mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Summary: License lint codebase Reviewed By: theschnitz Differential Revision: D29001799 fbshipit-source-id: 5c59869911785b0181b1663bbf430bc8b7fb2909
51 lines
1.9 KiB
Python
51 lines
1.9 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import torch
|
|
|
|
|
|
def mesh_edge_loss(meshes, target_length: float = 0.0):
    """
    Edge-length regularizer for a batch of meshes.

    For every edge, the squared difference between its Euclidean length and
    `target_length` is computed. Each edge's term is scaled by the reciprocal
    of the edge count of the mesh it belongs to, so that a mesh with many
    edges does not dominate a mesh with few. The scaled terms are summed and
    divided by the number of meshes in the batch.

    Args:
        meshes: Meshes object with a batch of meshes. It must provide
            `isempty()`, `device`, `__len__`, `edges_packed()`,
            `verts_packed()`, `edges_packed_to_mesh_idx()` and
            `num_edges_per_mesh()`.
        target_length: Resting value for the edge length.

    Returns:
        loss: Batch-averaged loss as a 0-dim tensor. When `meshes.isempty()`
        is true, a one-element float32 tensor `[0.0]` on `meshes.device`
        with `requires_grad=True` is returned instead.
    """
    # Guard: nothing to regularize. The zero still requires grad so callers
    # can combine it with other losses and call backward().
    if meshes.isempty():
        return torch.tensor(
            [0.0], dtype=torch.float32, device=meshes.device, requires_grad=True
        )

    batch_size = len(meshes)

    # Per-edge weight: 1 / (number of edges in the mesh owning that edge).
    # The per-mesh counts are expanded to one entry per packed edge through
    # the edge -> mesh index map.
    # TODO (nikhilar) Find a faster way of computing the weights for each edge
    # as this is currently a bottleneck for meshes with a large number of faces.
    mesh_idx_per_edge = meshes.edges_packed_to_mesh_idx()  # one entry per edge
    edge_counts = meshes.num_edges_per_mesh()  # one entry per mesh
    inv_edge_count = 1.0 / edge_counts.gather(0, mesh_idx_per_edge).float()

    # Look up both endpoints of every edge. Each row of edges_packed holds
    # two vertex indices into verts_packed, which unbind(1) splits apart.
    endpoints = meshes.verts_packed()[meshes.edges_packed()]
    start, end = endpoints.unbind(1)

    # Squared deviation of each edge length from the resting length.
    edge_lengths = (start - end).norm(dim=1, p=2)
    sq_deviation = (edge_lengths - target_length) ** 2.0

    weighted = sq_deviation * inv_edge_count
    return weighted.sum() / batch_size
|