mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-02-27 00:36:02 +08:00
fbcode/vision/fair/pytorch3d/pytorch3d/loss/point_mesh_distance.py
Reviewed By: bottler Differential Revision: D93708351 fbshipit-source-id: 06a877777e4cb72a497a44ff55db0b6222bda83b
This commit is contained in:
committed by
meta-codesync[bot]
parent
e9ed1cb178
commit
42d66c1145
@@ -6,6 +6,7 @@
|
||||
|
||||
# pyre-unsafe
|
||||
|
||||
import torch
|
||||
from pytorch3d import _C
|
||||
from pytorch3d.structures import Meshes, Pointclouds
|
||||
from torch.autograd import Function
|
||||
@@ -302,8 +303,7 @@ def point_mesh_edge_distance(meshes: Meshes, pcls: Pointclouds):
|
||||
point_to_cloud_idx = pcls.packed_to_cloud_idx() # (sum(P_i), )
|
||||
num_points_per_cloud = pcls.num_points_per_cloud() # (N,)
|
||||
weights_p = num_points_per_cloud.gather(0, point_to_cloud_idx)
|
||||
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
|
||||
weights_p = 1.0 / weights_p.float()
|
||||
weights_p = torch.reciprocal(weights_p.float())
|
||||
point_to_edge = point_to_edge * weights_p
|
||||
point_dist = point_to_edge.sum() / N
|
||||
|
||||
@@ -377,8 +377,7 @@ def point_mesh_face_distance(
|
||||
point_to_cloud_idx = pcls.packed_to_cloud_idx() # (sum(P_i),)
|
||||
num_points_per_cloud = pcls.num_points_per_cloud() # (N,)
|
||||
weights_p = num_points_per_cloud.gather(0, point_to_cloud_idx)
|
||||
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
|
||||
weights_p = 1.0 / weights_p.float()
|
||||
weights_p = torch.reciprocal(weights_p.float())
|
||||
point_to_face = point_to_face * weights_p
|
||||
point_dist = point_to_face.sum() / N
|
||||
|
||||
|
||||
Reference in New Issue
Block a user