mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-03 04:12:48 +08:00
Adding support for changing background color
Summary: Adds support to hard_rgb_blend and hard blending shaders in shader.py (HardPhongShader, HardGouraudShader, and HardFlatShader) for changing the background color on which objects are rendered Reviewed By: nikhilaravi Differential Revision: D21746062 fbshipit-source-id: 08001200f4339d6a69c52405c6b8f4cac9f3f56e
This commit is contained in:
parent
e3819a49df
commit
65620e716c
@ -19,7 +19,7 @@ class BlendParams(NamedTuple):
|
|||||||
background_color: Sequence = (1.0, 1.0, 1.0)
|
background_color: Sequence = (1.0, 1.0, 1.0)
|
||||||
|
|
||||||
|
|
||||||
def hard_rgb_blend(colors, fragments) -> torch.Tensor:
|
def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Naive blending of top K faces to return an RGBA image
|
Naive blending of top K faces to return an RGBA image
|
||||||
- **RGB** - choose color of the closest point i.e. K=0
|
- **RGB** - choose color of the closest point i.e. K=0
|
||||||
@ -32,14 +32,31 @@ def hard_rgb_blend(colors, fragments) -> torch.Tensor:
|
|||||||
of the faces (in the packed representation) which
|
of the faces (in the packed representation) which
|
||||||
overlap each pixel in the image. This is used to
|
overlap each pixel in the image. This is used to
|
||||||
determine the output shape.
|
determine the output shape.
|
||||||
|
blend_params: BlendParams instance that contains a background_color
|
||||||
|
field specifying the color for the background
|
||||||
Returns:
|
Returns:
|
||||||
RGBA pixel_colors: (N, H, W, 4)
|
RGBA pixel_colors: (N, H, W, 4)
|
||||||
"""
|
"""
|
||||||
N, H, W, K = fragments.pix_to_face.shape
|
N, H, W, K = fragments.pix_to_face.shape
|
||||||
device = fragments.pix_to_face.device
|
device = fragments.pix_to_face.device
|
||||||
pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=device)
|
|
||||||
pixel_colors[..., :3] = colors[..., 0, :]
|
# Mask for the background.
|
||||||
return pixel_colors
|
is_background = fragments.pix_to_face[..., 0] < 0 # (N, H, W)
|
||||||
|
|
||||||
|
background_color = colors.new_tensor(blend_params.background_color) # (3)
|
||||||
|
|
||||||
|
# Find out how much background_color needs to be expanded to be used for masked_scatter.
|
||||||
|
num_background_pixels = is_background.sum()
|
||||||
|
|
||||||
|
# Set background color.
|
||||||
|
pixel_colors = colors[..., 0, :].masked_scatter(
|
||||||
|
is_background[..., None],
|
||||||
|
background_color[None, :].expand(num_background_pixels, -1),
|
||||||
|
) # (N, H, W, 3)
|
||||||
|
|
||||||
|
# Concat with the alpha channel.
|
||||||
|
alpha = torch.ones((N, H, W, 1), dtype=colors.dtype, device=device)
|
||||||
|
return torch.cat([pixel_colors, alpha], dim=-1) # (N, H, W, 4)
|
||||||
|
|
||||||
|
|
||||||
def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
|
def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
|
||||||
|
@ -39,13 +39,16 @@ class HardPhongShader(nn.Module):
|
|||||||
shader = HardPhongShader(device=torch.device("cuda:0"))
|
shader = HardPhongShader(device=torch.device("cuda:0"))
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
|
def __init__(
|
||||||
|
self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.lights = lights if lights is not None else PointLights(device=device)
|
self.lights = lights if lights is not None else PointLights(device=device)
|
||||||
self.materials = (
|
self.materials = (
|
||||||
materials if materials is not None else Materials(device=device)
|
materials if materials is not None else Materials(device=device)
|
||||||
)
|
)
|
||||||
self.cameras = cameras
|
self.cameras = cameras
|
||||||
|
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
||||||
|
|
||||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||||
cameras = kwargs.get("cameras", self.cameras)
|
cameras = kwargs.get("cameras", self.cameras)
|
||||||
@ -57,6 +60,7 @@ class HardPhongShader(nn.Module):
|
|||||||
texels = interpolate_vertex_colors(fragments, meshes)
|
texels = interpolate_vertex_colors(fragments, meshes)
|
||||||
lights = kwargs.get("lights", self.lights)
|
lights = kwargs.get("lights", self.lights)
|
||||||
materials = kwargs.get("materials", self.materials)
|
materials = kwargs.get("materials", self.materials)
|
||||||
|
blend_params = kwargs.get("blend_params", self.blend_params)
|
||||||
colors = phong_shading(
|
colors = phong_shading(
|
||||||
meshes=meshes,
|
meshes=meshes,
|
||||||
fragments=fragments,
|
fragments=fragments,
|
||||||
@ -65,7 +69,7 @@ class HardPhongShader(nn.Module):
|
|||||||
cameras=cameras,
|
cameras=cameras,
|
||||||
materials=materials,
|
materials=materials,
|
||||||
)
|
)
|
||||||
images = hard_rgb_blend(colors, fragments)
|
images = hard_rgb_blend(colors, fragments, blend_params)
|
||||||
return images
|
return images
|
||||||
|
|
||||||
|
|
||||||
@ -130,13 +134,16 @@ class HardGouraudShader(nn.Module):
|
|||||||
shader = HardGouraudShader(device=torch.device("cuda:0"))
|
shader = HardGouraudShader(device=torch.device("cuda:0"))
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
|
def __init__(
|
||||||
|
self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.lights = lights if lights is not None else PointLights(device=device)
|
self.lights = lights if lights is not None else PointLights(device=device)
|
||||||
self.materials = (
|
self.materials = (
|
||||||
materials if materials is not None else Materials(device=device)
|
materials if materials is not None else Materials(device=device)
|
||||||
)
|
)
|
||||||
self.cameras = cameras
|
self.cameras = cameras
|
||||||
|
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
||||||
|
|
||||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||||
cameras = kwargs.get("cameras", self.cameras)
|
cameras = kwargs.get("cameras", self.cameras)
|
||||||
@ -146,6 +153,7 @@ class HardGouraudShader(nn.Module):
|
|||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
lights = kwargs.get("lights", self.lights)
|
lights = kwargs.get("lights", self.lights)
|
||||||
materials = kwargs.get("materials", self.materials)
|
materials = kwargs.get("materials", self.materials)
|
||||||
|
blend_params = kwargs.get("blend_params", self.blend_params)
|
||||||
pixel_colors = gouraud_shading(
|
pixel_colors = gouraud_shading(
|
||||||
meshes=meshes,
|
meshes=meshes,
|
||||||
fragments=fragments,
|
fragments=fragments,
|
||||||
@ -153,7 +161,7 @@ class HardGouraudShader(nn.Module):
|
|||||||
cameras=cameras,
|
cameras=cameras,
|
||||||
materials=materials,
|
materials=materials,
|
||||||
)
|
)
|
||||||
images = hard_rgb_blend(pixel_colors, fragments)
|
images = hard_rgb_blend(pixel_colors, fragments, blend_params)
|
||||||
return images
|
return images
|
||||||
|
|
||||||
|
|
||||||
@ -266,13 +274,16 @@ class HardFlatShader(nn.Module):
|
|||||||
shader = HardFlatShader(device=torch.device("cuda:0"))
|
shader = HardFlatShader(device=torch.device("cuda:0"))
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
|
def __init__(
|
||||||
|
self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.lights = lights if lights is not None else PointLights(device=device)
|
self.lights = lights if lights is not None else PointLights(device=device)
|
||||||
self.materials = (
|
self.materials = (
|
||||||
materials if materials is not None else Materials(device=device)
|
materials if materials is not None else Materials(device=device)
|
||||||
)
|
)
|
||||||
self.cameras = cameras
|
self.cameras = cameras
|
||||||
|
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
||||||
|
|
||||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||||
cameras = kwargs.get("cameras", self.cameras)
|
cameras = kwargs.get("cameras", self.cameras)
|
||||||
@ -283,6 +294,7 @@ class HardFlatShader(nn.Module):
|
|||||||
texels = interpolate_vertex_colors(fragments, meshes)
|
texels = interpolate_vertex_colors(fragments, meshes)
|
||||||
lights = kwargs.get("lights", self.lights)
|
lights = kwargs.get("lights", self.lights)
|
||||||
materials = kwargs.get("materials", self.materials)
|
materials = kwargs.get("materials", self.materials)
|
||||||
|
blend_params = kwargs.get("blend_params", self.blend_params)
|
||||||
colors = flat_shading(
|
colors = flat_shading(
|
||||||
meshes=meshes,
|
meshes=meshes,
|
||||||
fragments=fragments,
|
fragments=fragments,
|
||||||
@ -291,7 +303,7 @@ class HardFlatShader(nn.Module):
|
|||||||
cameras=cameras,
|
cameras=cameras,
|
||||||
materials=materials,
|
materials=materials,
|
||||||
)
|
)
|
||||||
images = hard_rgb_blend(colors, fragments)
|
images = hard_rgb_blend(colors, fragments, blend_params)
|
||||||
return images
|
return images
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ import unittest
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from common_testing import TestCaseMixin
|
||||||
from pytorch3d.renderer.blending import (
|
from pytorch3d.renderer.blending import (
|
||||||
BlendParams,
|
BlendParams,
|
||||||
hard_rgb_blend,
|
hard_rgb_blend,
|
||||||
@ -128,7 +129,7 @@ def softmax_blend_naive(colors, fragments, blend_params):
|
|||||||
return pixel_colors
|
return pixel_colors
|
||||||
|
|
||||||
|
|
||||||
class TestBlending(unittest.TestCase):
|
class TestBlending(TestCaseMixin, unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
|
|
||||||
@ -156,7 +157,7 @@ class TestBlending(unittest.TestCase):
|
|||||||
|
|
||||||
def test_hard_rgb_blend(self):
|
def test_hard_rgb_blend(self):
|
||||||
N, H, W, K = 5, 10, 10, 20
|
N, H, W, K = 5, 10, 10, 20
|
||||||
pix_to_face = torch.ones((N, H, W, K))
|
pix_to_face = torch.randint(low=-1, high=100, size=(N, H, W, K))
|
||||||
bary_coords = torch.ones((N, H, W, K, 3))
|
bary_coords = torch.ones((N, H, W, K, 3))
|
||||||
fragments = Fragments(
|
fragments = Fragments(
|
||||||
pix_to_face=pix_to_face,
|
pix_to_face=pix_to_face,
|
||||||
@ -164,14 +165,21 @@ class TestBlending(unittest.TestCase):
|
|||||||
zbuf=pix_to_face, # dummy
|
zbuf=pix_to_face, # dummy
|
||||||
dists=pix_to_face, # dummy
|
dists=pix_to_face, # dummy
|
||||||
)
|
)
|
||||||
colors = bary_coords.clone()
|
colors = torch.randn((N, H, W, K, 3))
|
||||||
top_k = torch.randn((K, 3))
|
blend_params = BlendParams(1e-4, 1e-4, (0.5, 0.5, 1))
|
||||||
colors[..., :, :] = top_k
|
images = hard_rgb_blend(colors, fragments, blend_params)
|
||||||
images = hard_rgb_blend(colors, fragments)
|
|
||||||
expected_vals = torch.ones((N, H, W, 4))
|
# Examine if the foreground colors are correct.
|
||||||
pix_cols = torch.ones_like(expected_vals[..., :3]) * top_k[0, :]
|
is_foreground = pix_to_face[..., 0] >= 0
|
||||||
expected_vals[..., :3] = pix_cols
|
self.assertClose(images[is_foreground][:, :3], colors[is_foreground][..., 0, :])
|
||||||
self.assertTrue(torch.allclose(images, expected_vals))
|
|
||||||
|
# Examine if the background colors are correct.
|
||||||
|
for i in range(3): # i.e. RGB
|
||||||
|
channel_color = blend_params.background_color[i]
|
||||||
|
self.assertTrue(images[~is_foreground][..., i].eq(channel_color).all())
|
||||||
|
|
||||||
|
# Examine the alpha channel is correct
|
||||||
|
self.assertTrue(images[..., 3].eq(1).all())
|
||||||
|
|
||||||
def test_sigmoid_alpha_blend_manual_gradients(self):
|
def test_sigmoid_alpha_blend_manual_gradients(self):
|
||||||
# Create dummy outputs of rasterization
|
# Create dummy outputs of rasterization
|
||||||
|
@ -77,6 +77,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
|||||||
image_size=512, blur_radius=0.0, faces_per_pixel=1
|
image_size=512, blur_radius=0.0, faces_per_pixel=1
|
||||||
)
|
)
|
||||||
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
|
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
|
||||||
|
blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))
|
||||||
|
|
||||||
# Test several shaders
|
# Test several shaders
|
||||||
shaders = {
|
shaders = {
|
||||||
@ -85,7 +86,12 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
|||||||
"flat": HardFlatShader,
|
"flat": HardFlatShader,
|
||||||
}
|
}
|
||||||
for (name, shader_init) in shaders.items():
|
for (name, shader_init) in shaders.items():
|
||||||
shader = shader_init(lights=lights, cameras=cameras, materials=materials)
|
shader = shader_init(
|
||||||
|
lights=lights,
|
||||||
|
cameras=cameras,
|
||||||
|
materials=materials,
|
||||||
|
blend_params=blend_params,
|
||||||
|
)
|
||||||
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
|
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
|
||||||
images = renderer(sphere_mesh)
|
images = renderer(sphere_mesh)
|
||||||
filename = "simple_sphere_light_%s%s.png" % (name, postfix)
|
filename = "simple_sphere_light_%s%s.png" % (name, postfix)
|
||||||
@ -105,7 +111,10 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
|||||||
########################################################
|
########################################################
|
||||||
lights.location[..., 2] = -2.0
|
lights.location[..., 2] = -2.0
|
||||||
phong_shader = HardPhongShader(
|
phong_shader = HardPhongShader(
|
||||||
lights=lights, cameras=cameras, materials=materials
|
lights=lights,
|
||||||
|
cameras=cameras,
|
||||||
|
materials=materials,
|
||||||
|
blend_params=blend_params,
|
||||||
)
|
)
|
||||||
phong_renderer = MeshRenderer(rasterizer=rasterizer, shader=phong_shader)
|
phong_renderer = MeshRenderer(rasterizer=rasterizer, shader=phong_shader)
|
||||||
images = phong_renderer(sphere_mesh, lights=lights)
|
images = phong_renderer(sphere_mesh, lights=lights)
|
||||||
@ -162,6 +171,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
|||||||
materials = Materials(device=device)
|
materials = Materials(device=device)
|
||||||
lights = PointLights(device=device)
|
lights = PointLights(device=device)
|
||||||
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
|
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
|
||||||
|
blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))
|
||||||
|
|
||||||
# Init renderer
|
# Init renderer
|
||||||
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
|
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
|
||||||
@ -171,7 +181,12 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
|||||||
"flat": HardFlatShader,
|
"flat": HardFlatShader,
|
||||||
}
|
}
|
||||||
for (name, shader_init) in shaders.items():
|
for (name, shader_init) in shaders.items():
|
||||||
shader = shader_init(lights=lights, cameras=cameras, materials=materials)
|
shader = shader_init(
|
||||||
|
lights=lights,
|
||||||
|
cameras=cameras,
|
||||||
|
materials=materials,
|
||||||
|
blend_params=blend_params,
|
||||||
|
)
|
||||||
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
|
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
|
||||||
images = renderer(sphere_meshes)
|
images = renderer(sphere_meshes)
|
||||||
image_ref = load_rgb_image(
|
image_ref = load_rgb_image(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user