fbcode/vision/fair/pytorch3d/pytorch3d/loss/point_mesh_distance.py

Reviewed By: bottler

Differential Revision: D93708351

fbshipit-source-id: 06a877777e4cb72a497a44ff55db0b6222bda83b
This commit is contained in:
generatedunixname1417043136753450
2026-02-22 06:55:36 -08:00
committed by meta-codesync[bot]
parent e9ed1cb178
commit 42d66c1145

View File

@@ -6,6 +6,7 @@
# pyre-unsafe
import torch
from pytorch3d import _C
from pytorch3d.structures import Meshes, Pointclouds
from torch.autograd import Function
@@ -302,8 +303,7 @@ def point_mesh_edge_distance(meshes: Meshes, pcls: Pointclouds):
point_to_cloud_idx = pcls.packed_to_cloud_idx() # (sum(P_i), )
num_points_per_cloud = pcls.num_points_per_cloud() # (N,)
weights_p = num_points_per_cloud.gather(0, point_to_cloud_idx)
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
weights_p = 1.0 / weights_p.float()
weights_p = torch.reciprocal(weights_p.float())
point_to_edge = point_to_edge * weights_p
point_dist = point_to_edge.sum() / N
@@ -377,8 +377,7 @@ def point_mesh_face_distance(
point_to_cloud_idx = pcls.packed_to_cloud_idx() # (sum(P_i),)
num_points_per_cloud = pcls.num_points_per_cloud() # (N,)
weights_p = num_points_per_cloud.gather(0, point_to_cloud_idx)
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
weights_p = 1.0 / weights_p.float()
weights_p = torch.reciprocal(weights_p.float())
point_to_face = point_to_face * weights_p
point_dist = point_to_face.sum() / N