mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Barycentric clipping in the renderer and flat shading
Summary: Updates to the Renderer to enable barycentric clipping. This is important when there is blurring in the rasterization step. Also added support for flat shading. Reviewed By: jcjohnson Differential Revision: D19934259 fbshipit-source-id: 036e48636cd80d28a04405d7a29fcc71a2982904
This commit is contained in:
parent
f358b9b14d
commit
ff19c642cb
@ -90,7 +90,9 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
|
||||
return torch.flip(pixel_colors, [1])
|
||||
|
||||
|
||||
def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
|
||||
def softmax_rgb_blend(
|
||||
colors, fragments, blend_params, znear: float = 1.0, zfar: float = 100
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
RGB and alpha channel blending to return an RGBA image based on the method
|
||||
proposed in [0]
|
||||
@ -118,6 +120,8 @@ def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
|
||||
exponential function used to control the opacity of the color.
|
||||
- background_color: (3) element list/tuple/torch.Tensor specifying
|
||||
the RGB values for the background color.
|
||||
znear: float, near clipping plane in the z direction
|
||||
zfar: float, far clipping plane in the z direction
|
||||
|
||||
Returns:
|
||||
RGBA pixel_colors: (N, H, W, 4)
|
||||
@ -125,6 +129,7 @@ def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
|
||||
[0] Shichen Liu et al, 'Soft Rasterizer: A Differentiable Renderer for
|
||||
Image-based 3D Reasoning'
|
||||
"""
|
||||
|
||||
N, H, W, K = fragments.pix_to_face.shape
|
||||
device = fragments.pix_to_face.device
|
||||
pix_colors = torch.ones(
|
||||
@ -140,11 +145,6 @@ def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
|
||||
delta = np.exp(1e-10 / blend_params.gamma) * 1e-10
|
||||
delta = torch.tensor(delta, device=device)
|
||||
|
||||
# Near and far clipping planes.
|
||||
# TODO: add zfar/znear as input params.
|
||||
zfar = 100.0
|
||||
znear = 1.0
|
||||
|
||||
# Mask for padded pixels.
|
||||
mask = fragments.pix_to_face >= 0
|
||||
|
||||
@ -164,6 +164,7 @@ def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
|
||||
# Weights for each face. Adjust the exponential by the max z to prevent
|
||||
# overflow. zbuf shape (N, H, W, K), find max over K.
|
||||
# TODO: there may still be some instability in the exponent calculation.
|
||||
|
||||
z_inv = (zfar - fragments.zbuf) / (zfar - znear) * mask
|
||||
z_inv_max = torch.max(z_inv, dim=-1).values[..., None]
|
||||
weights_num = prob_map * torch.exp((z_inv - z_inv_max) / blend_params.gamma)
|
||||
|
@ -1,5 +1,10 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
from .texturing import ( # isort:skip
|
||||
interpolate_texture_map,
|
||||
interpolate_vertex_colors,
|
||||
)
|
||||
from .rasterize_meshes import rasterize_meshes
|
||||
from .rasterizer import MeshRasterizer, RasterizationSettings
|
||||
from .renderer import MeshRenderer
|
||||
@ -13,10 +18,6 @@ from .shader import (
|
||||
TexturedSoftPhongShader,
|
||||
)
|
||||
from .shading import gouraud_shading, phong_shading
|
||||
from .texturing import ( # isort: skip
|
||||
interpolate_face_attributes,
|
||||
interpolate_texture_map,
|
||||
interpolate_vertex_colors,
|
||||
)
|
||||
from .utils import interpolate_face_attributes
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||
|
@ -5,6 +5,9 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .rasterizer import Fragments
|
||||
from .utils import _clip_barycentric_coordinates, _interpolate_zbuf
|
||||
|
||||
# A renderer class should be initialized with a
|
||||
# function for rasterization and a function for shading.
|
||||
# The rasterizer should:
|
||||
@ -34,6 +37,34 @@ class MeshRenderer(nn.Module):
|
||||
self.shader = shader
|
||||
|
||||
def forward(self, meshes_world, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
Render a batch of images from a batch of meshes by rasterizing and then shading.
|
||||
|
||||
NOTE: If the blur radius for rasterization is > 0.0, some pixels can have one or
|
||||
more barycentric coordinates lying outside the range [0, 1]. For a pixel with
|
||||
out of bounds barycentric coordinates with respect to a face f, clipping is required
|
||||
before interpolating the texture uv coordinates and z buffer so that the colors and
|
||||
depths are limited to the range for the corresponding face.
|
||||
"""
|
||||
fragments = self.rasterizer(meshes_world, **kwargs)
|
||||
raster_settings = kwargs.get(
|
||||
"raster_settings", self.rasterizer.raster_settings
|
||||
)
|
||||
if raster_settings.blur_radius > 0.0:
|
||||
# TODO: potentially move barycentric clipping to the rasterizer
|
||||
# if no downstream functions requires unclipped values.
|
||||
# This will avoid unnecssary re-interpolation of the z buffer.
|
||||
clipped_bary_coords = _clip_barycentric_coordinates(
|
||||
fragments.bary_coords
|
||||
)
|
||||
clipped_zbuf = _interpolate_zbuf(
|
||||
fragments.pix_to_face, clipped_bary_coords, meshes_world
|
||||
)
|
||||
fragments = Fragments(
|
||||
bary_coords=clipped_bary_coords,
|
||||
zbuf=clipped_zbuf,
|
||||
dists=fragments.dists,
|
||||
pix_to_face=fragments.pix_to_face,
|
||||
)
|
||||
images = self.shader(fragments, meshes_world, **kwargs)
|
||||
return images
|
||||
|
@ -270,6 +270,7 @@ class TexturedSoftPhongShader(nn.Module):
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
lights = kwargs.get("lights", self.lights)
|
||||
materials = kwargs.get("materials", self.materials)
|
||||
blend_params = kwargs.get("blend_params", self.blend_params)
|
||||
colors = phong_shading(
|
||||
meshes=meshes,
|
||||
fragments=fragments,
|
||||
@ -278,7 +279,7 @@ class TexturedSoftPhongShader(nn.Module):
|
||||
cameras=cameras,
|
||||
materials=materials,
|
||||
)
|
||||
images = softmax_rgb_blend(colors, fragments, self.blend_params)
|
||||
images = softmax_rgb_blend(colors, fragments, blend_params)
|
||||
return images
|
||||
|
||||
|
||||
|
@ -70,8 +70,12 @@ def phong_shading(
|
||||
vertex_normals = meshes.verts_normals_packed() # (V, 3)
|
||||
faces_verts = verts[faces]
|
||||
faces_normals = vertex_normals[faces]
|
||||
pixel_coords = interpolate_face_attributes(fragments, faces_verts)
|
||||
pixel_normals = interpolate_face_attributes(fragments, faces_normals)
|
||||
pixel_coords = interpolate_face_attributes(
|
||||
fragments.pix_to_face, fragments.bary_coords, faces_verts
|
||||
)
|
||||
pixel_normals = interpolate_face_attributes(
|
||||
fragments.pix_to_face, fragments.bary_coords, faces_normals
|
||||
)
|
||||
ambient, diffuse, specular = _apply_lighting(
|
||||
pixel_coords, pixel_normals, lights, cameras, materials
|
||||
)
|
||||
@ -122,7 +126,9 @@ def gouraud_shading(
|
||||
)
|
||||
verts_colors_shaded = vertex_colors * (ambient + diffuse) + specular
|
||||
face_colors = verts_colors_shaded[faces]
|
||||
colors = interpolate_face_attributes(fragments, face_colors)
|
||||
colors = interpolate_face_attributes(
|
||||
fragments.pix_to_face, fragments.bary_coords, face_colors
|
||||
)
|
||||
return colors
|
||||
|
||||
|
||||
|
@ -7,75 +7,7 @@ import torch.nn.functional as F
|
||||
|
||||
from pytorch3d.structures.textures import Textures
|
||||
|
||||
|
||||
def _clip_barycentric_coordinates(bary) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
bary: barycentric coordinates of shape (...., 3) where `...` represents
|
||||
an arbitrary number of dimensions
|
||||
|
||||
Returns:
|
||||
bary: All barycentric coordinate values clipped to the range [0, 1]
|
||||
and renormalized. The output is the same shape as the input.
|
||||
"""
|
||||
if bary.shape[-1] != 3:
|
||||
msg = "Expected barycentric coords to have last dim = 3; got %r"
|
||||
raise ValueError(msg % bary.shape)
|
||||
clipped = bary.clamp(min=0, max=1)
|
||||
clipped_sum = torch.clamp(clipped.sum(dim=-1, keepdim=True), min=1e-5)
|
||||
clipped = clipped / clipped_sum
|
||||
return clipped
|
||||
|
||||
|
||||
def interpolate_face_attributes(
|
||||
fragments, face_attributes: torch.Tensor, bary_clip: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Interpolate arbitrary face attributes using the barycentric coordinates
|
||||
for each pixel in the rasterized output.
|
||||
|
||||
Args:
|
||||
fragments:
|
||||
The outputs of rasterization. From this we use
|
||||
|
||||
- pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
|
||||
of the faces (in the packed representation) which
|
||||
overlap each pixel in the image.
|
||||
- barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
|
||||
the barycentric coordianates of each pixel
|
||||
relative to the faces (in the packed
|
||||
representation) which overlap the pixel.
|
||||
face_attributes: packed attributes of shape (total_faces, 3, D),
|
||||
specifying the value of the attribute for each
|
||||
vertex in the face.
|
||||
bary_clip: Bool to indicate if barycentric_coords should be clipped
|
||||
before being used for interpolation.
|
||||
|
||||
Returns:
|
||||
pixel_vals: tensor of shape (N, H, W, K, D) giving the interpolated
|
||||
value of the face attribute for each pixel.
|
||||
"""
|
||||
pix_to_face = fragments.pix_to_face
|
||||
barycentric_coords = fragments.bary_coords
|
||||
F, FV, D = face_attributes.shape
|
||||
if FV != 3:
|
||||
raise ValueError("Faces can only have three vertices; got %r" % FV)
|
||||
N, H, W, K, _ = barycentric_coords.shape
|
||||
if pix_to_face.shape != (N, H, W, K):
|
||||
msg = "pix_to_face must have shape (batch_size, H, W, K); got %r"
|
||||
raise ValueError(msg % pix_to_face.shape)
|
||||
if bary_clip:
|
||||
barycentric_coords = _clip_barycentric_coordinates(barycentric_coords)
|
||||
|
||||
# Replace empty pixels in pix_to_face with 0 in order to interpolate.
|
||||
mask = pix_to_face == -1
|
||||
pix_to_face = pix_to_face.clone()
|
||||
pix_to_face[mask] = 0
|
||||
idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
|
||||
pixel_face_vals = face_attributes.gather(0, idx).view(N, H, W, K, 3, D)
|
||||
pixel_vals = (barycentric_coords[..., None] * pixel_face_vals).sum(dim=-2)
|
||||
pixel_vals[mask] = 0 # Replace masked values in output.
|
||||
return pixel_vals
|
||||
from .utils import interpolate_face_attributes
|
||||
|
||||
|
||||
def interpolate_texture_map(fragments, meshes) -> torch.Tensor:
|
||||
@ -97,8 +29,8 @@ def interpolate_texture_map(fragments, meshes) -> torch.Tensor:
|
||||
relative to the faces (in the packed
|
||||
representation) which overlap the pixel.
|
||||
meshes: Meshes representing a batch of meshes. It is expected that
|
||||
meshes has a textures attribute which is an instance of the
|
||||
Textures class.
|
||||
meshes has a textures attribute which is an instance of the
|
||||
Textures class.
|
||||
|
||||
Returns:
|
||||
texels: tensor of shape (N, H, W, K, C) giving the interpolated
|
||||
@ -114,7 +46,9 @@ def interpolate_texture_map(fragments, meshes) -> torch.Tensor:
|
||||
texture_maps = meshes.textures.maps_padded()
|
||||
|
||||
# pixel_uvs: (N, H, W, K, 2)
|
||||
pixel_uvs = interpolate_face_attributes(fragments, faces_verts_uvs)
|
||||
pixel_uvs = interpolate_face_attributes(
|
||||
fragments.pix_to_face, fragments.bary_coords, faces_verts_uvs
|
||||
)
|
||||
|
||||
N, H_out, W_out, K = fragments.pix_to_face.shape
|
||||
N, H_in, W_in, C = texture_maps.shape # 3 for RGB
|
||||
@ -178,5 +112,7 @@ def interpolate_vertex_colors(fragments, meshes) -> torch.Tensor:
|
||||
vertex_textures = vertex_textures[meshes.verts_padded_to_packed_idx(), :]
|
||||
faces_packed = meshes.faces_packed()
|
||||
faces_textures = vertex_textures[faces_packed] # (F, 3, C)
|
||||
texels = interpolate_face_attributes(fragments, faces_textures)
|
||||
texels = interpolate_face_attributes(
|
||||
fragments.pix_to_face, fragments.bary_coords, faces_textures
|
||||
)
|
||||
return texels
|
||||
|
100
pytorch3d/renderer/mesh/utils.py
Normal file
100
pytorch3d/renderer/mesh/utils.py
Normal file
@ -0,0 +1,100 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _clip_barycentric_coordinates(bary) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
bary: barycentric coordinates of shape (...., 3) where `...` represents
|
||||
an arbitrary number of dimensions
|
||||
|
||||
Returns:
|
||||
bary: Barycentric coordinates clipped (i.e any values < 0 are set to 0)
|
||||
and renormalized. We only clip the negative values. Values > 1 will fall
|
||||
into the [0, 1] range after renormalization.
|
||||
The output is the same shape as the input.
|
||||
"""
|
||||
if bary.shape[-1] != 3:
|
||||
msg = "Expected barycentric coords to have last dim = 3; got %r"
|
||||
raise ValueError(msg % bary.shape)
|
||||
clipped = bary.clamp(min=0.0)
|
||||
clipped_sum = torch.clamp(clipped.sum(dim=-1, keepdim=True), min=1e-5)
|
||||
clipped = clipped / clipped_sum
|
||||
return clipped
|
||||
|
||||
|
||||
def interpolate_face_attributes(
|
||||
pix_to_face: torch.Tensor,
|
||||
barycentric_coords: torch.Tensor,
|
||||
face_attributes: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Interpolate arbitrary face attributes using the barycentric coordinates
|
||||
for each pixel in the rasterized output.
|
||||
|
||||
Args:
|
||||
pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
|
||||
of the faces (in the packed representation) which
|
||||
overlap each pixel in the image.
|
||||
barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
|
||||
the barycentric coordianates of each pixel
|
||||
relative to the faces (in the packed
|
||||
representation) which overlap the pixel.
|
||||
face_attributes: packed attributes of shape (total_faces, 3, D),
|
||||
specifying the value of the attribute for each
|
||||
vertex in the face.
|
||||
|
||||
Returns:
|
||||
pixel_vals: tensor of shape (N, H, W, K, D) giving the interpolated
|
||||
value of the face attribute for each pixel.
|
||||
"""
|
||||
F, FV, D = face_attributes.shape
|
||||
if FV != 3:
|
||||
raise ValueError("Faces can only have three vertices; got %r" % FV)
|
||||
N, H, W, K, _ = barycentric_coords.shape
|
||||
if pix_to_face.shape != (N, H, W, K):
|
||||
msg = "pix_to_face must have shape (batch_size, H, W, K); got %r"
|
||||
raise ValueError(msg % pix_to_face.shape)
|
||||
|
||||
# Replace empty pixels in pix_to_face with 0 in order to interpolate.
|
||||
mask = pix_to_face == -1
|
||||
pix_to_face = pix_to_face.clone()
|
||||
pix_to_face[mask] = 0
|
||||
idx = pix_to_face.view(N * H * W * K, 1, 1).expand(N * H * W * K, 3, D)
|
||||
pixel_face_vals = face_attributes.gather(0, idx).view(N, H, W, K, 3, D)
|
||||
pixel_vals = (barycentric_coords[..., None] * pixel_face_vals).sum(dim=-2)
|
||||
pixel_vals[mask] = 0 # Replace masked values in output.
|
||||
return pixel_vals
|
||||
|
||||
|
||||
def _interpolate_zbuf(
|
||||
pix_to_face: torch.Tensor, barycentric_coords: torch.Tensor, meshes
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
A helper function to calculate the z buffer for each pixel in the
|
||||
rasterized output.
|
||||
|
||||
Args:
|
||||
pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
|
||||
of the faces (in the packed representation) which
|
||||
overlap each pixel in the image.
|
||||
barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
|
||||
the barycentric coordianates of each pixel
|
||||
relative to the faces (in the packed
|
||||
representation) which overlap the pixel.
|
||||
meshes: Meshes object representing a batch of meshes.
|
||||
|
||||
Returns:
|
||||
zbuffer: (N, H, W, K) FloatTensor
|
||||
"""
|
||||
verts = meshes.verts_packed()
|
||||
faces = meshes.faces_packed()
|
||||
faces_verts_z = verts[faces][..., 2][..., None] # (F, 3, 1)
|
||||
return interpolate_face_attributes(
|
||||
pix_to_face, barycentric_coords, faces_verts_z
|
||||
)[
|
||||
..., 0
|
||||
] # (1, H, W, K)
|
BIN
tests/data/test_blurry_textured_rendering.png
Normal file
BIN
tests/data/test_blurry_textured_rendering.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 45 KiB |
BIN
tests/data/test_simple_sphere_light_flat.png
Normal file
BIN
tests/data/test_simple_sphere_light_flat.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 26 KiB |
BIN
tests/data/test_simple_sphere_light_flat_elevated_camera.png
Normal file
BIN
tests/data/test_simple_sphere_light_flat_elevated_camera.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 18 KiB |
24
tests/test_mesh_rendering_utils.py
Normal file
24
tests/test_mesh_rendering_utils.py
Normal file
@ -0,0 +1,24 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
import unittest
|
||||
import torch
|
||||
|
||||
from pytorch3d.renderer.mesh.utils import _clip_barycentric_coordinates
|
||||
|
||||
|
||||
class TestMeshRenderingUtils(unittest.TestCase):
|
||||
def test_bary_clip(self):
|
||||
N = 10
|
||||
bary = torch.randn(size=(N, 3))
|
||||
# randomly make some values negative
|
||||
bary[bary < 0.3] *= -1.0
|
||||
# randomly make some values be greater than 1
|
||||
bary[bary > 0.8] *= 2.0
|
||||
negative_mask = bary < 0.0
|
||||
positive_mask = bary > 1.0
|
||||
clipped = _clip_barycentric_coordinates(bary)
|
||||
self.assertTrue(clipped[negative_mask].sum() == 0)
|
||||
self.assertTrue(clipped[positive_mask].gt(1.0).sum() == 0)
|
||||
self.assertTrue(torch.allclose(clipped.sum(dim=-1), torch.ones(N)))
|
@ -25,6 +25,7 @@ from pytorch3d.renderer.mesh.rasterizer import (
|
||||
from pytorch3d.renderer.mesh.renderer import MeshRenderer
|
||||
from pytorch3d.renderer.mesh.shader import (
|
||||
BlendParams,
|
||||
HardFlatShader,
|
||||
HardGouraudShader,
|
||||
HardPhongShader,
|
||||
SoftSilhouetteShader,
|
||||
@ -99,8 +100,9 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
images = renderer(sphere_mesh)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
if DEBUG:
|
||||
filename = "DEBUG_simple_sphere_light%s.png" % postfix
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / "DEBUG_simple_sphere_light%s.png" % postfix
|
||||
DATA_DIR / filename
|
||||
)
|
||||
|
||||
# Load reference image
|
||||
@ -117,8 +119,9 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
images = renderer(sphere_mesh, lights=lights)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
if DEBUG:
|
||||
filename = "DEBUG_simple_sphere_dark%s.png" % postfix
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / "DEBUG_simple_sphere_dark%s.png" % postfix
|
||||
DATA_DIR / filename
|
||||
)
|
||||
|
||||
# Load reference image
|
||||
@ -140,8 +143,9 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
images = renderer(sphere_mesh)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
if DEBUG:
|
||||
filename = "DEBUG_simple_sphere_light_gourad%s.png" % postfix
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / "DEBUG_simple_sphere_light_gouraud%s.png" % postfix
|
||||
DATA_DIR / filename
|
||||
)
|
||||
|
||||
# Load reference image
|
||||
@ -149,7 +153,30 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
"test_simple_sphere_light_gouraud%s.png" % postfix
|
||||
)
|
||||
self.assertTrue(torch.allclose(rgb, image_ref_gouraud, atol=0.005))
|
||||
self.assertFalse(torch.allclose(rgb, image_ref_phong, atol=0.005))
|
||||
|
||||
######################################
|
||||
# Change the shader to a HardFlatShader
|
||||
######################################
|
||||
lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
|
||||
renderer = MeshRenderer(
|
||||
rasterizer=rasterizer,
|
||||
shader=HardFlatShader(
|
||||
lights=lights, cameras=cameras, materials=materials
|
||||
),
|
||||
)
|
||||
images = renderer(sphere_mesh)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
if DEBUG:
|
||||
filename = "DEBUG_simple_sphere_light_flat%s.png" % postfix
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / filename
|
||||
)
|
||||
|
||||
# Load reference image
|
||||
image_ref_flat = load_rgb_image(
|
||||
"test_simple_sphere_light_flat%s.png" % postfix
|
||||
)
|
||||
self.assertTrue(torch.allclose(rgb, image_ref_flat, atol=0.005))
|
||||
|
||||
def test_simple_sphere_elevated_camera(self):
|
||||
"""
|
||||
@ -287,9 +314,6 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
materials = Materials(device=device)
|
||||
lights = PointLights(device=device)
|
||||
lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
|
||||
raster_settings = RasterizationSettings(
|
||||
image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
|
||||
)
|
||||
|
||||
# Init renderer
|
||||
renderer = MeshRenderer(
|
||||
@ -327,3 +351,32 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
images = renderer(mesh2)
|
||||
images[0, ...].sum().backward()
|
||||
self.assertIsNotNone(verts.grad)
|
||||
|
||||
#################################
|
||||
# Add blurring to rasterization
|
||||
#################################
|
||||
|
||||
blend_params = BlendParams(sigma=5e-4, gamma=1e-4)
|
||||
raster_settings = RasterizationSettings(
|
||||
image_size=512,
|
||||
blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
|
||||
faces_per_pixel=100,
|
||||
bin_size=0,
|
||||
)
|
||||
|
||||
images = renderer(
|
||||
mesh.clone(),
|
||||
raster_settings=raster_settings,
|
||||
blend_params=blend_params,
|
||||
)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
|
||||
# Load reference image
|
||||
image_ref = load_rgb_image("test_blurry_textured_rendering.png")
|
||||
|
||||
if DEBUG:
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / "DEBUG_blurry_textured_rendering.png"
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
|
||||
|
@ -8,7 +8,6 @@ import torch.nn.functional as F
|
||||
|
||||
from pytorch3d.renderer.mesh.rasterizer import Fragments
|
||||
from pytorch3d.renderer.mesh.texturing import (
|
||||
_clip_barycentric_coordinates,
|
||||
interpolate_face_attributes,
|
||||
interpolate_texture_map,
|
||||
interpolate_vertex_colors,
|
||||
@ -94,7 +93,9 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
|
||||
dists=pix_to_face,
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
interpolate_face_attributes(fragments, face_attributes)
|
||||
interpolate_face_attributes(
|
||||
fragments.pix_to_face, fragments.bary_coords, face_attributes
|
||||
)
|
||||
|
||||
# 2. pix_to_face must have shape (N, H, W, K)
|
||||
pix_to_face = torch.ones((1, 1, 1, 1, 3))
|
||||
@ -105,7 +106,9 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
|
||||
dists=pix_to_face,
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
interpolate_face_attributes(fragments, face_attributes)
|
||||
interpolate_face_attributes(
|
||||
fragments.pix_to_face, fragments.bary_coords, face_attributes
|
||||
)
|
||||
|
||||
def test_interpolate_texture_map(self):
|
||||
barycentric_coords = torch.tensor(
|
||||
@ -220,13 +223,3 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
tex_mesh.extend(N=-1)
|
||||
|
||||
def test_clip_barycentric_coords(self):
|
||||
barycentric_coords = torch.tensor(
|
||||
[[1.5, -0.3, -0.2], [1.2, 0.3, -0.5]], dtype=torch.float32
|
||||
)
|
||||
expected_out = torch.tensor(
|
||||
[[1.0, 0.0, 0.0], [1.0 / 1.3, 0.3 / 1.3, 0.0]], dtype=torch.float32
|
||||
)
|
||||
clipped = _clip_barycentric_coordinates(barycentric_coords)
|
||||
self.assertTrue(torch.allclose(clipped, expected_out))
|
||||
|
Loading…
x
Reference in New Issue
Block a user