diff --git a/pytorch3d/loss/mesh_edge_loss.py b/pytorch3d/loss/mesh_edge_loss.py index 77eaa9dc..00d235ad 100644 --- a/pytorch3d/loss/mesh_edge_loss.py +++ b/pytorch3d/loss/mesh_edge_loss.py @@ -7,8 +7,8 @@ import torch def mesh_edge_loss(meshes, target_length: float = 0.0): """ Computes mesh edge length regularization loss averaged across all meshes - in a batch. Each edge contributes equally to the final loss, regardless of - numbers of edges per mesh in the batch by weighting each mesh with the + in a batch. Each mesh contributes equally to the final loss, regardless of + the number of edges per mesh in the batch by weighting each mesh with the inverse number of edges. For example, if mesh 3 (out of N) has only E=4 edges, then the loss for each edge in mesh 3 should be multiplied by 1/E to contribute to the final loss.