diff --git a/pytorch3d/renderer/utils.py b/pytorch3d/renderer/utils.py index 3e8c3e33..7d6f5acf 100644 --- a/pytorch3d/renderer/utils.py +++ b/pytorch3d/renderer/utils.py @@ -349,7 +349,4 @@ def convert_to_tensors_and_broadcast( expand_sizes = (N,) + (-1,) * len(c.shape[1:]) args_Nd.append(c.expand(*expand_sizes)) - if len(args) == 1: - args_Nd = args_Nd[0] # Return the first element - return args_Nd diff --git a/tests/test_lighting.py b/tests/test_lighting.py index 4f8542d5..dabafdca 100644 --- a/tests/test_lighting.py +++ b/tests/test_lighting.py @@ -9,7 +9,7 @@ import unittest import numpy as np import torch from common_testing import TestCaseMixin -from pytorch3d.renderer.lighting import DirectionalLights, PointLights +from pytorch3d.renderer.lighting import AmbientLights, DirectionalLights, PointLights from pytorch3d.transforms import RotateAxisAngle @@ -121,6 +121,17 @@ class TestLights(TestCaseMixin, unittest.TestCase): with self.assertRaises(ValueError): PointLights(location=torch.randn(10, 4)) + def test_initialize_ambient(self): + N = 13 + color = 0.8 * torch.ones((N, 3)) + lights = AmbientLights(ambient_color=color) + self.assertEqual(len(lights), N) + self.assertClose(lights.ambient_color, color) + + lights = AmbientLights(ambient_color=color[:1]) + self.assertEqual(len(lights), 1) + self.assertClose(lights.ambient_color, color[:1]) + class TestDiffuseLighting(TestCaseMixin, unittest.TestCase): def test_diffuse_directional_lights(self):