3 Commits

Author SHA1 Message Date
Jeremy Reizenstein
61cc79aa34 Make _sqrt_positive_part ONNX-exportable
Summary:
Replace boolean indexing and torch.is_grad_enabled() control flow in _sqrt_positive_part with a pure torch.where implementation. The old code used ret[positive_mask] = torch.sqrt(x[positive_mask]) which produces an incorrect ONNX Where/index_put node with mismatched broadcast shapes when the model is exported via torch.onnx.export.

The new implementation substitutes 1.0 for non-positive values before sqrt (avoiding infinite gradient at sqrt(0)) and masks the result back to 0, preserving the zero-subgradient-at-zero property.

Fixes https://github.com/facebookresearch/pytorch3d/issues/2020

Reviewed By: sgrigory

Differential Revision: D94365479

fbshipit-source-id: a1ebe8dc077573f83efc262520b6669159b83ef0
2026-03-06 05:23:55 -08:00
generatedunixname2645487282517272
7a6157e38e Fix CQS signal modernize-use-using in fbcode/vision/fair
Reviewed By: bottler

Differential Revision: D94879733

fbshipit-source-id: fc35eaaa723a2a035b3b204732add7ba8b225c57
2026-03-02 05:59:34 -08:00
generatedunixname1417043136753450
d9839a95f2 fbcode/vision/fair/pytorch3d/pytorch3d/ops/cameras_alignment.py
Reviewed By: sgrigory

Differential Revision: D93710806

fbshipit-source-id: da6c1e1e5b7a1c5cdfbf5026993c42c7ec387415
2026-02-23 15:52:03 -08:00
4 changed files with 5 additions and 10 deletions

View File

@@ -19,7 +19,7 @@ template <
std::is_same<T, double>::value || std::is_same<T, float>::value>>
struct vec2 {
T x, y;
typedef T scalar_t;
using scalar_t = T;
vec2(T x, T y) : x(x), y(y) {}
};

View File

@@ -18,7 +18,7 @@ template <
std::is_same<T, double>::value || std::is_same<T, float>::value>>
struct vec3 {
T x, y, z;
typedef T scalar_t;
using scalar_t = T;
vec3(T x, T y, T z) : x(x), y(y), z(z) {}
};

View File

@@ -223,8 +223,7 @@ def _align_camera_extrinsics(
# of centered A and centered B
Ac = A - Amu
Bc = B - Bmu
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
align_t_s = (Ac * Bc).mean() / (Ac**2).mean().clamp(eps)
align_t_s = (Ac * Bc).mean() / torch.square(Ac).mean().clamp(eps)
else:
# set the scale to identity
align_t_s = 1.0

View File

@@ -94,13 +94,9 @@ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
Returns torch.sqrt(torch.max(0, x))
but with a zero subgradient where x is 0.
"""
ret = torch.zeros_like(x)
positive_mask = x > 0
if torch.is_grad_enabled():
ret[positive_mask] = torch.sqrt(x[positive_mask])
else:
ret = torch.where(positive_mask, torch.sqrt(x), ret)
return ret
safe_x = torch.where(positive_mask, x, 1.0)
return torch.where(positive_mask, torch.sqrt(safe_x), 0.0)
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: