From bc8361fa471e39280b1c5e3717309167b08ffed5 Mon Sep 17 00:00:00 2001 From: Nikhila Ravi Date: Mon, 14 Jun 2021 11:47:35 -0700 Subject: [PATCH] Lighting broadcasting bug fix Summary: Fixed multiple issues with shape broadcasting in lighting, shading and blending and updated the tests. Reviewed By: bottler Differential Revision: D28997941 fbshipit-source-id: d3ef93f979344076b1d9098a86178b4da63844c8 --- pytorch3d/renderer/blending.py | 18 +++++++--- pytorch3d/renderer/lighting.py | 18 ++++++++-- pytorch3d/renderer/mesh/shading.py | 10 ++++-- tests/test_render_meshes.py | 58 ++++++++++++++++++------------ 4 files changed, 73 insertions(+), 31 deletions(-) diff --git a/pytorch3d/renderer/blending.py b/pytorch3d/renderer/blending.py index d91d556e..8c9205ec 100644 --- a/pytorch3d/renderer/blending.py +++ b/pytorch3d/renderer/blending.py @@ -1,12 +1,11 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -from typing import NamedTuple, Sequence +from typing import NamedTuple, Sequence, Union import torch from pytorch3d import _C # pyre-fixme[21]: Could not find name `_C` in `pytorch3d`. - # Example functions for blending the top K colors per pixel using the outputs # from rasterization. # NOTE: All blending function should return an RGBA image per batch element @@ -117,7 +116,11 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor: def softmax_rgb_blend( - colors, fragments, blend_params, znear: float = 1.0, zfar: float = 100 + colors, + fragments, + blend_params, + znear: Union[float, torch.Tensor] = 1.0, + zfar: Union[float, torch.Tensor] = 100, ) -> torch.Tensor: """ RGB and alpha channel blending to return an RGBA image based on the method @@ -184,11 +187,16 @@ def softmax_rgb_blend( # overflow. zbuf shape (N, H, W, K), find max over K. # TODO: there may still be some instability in the exponent calculation. + # Reshape to be compatible with (N, H, W, K) values in fragments + if torch.is_tensor(zfar): + # pyre-fixme[16] + zfar = zfar[:, None, None, None] + if torch.is_tensor(znear): + znear = znear[:, None, None, None] + z_inv = (zfar - fragments.zbuf) / (zfar - znear) * mask # pyre-fixme[16]: `Tuple` has no attribute `values`. - # pyre-fixme[6]: Expected `Tensor` for 1st param but got `float`. z_inv_max = torch.max(z_inv, dim=-1).values[..., None].clamp(min=eps) - # pyre-fixme[6]: Expected `Tensor` for 1st param but got `float`. weights_num = prob_map * torch.exp((z_inv - z_inv_max) / blend_params.gamma) # Also apply exp normalize trick for the background color weight. diff --git a/pytorch3d/renderer/lighting.py b/pytorch3d/renderer/lighting.py index 915f7141..fc690573 100644 --- a/pytorch3d/renderer/lighting.py +++ b/pytorch3d/renderer/lighting.py @@ -253,12 +253,26 @@ class PointLights(TensorProperties): other = self.__class__(device=self.device) return super().clone(other) + def reshape_location(self, points) -> torch.Tensor: + """ + Reshape the location tensor to have dimensions + compatible with the points which can either be of + shape (P, 3) or (N, H, W, K, 3). + """ + if self.location.ndim == points.ndim: + # pyre-fixme[7] + return self.location + # pyre-fixme[29] + return self.location[:, None, None, None, :] + def diffuse(self, normals, points) -> torch.Tensor: - direction = self.location - points + location = self.reshape_location(points) + direction = location - points return diffuse(normals=normals, color=self.diffuse_color, direction=direction) def specular(self, normals, points, camera_position, shininess) -> torch.Tensor: - direction = self.location - points + location = self.reshape_location(points) + direction = location - points return specular( points=points, normals=normals, diff --git a/pytorch3d/renderer/mesh/shading.py b/pytorch3d/renderer/mesh/shading.py index 2d248823..20c34bea 100644 --- a/pytorch3d/renderer/mesh/shading.py +++ b/pytorch3d/renderer/mesh/shading.py @@ -14,8 +14,8 @@ def _apply_lighting( ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: - points: torch tensor of shape (N, P, 3) or (P, 3). - normals: torch tensor of shape (N, P, 3) or (P, 3) + points: torch tensor of shape (N, ..., 3) or (P, 3). + normals: torch tensor of shape (N, ..., 3) or (P, 3) lights: instance of the Lights class. cameras: instance of the Cameras class. materials: instance of the Materials class. @@ -35,6 +35,7 @@ def _apply_lighting( ambient_color = materials.ambient_color * lights.ambient_color diffuse_color = materials.diffuse_color * light_diffuse specular_color = materials.specular_color * light_specular + if normals.dim() == 2 and points.dim() == 2: # If given packed inputs remove batch dim in output. return ( @@ -42,6 +43,11 @@ def _apply_lighting( diffuse_color.squeeze(), specular_color.squeeze(), ) + + if ambient_color.ndim != diffuse_color.ndim: + # Reshape from (N, 3) to have dimensions compatible with + # diffuse_color which is of shape (N, H, W, K, 3) + ambient_color = ambient_color[:, None, None, None, :] return ambient_color, diffuse_color, specular_color diff --git a/tests/test_render_meshes.py b/tests/test_render_meshes.py index 64bbe7fd..71c54b6d 100644 --- a/tests/test_render_meshes.py +++ b/tests/test_render_meshes.py @@ -6,6 +6,7 @@ Sanity checks for output images from the renderer. """ import os import unittest +from collections import namedtuple import numpy as np import torch @@ -53,6 +54,8 @@ DEBUG = False DATA_DIR = get_tests_dir() / "data" TUTORIAL_DATA_DIR = get_pytorch3d_dir() / "docs/tutorials/data" +ShaderTest = namedtuple("ShaderTest", ["shader", "reference_name", "debug_name"]) + class TestRenderMeshes(TestCaseMixin, unittest.TestCase): def test_simple_sphere(self, elevated_camera=False, check_depth=False): @@ -107,13 +110,13 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0)) # Test several shaders - shaders = { - "phong": HardPhongShader, - "gouraud": HardGouraudShader, - "flat": HardFlatShader, - } - for (name, shader_init) in shaders.items(): - shader = shader_init( + shader_tests = [ + ShaderTest(HardPhongShader, "phong", "hard_phong"), + ShaderTest(HardGouraudShader, "gouraud", "hard_gouraud"), + ShaderTest(HardFlatShader, "flat", "hard_flat"), + ] + for test in shader_tests: + shader = test.shader( lights=lights, cameras=cameras, materials=materials, @@ -135,7 +138,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): rgb = images[0, ..., :3].squeeze().cpu() filename = "simple_sphere_light_%s%s%s.png" % ( - name, + test.reference_name, postfix, cam_type.__name__, ) @@ -144,7 +147,12 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): self.assertClose(rgb, image_ref, atol=0.05) if DEBUG: - filename = "DEBUG_%s" % filename + debug_filename = "simple_sphere_light_%s%s%s.png" % ( + test.debug_name, + postfix, + cam_type.__name__, + ) + filename = "DEBUG_%s" % debug_filename Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( DATA_DIR / filename ) @@ -269,7 +277,8 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): def test_simple_sphere_batched(self): """ Test a mesh with vertex textures can be extended to form a batch, and - is rendered correctly with Phong, Gouraud and Flat Shaders. + is rendered correctly with Phong, Gouraud and Flat Shaders with batched + lighting and hard and soft blending. """ batch_size = 5 device = torch.device("cuda:0") @@ -291,24 +300,28 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): R, T = look_at_view_transform(dist, elev, azim) cameras = FoVPerspectiveCameras(device=device, R=R, T=T) raster_settings = RasterizationSettings( - image_size=512, blur_radius=0.0, faces_per_pixel=1 + image_size=512, blur_radius=0.0, faces_per_pixel=4 ) # Init shader settings materials = Materials(device=device) - lights = PointLights(device=device) - lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None] + lights_location = torch.tensor([0.0, 0.0, +2.0], device=device) + lights_location = lights_location[None].expand(batch_size, -1) + lights = PointLights(device=device, location=lights_location) blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0)) # Init renderer rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings) - shaders = { - "phong": HardPhongShader, - "gouraud": HardGouraudShader, - "flat": HardFlatShader, - } - for (name, shader_init) in shaders.items(): - shader = shader_init( + shader_tests = [ + ShaderTest(HardPhongShader, "phong", "hard_phong"), + ShaderTest(SoftPhongShader, "phong", "soft_phong"), + ShaderTest(HardGouraudShader, "gouraud", "hard_gouraud"), + ShaderTest(HardFlatShader, "flat", "hard_flat"), + ] + for test in shader_tests: + reference_name = test.reference_name + debug_name = test.debug_name + shader = test.shader( lights=lights, cameras=cameras, materials=materials, @@ -317,14 +330,15 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): renderer = MeshRenderer(rasterizer=rasterizer, shader=shader) images = renderer(sphere_meshes) image_ref = load_rgb_image( - "test_simple_sphere_light_%s_%s.png" % (name, type(cameras).__name__), + "test_simple_sphere_light_%s_%s.png" + % (reference_name, type(cameras).__name__), DATA_DIR, ) for i in range(batch_size): rgb = images[i, ..., :3].squeeze().cpu() if i == 0 and DEBUG: filename = "DEBUG_simple_sphere_batched_%s_%s.png" % ( - name, + debug_name, type(cameras).__name__, ) Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(