diff --git a/pytorch3d/loss/point_mesh_distance.py b/pytorch3d/loss/point_mesh_distance.py index b30cd05e..32bb71dd 100644 --- a/pytorch3d/loss/point_mesh_distance.py +++ b/pytorch3d/loss/point_mesh_distance.py @@ -6,6 +6,7 @@ # pyre-unsafe +import torch from pytorch3d import _C from pytorch3d.structures import Meshes, Pointclouds from torch.autograd import Function @@ -302,8 +303,7 @@ def point_mesh_edge_distance(meshes: Meshes, pcls: Pointclouds): point_to_cloud_idx = pcls.packed_to_cloud_idx() # (sum(P_i), ) num_points_per_cloud = pcls.num_points_per_cloud() # (N,) weights_p = num_points_per_cloud.gather(0, point_to_cloud_idx) - # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. - weights_p = 1.0 / weights_p.float() + weights_p = torch.reciprocal(weights_p.float()) point_to_edge = point_to_edge * weights_p point_dist = point_to_edge.sum() / N @@ -377,8 +377,7 @@ def point_mesh_face_distance( point_to_cloud_idx = pcls.packed_to_cloud_idx() # (sum(P_i),) num_points_per_cloud = pcls.num_points_per_cloud() # (N,) weights_p = num_points_per_cloud.gather(0, point_to_cloud_idx) - # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. - weights_p = 1.0 / weights_p.float() + weights_p = torch.reciprocal(weights_p.float()) point_to_face = point_to_face * weights_p point_dist = point_to_face.sum() / N