mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-11 14:55:59 +08:00
fbcode/vision/fair/pytorch3d/pytorch3d/loss/point_mesh_distance.py
Reviewed By: bottler Differential Revision: D93708351 fbshipit-source-id: 06a877777e4cb72a497a44ff55db0b6222bda83b
This commit is contained in:
committed by
meta-codesync[bot]
parent
e9ed1cb178
commit
42d66c1145
@@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
# pyre-unsafe
|
# pyre-unsafe
|
||||||
|
|
||||||
|
import torch
|
||||||
from pytorch3d import _C
|
from pytorch3d import _C
|
||||||
from pytorch3d.structures import Meshes, Pointclouds
|
from pytorch3d.structures import Meshes, Pointclouds
|
||||||
from torch.autograd import Function
|
from torch.autograd import Function
|
||||||
@@ -302,8 +303,7 @@ def point_mesh_edge_distance(meshes: Meshes, pcls: Pointclouds):
|
|||||||
point_to_cloud_idx = pcls.packed_to_cloud_idx() # (sum(P_i), )
|
point_to_cloud_idx = pcls.packed_to_cloud_idx() # (sum(P_i), )
|
||||||
num_points_per_cloud = pcls.num_points_per_cloud() # (N,)
|
num_points_per_cloud = pcls.num_points_per_cloud() # (N,)
|
||||||
weights_p = num_points_per_cloud.gather(0, point_to_cloud_idx)
|
weights_p = num_points_per_cloud.gather(0, point_to_cloud_idx)
|
||||||
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
|
weights_p = torch.reciprocal(weights_p.float())
|
||||||
weights_p = 1.0 / weights_p.float()
|
|
||||||
point_to_edge = point_to_edge * weights_p
|
point_to_edge = point_to_edge * weights_p
|
||||||
point_dist = point_to_edge.sum() / N
|
point_dist = point_to_edge.sum() / N
|
||||||
|
|
||||||
@@ -377,8 +377,7 @@ def point_mesh_face_distance(
|
|||||||
point_to_cloud_idx = pcls.packed_to_cloud_idx() # (sum(P_i),)
|
point_to_cloud_idx = pcls.packed_to_cloud_idx() # (sum(P_i),)
|
||||||
num_points_per_cloud = pcls.num_points_per_cloud() # (N,)
|
num_points_per_cloud = pcls.num_points_per_cloud() # (N,)
|
||||||
weights_p = num_points_per_cloud.gather(0, point_to_cloud_idx)
|
weights_p = num_points_per_cloud.gather(0, point_to_cloud_idx)
|
||||||
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
|
weights_p = torch.reciprocal(weights_p.float())
|
||||||
weights_p = 1.0 / weights_p.float()
|
|
||||||
point_to_face = point_to_face * weights_p
|
point_to_face = point_to_face * weights_p
|
||||||
point_dist = point_to_face.sum() / N
|
point_dist = point_to_face.sum() / N
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user