Lighting broadcasting bug fix

Summary: Fixed multiple issues with shape broadcasting in lighting, shading and blending and updated the tests.

Reviewed By: bottler

Differential Revision: D28997941

fbshipit-source-id: d3ef93f979344076b1d9098a86178b4da63844c8
This commit is contained in:
Nikhila Ravi
2021-06-14 11:47:35 -07:00
committed by Facebook GitHub Bot
parent 9de627e01b
commit bc8361fa47
4 changed files with 73 additions and 31 deletions

View File

@@ -6,6 +6,7 @@ Sanity checks for output images from the renderer.
"""
import os
import unittest
from collections import namedtuple
import numpy as np
import torch
@@ -53,6 +54,8 @@ DEBUG = False
DATA_DIR = get_tests_dir() / "data"
TUTORIAL_DATA_DIR = get_pytorch3d_dir() / "docs/tutorials/data"
ShaderTest = namedtuple("ShaderTest", ["shader", "reference_name", "debug_name"])
class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
def test_simple_sphere(self, elevated_camera=False, check_depth=False):
@@ -107,13 +110,13 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))
# Test several shaders
shaders = {
"phong": HardPhongShader,
"gouraud": HardGouraudShader,
"flat": HardFlatShader,
}
for (name, shader_init) in shaders.items():
shader = shader_init(
shader_tests = [
ShaderTest(HardPhongShader, "phong", "hard_phong"),
ShaderTest(HardGouraudShader, "gouraud", "hard_gouraud"),
ShaderTest(HardFlatShader, "flat", "hard_flat"),
]
for test in shader_tests:
shader = test.shader(
lights=lights,
cameras=cameras,
materials=materials,
@@ -135,7 +138,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
rgb = images[0, ..., :3].squeeze().cpu()
filename = "simple_sphere_light_%s%s%s.png" % (
name,
test.reference_name,
postfix,
cam_type.__name__,
)
@@ -144,7 +147,12 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
self.assertClose(rgb, image_ref, atol=0.05)
if DEBUG:
filename = "DEBUG_%s" % filename
debug_filename = "simple_sphere_light_%s%s%s.png" % (
test.debug_name,
postfix,
cam_type.__name__,
)
filename = "DEBUG_%s" % debug_filename
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / filename
)
@@ -269,7 +277,8 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
def test_simple_sphere_batched(self):
"""
Test a mesh with vertex textures can be extended to form a batch, and
is rendered correctly with Phong, Gouraud and Flat Shaders.
is rendered correctly with Phong, Gouraud and Flat Shaders with batched
lighting and hard and soft blending.
"""
batch_size = 5
device = torch.device("cuda:0")
@@ -291,24 +300,28 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
R, T = look_at_view_transform(dist, elev, azim)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1
image_size=512, blur_radius=0.0, faces_per_pixel=4
)
# Init shader settings
materials = Materials(device=device)
lights = PointLights(device=device)
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
lights_location = torch.tensor([0.0, 0.0, +2.0], device=device)
lights_location = lights_location[None].expand(batch_size, -1)
lights = PointLights(device=device, location=lights_location)
blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))
# Init renderer
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
shaders = {
"phong": HardPhongShader,
"gouraud": HardGouraudShader,
"flat": HardFlatShader,
}
for (name, shader_init) in shaders.items():
shader = shader_init(
shader_tests = [
ShaderTest(HardPhongShader, "phong", "hard_phong"),
ShaderTest(SoftPhongShader, "phong", "soft_phong"),
ShaderTest(HardGouraudShader, "gouraud", "hard_gouraud"),
ShaderTest(HardFlatShader, "flat", "hard_flat"),
]
for test in shader_tests:
reference_name = test.reference_name
debug_name = test.debug_name
shader = test.shader(
lights=lights,
cameras=cameras,
materials=materials,
@@ -317,14 +330,15 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
images = renderer(sphere_meshes)
image_ref = load_rgb_image(
"test_simple_sphere_light_%s_%s.png" % (name, type(cameras).__name__),
"test_simple_sphere_light_%s_%s.png"
% (reference_name, type(cameras).__name__),
DATA_DIR,
)
for i in range(batch_size):
rgb = images[i, ..., :3].squeeze().cpu()
if i == 0 and DEBUG:
filename = "DEBUG_simple_sphere_batched_%s_%s.png" % (
name,
debug_name,
type(cameras).__name__,
)
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(