mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 06:40:35 +08:00
point mesh distances
Summary: Implementation of point to mesh distances. The current diff contains two types: (a) Point to Edge (b) Point to Face ``` Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- POINT_MESH_EDGE_4_100_300_5000_cuda:0 2745 3138 183 POINT_MESH_EDGE_4_100_300_10000_cuda:0 4408 4499 114 POINT_MESH_EDGE_4_100_3000_5000_cuda:0 4978 5070 101 POINT_MESH_EDGE_4_100_3000_10000_cuda:0 9076 9187 56 POINT_MESH_EDGE_4_1000_300_5000_cuda:0 1411 1487 355 POINT_MESH_EDGE_4_1000_300_10000_cuda:0 4829 5030 104 POINT_MESH_EDGE_4_1000_3000_5000_cuda:0 7539 7620 67 POINT_MESH_EDGE_4_1000_3000_10000_cuda:0 12088 12272 42 POINT_MESH_EDGE_8_100_300_5000_cuda:0 3106 3222 161 POINT_MESH_EDGE_8_100_300_10000_cuda:0 8561 8648 59 POINT_MESH_EDGE_8_100_3000_5000_cuda:0 6932 7021 73 POINT_MESH_EDGE_8_100_3000_10000_cuda:0 24032 24176 21 POINT_MESH_EDGE_8_1000_300_5000_cuda:0 5272 5399 95 POINT_MESH_EDGE_8_1000_300_10000_cuda:0 11348 11430 45 POINT_MESH_EDGE_8_1000_3000_5000_cuda:0 17478 17683 29 POINT_MESH_EDGE_8_1000_3000_10000_cuda:0 25961 26236 20 POINT_MESH_EDGE_16_100_300_5000_cuda:0 8244 8323 61 POINT_MESH_EDGE_16_100_300_10000_cuda:0 18018 18071 28 POINT_MESH_EDGE_16_100_3000_5000_cuda:0 19428 19544 26 POINT_MESH_EDGE_16_100_3000_10000_cuda:0 44967 45135 12 POINT_MESH_EDGE_16_1000_300_5000_cuda:0 7825 7937 64 POINT_MESH_EDGE_16_1000_300_10000_cuda:0 18504 18571 28 POINT_MESH_EDGE_16_1000_3000_5000_cuda:0 65805 66132 8 POINT_MESH_EDGE_16_1000_3000_10000_cuda:0 90885 91089 6 -------------------------------------------------------------------------------- Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- POINT_MESH_FACE_4_100_300_5000_cuda:0 1561 1685 321 POINT_MESH_FACE_4_100_300_10000_cuda:0 2818 2954 178 POINT_MESH_FACE_4_100_3000_5000_cuda:0 15893 16018 32 POINT_MESH_FACE_4_100_3000_10000_cuda:0 16350 16439 31 POINT_MESH_FACE_4_1000_300_5000_cuda:0 3179 3278 158 POINT_MESH_FACE_4_1000_300_10000_cuda:0 2353 2436 213 POINT_MESH_FACE_4_1000_3000_5000_cuda:0 16262 16336 31 POINT_MESH_FACE_4_1000_3000_10000_cuda:0 9334 9448 54 POINT_MESH_FACE_8_100_300_5000_cuda:0 4377 4493 115 POINT_MESH_FACE_8_100_300_10000_cuda:0 9728 9822 52 POINT_MESH_FACE_8_100_3000_5000_cuda:0 26428 26544 19 POINT_MESH_FACE_8_100_3000_10000_cuda:0 42238 43031 12 POINT_MESH_FACE_8_1000_300_5000_cuda:0 3891 3982 129 POINT_MESH_FACE_8_1000_300_10000_cuda:0 5363 5429 94 POINT_MESH_FACE_8_1000_3000_5000_cuda:0 20998 21084 24 POINT_MESH_FACE_8_1000_3000_10000_cuda:0 39711 39897 13 POINT_MESH_FACE_16_100_300_5000_cuda:0 5955 6001 84 POINT_MESH_FACE_16_100_300_10000_cuda:0 12082 12144 42 POINT_MESH_FACE_16_100_3000_5000_cuda:0 44996 45176 12 POINT_MESH_FACE_16_100_3000_10000_cuda:0 73042 73197 7 POINT_MESH_FACE_16_1000_300_5000_cuda:0 8292 8374 61 POINT_MESH_FACE_16_1000_300_10000_cuda:0 19442 19506 26 POINT_MESH_FACE_16_1000_3000_5000_cuda:0 36059 36194 14 POINT_MESH_FACE_16_1000_3000_10000_cuda:0 64644 64822 8 -------------------------------------------------------------------------------- ``` Reviewed By: jcjohnson Differential Revision: D20590462 fbshipit-source-id: 42a39837b514a546ac9471bfaff60eefe7fae829
This commit is contained in:
committed by
Facebook GitHub Bot
parent
474c8b456a
commit
487d4d6607
@@ -5,6 +5,7 @@ from .chamfer import chamfer_distance
|
||||
from .mesh_edge_loss import mesh_edge_loss
|
||||
from .mesh_laplacian_smoothing import mesh_laplacian_smoothing
|
||||
from .mesh_normal_consistency import mesh_normal_consistency
|
||||
from .point_mesh_distance import point_mesh_edge_distance, point_mesh_face_distance
|
||||
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||
|
||||
351
pytorch3d/loss/point_mesh_distance.py
Normal file
351
pytorch3d/loss/point_mesh_distance.py
Normal file
@@ -0,0 +1,351 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from pytorch3d import _C
|
||||
from pytorch3d.structures import Meshes, Pointclouds
|
||||
from torch.autograd import Function
|
||||
from torch.autograd.function import once_differentiable
|
||||
|
||||
|
||||
"""
|
||||
This file defines distances between meshes and pointclouds.
|
||||
The functions make use of the definition of a distance between a point and
|
||||
an edge segment or the distance of a point and a triangle (face).
|
||||
|
||||
The exact mathematical formulations and implementations of these
|
||||
distances can be found in `csrc/utils/geometry_utils.cuh`.
|
||||
"""
|
||||
|
||||
|
||||
# PointFaceDistance
|
||||
class _PointFaceDistance(Function):
|
||||
"""
|
||||
Torch autograd Function wrapper PointFaceDistance Cuda implementation
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, points, points_first_idx, tris, tris_first_idx, max_points):
|
||||
"""
|
||||
Args:
|
||||
ctx: Context object used to calculate gradients.
|
||||
points: FloatTensor of shape `(P, 3)`
|
||||
points_first_idx: LongTensor of shape `(N,)` indicating the first point
|
||||
index in each example in the batch
|
||||
tris: FloatTensor of shape `(T, 3, 3)` of triangular faces. The `t`-th
|
||||
triangular face is spanned by `(tris[t, 0], tris[t, 1], tris[t, 2])`
|
||||
tris_first_idx: LongTensor of shape `(N,)` indicating the first face
|
||||
index in each example in the batch
|
||||
max_points: Scalar equal to maximum number of points in the batch
|
||||
Returns:
|
||||
dists: FloatTensor of shape `(P,)`, where `dists[p]` is the squared
|
||||
euclidean distance of `p`-th point to the closest triangular face
|
||||
in the corresponding example in the batch
|
||||
idxs: LongTensor of shape `(P,)` indicating the closest triangular face
|
||||
in the corresponindg example in the batch.
|
||||
|
||||
`dists[p] = d(points[p], tris[idxs[p], 0], tris[idxs[p], 1], tris[idxs[p], 2])`
|
||||
where `d(u, v0, v1, v2)` is the distance of point `u` from the trianfular face `(v0, v1, v2)`
|
||||
|
||||
"""
|
||||
dists, idxs = _C.point_face_dist_forward(
|
||||
points, points_first_idx, tris, tris_first_idx, max_points
|
||||
)
|
||||
ctx.save_for_backward(points, tris, idxs)
|
||||
return dists
|
||||
|
||||
@staticmethod
|
||||
@once_differentiable
|
||||
def backward(ctx, grad_dists):
|
||||
grad_dists = grad_dists.contiguous()
|
||||
points, tris, idxs = ctx.saved_tensors
|
||||
grad_points, grad_tris = _C.point_face_dist_backward(
|
||||
points, tris, idxs, grad_dists
|
||||
)
|
||||
return grad_points, None, grad_tris, None, None
|
||||
|
||||
|
||||
point_face_distance = _PointFaceDistance.apply
|
||||
|
||||
|
||||
# FacePointDistance
|
||||
class _FacePointDistance(Function):
|
||||
"""
|
||||
Torch autograd Function wrapper FacePointDistance Cuda implementation
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, points, points_first_idx, tris, tris_first_idx, max_tris):
|
||||
"""
|
||||
Args:
|
||||
ctx: Context object used to calculate gradients.
|
||||
points: FloatTensor of shape `(P, 3)`
|
||||
points_first_idx: LongTensor of shape `(N,)` indicating the first point
|
||||
index in each example in the batch
|
||||
tris: FloatTensor of shape `(T, 3, 3)` of triangular faces. The `t`-th
|
||||
triangular face is spanned by `(tris[t, 0], tris[t, 1], tris[t, 2])`
|
||||
tris_first_idx: LongTensor of shape `(N,)` indicating the first face
|
||||
index in each example in the batch
|
||||
max_tris: Scalar equal to maximum number of faces in the batch
|
||||
Returns:
|
||||
dists: FloatTensor of shape `(T,)`, where `dists[t]` is the squared
|
||||
euclidean distance of `t`-th trianguar face to the closest point in the
|
||||
corresponding example in the batch
|
||||
idxs: LongTensor of shape `(T,)` indicating the closest point in the
|
||||
corresponindg example in the batch.
|
||||
|
||||
`dists[t] = d(points[idxs[t]], tris[t, 0], tris[t, 1], tris[t, 2])`,
|
||||
where `d(u, v0, v1, v2)` is the distance of point `u` from the triangular
|
||||
face `(v0, v1, v2)`.
|
||||
"""
|
||||
dists, idxs = _C.face_point_dist_forward(
|
||||
points, points_first_idx, tris, tris_first_idx, max_tris
|
||||
)
|
||||
ctx.save_for_backward(points, tris, idxs)
|
||||
return dists
|
||||
|
||||
@staticmethod
|
||||
@once_differentiable
|
||||
def backward(ctx, grad_dists):
|
||||
grad_dists = grad_dists.contiguous()
|
||||
points, tris, idxs = ctx.saved_tensors
|
||||
grad_points, grad_tris = _C.face_point_dist_backward(
|
||||
points, tris, idxs, grad_dists
|
||||
)
|
||||
return grad_points, None, grad_tris, None, None
|
||||
|
||||
|
||||
face_point_distance = _FacePointDistance.apply
|
||||
|
||||
|
||||
# PointEdgeDistance
|
||||
class _PointEdgeDistance(Function):
|
||||
"""
|
||||
Torch autograd Function wrapper PointEdgeDistance Cuda implementation
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, points, points_first_idx, segms, segms_first_idx, max_points):
|
||||
"""
|
||||
Args:
|
||||
ctx: Context object used to calculate gradients.
|
||||
points: FloatTensor of shape `(P, 3)`
|
||||
points_first_idx: LongTensor of shape `(N,)` indicating the first point
|
||||
index for each example in the mesh
|
||||
segms: FloatTensor of shape `(S, 2, 3)` of edge segments. The `s`-th
|
||||
edge segment is spanned by `(segms[s, 0], segms[s, 1])`
|
||||
segms_first_idx: LongTensor of shape `(N,)` indicating the first edge
|
||||
index for each example in the mesh
|
||||
max_points: Scalar equal to maximum number of points in the batch
|
||||
Returns:
|
||||
dists: FloatTensor of shape `(P,)`, where `dists[p]` is the squared
|
||||
euclidean distance of `p`-th point to the closest edge in the
|
||||
corresponding example in the batch
|
||||
idxs: LongTensor of shape `(P,)` indicating the closest edge in the
|
||||
corresponindg example in the batch.
|
||||
|
||||
`dists[p] = d(points[p], segms[idxs[p], 0], segms[idxs[p], 1])`,
|
||||
where `d(u, v0, v1)` is the distance of point `u` from the edge segment
|
||||
spanned by `(v0, v1)`.
|
||||
"""
|
||||
dists, idxs = _C.point_edge_dist_forward(
|
||||
points, points_first_idx, segms, segms_first_idx, max_points
|
||||
)
|
||||
ctx.save_for_backward(points, segms, idxs)
|
||||
return dists
|
||||
|
||||
@staticmethod
|
||||
@once_differentiable
|
||||
def backward(ctx, grad_dists):
|
||||
grad_dists = grad_dists.contiguous()
|
||||
points, segms, idxs = ctx.saved_tensors
|
||||
grad_points, grad_segms = _C.point_edge_dist_backward(
|
||||
points, segms, idxs, grad_dists
|
||||
)
|
||||
return grad_points, None, grad_segms, None, None
|
||||
|
||||
|
||||
point_edge_distance = _PointEdgeDistance.apply
|
||||
|
||||
|
||||
# EdgePointDistance
|
||||
class _EdgePointDistance(Function):
|
||||
"""
|
||||
Torch autograd Function wrapper EdgePointDistance Cuda implementation
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, points, points_first_idx, segms, segms_first_idx, max_segms):
|
||||
"""
|
||||
Args:
|
||||
ctx: Context object used to calculate gradients.
|
||||
points: FloatTensor of shape `(P, 3)`
|
||||
points_first_idx: LongTensor of shape `(N,)` indicating the first point
|
||||
index for each example in the mesh
|
||||
segms: FloatTensor of shape `(S, 2, 3)` of edge segments. The `s`-th
|
||||
edge segment is spanned by `(segms[s, 0], segms[s, 1])`
|
||||
segms_first_idx: LongTensor of shape `(N,)` indicating the first edge
|
||||
index for each example in the mesh
|
||||
max_segms: Scalar equal to maximum number of edges in the batch
|
||||
Returns:
|
||||
dists: FloatTensor of shape `(S,)`, where `dists[s]` is the squared
|
||||
euclidean distance of `s`-th edge to the closest point in the
|
||||
corresponding example in the batch
|
||||
idxs: LongTensor of shape `(S,)` indicating the closest point in the
|
||||
corresponindg example in the batch.
|
||||
|
||||
`dists[s] = d(points[idxs[s]], edges[s, 0], edges[s, 1])`,
|
||||
where `d(u, v0, v1)` is the distance of point `u` from the segment
|
||||
spanned by `(v0, v1)`.
|
||||
"""
|
||||
dists, idxs = _C.edge_point_dist_forward(
|
||||
points, points_first_idx, segms, segms_first_idx, max_segms
|
||||
)
|
||||
ctx.save_for_backward(points, segms, idxs)
|
||||
return dists
|
||||
|
||||
@staticmethod
|
||||
@once_differentiable
|
||||
def backward(ctx, grad_dists):
|
||||
grad_dists = grad_dists.contiguous()
|
||||
points, segms, idxs = ctx.saved_tensors
|
||||
grad_points, grad_segms = _C.edge_point_dist_backward(
|
||||
points, segms, idxs, grad_dists
|
||||
)
|
||||
return grad_points, None, grad_segms, None, None
|
||||
|
||||
|
||||
edge_point_distance = _EdgePointDistance.apply
|
||||
|
||||
|
||||
def point_mesh_edge_distance(meshes: Meshes, pcls: Pointclouds):
|
||||
"""
|
||||
Computes the distance between a pointcloud and a mesh within a batch.
|
||||
Given a pair `(mesh, pcl)` in the batch, we define the distance to be the
|
||||
sum of two distances, namely `point_edge(mesh, pcl) + edge_point(mesh, pcl)`
|
||||
|
||||
`point_edge(mesh, pcl)`: Computes the squared distance of each point p in pcl
|
||||
to the closest edge segment in mesh and averages across all points in pcl
|
||||
`edge_point(mesh, pcl)`: Computes the squared distance of each edge segment in mesh
|
||||
to the closest point in pcl and averages across all edges in mesh.
|
||||
|
||||
The above distance functions are applied for all `(mesh, pcl)` pairs in the batch and
|
||||
then averaged across the batch.
|
||||
|
||||
Args:
|
||||
meshes: A Meshes data structure containing N meshes
|
||||
pcls: A Pointclouds data structure containing N pointclouds
|
||||
|
||||
Returns:
|
||||
loss: The `point_edge(mesh, pcl) + edge_point(mesh, pcl)` distance
|
||||
between all `(mesh, pcl)` in a batch averaged across the batch.
|
||||
"""
|
||||
if len(meshes) != len(pcls):
|
||||
raise ValueError("meshes and pointclouds be equal sized batches")
|
||||
N = len(meshes)
|
||||
|
||||
# packed representation for pointclouds
|
||||
points = pcls.points_packed() # (P, 3)
|
||||
points_first_idx = pcls.cloud_to_packed_first_idx()
|
||||
max_points = pcls.num_points_per_cloud().max().item()
|
||||
|
||||
# packed representation for edges
|
||||
verts_packed = meshes.verts_packed()
|
||||
edges_packed = meshes.edges_packed()
|
||||
segms = verts_packed[edges_packed] # (S, 2, 3)
|
||||
segms_first_idx = meshes.mesh_to_edges_packed_first_idx()
|
||||
max_segms = meshes.num_edges_per_mesh().max().item()
|
||||
|
||||
# point to edge distance: shape (P,)
|
||||
point_to_edge = point_edge_distance(
|
||||
points, points_first_idx, segms, segms_first_idx, max_points
|
||||
)
|
||||
|
||||
# weigh each example by the inverse of number of points in the example
|
||||
point_to_cloud_idx = pcls.packed_to_cloud_idx() # (sum(P_i), )
|
||||
num_points_per_cloud = pcls.num_points_per_cloud() # (N,)
|
||||
weights_p = num_points_per_cloud.gather(0, point_to_cloud_idx)
|
||||
weights_p = 1.0 / weights_p.float()
|
||||
point_to_edge = point_to_edge * weights_p
|
||||
point_dist = point_to_edge.sum() / N
|
||||
|
||||
# edge to edge distance: shape (S,)
|
||||
edge_to_point = edge_point_distance(
|
||||
points, points_first_idx, segms, segms_first_idx, max_segms
|
||||
)
|
||||
|
||||
# weigh each example by the inverse of number of edges in the example
|
||||
segm_to_mesh_idx = meshes.edges_packed_to_mesh_idx() # (sum(S_n),)
|
||||
num_segms_per_mesh = meshes.num_edges_per_mesh() # (N,)
|
||||
weights_s = num_segms_per_mesh.gather(0, segm_to_mesh_idx)
|
||||
weights_s = 1.0 / weights_s.float()
|
||||
edge_to_point = edge_to_point * weights_s
|
||||
edge_dist = edge_to_point.sum() / N
|
||||
|
||||
return point_dist + edge_dist
|
||||
|
||||
|
||||
def point_mesh_face_distance(meshes: Meshes, pcls: Pointclouds):
|
||||
"""
|
||||
Computes the distance between a pointcloud and a mesh within a batch.
|
||||
Given a pair `(mesh, pcl)` in the batch, we define the distance to be the
|
||||
sum of two distances, namely `point_face(mesh, pcl) + face_point(mesh, pcl)`
|
||||
|
||||
`point_face(mesh, pcl)`: Computes the squared distance of each point p in pcl
|
||||
to the closest triangular face in mesh and averages across all points in pcl
|
||||
`face_point(mesh, pcl)`: Computes the squared distance of each triangular face in mesh
|
||||
to the closest point in pcl and averages across all faces in mesh.
|
||||
|
||||
The above distance functions are applied for all `(mesh, pcl)` pairs in the batch and
|
||||
then averaged across the batch.
|
||||
|
||||
Args:
|
||||
meshes: A Meshes data structure containing N meshes
|
||||
pcls: A Pointclouds data structure containing N pointclouds
|
||||
|
||||
Returns:
|
||||
loss: The `point_face(mesh, pcl) + face_point(mesh, pcl)` distance
|
||||
between all `(mesh, pcl)` in a batch averaged across the batch.
|
||||
"""
|
||||
|
||||
if len(meshes) != len(pcls):
|
||||
raise ValueError("meshes and pointclouds must be equal sized batches")
|
||||
N = len(meshes)
|
||||
|
||||
# packed representation for pointclouds
|
||||
points = pcls.points_packed() # (P, 3)
|
||||
points_first_idx = pcls.cloud_to_packed_first_idx()
|
||||
max_points = pcls.num_points_per_cloud().max().item()
|
||||
|
||||
# packed representation for faces
|
||||
verts_packed = meshes.verts_packed()
|
||||
faces_packed = meshes.faces_packed()
|
||||
tris = verts_packed[faces_packed] # (T, 3, 3)
|
||||
tris_first_idx = meshes.mesh_to_faces_packed_first_idx()
|
||||
max_tris = meshes.num_faces_per_mesh().max().item()
|
||||
|
||||
# point to face distance: shape (P,)
|
||||
point_to_face = point_face_distance(
|
||||
points, points_first_idx, tris, tris_first_idx, max_points
|
||||
)
|
||||
|
||||
# weigh each example by the inverse of number of points in the example
|
||||
point_to_cloud_idx = pcls.packed_to_cloud_idx() # (sum(P_i),)
|
||||
num_points_per_cloud = pcls.num_points_per_cloud() # (N,)
|
||||
weights_p = num_points_per_cloud.gather(0, point_to_cloud_idx)
|
||||
weights_p = 1.0 / weights_p.float()
|
||||
point_to_face = point_to_face * weights_p
|
||||
point_dist = point_to_face.sum() / N
|
||||
|
||||
# face to point distance: shape (T,)
|
||||
face_to_point = face_point_distance(
|
||||
points, points_first_idx, tris, tris_first_idx, max_tris
|
||||
)
|
||||
|
||||
# weigh each example by the inverse of number of faces in the example
|
||||
tri_to_mesh_idx = meshes.faces_packed_to_mesh_idx() # (sum(T_n),)
|
||||
num_tris_per_mesh = meshes.num_faces_per_mesh() # (N, )
|
||||
weights_t = num_tris_per_mesh.gather(0, tri_to_mesh_idx)
|
||||
weights_t = 1.0 / weights_t.float()
|
||||
face_to_point = face_to_point * weights_t
|
||||
face_dist = face_to_point.sum() / N
|
||||
|
||||
return point_dist + face_dist
|
||||
Reference in New Issue
Block a user