mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-02-26 08:06:00 +08:00
Compare commits
5 Commits
3ba2030aa4
...
7b5c78460a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7b5c78460a | ||
|
|
e3c80a4368 | ||
|
|
b9b5ea3428 | ||
|
|
0e435c297c | ||
|
|
d631b56fba |
@@ -82,10 +82,12 @@ class _SymEig3x3(nn.Module):
|
|||||||
q = inputs_trace / 3.0
|
q = inputs_trace / 3.0
|
||||||
|
|
||||||
# Calculate squared sum of elements outside the main diagonal / 2
|
# Calculate squared sum of elements outside the main diagonal / 2
|
||||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
p1 = (
|
||||||
p1 = ((inputs**2).sum(dim=(-1, -2)) - (inputs_diag**2).sum(-1)) / 2
|
torch.square(inputs).sum(dim=(-1, -2)) - torch.square(inputs_diag).sum(-1)
|
||||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
) / 2
|
||||||
p2 = ((inputs_diag - q[..., None]) ** 2).sum(dim=-1) + 2.0 * p1.clamp(self._eps)
|
p2 = torch.square(inputs_diag - q[..., None]).sum(dim=-1) + 2.0 * p1.clamp(
|
||||||
|
self._eps
|
||||||
|
)
|
||||||
|
|
||||||
p = torch.sqrt(p2 / 6.0)
|
p = torch.sqrt(p2 / 6.0)
|
||||||
B = (inputs - q[..., None, None] * self._identity) / p[..., None, None]
|
B = (inputs - q[..., None, None] * self._identity) / p[..., None, None]
|
||||||
@@ -104,7 +106,9 @@ class _SymEig3x3(nn.Module):
|
|||||||
# Soft dispatch between the degenerate case (diagonal A) and general.
|
# Soft dispatch between the degenerate case (diagonal A) and general.
|
||||||
# diag_soft_cond -> 1.0 when p1 < 6 * eps and diag_soft_cond -> 0.0 otherwise.
|
# diag_soft_cond -> 1.0 when p1 < 6 * eps and diag_soft_cond -> 0.0 otherwise.
|
||||||
# We use 6 * eps to take into account the error accumulated during the p1 summation
|
# We use 6 * eps to take into account the error accumulated during the p1 summation
|
||||||
diag_soft_cond = torch.exp(-((p1 / (6 * self._eps)) ** 2)).detach()[..., None]
|
diag_soft_cond = torch.exp(-torch.square(p1 / (6 * self._eps))).detach()[
|
||||||
|
..., None
|
||||||
|
]
|
||||||
|
|
||||||
# Eigenvalues are the ordered elements of main diagonal in the degenerate case
|
# Eigenvalues are the ordered elements of main diagonal in the degenerate case
|
||||||
diag_eigenvals, _ = torch.sort(inputs_diag, dim=-1)
|
diag_eigenvals, _ = torch.sort(inputs_diag, dim=-1)
|
||||||
@@ -199,8 +203,7 @@ class _SymEig3x3(nn.Module):
|
|||||||
cross_products[..., :1, :]
|
cross_products[..., :1, :]
|
||||||
)
|
)
|
||||||
|
|
||||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
norms_sq = torch.square(cross_products).sum(dim=-1)
|
||||||
norms_sq = (cross_products**2).sum(dim=-1)
|
|
||||||
max_norms_index = norms_sq.argmax(dim=-1)
|
max_norms_index = norms_sq.argmax(dim=-1)
|
||||||
|
|
||||||
# Pick only the cross-product with highest squared norm for each input
|
# Pick only the cross-product with highest squared norm for each input
|
||||||
|
|||||||
@@ -182,8 +182,7 @@ def iterative_closest_point(
|
|||||||
t_history.append(SimilarityTransform(R, T, s))
|
t_history.append(SimilarityTransform(R, T, s))
|
||||||
|
|
||||||
# compute the root mean squared error
|
# compute the root mean squared error
|
||||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
Xt_sq_diff = torch.square(Xt - Xt_nn_points).sum(2)
|
||||||
Xt_sq_diff = ((Xt - Xt_nn_points) ** 2).sum(2)
|
|
||||||
rmse = oputil.wmean(Xt_sq_diff[:, :, None], mask_X).sqrt()[:, 0, 0]
|
rmse = oputil.wmean(Xt_sq_diff[:, :, None], mask_X).sqrt()[:, 0, 0]
|
||||||
|
|
||||||
# compute the relative rmse
|
# compute the relative rmse
|
||||||
|
|||||||
@@ -179,9 +179,7 @@ def sample_farthest_points_naive(
|
|||||||
# and all the other points. If a point has already been selected
|
# and all the other points. If a point has already been selected
|
||||||
# it's distance will be 0.0 so it will not be selected again as the max.
|
# it's distance will be 0.0 so it will not be selected again as the max.
|
||||||
dist = points[n, selected_idx, :] - points[n, : lengths[n], :]
|
dist = points[n, selected_idx, :] - points[n, : lengths[n], :]
|
||||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
dist_to_last_selected = torch.square(dist).sum(-1) # (P - i)
|
||||||
# `int`.
|
|
||||||
dist_to_last_selected = (dist**2).sum(-1) # (P - i)
|
|
||||||
|
|
||||||
# If closer than currently saved distance to one of the selected
|
# If closer than currently saved distance to one of the selected
|
||||||
# points, then updated closest_dists
|
# points, then updated closest_dists
|
||||||
|
|||||||
@@ -132,15 +132,13 @@ def _get_splat_kernel_normalization(
|
|||||||
|
|
||||||
epsilon = 0.05
|
epsilon = 0.05
|
||||||
normalization_constant = torch.exp(
|
normalization_constant = torch.exp(
|
||||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
-torch.square(offsets).sum(dim=1) / (2 * sigma**2)
|
||||||
-(offsets**2).sum(dim=1) / (2 * sigma**2)
|
|
||||||
).sum()
|
).sum()
|
||||||
|
|
||||||
# We add an epsilon to the normalization constant to ensure the gradient will travel
|
# We add an epsilon to the normalization constant to ensure the gradient will travel
|
||||||
# through non-boundary pixels' normalization factor, see Sec. 3.3.1 in "Differentia-
|
# through non-boundary pixels' normalization factor, see Sec. 3.3.1 in "Differentia-
|
||||||
# ble Surface Rendering via Non-Differentiable Sampling", Cole et al.
|
# ble Surface Rendering via Non-Differentiable Sampling", Cole et al.
|
||||||
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
|
return torch.div(1 + epsilon, normalization_constant)
|
||||||
return (1 + epsilon) / normalization_constant
|
|
||||||
|
|
||||||
|
|
||||||
def _compute_occlusion_layers(
|
def _compute_occlusion_layers(
|
||||||
@@ -264,8 +262,9 @@ def _compute_splatting_colors_and_weights(
|
|||||||
torch.floor(pixel_coords_screen[..., :2]) - pixel_coords_screen[..., :2] + 0.5
|
torch.floor(pixel_coords_screen[..., :2]) - pixel_coords_screen[..., :2] + 0.5
|
||||||
).view((N, H, W, K, 1, 2))
|
).view((N, H, W, K, 1, 2))
|
||||||
|
|
||||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
dist2_p_q = torch.sum(
|
||||||
dist2_p_q = torch.sum((q_to_px_center + offsets) ** 2, dim=5) # (N, H, W, K, 9)
|
torch.square(q_to_px_center + offsets), dim=5
|
||||||
|
) # (N, H, W, K, 9)
|
||||||
splat_weights = torch.exp(-dist2_p_q / (2 * sigma**2))
|
splat_weights = torch.exp(-dist2_p_q / (2 * sigma**2))
|
||||||
alpha = colors[..., 3:4]
|
alpha = colors[..., 3:4]
|
||||||
splat_weights = (alpha * splat_kernel_normalization * splat_weights).unsqueeze(
|
splat_weights = (alpha * splat_kernel_normalization * splat_weights).unsqueeze(
|
||||||
@@ -417,12 +416,12 @@ def _normalize_and_compose_all_layers(
|
|||||||
device = splatted_colors_per_occlusion_layer.device
|
device = splatted_colors_per_occlusion_layer.device
|
||||||
|
|
||||||
# Normalize each of bg/surface/fg splat layers separately.
|
# Normalize each of bg/surface/fg splat layers separately.
|
||||||
normalization_scales = 1.0 / (
|
normalization_scales = torch.div(
|
||||||
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
|
1.0,
|
||||||
torch.maximum(
|
torch.maximum(
|
||||||
splatted_weights_per_occlusion_layer,
|
splatted_weights_per_occlusion_layer,
|
||||||
torch.tensor([1.0], device=device),
|
torch.tensor([1.0], device=device),
|
||||||
)
|
),
|
||||||
) # (N, H, W, 1, 3)
|
) # (N, H, W, 1, 3)
|
||||||
|
|
||||||
normalized_splatted_colors = (
|
normalized_splatted_colors = (
|
||||||
|
|||||||
@@ -195,15 +195,15 @@ def _se3_V_matrix(
|
|||||||
V = (
|
V = (
|
||||||
torch.eye(3, dtype=log_rotation.dtype, device=log_rotation.device)[None]
|
torch.eye(3, dtype=log_rotation.dtype, device=log_rotation.device)[None]
|
||||||
+ log_rotation_hat
|
+ log_rotation_hat
|
||||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
* ((1 - torch.cos(rotation_angles)) / torch.square(rotation_angles))[
|
||||||
* ((1 - torch.cos(rotation_angles)) / (rotation_angles**2))[:, None, None]
|
|
||||||
+ (
|
|
||||||
log_rotation_hat_square
|
|
||||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
|
||||||
# `int`.
|
|
||||||
* ((rotation_angles - torch.sin(rotation_angles)) / (rotation_angles**3))[
|
|
||||||
:, None, None
|
:, None, None
|
||||||
]
|
]
|
||||||
|
+ (
|
||||||
|
log_rotation_hat_square
|
||||||
|
* (
|
||||||
|
(rotation_angles - torch.sin(rotation_angles))
|
||||||
|
/ torch.pow(rotation_angles, 3)
|
||||||
|
)[:, None, None]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -215,8 +215,7 @@ def _get_se3_V_input(log_rotation: torch.Tensor, eps: float = 1e-4):
|
|||||||
A helper function that computes the input variables to the `_se3_V_matrix`
|
A helper function that computes the input variables to the `_se3_V_matrix`
|
||||||
function.
|
function.
|
||||||
"""
|
"""
|
||||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
nrms = torch.square(log_rotation).sum(-1)
|
||||||
nrms = (log_rotation**2).sum(-1)
|
|
||||||
rotation_angles = torch.clamp(nrms, eps).sqrt()
|
rotation_angles = torch.clamp(nrms, eps).sqrt()
|
||||||
log_rotation_hat = hat(log_rotation)
|
log_rotation_hat = hat(log_rotation)
|
||||||
log_rotation_hat_square = torch.bmm(log_rotation_hat, log_rotation_hat)
|
log_rotation_hat_square = torch.bmm(log_rotation_hat, log_rotation_hat)
|
||||||
|
|||||||
Reference in New Issue
Block a user