mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 22:30:35 +08:00
Adding support for changing background color
Summary: Adds support to hard_rgb_blend and hard blending shaders in shader.py (HardPhongShader, HardGouraudShader, and HardFlatShader) for changing the background color on which objects are rendered Reviewed By: nikhilaravi Differential Revision: D21746062 fbshipit-source-id: 08001200f4339d6a69c52405c6b8f4cac9f3f56e
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e3819a49df
commit
65620e716c
@@ -19,7 +19,7 @@ class BlendParams(NamedTuple):
|
||||
background_color: Sequence = (1.0, 1.0, 1.0)
|
||||
|
||||
|
||||
def hard_rgb_blend(colors, fragments) -> torch.Tensor:
|
||||
def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
|
||||
"""
|
||||
Naive blending of top K faces to return an RGBA image
|
||||
- **RGB** - choose color of the closest point i.e. K=0
|
||||
@@ -32,14 +32,31 @@ def hard_rgb_blend(colors, fragments) -> torch.Tensor:
|
||||
of the faces (in the packed representation) which
|
||||
overlap each pixel in the image. This is used to
|
||||
determine the output shape.
|
||||
blend_params: BlendParams instance that contains a background_color
|
||||
field specifying the color for the background
|
||||
Returns:
|
||||
RGBA pixel_colors: (N, H, W, 4)
|
||||
"""
|
||||
N, H, W, K = fragments.pix_to_face.shape
|
||||
device = fragments.pix_to_face.device
|
||||
pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=device)
|
||||
pixel_colors[..., :3] = colors[..., 0, :]
|
||||
return pixel_colors
|
||||
|
||||
# Mask for the background.
|
||||
is_background = fragments.pix_to_face[..., 0] < 0 # (N, H, W)
|
||||
|
||||
background_color = colors.new_tensor(blend_params.background_color) # (3)
|
||||
|
||||
# Find out how much background_color needs to be expanded to be used for masked_scatter.
|
||||
num_background_pixels = is_background.sum()
|
||||
|
||||
# Set background color.
|
||||
pixel_colors = colors[..., 0, :].masked_scatter(
|
||||
is_background[..., None],
|
||||
background_color[None, :].expand(num_background_pixels, -1),
|
||||
) # (N, H, W, 3)
|
||||
|
||||
# Concat with the alpha channel.
|
||||
alpha = torch.ones((N, H, W, 1), dtype=colors.dtype, device=device)
|
||||
return torch.cat([pixel_colors, alpha], dim=-1) # (N, H, W, 4)
|
||||
|
||||
|
||||
def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
|
||||
|
||||
@@ -39,13 +39,16 @@ class HardPhongShader(nn.Module):
|
||||
shader = HardPhongShader(device=torch.device("cuda:0"))
|
||||
"""
|
||||
|
||||
def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
|
||||
def __init__(
|
||||
self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
|
||||
):
|
||||
super().__init__()
|
||||
self.lights = lights if lights is not None else PointLights(device=device)
|
||||
self.materials = (
|
||||
materials if materials is not None else Materials(device=device)
|
||||
)
|
||||
self.cameras = cameras
|
||||
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
||||
|
||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
@@ -57,6 +60,7 @@ class HardPhongShader(nn.Module):
|
||||
texels = interpolate_vertex_colors(fragments, meshes)
|
||||
lights = kwargs.get("lights", self.lights)
|
||||
materials = kwargs.get("materials", self.materials)
|
||||
blend_params = kwargs.get("blend_params", self.blend_params)
|
||||
colors = phong_shading(
|
||||
meshes=meshes,
|
||||
fragments=fragments,
|
||||
@@ -65,7 +69,7 @@ class HardPhongShader(nn.Module):
|
||||
cameras=cameras,
|
||||
materials=materials,
|
||||
)
|
||||
images = hard_rgb_blend(colors, fragments)
|
||||
images = hard_rgb_blend(colors, fragments, blend_params)
|
||||
return images
|
||||
|
||||
|
||||
@@ -130,13 +134,16 @@ class HardGouraudShader(nn.Module):
|
||||
shader = HardGouraudShader(device=torch.device("cuda:0"))
|
||||
"""
|
||||
|
||||
def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
|
||||
def __init__(
|
||||
self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
|
||||
):
|
||||
super().__init__()
|
||||
self.lights = lights if lights is not None else PointLights(device=device)
|
||||
self.materials = (
|
||||
materials if materials is not None else Materials(device=device)
|
||||
)
|
||||
self.cameras = cameras
|
||||
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
||||
|
||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
@@ -146,6 +153,7 @@ class HardGouraudShader(nn.Module):
|
||||
raise ValueError(msg)
|
||||
lights = kwargs.get("lights", self.lights)
|
||||
materials = kwargs.get("materials", self.materials)
|
||||
blend_params = kwargs.get("blend_params", self.blend_params)
|
||||
pixel_colors = gouraud_shading(
|
||||
meshes=meshes,
|
||||
fragments=fragments,
|
||||
@@ -153,7 +161,7 @@ class HardGouraudShader(nn.Module):
|
||||
cameras=cameras,
|
||||
materials=materials,
|
||||
)
|
||||
images = hard_rgb_blend(pixel_colors, fragments)
|
||||
images = hard_rgb_blend(pixel_colors, fragments, blend_params)
|
||||
return images
|
||||
|
||||
|
||||
@@ -266,13 +274,16 @@ class HardFlatShader(nn.Module):
|
||||
shader = HardFlatShader(device=torch.device("cuda:0"))
|
||||
"""
|
||||
|
||||
def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
|
||||
def __init__(
|
||||
self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
|
||||
):
|
||||
super().__init__()
|
||||
self.lights = lights if lights is not None else PointLights(device=device)
|
||||
self.materials = (
|
||||
materials if materials is not None else Materials(device=device)
|
||||
)
|
||||
self.cameras = cameras
|
||||
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
||||
|
||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
@@ -283,6 +294,7 @@ class HardFlatShader(nn.Module):
|
||||
texels = interpolate_vertex_colors(fragments, meshes)
|
||||
lights = kwargs.get("lights", self.lights)
|
||||
materials = kwargs.get("materials", self.materials)
|
||||
blend_params = kwargs.get("blend_params", self.blend_params)
|
||||
colors = flat_shading(
|
||||
meshes=meshes,
|
||||
fragments=fragments,
|
||||
@@ -291,7 +303,7 @@ class HardFlatShader(nn.Module):
|
||||
cameras=cameras,
|
||||
materials=materials,
|
||||
)
|
||||
images = hard_rgb_blend(colors, fragments)
|
||||
images = hard_rgb_blend(colors, fragments, blend_params)
|
||||
return images
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user