mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-02-25 23:56:00 +08:00
Compare commits
5 Commits
3ba2030aa4
...
7b5c78460a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7b5c78460a | ||
|
|
e3c80a4368 | ||
|
|
b9b5ea3428 | ||
|
|
0e435c297c | ||
|
|
d631b56fba |
@@ -82,10 +82,12 @@ class _SymEig3x3(nn.Module):
|
||||
q = inputs_trace / 3.0
|
||||
|
||||
# Calculate squared sum of elements outside the main diagonal / 2
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||
p1 = ((inputs**2).sum(dim=(-1, -2)) - (inputs_diag**2).sum(-1)) / 2
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||
p2 = ((inputs_diag - q[..., None]) ** 2).sum(dim=-1) + 2.0 * p1.clamp(self._eps)
|
||||
p1 = (
|
||||
torch.square(inputs).sum(dim=(-1, -2)) - torch.square(inputs_diag).sum(-1)
|
||||
) / 2
|
||||
p2 = torch.square(inputs_diag - q[..., None]).sum(dim=-1) + 2.0 * p1.clamp(
|
||||
self._eps
|
||||
)
|
||||
|
||||
p = torch.sqrt(p2 / 6.0)
|
||||
B = (inputs - q[..., None, None] * self._identity) / p[..., None, None]
|
||||
@@ -104,7 +106,9 @@ class _SymEig3x3(nn.Module):
|
||||
# Soft dispatch between the degenerate case (diagonal A) and general.
|
||||
# diag_soft_cond -> 1.0 when p1 < 6 * eps and diag_soft_cond -> 0.0 otherwise.
|
||||
# We use 6 * eps to take into account the error accumulated during the p1 summation
|
||||
diag_soft_cond = torch.exp(-((p1 / (6 * self._eps)) ** 2)).detach()[..., None]
|
||||
diag_soft_cond = torch.exp(-torch.square(p1 / (6 * self._eps))).detach()[
|
||||
..., None
|
||||
]
|
||||
|
||||
# Eigenvalues are the ordered elements of main diagonal in the degenerate case
|
||||
diag_eigenvals, _ = torch.sort(inputs_diag, dim=-1)
|
||||
@@ -199,8 +203,7 @@ class _SymEig3x3(nn.Module):
|
||||
cross_products[..., :1, :]
|
||||
)
|
||||
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||
norms_sq = (cross_products**2).sum(dim=-1)
|
||||
norms_sq = torch.square(cross_products).sum(dim=-1)
|
||||
max_norms_index = norms_sq.argmax(dim=-1)
|
||||
|
||||
# Pick only the cross-product with highest squared norm for each input
|
||||
|
||||
@@ -182,8 +182,7 @@ def iterative_closest_point(
|
||||
t_history.append(SimilarityTransform(R, T, s))
|
||||
|
||||
# compute the root mean squared error
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||
Xt_sq_diff = ((Xt - Xt_nn_points) ** 2).sum(2)
|
||||
Xt_sq_diff = torch.square(Xt - Xt_nn_points).sum(2)
|
||||
rmse = oputil.wmean(Xt_sq_diff[:, :, None], mask_X).sqrt()[:, 0, 0]
|
||||
|
||||
# compute the relative rmse
|
||||
|
||||
@@ -179,9 +179,7 @@ def sample_farthest_points_naive(
|
||||
# and all the other points. If a point has already been selected
|
||||
# it's distance will be 0.0 so it will not be selected again as the max.
|
||||
dist = points[n, selected_idx, :] - points[n, : lengths[n], :]
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
||||
# `int`.
|
||||
dist_to_last_selected = (dist**2).sum(-1) # (P - i)
|
||||
dist_to_last_selected = torch.square(dist).sum(-1) # (P - i)
|
||||
|
||||
# If closer than currently saved distance to one of the selected
|
||||
# points, then updated closest_dists
|
||||
|
||||
@@ -132,15 +132,13 @@ def _get_splat_kernel_normalization(
|
||||
|
||||
epsilon = 0.05
|
||||
normalization_constant = torch.exp(
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||
-(offsets**2).sum(dim=1) / (2 * sigma**2)
|
||||
-torch.square(offsets).sum(dim=1) / (2 * sigma**2)
|
||||
).sum()
|
||||
|
||||
# We add an epsilon to the normalization constant to ensure the gradient will travel
|
||||
# through non-boundary pixels' normalization factor, see Sec. 3.3.1 in "Differentia-
|
||||
# ble Surface Rendering via Non-Differentiable Sampling", Cole et al.
|
||||
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
|
||||
return (1 + epsilon) / normalization_constant
|
||||
return torch.div(1 + epsilon, normalization_constant)
|
||||
|
||||
|
||||
def _compute_occlusion_layers(
|
||||
@@ -264,8 +262,9 @@ def _compute_splatting_colors_and_weights(
|
||||
torch.floor(pixel_coords_screen[..., :2]) - pixel_coords_screen[..., :2] + 0.5
|
||||
).view((N, H, W, K, 1, 2))
|
||||
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||
dist2_p_q = torch.sum((q_to_px_center + offsets) ** 2, dim=5) # (N, H, W, K, 9)
|
||||
dist2_p_q = torch.sum(
|
||||
torch.square(q_to_px_center + offsets), dim=5
|
||||
) # (N, H, W, K, 9)
|
||||
splat_weights = torch.exp(-dist2_p_q / (2 * sigma**2))
|
||||
alpha = colors[..., 3:4]
|
||||
splat_weights = (alpha * splat_kernel_normalization * splat_weights).unsqueeze(
|
||||
@@ -417,12 +416,12 @@ def _normalize_and_compose_all_layers(
|
||||
device = splatted_colors_per_occlusion_layer.device
|
||||
|
||||
# Normalize each of bg/surface/fg splat layers separately.
|
||||
normalization_scales = 1.0 / (
|
||||
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
|
||||
normalization_scales = torch.div(
|
||||
1.0,
|
||||
torch.maximum(
|
||||
splatted_weights_per_occlusion_layer,
|
||||
torch.tensor([1.0], device=device),
|
||||
)
|
||||
),
|
||||
) # (N, H, W, 1, 3)
|
||||
|
||||
normalized_splatted_colors = (
|
||||
|
||||
@@ -195,15 +195,15 @@ def _se3_V_matrix(
|
||||
V = (
|
||||
torch.eye(3, dtype=log_rotation.dtype, device=log_rotation.device)[None]
|
||||
+ log_rotation_hat
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||
* ((1 - torch.cos(rotation_angles)) / (rotation_angles**2))[:, None, None]
|
||||
* ((1 - torch.cos(rotation_angles)) / torch.square(rotation_angles))[
|
||||
:, None, None
|
||||
]
|
||||
+ (
|
||||
log_rotation_hat_square
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
||||
# `int`.
|
||||
* ((rotation_angles - torch.sin(rotation_angles)) / (rotation_angles**3))[
|
||||
:, None, None
|
||||
]
|
||||
* (
|
||||
(rotation_angles - torch.sin(rotation_angles))
|
||||
/ torch.pow(rotation_angles, 3)
|
||||
)[:, None, None]
|
||||
)
|
||||
)
|
||||
|
||||
@@ -215,8 +215,7 @@ def _get_se3_V_input(log_rotation: torch.Tensor, eps: float = 1e-4):
|
||||
A helper function that computes the input variables to the `_se3_V_matrix`
|
||||
function.
|
||||
"""
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||
nrms = (log_rotation**2).sum(-1)
|
||||
nrms = torch.square(log_rotation).sum(-1)
|
||||
rotation_angles = torch.clamp(nrms, eps).sqrt()
|
||||
log_rotation_hat = hat(log_rotation)
|
||||
log_rotation_hat_square = torch.bmm(log_rotation_hat, log_rotation_hat)
|
||||
|
||||
Reference in New Issue
Block a user