5 Commits

Author SHA1 Message Date
generatedunixname1417043136753450
7b5c78460a fbcode/vision/fair/pytorch3d/pytorch3d/transforms/se3.py
Reviewed By: sgrigory

Differential Revision: D93709801

fbshipit-source-id: e4bae81fe1a88fed547304e6e21b248c5a345277
2026-02-23 14:51:32 -08:00
generatedunixname1417043136753450
e3c80a4368 fbcode/vision/fair/pytorch3d/pytorch3d/renderer/splatter_blend.py
Reviewed By: sgrigory

Differential Revision: D93710022

fbshipit-source-id: 39253258b93a467fbda6b51ef8d6d3975bb49810
2026-02-23 12:43:53 -08:00
generatedunixname1417043136753450
b9b5ea3428 fbcode/vision/fair/pytorch3d/pytorch3d/common/workaround/symeig3x3.py
Reviewed By: sgrigory

Differential Revision: D93715209

fbshipit-source-id: 1880a8dd72e35ce5cc93cdeecf770aab6469ca31
2026-02-23 12:42:24 -08:00
generatedunixname1417043136753450
0e435c297c fbcode/vision/fair/pytorch3d/pytorch3d/ops/points_alignment.py
Reviewed By: sgrigory

Differential Revision: D93712744

fbshipit-source-id: 660560cdef9ff1d2173ae06de54df31766ee537f
2026-02-23 12:28:37 -08:00
generatedunixname1417043136753450
d631b56fba fbcode/vision/fair/pytorch3d/pytorch3d/ops/sample_farthest_points.py
Reviewed By: sgrigory

Differential Revision: D93708653

fbshipit-source-id: 112158092cd64ac8afddf1378b931cb44e19c372
2026-02-23 10:21:52 -08:00
5 changed files with 28 additions and 30 deletions

View File

@@ -82,10 +82,12 @@ class _SymEig3x3(nn.Module):
q = inputs_trace / 3.0
# Calculate squared sum of elements outside the main diagonal / 2
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
p1 = ((inputs**2).sum(dim=(-1, -2)) - (inputs_diag**2).sum(-1)) / 2
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
p2 = ((inputs_diag - q[..., None]) ** 2).sum(dim=-1) + 2.0 * p1.clamp(self._eps)
p1 = (
torch.square(inputs).sum(dim=(-1, -2)) - torch.square(inputs_diag).sum(-1)
) / 2
p2 = torch.square(inputs_diag - q[..., None]).sum(dim=-1) + 2.0 * p1.clamp(
self._eps
)
p = torch.sqrt(p2 / 6.0)
B = (inputs - q[..., None, None] * self._identity) / p[..., None, None]
@@ -104,7 +106,9 @@ class _SymEig3x3(nn.Module):
# Soft dispatch between the degenerate case (diagonal A) and general.
# diag_soft_cond -> 1.0 when p1 < 6 * eps and diag_soft_cond -> 0.0 otherwise.
# We use 6 * eps to take into account the error accumulated during the p1 summation
diag_soft_cond = torch.exp(-((p1 / (6 * self._eps)) ** 2)).detach()[..., None]
diag_soft_cond = torch.exp(-torch.square(p1 / (6 * self._eps))).detach()[
..., None
]
# Eigenvalues are the ordered elements of main diagonal in the degenerate case
diag_eigenvals, _ = torch.sort(inputs_diag, dim=-1)
@@ -199,8 +203,7 @@ class _SymEig3x3(nn.Module):
cross_products[..., :1, :]
)
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
norms_sq = (cross_products**2).sum(dim=-1)
norms_sq = torch.square(cross_products).sum(dim=-1)
max_norms_index = norms_sq.argmax(dim=-1)
# Pick only the cross-product with highest squared norm for each input

View File

@@ -182,8 +182,7 @@ def iterative_closest_point(
t_history.append(SimilarityTransform(R, T, s))
# compute the root mean squared error
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
Xt_sq_diff = ((Xt - Xt_nn_points) ** 2).sum(2)
Xt_sq_diff = torch.square(Xt - Xt_nn_points).sum(2)
rmse = oputil.wmean(Xt_sq_diff[:, :, None], mask_X).sqrt()[:, 0, 0]
# compute the relative rmse

View File

@@ -179,9 +179,7 @@ def sample_farthest_points_naive(
# and all the other points. If a point has already been selected
# it's distance will be 0.0 so it will not be selected again as the max.
dist = points[n, selected_idx, :] - points[n, : lengths[n], :]
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
dist_to_last_selected = (dist**2).sum(-1) # (P - i)
dist_to_last_selected = torch.square(dist).sum(-1) # (P - i)
# If closer than currently saved distance to one of the selected
# points, then updated closest_dists

View File

@@ -132,15 +132,13 @@ def _get_splat_kernel_normalization(
epsilon = 0.05
normalization_constant = torch.exp(
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
-(offsets**2).sum(dim=1) / (2 * sigma**2)
-torch.square(offsets).sum(dim=1) / (2 * sigma**2)
).sum()
# We add an epsilon to the normalization constant to ensure the gradient will travel
# through non-boundary pixels' normalization factor, see Sec. 3.3.1 in "Differentia-
# ble Surface Rendering via Non-Differentiable Sampling", Cole et al.
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
return (1 + epsilon) / normalization_constant
return torch.div(1 + epsilon, normalization_constant)
def _compute_occlusion_layers(
@@ -264,8 +262,9 @@ def _compute_splatting_colors_and_weights(
torch.floor(pixel_coords_screen[..., :2]) - pixel_coords_screen[..., :2] + 0.5
).view((N, H, W, K, 1, 2))
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
dist2_p_q = torch.sum((q_to_px_center + offsets) ** 2, dim=5) # (N, H, W, K, 9)
dist2_p_q = torch.sum(
torch.square(q_to_px_center + offsets), dim=5
) # (N, H, W, K, 9)
splat_weights = torch.exp(-dist2_p_q / (2 * sigma**2))
alpha = colors[..., 3:4]
splat_weights = (alpha * splat_kernel_normalization * splat_weights).unsqueeze(
@@ -417,12 +416,12 @@ def _normalize_and_compose_all_layers(
device = splatted_colors_per_occlusion_layer.device
# Normalize each of bg/surface/fg splat layers separately.
normalization_scales = 1.0 / (
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
normalization_scales = torch.div(
1.0,
torch.maximum(
splatted_weights_per_occlusion_layer,
torch.tensor([1.0], device=device),
)
),
) # (N, H, W, 1, 3)
normalized_splatted_colors = (

View File

@@ -195,15 +195,15 @@ def _se3_V_matrix(
V = (
torch.eye(3, dtype=log_rotation.dtype, device=log_rotation.device)[None]
+ log_rotation_hat
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
* ((1 - torch.cos(rotation_angles)) / (rotation_angles**2))[:, None, None]
* ((1 - torch.cos(rotation_angles)) / torch.square(rotation_angles))[
:, None, None
]
+ (
log_rotation_hat_square
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
* ((rotation_angles - torch.sin(rotation_angles)) / (rotation_angles**3))[
:, None, None
]
* (
(rotation_angles - torch.sin(rotation_angles))
/ torch.pow(rotation_angles, 3)
)[:, None, None]
)
)
@@ -215,8 +215,7 @@ def _get_se3_V_input(log_rotation: torch.Tensor, eps: float = 1e-4):
A helper function that computes the input variables to the `_se3_V_matrix`
function.
"""
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
nrms = (log_rotation**2).sum(-1)
nrms = torch.square(log_rotation).sum(-1)
rotation_angles = torch.clamp(nrms, eps).sqrt()
log_rotation_hat = hat(log_rotation)
log_rotation_hat_square = torch.bmm(log_rotation_hat, log_rotation_hat)