Lighting broadcasting bug fix

Summary: Fixed multiple shape-broadcasting issues in lighting, shading, and blending: batched `znear`/`zfar` tensors in `softmax_rgb_blend`, batched light locations in `PointLights`, and the ambient color term in `_apply_lighting`. Updated the tests to cover batched lighting with hard and soft blending.

Reviewed By: bottler

Differential Revision: D28997941

fbshipit-source-id: d3ef93f979344076b1d9098a86178b4da63844c8
Author: Nikhila Ravi, 2021-06-14 11:47:35 -07:00 (committed by Facebook GitHub Bot)
Commit: bc8361fa47 (parent: 9de627e01b)
4 changed files with 73 additions and 31 deletions

File: pytorch3d/renderer/blending.py

@@ -1,12 +1,11 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

-from typing import NamedTuple, Sequence
+from typing import NamedTuple, Sequence, Union

 import torch

 from pytorch3d import _C  # pyre-fixme[21]: Could not find name `_C` in `pytorch3d`.

 # Example functions for blending the top K colors per pixel using the outputs
 # from rasterization.
 # NOTE: All blending function should return an RGBA image per batch element

@@ -117,7 +116,11 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
 def softmax_rgb_blend(
-    colors, fragments, blend_params, znear: float = 1.0, zfar: float = 100
+    colors,
+    fragments,
+    blend_params,
+    znear: Union[float, torch.Tensor] = 1.0,
+    zfar: Union[float, torch.Tensor] = 100,
 ) -> torch.Tensor:
     """
     RGB and alpha channel blending to return an RGBA image based on the method

@@ -184,11 +187,16 @@ def softmax_rgb_blend(
     # overflow. zbuf shape (N, H, W, K), find max over K.
     # TODO: there may still be some instability in the exponent calculation.

+    # Reshape to be compatible with (N, H, W, K) values in fragments
+    if torch.is_tensor(zfar):
+        # pyre-fixme[16]
+        zfar = zfar[:, None, None, None]
+    if torch.is_tensor(znear):
+        znear = znear[:, None, None, None]

     z_inv = (zfar - fragments.zbuf) / (zfar - znear) * mask
     # pyre-fixme[16]: `Tuple` has no attribute `values`.
-    # pyre-fixme[6]: Expected `Tensor` for 1st param but got `float`.
     z_inv_max = torch.max(z_inv, dim=-1).values[..., None].clamp(min=eps)
-    # pyre-fixme[6]: Expected `Tensor` for 1st param but got `float`.
     weights_num = prob_map * torch.exp((z_inv - z_inv_max) / blend_params.gamma)
     # Also apply exp normalize trick for the background color weight.
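Why the reshape is needed: PyTorch broadcasting aligns trailing dimensions, so a per-batch `znear`/`zfar` of shape (N,) cannot broadcast directly against `fragments.zbuf` of shape (N, H, W, K). A minimal sketch of the pattern (illustration only, not part of the commit; shapes and values are made up):

import torch

N, H, W, K = 2, 4, 4, 3
zbuf = torch.rand(N, H, W, K) * 99 + 1  # stand-in for fragments.zbuf
znear = torch.tensor([1.0, 0.5])        # per-batch near planes, shape (N,)
zfar = torch.tensor([100.0, 50.0])      # per-batch far planes, shape (N,)

# Right-aligned broadcasting would pair the length-N dim with K, so
# (zfar - zbuf) raises a size-mismatch error without the reshape below.
znear = znear[:, None, None, None]      # (N, 1, 1, 1)
zfar = zfar[:, None, None, None]        # (N, 1, 1, 1)
z_inv = (zfar - zbuf) / (zfar - znear)  # broadcasts to (N, H, W, K)
print(z_inv.shape)                      # torch.Size([2, 4, 4, 3])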

File: pytorch3d/renderer/lighting.py

@@ -253,12 +253,26 @@ class PointLights(TensorProperties):
         other = self.__class__(device=self.device)
         return super().clone(other)

+    def reshape_location(self, points) -> torch.Tensor:
+        """
+        Reshape the location tensor to have dimensions
+        compatible with the points which can either be of
+        shape (P, 3) or (N, H, W, K, 3).
+        """
+        if self.location.ndim == points.ndim:
+            # pyre-fixme[7]
+            return self.location
+        # pyre-fixme[29]
+        return self.location[:, None, None, None, :]
+
     def diffuse(self, normals, points) -> torch.Tensor:
-        direction = self.location - points
+        location = self.reshape_location(points)
+        direction = location - points
         return diffuse(normals=normals, color=self.diffuse_color, direction=direction)

     def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
-        direction = self.location - points
+        location = self.reshape_location(points)
+        direction = location - points
         return specular(
             points=points,
             normals=normals,
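A small usage sketch of the new `reshape_location` (illustrative, assuming the post-fix API; the shapes follow the docstring above and the tensor values are placeholders):

import torch
from pytorch3d.renderer import PointLights

# One light location per batch element: shape (N, 3).
lights = PointLights(location=torch.rand(2, 3))

# Interpolated pixel coordinates from soft rasterization: (N, H, W, K, 3).
points = torch.rand(2, 8, 8, 4, 3)

# (N, 3) -> (N, 1, 1, 1, 3) so the subtraction broadcasts per pixel.
location = lights.reshape_location(points)
print(location.shape)             # torch.Size([2, 1, 1, 1, 3])
print((location - points).shape)  # torch.Size([2, 8, 8, 4, 3])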

File: pytorch3d/renderer/mesh/shading.py

@@ -14,8 +14,8 @@ def _apply_lighting(
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Args:
-        points: torch tensor of shape (N, P, 3) or (P, 3).
-        normals: torch tensor of shape (N, P, 3) or (P, 3)
+        points: torch tensor of shape (N, ..., 3) or (P, 3).
+        normals: torch tensor of shape (N, ..., 3) or (P, 3)
         lights: instance of the Lights class.
         cameras: instance of the Cameras class.
         materials: instance of the Materials class.

@@ -35,6 +35,7 @@ def _apply_lighting(
     ambient_color = materials.ambient_color * lights.ambient_color
     diffuse_color = materials.diffuse_color * light_diffuse
     specular_color = materials.specular_color * light_specular
+
     if normals.dim() == 2 and points.dim() == 2:
         # If given packed inputs remove batch dim in output.
         return (

@@ -42,6 +43,11 @@ def _apply_lighting(
             diffuse_color.squeeze(),
             specular_color.squeeze(),
         )
+
+    if ambient_color.ndim != diffuse_color.ndim:
+        # Reshape from (N, 3) to have dimensions compatible with
+        # diffuse_color which is of shape (N, H, W, K, 3)
+        ambient_color = ambient_color[:, None, None, None, :]
     return ambient_color, diffuse_color, specular_color
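The `ambient_color` branch fixes the same trailing-dimension mismatch: the ambient term stays (N, 3), while the diffuse and specular terms pick up per-pixel dimensions under soft rasterization. A standalone sketch (not from the commit; shapes are made up):

import torch

ambient_color = torch.rand(2, 3)           # (N, 3): one RGB ambient term per element
diffuse_color = torch.rand(2, 8, 8, 4, 3)  # (N, H, W, K, 3): per-pixel diffuse term

# Right-aligned broadcasting would pair the batch dim N with K and fail,
# so insert singleton dims to get (N, 1, 1, 1, 3) before combining terms.
if ambient_color.ndim != diffuse_color.ndim:
    ambient_color = ambient_color[:, None, None, None, :]
print((ambient_color + diffuse_color).shape)  # torch.Size([2, 8, 8, 4, 3])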

File: tests/test_render_meshes.py

@@ -6,6 +6,7 @@ Sanity checks for output images from the renderer.
 """
 import os
 import unittest
+from collections import namedtuple

 import numpy as np
 import torch

@@ -53,6 +54,8 @@ DEBUG = False
 DATA_DIR = get_tests_dir() / "data"
 TUTORIAL_DATA_DIR = get_pytorch3d_dir() / "docs/tutorials/data"

+ShaderTest = namedtuple("ShaderTest", ["shader", "reference_name", "debug_name"])
+

 class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
     def test_simple_sphere(self, elevated_camera=False, check_depth=False):

@@ -107,13 +110,13 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
         blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))

         # Test several shaders
-        shaders = {
-            "phong": HardPhongShader,
-            "gouraud": HardGouraudShader,
-            "flat": HardFlatShader,
-        }
-        for (name, shader_init) in shaders.items():
-            shader = shader_init(
+        shader_tests = [
+            ShaderTest(HardPhongShader, "phong", "hard_phong"),
+            ShaderTest(HardGouraudShader, "gouraud", "hard_gouraud"),
+            ShaderTest(HardFlatShader, "flat", "hard_flat"),
+        ]
+        for test in shader_tests:
+            shader = test.shader(
                 lights=lights,
                 cameras=cameras,
                 materials=materials,

@@ -135,7 +138,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
             rgb = images[0, ..., :3].squeeze().cpu()
             filename = "simple_sphere_light_%s%s%s.png" % (
-                name,
+                test.reference_name,
                 postfix,
                 cam_type.__name__,
             )

@@ -144,7 +147,12 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
             self.assertClose(rgb, image_ref, atol=0.05)

             if DEBUG:
-                filename = "DEBUG_%s" % filename
+                debug_filename = "simple_sphere_light_%s%s%s.png" % (
+                    test.debug_name,
+                    postfix,
+                    cam_type.__name__,
+                )
+                filename = "DEBUG_%s" % debug_filename
                 Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                     DATA_DIR / filename
                 )

@@ -269,7 +277,8 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
     def test_simple_sphere_batched(self):
         """
         Test a mesh with vertex textures can be extended to form a batch, and
-        is rendered correctly with Phong, Gouraud and Flat Shaders.
+        is rendered correctly with Phong, Gouraud and Flat Shaders with batched
+        lighting and hard and soft blending.
         """

@@ -291,24 +300,28 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
         R, T = look_at_view_transform(dist, elev, azim)
         cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
         raster_settings = RasterizationSettings(
-            image_size=512, blur_radius=0.0, faces_per_pixel=1
+            image_size=512, blur_radius=0.0, faces_per_pixel=4
         )

         # Init shader settings
         materials = Materials(device=device)
-        lights = PointLights(device=device)
-        lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
+        lights_location = torch.tensor([0.0, 0.0, +2.0], device=device)
+        lights_location = lights_location[None].expand(batch_size, -1)
+        lights = PointLights(device=device, location=lights_location)
         blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))

         # Init renderer
         rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
-        shaders = {
-            "phong": HardPhongShader,
-            "gouraud": HardGouraudShader,
-            "flat": HardFlatShader,
-        }
-        for (name, shader_init) in shaders.items():
-            shader = shader_init(
+        shader_tests = [
+            ShaderTest(HardPhongShader, "phong", "hard_phong"),
+            ShaderTest(SoftPhongShader, "phong", "soft_phong"),
+            ShaderTest(HardGouraudShader, "gouraud", "hard_gouraud"),
+            ShaderTest(HardFlatShader, "flat", "hard_flat"),
+        ]
+        for test in shader_tests:
+            reference_name = test.reference_name
+            debug_name = test.debug_name
+            shader = test.shader(
                 lights=lights,
                 cameras=cameras,
                 materials=materials,

@@ -317,14 +330,15 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
             renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
             images = renderer(sphere_meshes)
             image_ref = load_rgb_image(
-                "test_simple_sphere_light_%s_%s.png" % (name, type(cameras).__name__),
+                "test_simple_sphere_light_%s_%s.png"
+                % (reference_name, type(cameras).__name__),
                 DATA_DIR,
             )
             for i in range(batch_size):
                 rgb = images[i, ..., :3].squeeze().cpu()
                 if i == 0 and DEBUG:
                     filename = "DEBUG_simple_sphere_batched_%s_%s.png" % (
-                        name,
+                        debug_name,
                         type(cameras).__name__,
                     )
                     Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
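For context, a self-contained sketch of the configuration the updated batched test exercises: per-batch-element light locations together with a soft shader and `faces_per_pixel > 1`. This is an illustration rather than the test code; the mesh, image size, and camera/light placements are placeholders:

import torch
from pytorch3d.renderer import (
    FoVPerspectiveCameras,
    MeshRasterizer,
    MeshRenderer,
    PointLights,
    RasterizationSettings,
    SoftPhongShader,
    TexturesVertex,
    look_at_view_transform,
)
from pytorch3d.utils import ico_sphere

batch_size = 2
device = torch.device("cpu")

# A batch of identical white spheres with per-vertex textures.
mesh = ico_sphere(2, device)
mesh.textures = TexturesVertex(torch.ones_like(mesh.verts_padded()))
meshes = mesh.extend(batch_size)

# One camera and one light location per batch element.
R, T = look_at_view_transform(
    dist=2.7, elev=torch.zeros(batch_size), azim=torch.linspace(-45, 45, batch_size)
)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
lights = PointLights(device=device, location=torch.rand(batch_size, 3) * 2)

renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras,
        raster_settings=RasterizationSettings(image_size=128, faces_per_pixel=4),
    ),
    shader=SoftPhongShader(device=device, cameras=cameras, lights=lights),
)
images = renderer(meshes)
print(images.shape)  # torch.Size([2, 128, 128, 4]): RGBA per batch element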