diff --git a/pytorch3d/renderer/blending.py b/pytorch3d/renderer/blending.py index 1816e767..0865547a 100644 --- a/pytorch3d/renderer/blending.py +++ b/pytorch3d/renderer/blending.py @@ -3,7 +3,7 @@ import numpy as np -from typing import NamedTuple +from typing import NamedTuple, Sequence import torch # Example functions for blending the top K colors per pixel using the outputs @@ -15,7 +15,7 @@ import torch class BlendParams(NamedTuple): sigma: float = 1e-4 gamma: float = 1e-4 - background_color = (1.0, 1.0, 1.0) + background_color: Sequence = (1.0, 1.0, 1.0) def hard_rgb_blend(colors, fragments) -> torch.Tensor: diff --git a/tests/test_blending.py b/tests/test_blending.py index 81f7516c..87efad60 100644 --- a/tests/test_blending.py +++ b/tests/test_blending.py @@ -234,3 +234,12 @@ class TestBlending(unittest.TestCase): self.assertTrue(torch.allclose(dists1.grad, dists2.grad, atol=2e-5)) self.assertTrue(torch.allclose(zbuf1.grad, zbuf2.grad, atol=2e-5)) + + def test_blend_params(self): + """Test colour parameter of BlendParams(). + Assert passed value overrides default value. + """ + bp_default = BlendParams() + bp_new = BlendParams(background_color=(0.5, 0.5, 0.5)) + self.assertEqual(bp_new.background_color, (0.5, 0.5, 0.5)) + self.assertEqual(bp_default.background_color, (1.0, 1.0, 1.0))