From 9e2bc3a17faf7dfffad9b0803f335da328b08a61 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Thu, 20 Jan 2022 09:43:32 -0800 Subject: [PATCH] ambient lights batching #1043 Summary: convert_to_tensors_and_broadcast had a special case for a single input, which is not used anywhere except fails to do the right thing if a TensorProperties has only one kwarg. At the moment AmbientLights may be the only way to hit the problem. Fix by removing the special case. Fixes https://github.com/facebookresearch/pytorch3d/issues/1043 Reviewed By: nikhilaravi Differential Revision: D33638345 fbshipit-source-id: 7a6695f44242e650504320f73b6da74254d49ac7 --- pytorch3d/renderer/utils.py | 3 --- tests/test_lighting.py | 13 ++++++++++++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/pytorch3d/renderer/utils.py b/pytorch3d/renderer/utils.py index 3e8c3e33..7d6f5acf 100644 --- a/pytorch3d/renderer/utils.py +++ b/pytorch3d/renderer/utils.py @@ -349,7 +349,4 @@ def convert_to_tensors_and_broadcast( expand_sizes = (N,) + (-1,) * len(c.shape[1:]) args_Nd.append(c.expand(*expand_sizes)) - if len(args) == 1: - args_Nd = args_Nd[0] # Return the first element - return args_Nd diff --git a/tests/test_lighting.py b/tests/test_lighting.py index 4f8542d5..dabafdca 100644 --- a/tests/test_lighting.py +++ b/tests/test_lighting.py @@ -9,7 +9,7 @@ import unittest import numpy as np import torch from common_testing import TestCaseMixin -from pytorch3d.renderer.lighting import DirectionalLights, PointLights +from pytorch3d.renderer.lighting import AmbientLights, DirectionalLights, PointLights from pytorch3d.transforms import RotateAxisAngle @@ -121,6 +121,17 @@ class TestLights(TestCaseMixin, unittest.TestCase): with self.assertRaises(ValueError): PointLights(location=torch.randn(10, 4)) + def test_initialize_ambient(self): + N = 13 + color = 0.8 * torch.ones((N, 3)) + lights = AmbientLights(ambient_color=color) + self.assertEqual(len(lights), N) + self.assertClose(lights.ambient_color, color) + + lights = AmbientLights(ambient_color=color[:1]) + self.assertEqual(len(lights), 1) + self.assertClose(lights.ambient_color, color[:1]) + class TestDiffuseLighting(TestCaseMixin, unittest.TestCase): def test_diffuse_directional_lights(self):