fbcode/vision/fair/pytorch3d/pytorch3d/renderer/splatter_blend.py

Reviewed By: sgrigory

Differential Revision: D93710022

fbshipit-source-id: 39253258b93a467fbda6b51ef8d6d3975bb49810
This commit is contained in:
generatedunixname1417043136753450
2026-02-23 12:43:53 -08:00
committed by meta-codesync[bot]
parent b9b5ea3428
commit e3c80a4368

View File

@@ -132,15 +132,13 @@ def _get_splat_kernel_normalization(
epsilon = 0.05 epsilon = 0.05
normalization_constant = torch.exp( normalization_constant = torch.exp(
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`. -torch.square(offsets).sum(dim=1) / (2 * sigma**2)
-(offsets**2).sum(dim=1) / (2 * sigma**2)
).sum() ).sum()
# We add an epsilon to the normalization constant to ensure the gradient will travel # We add an epsilon to the normalization constant to ensure the gradient will travel
# through non-boundary pixels' normalization factor, see Sec. 3.3.1 in "Differentia- # through non-boundary pixels' normalization factor, see Sec. 3.3.1 in "Differentia-
# ble Surface Rendering via Non-Differentiable Sampling", Cole et al. # ble Surface Rendering via Non-Differentiable Sampling", Cole et al.
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. return torch.div(1 + epsilon, normalization_constant)
return (1 + epsilon) / normalization_constant
def _compute_occlusion_layers( def _compute_occlusion_layers(
@@ -264,8 +262,9 @@ def _compute_splatting_colors_and_weights(
torch.floor(pixel_coords_screen[..., :2]) - pixel_coords_screen[..., :2] + 0.5 torch.floor(pixel_coords_screen[..., :2]) - pixel_coords_screen[..., :2] + 0.5
).view((N, H, W, K, 1, 2)) ).view((N, H, W, K, 1, 2))
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`. dist2_p_q = torch.sum(
dist2_p_q = torch.sum((q_to_px_center + offsets) ** 2, dim=5) # (N, H, W, K, 9) torch.square(q_to_px_center + offsets), dim=5
) # (N, H, W, K, 9)
splat_weights = torch.exp(-dist2_p_q / (2 * sigma**2)) splat_weights = torch.exp(-dist2_p_q / (2 * sigma**2))
alpha = colors[..., 3:4] alpha = colors[..., 3:4]
splat_weights = (alpha * splat_kernel_normalization * splat_weights).unsqueeze( splat_weights = (alpha * splat_kernel_normalization * splat_weights).unsqueeze(
@@ -417,12 +416,12 @@ def _normalize_and_compose_all_layers(
device = splatted_colors_per_occlusion_layer.device device = splatted_colors_per_occlusion_layer.device
# Normalize each of bg/surface/fg splat layers separately. # Normalize each of bg/surface/fg splat layers separately.
normalization_scales = 1.0 / ( normalization_scales = torch.div(
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. 1.0,
torch.maximum( torch.maximum(
splatted_weights_per_occlusion_layer, splatted_weights_per_occlusion_layer,
torch.tensor([1.0], device=device), torch.tensor([1.0], device=device),
) ),
) # (N, H, W, 1, 3) ) # (N, H, W, 1, 3)
normalized_splatted_colors = ( normalized_splatted_colors = (