Barycentric clipping in the renderer and flat shading

Summary:
Updates to the Renderer to enable barycentric clipping. This is important when there is blurring in the rasterization step.

Also added support for flat shading.

Reviewed By: jcjohnson

Differential Revision: D19934259

fbshipit-source-id: 036e48636cd80d28a04405d7a29fcc71a2982904
This commit is contained in:
Nikhila Ravi
2020-02-28 21:28:32 -08:00
committed by Facebook Github Bot
parent f358b9b14d
commit ff19c642cb
14 changed files with 254 additions and 108 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 45 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 18 KiB

View File

@@ -0,0 +1,24 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
import torch
from pytorch3d.renderer.mesh.utils import _clip_barycentric_coordinates
class TestMeshRenderingUtils(unittest.TestCase):
def test_bary_clip(self):
N = 10
bary = torch.randn(size=(N, 3))
# randomly make some values negative
bary[bary < 0.3] *= -1.0
# randomly make some values be greater than 1
bary[bary > 0.8] *= 2.0
negative_mask = bary < 0.0
positive_mask = bary > 1.0
clipped = _clip_barycentric_coordinates(bary)
self.assertTrue(clipped[negative_mask].sum() == 0)
self.assertTrue(clipped[positive_mask].gt(1.0).sum() == 0)
self.assertTrue(torch.allclose(clipped.sum(dim=-1), torch.ones(N)))

View File

@@ -25,6 +25,7 @@ from pytorch3d.renderer.mesh.rasterizer import (
from pytorch3d.renderer.mesh.renderer import MeshRenderer
from pytorch3d.renderer.mesh.shader import (
BlendParams,
HardFlatShader,
HardGouraudShader,
HardPhongShader,
SoftSilhouetteShader,
@@ -99,8 +100,9 @@ class TestRenderingMeshes(unittest.TestCase):
images = renderer(sphere_mesh)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
filename = "DEBUG_simple_sphere_light%s.png" % postfix
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_simple_sphere_light%s.png" % postfix
DATA_DIR / filename
)
# Load reference image
@@ -117,8 +119,9 @@ class TestRenderingMeshes(unittest.TestCase):
images = renderer(sphere_mesh, lights=lights)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
filename = "DEBUG_simple_sphere_dark%s.png" % postfix
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_simple_sphere_dark%s.png" % postfix
DATA_DIR / filename
)
# Load reference image
@@ -140,8 +143,9 @@ class TestRenderingMeshes(unittest.TestCase):
images = renderer(sphere_mesh)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
filename = "DEBUG_simple_sphere_light_gourad%s.png" % postfix
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_simple_sphere_light_gouraud%s.png" % postfix
DATA_DIR / filename
)
# Load reference image
@@ -149,7 +153,30 @@ class TestRenderingMeshes(unittest.TestCase):
"test_simple_sphere_light_gouraud%s.png" % postfix
)
self.assertTrue(torch.allclose(rgb, image_ref_gouraud, atol=0.005))
self.assertFalse(torch.allclose(rgb, image_ref_phong, atol=0.005))
######################################
# Change the shader to a HardFlatShader
######################################
lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
renderer = MeshRenderer(
rasterizer=rasterizer,
shader=HardFlatShader(
lights=lights, cameras=cameras, materials=materials
),
)
images = renderer(sphere_mesh)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
filename = "DEBUG_simple_sphere_light_flat%s.png" % postfix
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / filename
)
# Load reference image
image_ref_flat = load_rgb_image(
"test_simple_sphere_light_flat%s.png" % postfix
)
self.assertTrue(torch.allclose(rgb, image_ref_flat, atol=0.005))
def test_simple_sphere_elevated_camera(self):
"""
@@ -287,9 +314,6 @@ class TestRenderingMeshes(unittest.TestCase):
materials = Materials(device=device)
lights = PointLights(device=device)
lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
)
# Init renderer
renderer = MeshRenderer(
@@ -327,3 +351,32 @@ class TestRenderingMeshes(unittest.TestCase):
images = renderer(mesh2)
images[0, ...].sum().backward()
self.assertIsNotNone(verts.grad)
#################################
# Add blurring to rasterization
#################################
blend_params = BlendParams(sigma=5e-4, gamma=1e-4)
raster_settings = RasterizationSettings(
image_size=512,
blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
faces_per_pixel=100,
bin_size=0,
)
images = renderer(
mesh.clone(),
raster_settings=raster_settings,
blend_params=blend_params,
)
rgb = images[0, ..., :3].squeeze().cpu()
# Load reference image
image_ref = load_rgb_image("test_blurry_textured_rendering.png")
if DEBUG:
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_blurry_textured_rendering.png"
)
self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))

View File

@@ -8,7 +8,6 @@ import torch.nn.functional as F
from pytorch3d.renderer.mesh.rasterizer import Fragments
from pytorch3d.renderer.mesh.texturing import (
_clip_barycentric_coordinates,
interpolate_face_attributes,
interpolate_texture_map,
interpolate_vertex_colors,
@@ -94,7 +93,9 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
dists=pix_to_face,
)
with self.assertRaises(ValueError):
interpolate_face_attributes(fragments, face_attributes)
interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, face_attributes
)
# 2. pix_to_face must have shape (N, H, W, K)
pix_to_face = torch.ones((1, 1, 1, 1, 3))
@@ -105,7 +106,9 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
dists=pix_to_face,
)
with self.assertRaises(ValueError):
interpolate_face_attributes(fragments, face_attributes)
interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, face_attributes
)
def test_interpolate_texture_map(self):
barycentric_coords = torch.tensor(
@@ -220,13 +223,3 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
)
with self.assertRaises(ValueError):
tex_mesh.extend(N=-1)
def test_clip_barycentric_coords(self):
barycentric_coords = torch.tensor(
[[1.5, -0.3, -0.2], [1.2, 0.3, -0.5]], dtype=torch.float32
)
expected_out = torch.tensor(
[[1.0, 0.0, 0.0], [1.0 / 1.3, 0.3 / 1.3, 0.0]], dtype=torch.float32
)
clipped = _clip_barycentric_coordinates(barycentric_coords)
self.assertTrue(torch.allclose(clipped, expected_out))