From ea5df60d72307378d4c0641519e4e7a3671458dc Mon Sep 17 00:00:00 2001 From: Krzysztof Chalupka Date: Mon, 16 May 2022 18:23:51 -0700 Subject: [PATCH] In blending, pull common functionality into get_background_color Summary: A small refactor, originally intended for use with the splatter. Reviewed By: bottler Differential Revision: D36210393 fbshipit-source-id: b3372f7cc7690ee45dd3059b2d4be1c8dfa63180 --- pytorch3d/renderer/blending.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/pytorch3d/renderer/blending.py b/pytorch3d/renderer/blending.py index cdfb4e60..bfdae6c9 100644 --- a/pytorch3d/renderer/blending.py +++ b/pytorch3d/renderer/blending.py @@ -9,6 +9,7 @@ from typing import NamedTuple, Sequence, Union import torch from pytorch3d import _C +from pytorch3d.common.datatypes import Device # Example functions for blending the top K colors per pixel using the outputs @@ -37,6 +38,17 @@ class BlendParams(NamedTuple): background_color: Union[torch.Tensor, Sequence[float]] = (1.0, 1.0, 1.0) +def _get_background_color( + blend_params: BlendParams, device: Device, dtype=torch.float32 +) -> torch.Tensor: + background_color_ = blend_params.background_color + if isinstance(background_color_, torch.Tensor): + background_color = background_color_.to(device) + else: + background_color = torch.tensor(background_color_, dtype=dtype, device=device) + return background_color + + def hard_rgb_blend( colors: torch.Tensor, fragments, blend_params: BlendParams ) -> torch.Tensor: @@ -57,18 +69,11 @@ def hard_rgb_blend( Returns: RGBA pixel_colors: (N, H, W, 4) """ - N, H, W, K = fragments.pix_to_face.shape - device = fragments.pix_to_face.device + background_color = _get_background_color(blend_params, fragments.pix_to_face.device) # Mask for the background. is_background = fragments.pix_to_face[..., 0] < 0 # (N, H, W) - background_color_ = blend_params.background_color - if isinstance(background_color_, torch.Tensor): - background_color = background_color_.to(device) - else: - background_color = colors.new_tensor(background_color_) - # Find out how much background_color needs to be expanded to be used for masked_scatter. num_background_pixels = is_background.sum() @@ -182,13 +187,8 @@ def softmax_rgb_blend( """ N, H, W, K = fragments.pix_to_face.shape - device = fragments.pix_to_face.device pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=colors.device) - background_ = blend_params.background_color - if not isinstance(background_, torch.Tensor): - background = torch.tensor(background_, dtype=torch.float32, device=device) - else: - background = background_.to(device) + background_color = _get_background_color(blend_params, fragments.pix_to_face.device) # Weight for background color eps = 1e-10 @@ -233,7 +233,7 @@ def softmax_rgb_blend( # Sum: weights * textures + background color weighted_colors = (weights_num[..., None] * colors).sum(dim=-2) - weighted_background = delta * background + weighted_background = delta * background_color pixel_colors[..., :3] = (weighted_colors + weighted_background) / denom pixel_colors[..., 3] = 1.0 - alpha