3 Commits

Author SHA1 Message Date
Jeremy Reizenstein
61cc79aa34 Make _sqrt_positive_part ONNX-exportable
Summary:
Replace boolean indexing and torch.is_grad_enabled() control flow in _sqrt_positive_part with a pure torch.where implementation. The old code used ret[positive_mask] = torch.sqrt(x[positive_mask]) which produces an incorrect ONNX Where/index_put node with mismatched broadcast shapes when the model is exported via torch.onnx.export.

The new implementation substitutes 1.0 for non-positive values before sqrt (avoiding infinite gradient at sqrt(0)) and masks the result back to 0, preserving the zero-subgradient-at-zero property.

Fixes https://github.com/facebookresearch/pytorch3d/issues/2020

Reviewed By: sgrigory

Differential Revision: D94365479

fbshipit-source-id: a1ebe8dc077573f83efc262520b6669159b83ef0
2026-03-06 05:23:55 -08:00
generatedunixname2645487282517272
7a6157e38e Fix CQS signal modernize-use-using in fbcode/vision/fair
Reviewed By: bottler

Differential Revision: D94879733

fbshipit-source-id: fc35eaaa723a2a035b3b204732add7ba8b225c57
2026-03-02 05:59:34 -08:00
generatedunixname1417043136753450
d9839a95f2 fbcode/vision/fair/pytorch3d/pytorch3d/ops/cameras_alignment.py
Reviewed By: sgrigory

Differential Revision: D93710806

fbshipit-source-id: da6c1e1e5b7a1c5cdfbf5026993c42c7ec387415
2026-02-23 15:52:03 -08:00
4 changed files with 5 additions and 10 deletions

View File

@@ -19,7 +19,7 @@ template <
std::is_same<T, double>::value || std::is_same<T, float>::value>> std::is_same<T, double>::value || std::is_same<T, float>::value>>
struct vec2 { struct vec2 {
T x, y; T x, y;
typedef T scalar_t; using scalar_t = T;
vec2(T x, T y) : x(x), y(y) {} vec2(T x, T y) : x(x), y(y) {}
}; };

View File

@@ -18,7 +18,7 @@ template <
std::is_same<T, double>::value || std::is_same<T, float>::value>> std::is_same<T, double>::value || std::is_same<T, float>::value>>
struct vec3 { struct vec3 {
T x, y, z; T x, y, z;
typedef T scalar_t; using scalar_t = T;
vec3(T x, T y, T z) : x(x), y(y), z(z) {} vec3(T x, T y, T z) : x(x), y(y), z(z) {}
}; };

View File

@@ -223,8 +223,7 @@ def _align_camera_extrinsics(
# of centered A and centered B # of centered A and centered B
Ac = A - Amu Ac = A - Amu
Bc = B - Bmu Bc = B - Bmu
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`. align_t_s = (Ac * Bc).mean() / torch.square(Ac).mean().clamp(eps)
align_t_s = (Ac * Bc).mean() / (Ac**2).mean().clamp(eps)
else: else:
# set the scale to identity # set the scale to identity
align_t_s = 1.0 align_t_s = 1.0

View File

@@ -94,13 +94,9 @@ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
Returns torch.sqrt(torch.max(0, x)) Returns torch.sqrt(torch.max(0, x))
but with a zero subgradient where x is 0. but with a zero subgradient where x is 0.
""" """
ret = torch.zeros_like(x)
positive_mask = x > 0 positive_mask = x > 0
if torch.is_grad_enabled(): safe_x = torch.where(positive_mask, x, 1.0)
ret[positive_mask] = torch.sqrt(x[positive_mask]) return torch.where(positive_mask, torch.sqrt(safe_x), 0.0)
else:
ret = torch.where(positive_mask, torch.sqrt(x), ret)
return ret
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: