mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
In blending, pull common functionality into get_background_color
Summary: A small refactor, originally intended for use with the splatter. Reviewed By: bottler Differential Revision: D36210393 fbshipit-source-id: b3372f7cc7690ee45dd3059b2d4be1c8dfa63180
This commit is contained in:
parent
4372001981
commit
ea5df60d72
@ -9,6 +9,7 @@ from typing import NamedTuple, Sequence, Union
|
||||
|
||||
import torch
|
||||
from pytorch3d import _C
|
||||
from pytorch3d.common.datatypes import Device
|
||||
|
||||
|
||||
# Example functions for blending the top K colors per pixel using the outputs
|
||||
@ -37,6 +38,17 @@ class BlendParams(NamedTuple):
|
||||
background_color: Union[torch.Tensor, Sequence[float]] = (1.0, 1.0, 1.0)
|
||||
|
||||
|
||||
def _get_background_color(
|
||||
blend_params: BlendParams, device: Device, dtype=torch.float32
|
||||
) -> torch.Tensor:
|
||||
background_color_ = blend_params.background_color
|
||||
if isinstance(background_color_, torch.Tensor):
|
||||
background_color = background_color_.to(device)
|
||||
else:
|
||||
background_color = torch.tensor(background_color_, dtype=dtype, device=device)
|
||||
return background_color
|
||||
|
||||
|
||||
def hard_rgb_blend(
|
||||
colors: torch.Tensor, fragments, blend_params: BlendParams
|
||||
) -> torch.Tensor:
|
||||
@ -57,18 +69,11 @@ def hard_rgb_blend(
|
||||
Returns:
|
||||
RGBA pixel_colors: (N, H, W, 4)
|
||||
"""
|
||||
N, H, W, K = fragments.pix_to_face.shape
|
||||
device = fragments.pix_to_face.device
|
||||
background_color = _get_background_color(blend_params, fragments.pix_to_face.device)
|
||||
|
||||
# Mask for the background.
|
||||
is_background = fragments.pix_to_face[..., 0] < 0 # (N, H, W)
|
||||
|
||||
background_color_ = blend_params.background_color
|
||||
if isinstance(background_color_, torch.Tensor):
|
||||
background_color = background_color_.to(device)
|
||||
else:
|
||||
background_color = colors.new_tensor(background_color_)
|
||||
|
||||
# Find out how much background_color needs to be expanded to be used for masked_scatter.
|
||||
num_background_pixels = is_background.sum()
|
||||
|
||||
@ -182,13 +187,8 @@ def softmax_rgb_blend(
|
||||
"""
|
||||
|
||||
N, H, W, K = fragments.pix_to_face.shape
|
||||
device = fragments.pix_to_face.device
|
||||
pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=colors.device)
|
||||
background_ = blend_params.background_color
|
||||
if not isinstance(background_, torch.Tensor):
|
||||
background = torch.tensor(background_, dtype=torch.float32, device=device)
|
||||
else:
|
||||
background = background_.to(device)
|
||||
background_color = _get_background_color(blend_params, fragments.pix_to_face.device)
|
||||
|
||||
# Weight for background color
|
||||
eps = 1e-10
|
||||
@ -233,7 +233,7 @@ def softmax_rgb_blend(
|
||||
|
||||
# Sum: weights * textures + background color
|
||||
weighted_colors = (weights_num[..., None] * colors).sum(dim=-2)
|
||||
weighted_background = delta * background
|
||||
weighted_background = delta * background_color
|
||||
pixel_colors[..., :3] = (weighted_colors + weighted_background) / denom
|
||||
pixel_colors[..., 3] = 1.0 - alpha
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user