fbcode/vision/fair/pytorch3d/pytorch3d/renderer/splatter_blend.py

Reviewed By: sgrigory

Differential Revision: D93710022

fbshipit-source-id: 39253258b93a467fbda6b51ef8d6d3975bb49810
This commit is contained in:
generatedunixname1417043136753450
2026-02-23 12:43:53 -08:00
committed by meta-codesync[bot]
parent b9b5ea3428
commit e3c80a4368

View File

@@ -132,15 +132,13 @@ def _get_splat_kernel_normalization(
epsilon = 0.05
normalization_constant = torch.exp(
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
-(offsets**2).sum(dim=1) / (2 * sigma**2)
-torch.square(offsets).sum(dim=1) / (2 * sigma**2)
).sum()
# We add an epsilon to the normalization constant to ensure the gradient will travel
# through non-boundary pixels' normalization factor, see Sec. 3.3.1 in "Differentia-
# ble Surface Rendering via Non-Differentiable Sampling", Cole et al.
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
return (1 + epsilon) / normalization_constant
return torch.div(1 + epsilon, normalization_constant)
def _compute_occlusion_layers(
@@ -264,8 +262,9 @@ def _compute_splatting_colors_and_weights(
torch.floor(pixel_coords_screen[..., :2]) - pixel_coords_screen[..., :2] + 0.5
).view((N, H, W, K, 1, 2))
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
dist2_p_q = torch.sum((q_to_px_center + offsets) ** 2, dim=5) # (N, H, W, K, 9)
dist2_p_q = torch.sum(
torch.square(q_to_px_center + offsets), dim=5
) # (N, H, W, K, 9)
splat_weights = torch.exp(-dist2_p_q / (2 * sigma**2))
alpha = colors[..., 3:4]
splat_weights = (alpha * splat_kernel_normalization * splat_weights).unsqueeze(
@@ -417,12 +416,12 @@ def _normalize_and_compose_all_layers(
device = splatted_colors_per_occlusion_layer.device
# Normalize each of bg/surface/fg splat layers separately.
normalization_scales = 1.0 / (
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
normalization_scales = torch.div(
1.0,
torch.maximum(
splatted_weights_per_occlusion_layer,
torch.tensor([1.0], device=device),
)
),
) # (N, H, W, 1, 3)
normalized_splatted_colors = (