In blending, pull common functionality into get_background_color

Summary: A small refactor, originally intended for use with the splatter.

Reviewed By: bottler

Differential Revision: D36210393

fbshipit-source-id: b3372f7cc7690ee45dd3059b2d4be1c8dfa63180
This commit is contained in:
Krzysztof Chalupka 2022-05-16 18:23:51 -07:00 committed by Facebook GitHub Bot
parent 4372001981
commit ea5df60d72

View File

@ -9,6 +9,7 @@ from typing import NamedTuple, Sequence, Union
import torch
from pytorch3d import _C
from pytorch3d.common.datatypes import Device
# Example functions for blending the top K colors per pixel using the outputs
@ -37,6 +38,17 @@ class BlendParams(NamedTuple):
background_color: Union[torch.Tensor, Sequence[float]] = (1.0, 1.0, 1.0)
def _get_background_color(
blend_params: BlendParams, device: Device, dtype=torch.float32
) -> torch.Tensor:
background_color_ = blend_params.background_color
if isinstance(background_color_, torch.Tensor):
background_color = background_color_.to(device)
else:
background_color = torch.tensor(background_color_, dtype=dtype, device=device)
return background_color
def hard_rgb_blend(
colors: torch.Tensor, fragments, blend_params: BlendParams
) -> torch.Tensor:
@ -57,18 +69,11 @@ def hard_rgb_blend(
Returns:
RGBA pixel_colors: (N, H, W, 4)
"""
N, H, W, K = fragments.pix_to_face.shape
device = fragments.pix_to_face.device
background_color = _get_background_color(blend_params, fragments.pix_to_face.device)
# Mask for the background.
is_background = fragments.pix_to_face[..., 0] < 0 # (N, H, W)
background_color_ = blend_params.background_color
if isinstance(background_color_, torch.Tensor):
background_color = background_color_.to(device)
else:
background_color = colors.new_tensor(background_color_)
# Find out how much background_color needs to be expanded to be used for masked_scatter.
num_background_pixels = is_background.sum()
@ -182,13 +187,8 @@ def softmax_rgb_blend(
"""
N, H, W, K = fragments.pix_to_face.shape
device = fragments.pix_to_face.device
pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=colors.device)
background_ = blend_params.background_color
if not isinstance(background_, torch.Tensor):
background = torch.tensor(background_, dtype=torch.float32, device=device)
else:
background = background_.to(device)
background_color = _get_background_color(blend_params, fragments.pix_to_face.device)
# Weight for background color
eps = 1e-10
@ -233,7 +233,7 @@ def softmax_rgb_blend(
# Sum: weights * textures + background color
weighted_colors = (weights_num[..., None] * colors).sum(dim=-2)
weighted_background = delta * background
weighted_background = delta * background_color
pixel_colors[..., :3] = (weighted_colors + weighted_background) / denom
pixel_colors[..., 3] = 1.0 - alpha