From 65620e716c184ec162df04df2f3fe08351a957f5 Mon Sep 17 00:00:00 2001 From: Luya Gao Date: Mon, 1 Jun 2020 07:58:44 -0700 Subject: [PATCH] Adding support for changing background color Summary: Adds support to hard_rgb_blend and hard blending shaders in shader.py (HardPhongShader, HardGouraudShader, and HardFlatShader) for changing the background color on which objects are rendered Reviewed By: nikhilaravi Differential Revision: D21746062 fbshipit-source-id: 08001200f4339d6a69c52405c6b8f4cac9f3f56e --- pytorch3d/renderer/blending.py | 25 +++++++++++++++++++++---- pytorch3d/renderer/mesh/shader.py | 24 ++++++++++++++++++------ tests/test_blending.py | 28 ++++++++++++++++++---------- tests/test_render_meshes.py | 21 ++++++++++++++++++--- 4 files changed, 75 insertions(+), 23 deletions(-) diff --git a/pytorch3d/renderer/blending.py b/pytorch3d/renderer/blending.py index 87d8f59b..29cb6f12 100644 --- a/pytorch3d/renderer/blending.py +++ b/pytorch3d/renderer/blending.py @@ -19,7 +19,7 @@ class BlendParams(NamedTuple): background_color: Sequence = (1.0, 1.0, 1.0) -def hard_rgb_blend(colors, fragments) -> torch.Tensor: +def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor: """ Naive blending of top K faces to return an RGBA image - **RGB** - choose color of the closest point i.e. K=0 @@ -32,14 +32,31 @@ def hard_rgb_blend(colors, fragments) -> torch.Tensor: of the faces (in the packed representation) which overlap each pixel in the image. This is used to determine the output shape. + blend_params: BlendParams instance that contains a background_color + field specifying the color for the background Returns: RGBA pixel_colors: (N, H, W, 4) """ N, H, W, K = fragments.pix_to_face.shape device = fragments.pix_to_face.device - pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=device) - pixel_colors[..., :3] = colors[..., 0, :] - return pixel_colors + + # Mask for the background. + is_background = fragments.pix_to_face[..., 0] < 0 # (N, H, W) + + background_color = colors.new_tensor(blend_params.background_color) # (3) + + # Find out how much background_color needs to be expanded to be used for masked_scatter. + num_background_pixels = is_background.sum() + + # Set background color. + pixel_colors = colors[..., 0, :].masked_scatter( + is_background[..., None], + background_color[None, :].expand(num_background_pixels, -1), + ) # (N, H, W, 3) + + # Concat with the alpha channel. + alpha = torch.ones((N, H, W, 1), dtype=colors.dtype, device=device) + return torch.cat([pixel_colors, alpha], dim=-1) # (N, H, W, 4) def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor: diff --git a/pytorch3d/renderer/mesh/shader.py b/pytorch3d/renderer/mesh/shader.py index ccdd43c7..ecdc6ebe 100644 --- a/pytorch3d/renderer/mesh/shader.py +++ b/pytorch3d/renderer/mesh/shader.py @@ -39,13 +39,16 @@ class HardPhongShader(nn.Module): shader = HardPhongShader(device=torch.device("cuda:0")) """ - def __init__(self, device="cpu", cameras=None, lights=None, materials=None): + def __init__( + self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None + ): super().__init__() self.lights = lights if lights is not None else PointLights(device=device) self.materials = ( materials if materials is not None else Materials(device=device) ) self.cameras = cameras + self.blend_params = blend_params if blend_params is not None else BlendParams() def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: cameras = kwargs.get("cameras", self.cameras) @@ -57,6 +60,7 @@ class HardPhongShader(nn.Module): texels = interpolate_vertex_colors(fragments, meshes) lights = kwargs.get("lights", self.lights) materials = kwargs.get("materials", self.materials) + blend_params = kwargs.get("blend_params", self.blend_params) colors = phong_shading( meshes=meshes, fragments=fragments, @@ -65,7 +69,7 @@ class HardPhongShader(nn.Module): cameras=cameras, materials=materials, ) - images = hard_rgb_blend(colors, fragments) + images = hard_rgb_blend(colors, fragments, blend_params) return images @@ -130,13 +134,16 @@ class HardGouraudShader(nn.Module): shader = HardGouraudShader(device=torch.device("cuda:0")) """ - def __init__(self, device="cpu", cameras=None, lights=None, materials=None): + def __init__( + self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None + ): super().__init__() self.lights = lights if lights is not None else PointLights(device=device) self.materials = ( materials if materials is not None else Materials(device=device) ) self.cameras = cameras + self.blend_params = blend_params if blend_params is not None else BlendParams() def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: cameras = kwargs.get("cameras", self.cameras) @@ -146,6 +153,7 @@ class HardGouraudShader(nn.Module): raise ValueError(msg) lights = kwargs.get("lights", self.lights) materials = kwargs.get("materials", self.materials) + blend_params = kwargs.get("blend_params", self.blend_params) pixel_colors = gouraud_shading( meshes=meshes, fragments=fragments, @@ -153,7 +161,7 @@ class HardGouraudShader(nn.Module): cameras=cameras, materials=materials, ) - images = hard_rgb_blend(pixel_colors, fragments) + images = hard_rgb_blend(pixel_colors, fragments, blend_params) return images @@ -266,13 +274,16 @@ class HardFlatShader(nn.Module): shader = HardFlatShader(device=torch.device("cuda:0")) """ - def __init__(self, device="cpu", cameras=None, lights=None, materials=None): + def __init__( + self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None + ): super().__init__() self.lights = lights if lights is not None else PointLights(device=device) self.materials = ( materials if materials is not None else Materials(device=device) ) self.cameras = cameras + self.blend_params = blend_params if blend_params is not None else BlendParams() def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: cameras = kwargs.get("cameras", self.cameras) @@ -283,6 +294,7 @@ class HardFlatShader(nn.Module): texels = interpolate_vertex_colors(fragments, meshes) lights = kwargs.get("lights", self.lights) materials = kwargs.get("materials", self.materials) + blend_params = kwargs.get("blend_params", self.blend_params) colors = flat_shading( meshes=meshes, fragments=fragments, @@ -291,7 +303,7 @@ class HardFlatShader(nn.Module): cameras=cameras, materials=materials, ) - images = hard_rgb_blend(colors, fragments) + images = hard_rgb_blend(colors, fragments, blend_params) return images diff --git a/tests/test_blending.py b/tests/test_blending.py index 1e216c46..117ab11b 100644 --- a/tests/test_blending.py +++ b/tests/test_blending.py @@ -4,6 +4,7 @@ import unittest import numpy as np import torch +from common_testing import TestCaseMixin from pytorch3d.renderer.blending import ( BlendParams, hard_rgb_blend, @@ -128,7 +129,7 @@ def softmax_blend_naive(colors, fragments, blend_params): return pixel_colors -class TestBlending(unittest.TestCase): +class TestBlending(TestCaseMixin, unittest.TestCase): def setUp(self) -> None: torch.manual_seed(42) @@ -156,7 +157,7 @@ class TestBlending(unittest.TestCase): def test_hard_rgb_blend(self): N, H, W, K = 5, 10, 10, 20 - pix_to_face = torch.ones((N, H, W, K)) + pix_to_face = torch.randint(low=-1, high=100, size=(N, H, W, K)) bary_coords = torch.ones((N, H, W, K, 3)) fragments = Fragments( pix_to_face=pix_to_face, @@ -164,14 +165,21 @@ class TestBlending(unittest.TestCase): zbuf=pix_to_face, # dummy dists=pix_to_face, # dummy ) - colors = bary_coords.clone() - top_k = torch.randn((K, 3)) - colors[..., :, :] = top_k - images = hard_rgb_blend(colors, fragments) - expected_vals = torch.ones((N, H, W, 4)) - pix_cols = torch.ones_like(expected_vals[..., :3]) * top_k[0, :] - expected_vals[..., :3] = pix_cols - self.assertTrue(torch.allclose(images, expected_vals)) + colors = torch.randn((N, H, W, K, 3)) + blend_params = BlendParams(1e-4, 1e-4, (0.5, 0.5, 1)) + images = hard_rgb_blend(colors, fragments, blend_params) + + # Examine if the foreground colors are correct. + is_foreground = pix_to_face[..., 0] >= 0 + self.assertClose(images[is_foreground][:, :3], colors[is_foreground][..., 0, :]) + + # Examine if the background colors are correct. + for i in range(3): # i.e. RGB + channel_color = blend_params.background_color[i] + self.assertTrue(images[~is_foreground][..., i].eq(channel_color).all()) + + # Examine the alpha channel is correct + self.assertTrue(images[..., 3].eq(1).all()) def test_sigmoid_alpha_blend_manual_gradients(self): # Create dummy outputs of rasterization diff --git a/tests/test_render_meshes.py b/tests/test_render_meshes.py index 000ad679..87c83409 100644 --- a/tests/test_render_meshes.py +++ b/tests/test_render_meshes.py @@ -77,6 +77,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): image_size=512, blur_radius=0.0, faces_per_pixel=1 ) rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings) + blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0)) # Test several shaders shaders = { @@ -85,7 +86,12 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): "flat": HardFlatShader, } for (name, shader_init) in shaders.items(): - shader = shader_init(lights=lights, cameras=cameras, materials=materials) + shader = shader_init( + lights=lights, + cameras=cameras, + materials=materials, + blend_params=blend_params, + ) renderer = MeshRenderer(rasterizer=rasterizer, shader=shader) images = renderer(sphere_mesh) filename = "simple_sphere_light_%s%s.png" % (name, postfix) @@ -105,7 +111,10 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): ######################################################## lights.location[..., 2] = -2.0 phong_shader = HardPhongShader( - lights=lights, cameras=cameras, materials=materials + lights=lights, + cameras=cameras, + materials=materials, + blend_params=blend_params, ) phong_renderer = MeshRenderer(rasterizer=rasterizer, shader=phong_shader) images = phong_renderer(sphere_mesh, lights=lights) @@ -162,6 +171,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): materials = Materials(device=device) lights = PointLights(device=device) lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None] + blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0)) # Init renderer rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings) @@ -171,7 +181,12 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): "flat": HardFlatShader, } for (name, shader_init) in shaders.items(): - shader = shader_init(lights=lights, cameras=cameras, materials=materials) + shader = shader_init( + lights=lights, + cameras=cameras, + materials=materials, + blend_params=blend_params, + ) renderer = MeshRenderer(rasterizer=rasterizer, shader=shader) images = renderer(sphere_meshes) image_ref = load_rgb_image(