Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2026-01-17 11:50:35 +08:00
Single directional chamfer distance and non-absolute cosine similarity

Summary: Single directional chamfer distance and option to use non-absolute cosine similarity

Reviewed By: bottler

Differential Revision: D46593980

fbshipit-source-id: b2e591706a0cdde1c2d361614cecebb84a581433
Commit 5ffeb4d580 (parent 573a42cd5f), committed by Facebook GitHub Bot.
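For orientation, here is a brief usage sketch of the two new chamfer_distance arguments introduced by this commit. It is not part of the diff; the tensor shapes and values are arbitrary illustrations, and only the keyword arguments shown in the diff are exercised.

    # Illustrative usage of the new arguments; shapes/values are arbitrary.
    import torch
    from pytorch3d.loss import chamfer_distance

    x = torch.rand(2, 128, 3)  # batch of 2 point clouds with 128 points each
    y = torch.rand(2, 256, 3)

    # Default: symmetric chamfer distance (x -> y and y -> x).
    loss_bi, _ = chamfer_distance(x, y)

    # New: only the x -> y direction contributes to the loss.
    loss_xy, _ = chamfer_distance(x, y, single_directional=True)

    # New: keep the sign of the cosine similarity for the normals term.
    x_normals = torch.nn.functional.normalize(torch.rand(2, 128, 3), dim=2)
    y_normals = torch.nn.functional.normalize(torch.rand(2, 256, 3), dim=2)
    loss, loss_normals = chamfer_distance(
        x, y, x_normals=x_normals, y_normals=y_normals, abs_cosine=False
    )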
pytorch3d/loss/chamfer.py

@@ -68,6 +68,94 @@ def _handle_pointcloud_input(
     return X, lengths, normals
 
 
+def _chamfer_distance_single_direction(
+    x,
+    y,
+    x_lengths,
+    y_lengths,
+    x_normals,
+    y_normals,
+    weights,
+    batch_reduction: Union[str, None],
+    point_reduction: str,
+    norm: int,
+    abs_cosine: bool,
+):
+    return_normals = x_normals is not None and y_normals is not None
+
+    N, P1, D = x.shape
+
+    # Check if inputs are heterogeneous and create a lengths mask.
+    is_x_heterogeneous = (x_lengths != P1).any()
+    x_mask = (
+        torch.arange(P1, device=x.device)[None] >= x_lengths[:, None]
+    )  # shape [N, P1]
+    if y.shape[0] != N or y.shape[2] != D:
+        raise ValueError("y does not have the correct shape.")
+    if weights is not None:
+        if weights.size(0) != N:
+            raise ValueError("weights must be of shape (N,).")
+        if not (weights >= 0).all():
+            raise ValueError("weights cannot be negative.")
+        if weights.sum() == 0.0:
+            weights = weights.view(N, 1)
+            if batch_reduction in ["mean", "sum"]:
+                return (
+                    (x.sum((1, 2)) * weights).sum() * 0.0,
+                    (x.sum((1, 2)) * weights).sum() * 0.0,
+                )
+            return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0)
+
+    cham_norm_x = x.new_zeros(())
+
+    x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, norm=norm, K=1)
+    cham_x = x_nn.dists[..., 0]  # (N, P1)
+
+    if is_x_heterogeneous:
+        cham_x[x_mask] = 0.0
+
+    if weights is not None:
+        cham_x *= weights.view(N, 1)
+
+    if return_normals:
+        # Gather the normals using the indices and keep only value for k=0
+        x_normals_near = knn_gather(y_normals, x_nn.idx, y_lengths)[..., 0, :]
+
+        cosine_sim = F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6)
+        # If abs_cosine, ignore orientation and take the absolute value of the cosine sim.
+        cham_norm_x = 1 - (torch.abs(cosine_sim) if abs_cosine else cosine_sim)
+
+        if is_x_heterogeneous:
+            cham_norm_x[x_mask] = 0.0
+
+        if weights is not None:
+            cham_norm_x *= weights.view(N, 1)
+        cham_norm_x = cham_norm_x.sum(1)  # (N,)
+
+    # Apply point reduction
+    cham_x = cham_x.sum(1)  # (N,)
+    if point_reduction == "mean":
+        x_lengths_clamped = x_lengths.clamp(min=1)
+        cham_x /= x_lengths_clamped
+        if return_normals:
+            cham_norm_x /= x_lengths_clamped
+
+    if batch_reduction is not None:
+        # batch_reduction == "sum"
+        cham_x = cham_x.sum()
+        if return_normals:
+            cham_norm_x = cham_norm_x.sum()
+        if batch_reduction == "mean":
+            div = weights.sum() if weights is not None else max(N, 1)
+            cham_x /= div
+            if return_normals:
+                cham_norm_x /= div
+
+    cham_dist = cham_x
+    cham_normals = cham_norm_x if return_normals else None
+    return cham_dist, cham_normals
+
+
 def chamfer_distance(
     x,
     y,
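As a reading aid (not from the commit), the quantity this helper reduces can be written in a few lines of plain torch for the simplest case: a homogeneous batch, squared-L2 norm, mean point and batch reduction, and no weights or normals. The function name single_direction_chamfer below is made up for illustration only.

    import torch

    def single_direction_chamfer(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # x: (N, P1, D), y: (N, P2, D)
        d = torch.cdist(x, y)             # (N, P1, P2) pairwise L2 distances
        nn_sq = d.min(dim=2).values ** 2  # squared distance from each x point to its nearest y point
        return nn_sq.mean(dim=1).mean()   # mean over points, then mean over the batch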
@@ -79,6 +167,8 @@ def chamfer_distance(
     batch_reduction: Union[str, None] = "mean",
     point_reduction: str = "mean",
     norm: int = 2,
+    single_directional: bool = False,
+    abs_cosine: bool = True,
 ):
     """
     Chamfer distance between two pointclouds x and y.
@@ -103,6 +193,14 @@ def chamfer_distance(
         point_reduction: Reduction operation to apply for the loss across the
             points, can be one of ["mean", "sum"].
         norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2.
+        single_directional: If False (default), loss comes from both the distance between
+            each point in x and its nearest neighbor in y and each point in y and its nearest
+            neighbor in x. If True, loss is the distance between each point in x and its
+            nearest neighbor in y.
+        abs_cosine: If False, loss_normals is from one minus the cosine similarity.
+            If True (default), loss_normals is from one minus the absolute value of the
+            cosine similarity, which means that exactly opposite normals are considered
+            equivalent to exactly matching normals, i.e. sign does not matter.
 
     Returns:
         2-element tuple containing
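The abs_cosine flag only affects the normals term. A tiny numeric check (illustrative only, not part of the commit) makes the difference concrete: for exactly opposite normals the absolute variant reports a zero normals loss, while the signed variant reports the maximum value of 2.

    import torch
    import torch.nn.functional as F

    n1 = torch.tensor([[[0.0, 0.0, 1.0]]])             # (N=1, P=1, 3)
    n2 = -n1                                            # exactly opposite normal
    cos = F.cosine_similarity(n1, n2, dim=2, eps=1e-6)  # -1.0
    print(1 - torch.abs(cos))  # abs_cosine=True  -> tensor([[0.]])
    print(1 - cos)             # abs_cosine=False -> tensor([[2.]])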
@@ -112,116 +210,45 @@ def chamfer_distance(
         - **loss_normals**: Tensor giving the reduced cosine distance of normals
           between pointclouds in x and pointclouds in y. Returns None if
           x_normals and y_normals are None.
 
     """
     _validate_chamfer_reduction_inputs(batch_reduction, point_reduction)
 
     if not ((norm == 1) or (norm == 2)):
         raise ValueError("Support for 1 or 2 norm.")
 
     x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
     y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)
 
-    return_normals = x_normals is not None and y_normals is not None
-
-    N, P1, D = x.shape
-    P2 = y.shape[1]
-
-    # Check if inputs are heterogeneous and create a lengths mask.
-    is_x_heterogeneous = (x_lengths != P1).any()
-    is_y_heterogeneous = (y_lengths != P2).any()
-    x_mask = (
-        torch.arange(P1, device=x.device)[None] >= x_lengths[:, None]
-    )  # shape [N, P1]
-    y_mask = (
-        torch.arange(P2, device=y.device)[None] >= y_lengths[:, None]
-    )  # shape [N, P2]
-
-    if y.shape[0] != N or y.shape[2] != D:
-        raise ValueError("y does not have the correct shape.")
-    if weights is not None:
-        if weights.size(0) != N:
-            raise ValueError("weights must be of shape (N,).")
-        if not (weights >= 0).all():
-            raise ValueError("weights cannot be negative.")
-        if weights.sum() == 0.0:
-            weights = weights.view(N, 1)
-            if batch_reduction in ["mean", "sum"]:
-                return (
-                    (x.sum((1, 2)) * weights).sum() * 0.0,
-                    (x.sum((1, 2)) * weights).sum() * 0.0,
-                )
-            return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0)
-
-    cham_norm_x = x.new_zeros(())
-    cham_norm_y = x.new_zeros(())
-
-    x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, norm=norm, K=1)
-    y_nn = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, norm=norm, K=1)
-
-    cham_x = x_nn.dists[..., 0]  # (N, P1)
-    cham_y = y_nn.dists[..., 0]  # (N, P2)
-
-    if is_x_heterogeneous:
-        cham_x[x_mask] = 0.0
-    if is_y_heterogeneous:
-        cham_y[y_mask] = 0.0
-
-    if weights is not None:
-        cham_x *= weights.view(N, 1)
-        cham_y *= weights.view(N, 1)
-
-    if return_normals:
-        # Gather the normals using the indices and keep only value for k=0
-        x_normals_near = knn_gather(y_normals, x_nn.idx, y_lengths)[..., 0, :]
-        y_normals_near = knn_gather(x_normals, y_nn.idx, x_lengths)[..., 0, :]
-
-        cham_norm_x = 1 - torch.abs(
-            F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6)
-        )
-        cham_norm_y = 1 - torch.abs(
-            F.cosine_similarity(y_normals, y_normals_near, dim=2, eps=1e-6)
-        )
-
-        if is_x_heterogeneous:
-            cham_norm_x[x_mask] = 0.0
-        if is_y_heterogeneous:
-            cham_norm_y[y_mask] = 0.0
-
-        if weights is not None:
-            cham_norm_x *= weights.view(N, 1)
-            cham_norm_y *= weights.view(N, 1)
-
-    # Apply point reduction
-    cham_x = cham_x.sum(1)  # (N,)
-    cham_y = cham_y.sum(1)  # (N,)
-    if return_normals:
-        cham_norm_x = cham_norm_x.sum(1)  # (N,)
-        cham_norm_y = cham_norm_y.sum(1)  # (N,)
-    if point_reduction == "mean":
-        x_lengths_clamped = x_lengths.clamp(min=1)
-        y_lengths_clamped = y_lengths.clamp(min=1)
-        cham_x /= x_lengths_clamped
-        cham_y /= y_lengths_clamped
-        if return_normals:
-            cham_norm_x /= x_lengths_clamped
-            cham_norm_y /= y_lengths_clamped
-
-    if batch_reduction is not None:
-        # batch_reduction == "sum"
-        cham_x = cham_x.sum()
-        cham_y = cham_y.sum()
-        if return_normals:
-            cham_norm_x = cham_norm_x.sum()
-            cham_norm_y = cham_norm_y.sum()
-        if batch_reduction == "mean":
-            div = weights.sum() if weights is not None else max(N, 1)
-            cham_x /= div
-            cham_y /= div
-            if return_normals:
-                cham_norm_x /= div
-                cham_norm_y /= div
-
-    cham_dist = cham_x + cham_y
-    cham_normals = cham_norm_x + cham_norm_y if return_normals else None
-
-    return cham_dist, cham_normals
+    cham_x, cham_norm_x = _chamfer_distance_single_direction(
+        x,
+        y,
+        x_lengths,
+        y_lengths,
+        x_normals,
+        y_normals,
+        weights,
+        batch_reduction,
+        point_reduction,
+        norm,
+        abs_cosine,
+    )
+    if single_directional:
+        return cham_x, cham_norm_x
+    else:
+        cham_y, cham_norm_y = _chamfer_distance_single_direction(
+            y,
+            x,
+            y_lengths,
+            x_lengths,
+            y_normals,
+            x_normals,
+            weights,
+            batch_reduction,
+            point_reduction,
+            norm,
+            abs_cosine,
+        )
+        return (
+            cham_x + cham_y,
+            (cham_norm_x + cham_norm_y) if cham_norm_x is not None else None,
+        )
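Because the bidirectional path now simply sums the two single-directional results, a consistency check like the following should hold with the default reductions. This is a sketch assuming pytorch3d is installed; it is not part of the commit or its tests.

    import torch
    from pytorch3d.loss import chamfer_distance

    x, y = torch.rand(2, 64, 3), torch.rand(2, 96, 3)
    loss_bi, _ = chamfer_distance(x, y)                           # x -> y plus y -> x
    loss_xy, _ = chamfer_distance(x, y, single_directional=True)  # x -> y only
    loss_yx, _ = chamfer_distance(y, x, single_directional=True)  # y -> x only
    assert torch.allclose(loss_bi, loss_xy + loss_yx)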