pytorch3d/pytorch3d/loss/mesh_edge_loss.py
Patrick Labatut 3c71ab64cc Remove shebang line when not strictly required
Summary: The shebang line `#!<path to interpreter>` is only required for Python scripts, so remove it on source files for class or function definitions. Additionally explicitly mark as executable the actual Python scripts in the codebase.

Reviewed By: nikhilaravi

Differential Revision: D20095778

fbshipit-source-id: d312599fba485e978a243292f88a180d71e1b55a
2020-03-12 10:39:44 -07:00

47 lines
1.8 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
def mesh_edge_loss(meshes, target_length: float = 0.0):
    """
    Edge-length regularizer for a batch of meshes.

    For every edge, the squared difference between its Euclidean length and
    `target_length` is computed. Each edge term is scaled by the reciprocal
    of the edge count of the mesh it belongs to, the scaled terms are summed
    over the whole batch, and the sum is divided by the number of meshes.
    Every mesh therefore carries the same weight in the result, no matter
    how many edges it has.

    Args:
        meshes: Meshes object with a batch of meshes.
        target_length: Resting value for the edge length.

    Returns:
        loss: Scalar tensor holding the loss averaged across the batch.
        When `meshes.isempty()` is true, a one-element float32 zero tensor
        (created on `meshes.device` with `requires_grad=True`) is returned
        instead.
    """
    # Nothing to regularize: hand back a zero that still supports backward().
    if meshes.isempty():
        return torch.tensor(
            [0.0], dtype=torch.float32, device=meshes.device, requires_grad=True
        )

    num_meshes = len(meshes)
    edges = meshes.edges_packed()  # (sum(E_n), 2): two vertex indices per edge
    verts = meshes.verts_packed()  # (sum(V_n), 3)
    mesh_of_edge = meshes.edges_packed_to_mesh_idx()  # (sum(E_n),)
    edge_counts = meshes.num_edges_per_mesh()  # (N,)

    # Per-edge weight 1 / E_n, where E_n is the edge count of the mesh that
    # owns the edge.
    # TODO (nikhilar) Find a faster way of computing the weights for each edge
    # as this is currently a bottleneck for meshes with a large number of faces.
    inv_edge_count = 1.0 / edge_counts.gather(0, mesh_of_edge).float()

    # Indexing the vertices with the edge table yields both endpoints of every
    # edge; splitting along dim 1 separates them into start and end points.
    start, end = verts[edges].unbind(1)
    edge_len = (start - end).norm(dim=1, p=2)
    sq_dev = (edge_len - target_length) ** 2.0

    return (sq_dev * inv_edge_count).sum() / num_meshes