Barycentric clipping in the renderer and flat shading

Summary:
Updates to the Renderer to enable barycentric clipping. This is important when there is blurring in the rasterization step.

Also added support for flat shading.

Reviewed By: jcjohnson

Differential Revision: D19934259

fbshipit-source-id: 036e48636cd80d28a04405d7a29fcc71a2982904
This commit is contained in:
Nikhila Ravi 2020-02-28 21:28:32 -08:00 committed by Facebook Github Bot
parent f358b9b14d
commit ff19c642cb
14 changed files with 254 additions and 108 deletions

View File

@ -90,7 +90,9 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
return torch.flip(pixel_colors, [1])
def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
def softmax_rgb_blend(
colors, fragments, blend_params, znear: float = 1.0, zfar: float = 100
) -> torch.Tensor:
"""
RGB and alpha channel blending to return an RGBA image based on the method
proposed in [0]
@ -118,6 +120,8 @@ def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
exponential function used to control the opacity of the color.
- background_color: (3) element list/tuple/torch.Tensor specifying
the RGB values for the background color.
znear: float, near clipping plane in the z direction
zfar: float, far clipping plane in the z direction
Returns:
RGBA pixel_colors: (N, H, W, 4)
@ -125,6 +129,7 @@ def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
[0] Shichen Liu et al, 'Soft Rasterizer: A Differentiable Renderer for
Image-based 3D Reasoning'
"""
N, H, W, K = fragments.pix_to_face.shape
device = fragments.pix_to_face.device
pix_colors = torch.ones(
@ -140,11 +145,6 @@ def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
delta = np.exp(1e-10 / blend_params.gamma) * 1e-10
delta = torch.tensor(delta, device=device)
# Near and far clipping planes.
# TODO: add zfar/znear as input params.
zfar = 100.0
znear = 1.0
# Mask for padded pixels.
mask = fragments.pix_to_face >= 0
@ -164,6 +164,7 @@ def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
# Weights for each face. Adjust the exponential by the max z to prevent
# overflow. zbuf shape (N, H, W, K), find max over K.
# TODO: there may still be some instability in the exponent calculation.
z_inv = (zfar - fragments.zbuf) / (zfar - znear) * mask
z_inv_max = torch.max(z_inv, dim=-1).values[..., None]
weights_num = prob_map * torch.exp((z_inv - z_inv_max) / blend_params.gamma)

View File

@ -1,5 +1,10 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .texturing import ( # isort:skip
interpolate_texture_map,
interpolate_vertex_colors,
)
from .rasterize_meshes import rasterize_meshes
from .rasterizer import MeshRasterizer, RasterizationSettings
from .renderer import MeshRenderer
@ -13,10 +18,6 @@ from .shader import (
TexturedSoftPhongShader,
)
from .shading import gouraud_shading, phong_shading
from .texturing import ( # isort: skip
interpolate_face_attributes,
interpolate_texture_map,
interpolate_vertex_colors,
)
from .utils import interpolate_face_attributes
__all__ = [k for k in globals().keys() if not k.startswith("_")]

View File

@ -5,6 +5,9 @@
import torch
import torch.nn as nn
from .rasterizer import Fragments
from .utils import _clip_barycentric_coordinates, _interpolate_zbuf
# A renderer class should be initialized with a
# function for rasterization and a function for shading.
# The rasterizer should:
@ -34,6 +37,34 @@ class MeshRenderer(nn.Module):
self.shader = shader
def forward(self, meshes_world, **kwargs) -> torch.Tensor:
"""
Render a batch of images from a batch of meshes by rasterizing and then shading.
NOTE: If the blur radius for rasterization is > 0.0, some pixels can have one or
more barycentric coordinates lying outside the range [0, 1]. For a pixel with
out of bounds barycentric coordinates with respect to a face f, clipping is required
before interpolating the texture uv coordinates and z buffer so that the colors and
depths are limited to the range for the corresponding face.
"""
fragments = self.rasterizer(meshes_world, **kwargs)
raster_settings = kwargs.get(
"raster_settings", self.rasterizer.raster_settings
)
if raster_settings.blur_radius > 0.0:
# TODO: potentially move barycentric clipping to the rasterizer
# if no downstream functions requires unclipped values.
# This will avoid unnecssary re-interpolation of the z buffer.
clipped_bary_coords = _clip_barycentric_coordinates(
fragments.bary_coords
)
clipped_zbuf = _interpolate_zbuf(
fragments.pix_to_face, clipped_bary_coords, meshes_world
)
fragments = Fragments(
bary_coords=clipped_bary_coords,
zbuf=clipped_zbuf,
dists=fragments.dists,
pix_to_face=fragments.pix_to_face,
)
images = self.shader(fragments, meshes_world, **kwargs)
return images

View File

@ -270,6 +270,7 @@ class TexturedSoftPhongShader(nn.Module):
cameras = kwargs.get("cameras", self.cameras)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
blend_params = kwargs.get("blend_params", self.blend_params)
colors = phong_shading(
meshes=meshes,
fragments=fragments,
@ -278,7 +279,7 @@ class TexturedSoftPhongShader(nn.Module):
cameras=cameras,
materials=materials,
)
images = softmax_rgb_blend(colors, fragments, self.blend_params)
images = softmax_rgb_blend(colors, fragments, blend_params)
return images

View File

@ -70,8 +70,12 @@ def phong_shading(
vertex_normals = meshes.verts_normals_packed() # (V, 3)
faces_verts = verts[faces]
faces_normals = vertex_normals[faces]
pixel_coords = interpolate_face_attributes(fragments, faces_verts)
pixel_normals = interpolate_face_attributes(fragments, faces_normals)
pixel_coords = interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, faces_verts
)
pixel_normals = interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, faces_normals
)
ambient, diffuse, specular = _apply_lighting(
pixel_coords, pixel_normals, lights, cameras, materials
)
@ -122,7 +126,9 @@ def gouraud_shading(
)
verts_colors_shaded = vertex_colors * (ambient + diffuse) + specular
face_colors = verts_colors_shaded[faces]
colors = interpolate_face_attributes(fragments, face_colors)
colors = interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, face_colors
)
return colors

View File

@ -7,75 +7,7 @@ import torch.nn.functional as F
from pytorch3d.structures.textures import Textures
def _clip_barycentric_coordinates(bary) -> torch.Tensor:
"""
Args:
bary: barycentric coordinates of shape (...., 3) where `...` represents
an arbitrary number of dimensions
Returns:
bary: All barycentric coordinate values clipped to the range [0, 1]
and renormalized. The output is the same shape as the input.
"""
if bary.shape[-1] != 3:
msg = "Expected barycentric coords to have last dim = 3; got %r"
raise ValueError(msg % bary.shape)
clipped = bary.clamp(min=0, max=1)
clipped_sum = torch.clamp(clipped.sum(dim=-1, keepdim=True), min=1e-5)
clipped = clipped / clipped_sum
return clipped
def interpolate_face_attributes(
fragments, face_attributes: torch.Tensor, bary_clip: bool = False
) -> torch.Tensor:
"""
Interpolate arbitrary face attributes using the barycentric coordinates
for each pixel in the rasterized output.
Args:
fragments:
The outputs of rasterization. From this we use
- pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
of the faces (in the packed representation) which
overlap each pixel in the image.
- barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
the barycentric coordianates of each pixel
relative to the faces (in the packed
representation) which overlap the pixel.
face_attributes: packed attributes of shape (total_faces, 3, D),
specifying the value of the attribute for each
vertex in the face.
bary_clip: Bool to indicate if barycentric_coords should be clipped
before being used for interpolation.
Returns:
pixel_vals: tensor of shape (N, H, W, K, D) giving the interpolated
value of the face attribute for each pixel.
"""
pix_to_face = fragments.pix_to_face
barycentric_coords = fragments.bary_coords
F, FV, D = face_attributes.shape
if FV != 3:
raise ValueError("Faces can only have three vertices; got %r" % FV)
N, H, W, K, _ = barycentric_coords.shape
if pix_to_face.shape != (N, H, W, K):
msg = "pix_to_face must have shape (batch_size, H, W, K); got %r"
raise ValueError(msg % pix_to_face.shape)
if bary_clip:
barycentric_coords = _clip_barycentric_coordinates(barycentric_coords)
# Replace empty pixels in pix_to_face with 0 in order to interpolate.
mask = pix_to_face == -1
pix_to_face = pix_to_face.clone()
pix_to_face[mask] = 0
idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
pixel_face_vals = face_attributes.gather(0, idx).view(N, H, W, K, 3, D)
pixel_vals = (barycentric_coords[..., None] * pixel_face_vals).sum(dim=-2)
pixel_vals[mask] = 0 # Replace masked values in output.
return pixel_vals
from .utils import interpolate_face_attributes
def interpolate_texture_map(fragments, meshes) -> torch.Tensor:
@ -97,8 +29,8 @@ def interpolate_texture_map(fragments, meshes) -> torch.Tensor:
relative to the faces (in the packed
representation) which overlap the pixel.
meshes: Meshes representing a batch of meshes. It is expected that
meshes has a textures attribute which is an instance of the
Textures class.
meshes has a textures attribute which is an instance of the
Textures class.
Returns:
texels: tensor of shape (N, H, W, K, C) giving the interpolated
@ -114,7 +46,9 @@ def interpolate_texture_map(fragments, meshes) -> torch.Tensor:
texture_maps = meshes.textures.maps_padded()
# pixel_uvs: (N, H, W, K, 2)
pixel_uvs = interpolate_face_attributes(fragments, faces_verts_uvs)
pixel_uvs = interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, faces_verts_uvs
)
N, H_out, W_out, K = fragments.pix_to_face.shape
N, H_in, W_in, C = texture_maps.shape # 3 for RGB
@ -178,5 +112,7 @@ def interpolate_vertex_colors(fragments, meshes) -> torch.Tensor:
vertex_textures = vertex_textures[meshes.verts_padded_to_packed_idx(), :]
faces_packed = meshes.faces_packed()
faces_textures = vertex_textures[faces_packed] # (F, 3, C)
texels = interpolate_face_attributes(fragments, faces_textures)
texels = interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, faces_textures
)
return texels

View File

@ -0,0 +1,100 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
def _clip_barycentric_coordinates(bary) -> torch.Tensor:
"""
Args:
bary: barycentric coordinates of shape (...., 3) where `...` represents
an arbitrary number of dimensions
Returns:
bary: Barycentric coordinates clipped (i.e any values < 0 are set to 0)
and renormalized. We only clip the negative values. Values > 1 will fall
into the [0, 1] range after renormalization.
The output is the same shape as the input.
"""
if bary.shape[-1] != 3:
msg = "Expected barycentric coords to have last dim = 3; got %r"
raise ValueError(msg % bary.shape)
clipped = bary.clamp(min=0.0)
clipped_sum = torch.clamp(clipped.sum(dim=-1, keepdim=True), min=1e-5)
clipped = clipped / clipped_sum
return clipped
def interpolate_face_attributes(
pix_to_face: torch.Tensor,
barycentric_coords: torch.Tensor,
face_attributes: torch.Tensor,
) -> torch.Tensor:
"""
Interpolate arbitrary face attributes using the barycentric coordinates
for each pixel in the rasterized output.
Args:
pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
of the faces (in the packed representation) which
overlap each pixel in the image.
barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
the barycentric coordianates of each pixel
relative to the faces (in the packed
representation) which overlap the pixel.
face_attributes: packed attributes of shape (total_faces, 3, D),
specifying the value of the attribute for each
vertex in the face.
Returns:
pixel_vals: tensor of shape (N, H, W, K, D) giving the interpolated
value of the face attribute for each pixel.
"""
F, FV, D = face_attributes.shape
if FV != 3:
raise ValueError("Faces can only have three vertices; got %r" % FV)
N, H, W, K, _ = barycentric_coords.shape
if pix_to_face.shape != (N, H, W, K):
msg = "pix_to_face must have shape (batch_size, H, W, K); got %r"
raise ValueError(msg % pix_to_face.shape)
# Replace empty pixels in pix_to_face with 0 in order to interpolate.
mask = pix_to_face == -1
pix_to_face = pix_to_face.clone()
pix_to_face[mask] = 0
idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
pixel_face_vals = face_attributes.gather(0, idx).view(N, H, W, K, 3, D)
pixel_vals = (barycentric_coords[..., None] * pixel_face_vals).sum(dim=-2)
pixel_vals[mask] = 0 # Replace masked values in output.
return pixel_vals
def _interpolate_zbuf(
pix_to_face: torch.Tensor, barycentric_coords: torch.Tensor, meshes
) -> torch.Tensor:
"""
A helper function to calculate the z buffer for each pixel in the
rasterized output.
Args:
pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
of the faces (in the packed representation) which
overlap each pixel in the image.
barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
the barycentric coordianates of each pixel
relative to the faces (in the packed
representation) which overlap the pixel.
meshes: Meshes object representing a batch of meshes.
Returns:
zbuffer: (N, H, W, K) FloatTensor
"""
verts = meshes.verts_packed()
faces = meshes.faces_packed()
faces_verts_z = verts[faces][..., 2][..., None] # (F, 3, 1)
return interpolate_face_attributes(
pix_to_face, barycentric_coords, faces_verts_z
)[
..., 0
] # (1, H, W, K)

Binary file not shown.

After

Width:  |  Height:  |  Size: 45 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 18 KiB

View File

@ -0,0 +1,24 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
import torch
from pytorch3d.renderer.mesh.utils import _clip_barycentric_coordinates
class TestMeshRenderingUtils(unittest.TestCase):
def test_bary_clip(self):
N = 10
bary = torch.randn(size=(N, 3))
# randomly make some values negative
bary[bary < 0.3] *= -1.0
# randomly make some values be greater than 1
bary[bary > 0.8] *= 2.0
negative_mask = bary < 0.0
positive_mask = bary > 1.0
clipped = _clip_barycentric_coordinates(bary)
self.assertTrue(clipped[negative_mask].sum() == 0)
self.assertTrue(clipped[positive_mask].gt(1.0).sum() == 0)
self.assertTrue(torch.allclose(clipped.sum(dim=-1), torch.ones(N)))

View File

@ -25,6 +25,7 @@ from pytorch3d.renderer.mesh.rasterizer import (
from pytorch3d.renderer.mesh.renderer import MeshRenderer
from pytorch3d.renderer.mesh.shader import (
BlendParams,
HardFlatShader,
HardGouraudShader,
HardPhongShader,
SoftSilhouetteShader,
@ -99,8 +100,9 @@ class TestRenderingMeshes(unittest.TestCase):
images = renderer(sphere_mesh)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
filename = "DEBUG_simple_sphere_light%s.png" % postfix
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_simple_sphere_light%s.png" % postfix
DATA_DIR / filename
)
# Load reference image
@ -117,8 +119,9 @@ class TestRenderingMeshes(unittest.TestCase):
images = renderer(sphere_mesh, lights=lights)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
filename = "DEBUG_simple_sphere_dark%s.png" % postfix
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_simple_sphere_dark%s.png" % postfix
DATA_DIR / filename
)
# Load reference image
@ -140,8 +143,9 @@ class TestRenderingMeshes(unittest.TestCase):
images = renderer(sphere_mesh)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
filename = "DEBUG_simple_sphere_light_gourad%s.png" % postfix
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_simple_sphere_light_gouraud%s.png" % postfix
DATA_DIR / filename
)
# Load reference image
@ -149,7 +153,30 @@ class TestRenderingMeshes(unittest.TestCase):
"test_simple_sphere_light_gouraud%s.png" % postfix
)
self.assertTrue(torch.allclose(rgb, image_ref_gouraud, atol=0.005))
self.assertFalse(torch.allclose(rgb, image_ref_phong, atol=0.005))
######################################
# Change the shader to a HardFlatShader
######################################
lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
renderer = MeshRenderer(
rasterizer=rasterizer,
shader=HardFlatShader(
lights=lights, cameras=cameras, materials=materials
),
)
images = renderer(sphere_mesh)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
filename = "DEBUG_simple_sphere_light_flat%s.png" % postfix
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / filename
)
# Load reference image
image_ref_flat = load_rgb_image(
"test_simple_sphere_light_flat%s.png" % postfix
)
self.assertTrue(torch.allclose(rgb, image_ref_flat, atol=0.005))
def test_simple_sphere_elevated_camera(self):
"""
@ -287,9 +314,6 @@ class TestRenderingMeshes(unittest.TestCase):
materials = Materials(device=device)
lights = PointLights(device=device)
lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
)
# Init renderer
renderer = MeshRenderer(
@ -327,3 +351,32 @@ class TestRenderingMeshes(unittest.TestCase):
images = renderer(mesh2)
images[0, ...].sum().backward()
self.assertIsNotNone(verts.grad)
#################################
# Add blurring to rasterization
#################################
blend_params = BlendParams(sigma=5e-4, gamma=1e-4)
raster_settings = RasterizationSettings(
image_size=512,
blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
faces_per_pixel=100,
bin_size=0,
)
images = renderer(
mesh.clone(),
raster_settings=raster_settings,
blend_params=blend_params,
)
rgb = images[0, ..., :3].squeeze().cpu()
# Load reference image
image_ref = load_rgb_image("test_blurry_textured_rendering.png")
if DEBUG:
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_blurry_textured_rendering.png"
)
self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))

View File

@ -8,7 +8,6 @@ import torch.nn.functional as F
from pytorch3d.renderer.mesh.rasterizer import Fragments
from pytorch3d.renderer.mesh.texturing import (
_clip_barycentric_coordinates,
interpolate_face_attributes,
interpolate_texture_map,
interpolate_vertex_colors,
@ -94,7 +93,9 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
dists=pix_to_face,
)
with self.assertRaises(ValueError):
interpolate_face_attributes(fragments, face_attributes)
interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, face_attributes
)
# 2. pix_to_face must have shape (N, H, W, K)
pix_to_face = torch.ones((1, 1, 1, 1, 3))
@ -105,7 +106,9 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
dists=pix_to_face,
)
with self.assertRaises(ValueError):
interpolate_face_attributes(fragments, face_attributes)
interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, face_attributes
)
def test_interpolate_texture_map(self):
barycentric_coords = torch.tensor(
@ -220,13 +223,3 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
)
with self.assertRaises(ValueError):
tex_mesh.extend(N=-1)
def test_clip_barycentric_coords(self):
barycentric_coords = torch.tensor(
[[1.5, -0.3, -0.2], [1.2, 0.3, -0.5]], dtype=torch.float32
)
expected_out = torch.tensor(
[[1.0, 0.0, 0.0], [1.0 / 1.3, 0.3 / 1.3, 0.0]], dtype=torch.float32
)
clipped = _clip_barycentric_coordinates(barycentric_coords)
self.assertTrue(torch.allclose(clipped, expected_out))