diff --git a/pytorch3d/renderer/__init__.py b/pytorch3d/renderer/__init__.py index 382b0b41..b89d166d 100644 --- a/pytorch3d/renderer/__init__.py +++ b/pytorch3d/renderer/__init__.py @@ -34,7 +34,7 @@ from .implicit import ( ray_bundle_to_ray_points, ray_bundle_variables_to_ray_points, ) -from .lighting import DirectionalLights, PointLights, diffuse, specular +from .lighting import AmbientLights, DirectionalLights, PointLights, diffuse, specular from .materials import Materials from .mesh import ( HardFlatShader, diff --git a/pytorch3d/renderer/lighting.py b/pytorch3d/renderer/lighting.py index 2eee6b85..05ad4858 100644 --- a/pytorch3d/renderer/lighting.py +++ b/pytorch3d/renderer/lighting.py @@ -158,7 +158,7 @@ class DirectionalLights(TensorProperties): diffuse_color=((0.3, 0.3, 0.3),), specular_color=((0.2, 0.2, 0.2),), direction=((0, 1, 0),), - device: str = "cpu", + device="cpu", ): """ Args: @@ -219,7 +219,7 @@ class PointLights(TensorProperties): diffuse_color=((0.3, 0.3, 0.3),), specular_color=((0.2, 0.2, 0.2),), location=((0, 1, 0),), - device: str = "cpu", + device="cpu", ): """ Args: @@ -268,6 +268,42 @@ class PointLights(TensorProperties): ) +class AmbientLights(TensorProperties): + """ + A light object representing the same color of light everywhere. + By default, this is white, which effectively means lighting is + not used in rendering. + """ + + def __init__(self, *, ambient_color=None, device="cpu"): + """ + If ambient_color is provided, it should be a sequence of + triples of floats. + + Args: + ambient_color: RGB color + device: torch.device on which the tensors should be located + + The ambient_color if provided, should be + - 3 element tuple/list or list of lists + - torch tensor of shape (1, 3) + - torch tensor of shape (N, 3) + """ + if ambient_color is None: + ambient_color = ((1.0, 1.0, 1.0),) + super().__init__(ambient_color=ambient_color, device=device) + + def clone(self): + other = self.__class__(device=self.device) + return super().clone(other) + + def diffuse(self, normals, points) -> torch.Tensor: + return torch.zeros_like(points) + + def specular(self, normals, points, camera_position, shininess) -> torch.Tensor: + return torch.zeros_like(points) + + def _validate_light_properties(obj): props = ("ambient_color", "diffuse_color", "specular_color") for n in props: diff --git a/tests/test_render_meshes.py b/tests/test_render_meshes.py index c9509d78..e2341cc8 100644 --- a/tests/test_render_meshes.py +++ b/tests/test_render_meshes.py @@ -24,7 +24,7 @@ from pytorch3d.renderer.cameras import ( PerspectiveCameras, look_at_view_transform, ) -from pytorch3d.renderer.lighting import PointLights +from pytorch3d.renderer.lighting import AmbientLights, PointLights from pytorch3d.renderer.materials import Materials from pytorch3d.renderer.mesh import TexturesAtlas, TexturesUV, TexturesVertex from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings @@ -626,12 +626,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): image_size=256, blur_radius=0.0, faces_per_pixel=1 ) - lights = PointLights( - device=device, - ambient_color=((1.0, 1.0, 1.0),), - diffuse_color=((0.0, 0.0, 0.0),), - specular_color=((0.0, 0.0, 0.0),), - ) + lights = AmbientLights(device=device) blend_params = BlendParams( sigma=1e-1, gamma=1e-4, @@ -780,12 +775,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): image_size=256, blur_radius=0.0, faces_per_pixel=1 ) - lights = PointLights( - device=device, - ambient_color=((1.0, 1.0, 1.0),), - diffuse_color=((0.0, 0.0, 0.0),), - specular_color=((0.0, 0.0, 0.0),), - ) + lights = AmbientLights(device=device) blend_params = BlendParams( sigma=1e-1, gamma=1e-4, @@ -863,12 +853,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): perspective_correct=False, ) - lights = PointLights( - device=device, - ambient_color=((1.0, 1.0, 1.0),), - diffuse_color=((0.0, 0.0, 0.0),), - specular_color=((0.0, 0.0, 0.0),), - ) + lights = AmbientLights(device=device) blend_params = BlendParams( sigma=1e-1, gamma=1e-4,