Adding support for changing background color

Summary: Adds support for changing the background color on which objects are rendered, via the BlendParams passed to hard_rgb_blend and to the hard blending shaders in shader.py (HardPhongShader, HardGouraudShader, and HardFlatShader).

Reviewed By: nikhilaravi

Differential Revision: D21746062

fbshipit-source-id: 08001200f4339d6a69c52405c6b8f4cac9f3f56e
Luya Gao, 2020-06-01 07:58:44 -07:00 (committed by Facebook GitHub Bot)
parent e3819a49df
commit 65620e716c
4 changed files with 75 additions and 23 deletions
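Usage sketch (not part of the commit): construct a BlendParams with the desired background_color and pass it to one of the hard shaders, or override it per call through **kwargs. The top-level import path and the camera/light setup below are assumptions for illustration only.

```python
import torch
# Assumed top-level re-exports; the diff itself touches
# pytorch3d/renderer blending and shader modules.
from pytorch3d.renderer import (
    BlendParams,
    HardPhongShader,
    OpenGLPerspectiveCameras,
    PointLights,
)

device = torch.device("cpu")
cameras = OpenGLPerspectiveCameras(device=device)
lights = PointLights(device=device)

# Render onto a black background instead of the default white one.
blend_params = BlendParams(1e-4, 1e-4, (0.0, 0.0, 0.0))
shader = HardPhongShader(
    device=device, cameras=cameras, lights=lights, blend_params=blend_params
)

# The background can also be switched per call, because forward() reads
# blend_params from kwargs before falling back to self.blend_params:
#   images = shader(fragments, meshes, blend_params=BlendParams(1e-4, 1e-4, (1, 0, 0)))
```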


@@ -19,7 +19,7 @@ class BlendParams(NamedTuple):
     background_color: Sequence = (1.0, 1.0, 1.0)
 
 
-def hard_rgb_blend(colors, fragments) -> torch.Tensor:
+def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
     """
     Naive blending of top K faces to return an RGBA image
       - **RGB** - choose color of the closest point i.e. K=0
@@ -32,14 +32,31 @@ def hard_rgb_blend(colors, fragments) -> torch.Tensor:
              of the faces (in the packed representation) which
              overlap each pixel in the image. This is used to
              determine the output shape.
+        blend_params: BlendParams instance that contains a background_color
+        field specifying the color for the background
     Returns:
         RGBA pixel_colors: (N, H, W, 4)
     """
     N, H, W, K = fragments.pix_to_face.shape
     device = fragments.pix_to_face.device
-    pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=device)
-    pixel_colors[..., :3] = colors[..., 0, :]
-    return pixel_colors
+
+    # Mask for the background.
+    is_background = fragments.pix_to_face[..., 0] < 0  # (N, H, W)
+
+    background_color = colors.new_tensor(blend_params.background_color)  # (3)
+
+    # Find out how much background_color needs to be expanded to be used for masked_scatter.
+    num_background_pixels = is_background.sum()
+
+    # Set background color.
+    pixel_colors = colors[..., 0, :].masked_scatter(
+        is_background[..., None],
+        background_color[None, :].expand(num_background_pixels, -1),
+    )  # (N, H, W, 3)
+
+    # Concat with the alpha channel.
+    alpha = torch.ones((N, H, W, 1), dtype=colors.dtype, device=device)
+    return torch.cat([pixel_colors, alpha], dim=-1)  # (N, H, W, 4)
 
 
 def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
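The background fill above relies on Tensor.masked_scatter, which fills the masked positions from the source tensor in row-major order; that is why background_color is expanded to one RGB row per background pixel before the call. A standalone toy sketch of the same pattern (shapes and values are illustrative only, not part of the diff):

```python
import torch

# Fake "closest face" colors for a 1x2x2 image, plus a background mask.
colors = torch.arange(12, dtype=torch.float32).reshape(1, 2, 2, 3)  # (N, H, W, 3)
is_background = torch.tensor([[[False, True], [True, False]]])      # (N, H, W)
background_color = torch.tensor([0.5, 0.5, 1.0])                    # (3,)

# The source of masked_scatter must provide one RGB triple per background
# pixel, hence the expand to (num_background_pixels, 3).
num_background_pixels = int(is_background.sum())
filled = colors.masked_scatter(
    is_background[..., None],
    background_color[None, :].expand(num_background_pixels, -1),
)  # (N, H, W, 3)

print(filled[0, 0, 1])  # background pixel -> tensor([0.5000, 0.5000, 1.0000])
print(filled[0, 0, 0])  # foreground pixel keeps its original color
```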


@@ -39,13 +39,16 @@ class HardPhongShader(nn.Module):
         shader = HardPhongShader(device=torch.device("cuda:0"))
     """
 
-    def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
+    def __init__(
+        self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
+    ):
         super().__init__()
         self.lights = lights if lights is not None else PointLights(device=device)
         self.materials = (
             materials if materials is not None else Materials(device=device)
         )
         self.cameras = cameras
+        self.blend_params = blend_params if blend_params is not None else BlendParams()
 
     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
         cameras = kwargs.get("cameras", self.cameras)
@@ -57,6 +60,7 @@ class HardPhongShader(nn.Module):
         texels = interpolate_vertex_colors(fragments, meshes)
         lights = kwargs.get("lights", self.lights)
         materials = kwargs.get("materials", self.materials)
+        blend_params = kwargs.get("blend_params", self.blend_params)
         colors = phong_shading(
             meshes=meshes,
             fragments=fragments,
@@ -65,7 +69,7 @@ class HardPhongShader(nn.Module):
             cameras=cameras,
             materials=materials,
         )
-        images = hard_rgb_blend(colors, fragments)
+        images = hard_rgb_blend(colors, fragments, blend_params)
         return images
@@ -130,13 +134,16 @@ class HardGouraudShader(nn.Module):
         shader = HardGouraudShader(device=torch.device("cuda:0"))
     """
 
-    def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
+    def __init__(
+        self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
+    ):
         super().__init__()
         self.lights = lights if lights is not None else PointLights(device=device)
         self.materials = (
             materials if materials is not None else Materials(device=device)
         )
         self.cameras = cameras
+        self.blend_params = blend_params if blend_params is not None else BlendParams()
 
     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
         cameras = kwargs.get("cameras", self.cameras)
@@ -146,6 +153,7 @@ class HardGouraudShader(nn.Module):
             raise ValueError(msg)
         lights = kwargs.get("lights", self.lights)
         materials = kwargs.get("materials", self.materials)
+        blend_params = kwargs.get("blend_params", self.blend_params)
         pixel_colors = gouraud_shading(
             meshes=meshes,
             fragments=fragments,
@@ -153,7 +161,7 @@ class HardGouraudShader(nn.Module):
             cameras=cameras,
             materials=materials,
         )
-        images = hard_rgb_blend(pixel_colors, fragments)
+        images = hard_rgb_blend(pixel_colors, fragments, blend_params)
         return images
@@ -266,13 +274,16 @@ class HardFlatShader(nn.Module):
         shader = HardFlatShader(device=torch.device("cuda:0"))
     """
 
-    def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
+    def __init__(
+        self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
+    ):
         super().__init__()
         self.lights = lights if lights is not None else PointLights(device=device)
         self.materials = (
             materials if materials is not None else Materials(device=device)
         )
         self.cameras = cameras
+        self.blend_params = blend_params if blend_params is not None else BlendParams()
 
     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
         cameras = kwargs.get("cameras", self.cameras)
@@ -283,6 +294,7 @@ class HardFlatShader(nn.Module):
         texels = interpolate_vertex_colors(fragments, meshes)
         lights = kwargs.get("lights", self.lights)
         materials = kwargs.get("materials", self.materials)
+        blend_params = kwargs.get("blend_params", self.blend_params)
         colors = flat_shading(
             meshes=meshes,
             fragments=fragments,
@@ -291,7 +303,7 @@ class HardFlatShader(nn.Module):
             cameras=cameras,
             materials=materials,
         )
-        images = hard_rgb_blend(colors, fragments)
+        images = hard_rgb_blend(colors, fragments, blend_params)
         return images
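A note on the BlendParams arguments used positionally in the tests below: the field order is (sigma, gamma, background_color); sigma and gamma are consumed by the soft blending functions and ignored by hard_rgb_blend. A small sketch, assuming the default values for sigma and gamma:

```python
from pytorch3d.renderer.blending import BlendParams

# Positional form, as used in the tests: (sigma, gamma, background_color).
bp = BlendParams(1e-4, 1e-4, (0.5, 0.5, 1))
assert bp.background_color == (0.5, 0.5, 1)

# Equivalent keyword form; sigma and gamma keep their defaults.
bp = BlendParams(background_color=(0.5, 0.5, 1))
```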


@@ -4,6 +4,7 @@ import unittest
 import numpy as np
 import torch
+from common_testing import TestCaseMixin
 from pytorch3d.renderer.blending import (
     BlendParams,
     hard_rgb_blend,
@@ -128,7 +129,7 @@ def softmax_blend_naive(colors, fragments, blend_params):
     return pixel_colors
 
 
-class TestBlending(unittest.TestCase):
+class TestBlending(TestCaseMixin, unittest.TestCase):
     def setUp(self) -> None:
         torch.manual_seed(42)
@@ -156,7 +157,7 @@ class TestBlending(unittest.TestCase):
     def test_hard_rgb_blend(self):
         N, H, W, K = 5, 10, 10, 20
-        pix_to_face = torch.ones((N, H, W, K))
+        pix_to_face = torch.randint(low=-1, high=100, size=(N, H, W, K))
         bary_coords = torch.ones((N, H, W, K, 3))
         fragments = Fragments(
             pix_to_face=pix_to_face,
@@ -164,14 +165,21 @@ class TestBlending(unittest.TestCase):
             zbuf=pix_to_face,  # dummy
             dists=pix_to_face,  # dummy
         )
-        colors = bary_coords.clone()
-        top_k = torch.randn((K, 3))
-        colors[..., :, :] = top_k
-        images = hard_rgb_blend(colors, fragments)
-        expected_vals = torch.ones((N, H, W, 4))
-        pix_cols = torch.ones_like(expected_vals[..., :3]) * top_k[0, :]
-        expected_vals[..., :3] = pix_cols
-        self.assertTrue(torch.allclose(images, expected_vals))
+        colors = torch.randn((N, H, W, K, 3))
+        blend_params = BlendParams(1e-4, 1e-4, (0.5, 0.5, 1))
+        images = hard_rgb_blend(colors, fragments, blend_params)
+
+        # Examine if the foreground colors are correct.
+        is_foreground = pix_to_face[..., 0] >= 0
+        self.assertClose(images[is_foreground][:, :3], colors[is_foreground][..., 0, :])
+
+        # Examine if the background colors are correct.
+        for i in range(3):  # i.e. RGB
+            channel_color = blend_params.background_color[i]
+            self.assertTrue(images[~is_foreground][..., i].eq(channel_color).all())
+
+        # Examine the alpha channel is correct
+        self.assertTrue(images[..., 3].eq(1).all())
 
     def test_sigmoid_alpha_blend_manual_gradients(self):
         # Create dummy outputs of rasterization


@@ -77,6 +77,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
             image_size=512, blur_radius=0.0, faces_per_pixel=1
         )
         rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
+        blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))
 
         # Test several shaders
         shaders = {
@@ -85,7 +86,12 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
             "flat": HardFlatShader,
         }
         for (name, shader_init) in shaders.items():
-            shader = shader_init(lights=lights, cameras=cameras, materials=materials)
+            shader = shader_init(
+                lights=lights,
+                cameras=cameras,
+                materials=materials,
+                blend_params=blend_params,
+            )
             renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
             images = renderer(sphere_mesh)
             filename = "simple_sphere_light_%s%s.png" % (name, postfix)
@@ -105,7 +111,10 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
         ########################################################
         lights.location[..., 2] = -2.0
         phong_shader = HardPhongShader(
-            lights=lights, cameras=cameras, materials=materials
+            lights=lights,
+            cameras=cameras,
+            materials=materials,
+            blend_params=blend_params,
         )
         phong_renderer = MeshRenderer(rasterizer=rasterizer, shader=phong_shader)
         images = phong_renderer(sphere_mesh, lights=lights)
@@ -162,6 +171,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
         materials = Materials(device=device)
         lights = PointLights(device=device)
         lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
+        blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))
 
         # Init renderer
         rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
@@ -171,7 +181,12 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
             "flat": HardFlatShader,
         }
         for (name, shader_init) in shaders.items():
-            shader = shader_init(lights=lights, cameras=cameras, materials=materials)
+            shader = shader_init(
+                lights=lights,
+                cameras=cameras,
+                materials=materials,
+                blend_params=blend_params,
+            )
             renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
             images = renderer(sphere_meshes)
             image_ref = load_rgb_image(