mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 22:00:35 +08:00
Adding support for changing background color
Summary: Adds support to hard_rgb_blend and hard blending shaders in shader.py (HardPhongShader, HardGouraudShader, and HardFlatShader) for changing the background color on which objects are rendered Reviewed By: nikhilaravi Differential Revision: D21746062 fbshipit-source-id: 08001200f4339d6a69c52405c6b8f4cac9f3f56e
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e3819a49df
commit
65620e716c
@@ -4,6 +4,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.renderer.blending import (
|
||||
BlendParams,
|
||||
hard_rgb_blend,
|
||||
@@ -128,7 +129,7 @@ def softmax_blend_naive(colors, fragments, blend_params):
|
||||
return pixel_colors
|
||||
|
||||
|
||||
class TestBlending(unittest.TestCase):
|
||||
class TestBlending(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
torch.manual_seed(42)
|
||||
|
||||
@@ -156,7 +157,7 @@ class TestBlending(unittest.TestCase):
|
||||
|
||||
def test_hard_rgb_blend(self):
|
||||
N, H, W, K = 5, 10, 10, 20
|
||||
pix_to_face = torch.ones((N, H, W, K))
|
||||
pix_to_face = torch.randint(low=-1, high=100, size=(N, H, W, K))
|
||||
bary_coords = torch.ones((N, H, W, K, 3))
|
||||
fragments = Fragments(
|
||||
pix_to_face=pix_to_face,
|
||||
@@ -164,14 +165,21 @@ class TestBlending(unittest.TestCase):
|
||||
zbuf=pix_to_face, # dummy
|
||||
dists=pix_to_face, # dummy
|
||||
)
|
||||
colors = bary_coords.clone()
|
||||
top_k = torch.randn((K, 3))
|
||||
colors[..., :, :] = top_k
|
||||
images = hard_rgb_blend(colors, fragments)
|
||||
expected_vals = torch.ones((N, H, W, 4))
|
||||
pix_cols = torch.ones_like(expected_vals[..., :3]) * top_k[0, :]
|
||||
expected_vals[..., :3] = pix_cols
|
||||
self.assertTrue(torch.allclose(images, expected_vals))
|
||||
colors = torch.randn((N, H, W, K, 3))
|
||||
blend_params = BlendParams(1e-4, 1e-4, (0.5, 0.5, 1))
|
||||
images = hard_rgb_blend(colors, fragments, blend_params)
|
||||
|
||||
# Examine if the foreground colors are correct.
|
||||
is_foreground = pix_to_face[..., 0] >= 0
|
||||
self.assertClose(images[is_foreground][:, :3], colors[is_foreground][..., 0, :])
|
||||
|
||||
# Examine if the background colors are correct.
|
||||
for i in range(3): # i.e. RGB
|
||||
channel_color = blend_params.background_color[i]
|
||||
self.assertTrue(images[~is_foreground][..., i].eq(channel_color).all())
|
||||
|
||||
# Examine the alpha channel is correct
|
||||
self.assertTrue(images[..., 3].eq(1).all())
|
||||
|
||||
def test_sigmoid_alpha_blend_manual_gradients(self):
|
||||
# Create dummy outputs of rasterization
|
||||
|
||||
@@ -77,6 +77,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
image_size=512, blur_radius=0.0, faces_per_pixel=1
|
||||
)
|
||||
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
|
||||
blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))
|
||||
|
||||
# Test several shaders
|
||||
shaders = {
|
||||
@@ -85,7 +86,12 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
"flat": HardFlatShader,
|
||||
}
|
||||
for (name, shader_init) in shaders.items():
|
||||
shader = shader_init(lights=lights, cameras=cameras, materials=materials)
|
||||
shader = shader_init(
|
||||
lights=lights,
|
||||
cameras=cameras,
|
||||
materials=materials,
|
||||
blend_params=blend_params,
|
||||
)
|
||||
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
|
||||
images = renderer(sphere_mesh)
|
||||
filename = "simple_sphere_light_%s%s.png" % (name, postfix)
|
||||
@@ -105,7 +111,10 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
########################################################
|
||||
lights.location[..., 2] = -2.0
|
||||
phong_shader = HardPhongShader(
|
||||
lights=lights, cameras=cameras, materials=materials
|
||||
lights=lights,
|
||||
cameras=cameras,
|
||||
materials=materials,
|
||||
blend_params=blend_params,
|
||||
)
|
||||
phong_renderer = MeshRenderer(rasterizer=rasterizer, shader=phong_shader)
|
||||
images = phong_renderer(sphere_mesh, lights=lights)
|
||||
@@ -162,6 +171,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
materials = Materials(device=device)
|
||||
lights = PointLights(device=device)
|
||||
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
|
||||
blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))
|
||||
|
||||
# Init renderer
|
||||
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
|
||||
@@ -171,7 +181,12 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
"flat": HardFlatShader,
|
||||
}
|
||||
for (name, shader_init) in shaders.items():
|
||||
shader = shader_init(lights=lights, cameras=cameras, materials=materials)
|
||||
shader = shader_init(
|
||||
lights=lights,
|
||||
cameras=cameras,
|
||||
materials=materials,
|
||||
blend_params=blend_params,
|
||||
)
|
||||
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
|
||||
images = renderer(sphere_meshes)
|
||||
image_ref = load_rgb_image(
|
||||
|
||||
Reference in New Issue
Block a user