mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 05:40:34 +08:00
Barycentric clipping in the renderer and flat shading
Summary: Updates to the Renderer to enable barycentric clipping. This is important when there is blurring in the rasterization step. Also added support for flat shading. Reviewed By: jcjohnson Differential Revision: D19934259 fbshipit-source-id: 036e48636cd80d28a04405d7a29fcc71a2982904
This commit is contained in:
committed by
Facebook Github Bot
parent
f358b9b14d
commit
ff19c642cb
BIN
tests/data/test_blurry_textured_rendering.png
Normal file
BIN
tests/data/test_blurry_textured_rendering.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 45 KiB |
BIN
tests/data/test_simple_sphere_light_flat.png
Normal file
BIN
tests/data/test_simple_sphere_light_flat.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 26 KiB |
BIN
tests/data/test_simple_sphere_light_flat_elevated_camera.png
Normal file
BIN
tests/data/test_simple_sphere_light_flat_elevated_camera.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 18 KiB |
24
tests/test_mesh_rendering_utils.py
Normal file
24
tests/test_mesh_rendering_utils.py
Normal file
@@ -0,0 +1,24 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
import unittest
|
||||
import torch
|
||||
|
||||
from pytorch3d.renderer.mesh.utils import _clip_barycentric_coordinates
|
||||
|
||||
|
||||
class TestMeshRenderingUtils(unittest.TestCase):
|
||||
def test_bary_clip(self):
|
||||
N = 10
|
||||
bary = torch.randn(size=(N, 3))
|
||||
# randomly make some values negative
|
||||
bary[bary < 0.3] *= -1.0
|
||||
# randomly make some values be greater than 1
|
||||
bary[bary > 0.8] *= 2.0
|
||||
negative_mask = bary < 0.0
|
||||
positive_mask = bary > 1.0
|
||||
clipped = _clip_barycentric_coordinates(bary)
|
||||
self.assertTrue(clipped[negative_mask].sum() == 0)
|
||||
self.assertTrue(clipped[positive_mask].gt(1.0).sum() == 0)
|
||||
self.assertTrue(torch.allclose(clipped.sum(dim=-1), torch.ones(N)))
|
||||
@@ -25,6 +25,7 @@ from pytorch3d.renderer.mesh.rasterizer import (
|
||||
from pytorch3d.renderer.mesh.renderer import MeshRenderer
|
||||
from pytorch3d.renderer.mesh.shader import (
|
||||
BlendParams,
|
||||
HardFlatShader,
|
||||
HardGouraudShader,
|
||||
HardPhongShader,
|
||||
SoftSilhouetteShader,
|
||||
@@ -99,8 +100,9 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
images = renderer(sphere_mesh)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
if DEBUG:
|
||||
filename = "DEBUG_simple_sphere_light%s.png" % postfix
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / "DEBUG_simple_sphere_light%s.png" % postfix
|
||||
DATA_DIR / filename
|
||||
)
|
||||
|
||||
# Load reference image
|
||||
@@ -117,8 +119,9 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
images = renderer(sphere_mesh, lights=lights)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
if DEBUG:
|
||||
filename = "DEBUG_simple_sphere_dark%s.png" % postfix
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / "DEBUG_simple_sphere_dark%s.png" % postfix
|
||||
DATA_DIR / filename
|
||||
)
|
||||
|
||||
# Load reference image
|
||||
@@ -140,8 +143,9 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
images = renderer(sphere_mesh)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
if DEBUG:
|
||||
filename = "DEBUG_simple_sphere_light_gourad%s.png" % postfix
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / "DEBUG_simple_sphere_light_gouraud%s.png" % postfix
|
||||
DATA_DIR / filename
|
||||
)
|
||||
|
||||
# Load reference image
|
||||
@@ -149,7 +153,30 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
"test_simple_sphere_light_gouraud%s.png" % postfix
|
||||
)
|
||||
self.assertTrue(torch.allclose(rgb, image_ref_gouraud, atol=0.005))
|
||||
self.assertFalse(torch.allclose(rgb, image_ref_phong, atol=0.005))
|
||||
|
||||
######################################
|
||||
# Change the shader to a HardFlatShader
|
||||
######################################
|
||||
lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
|
||||
renderer = MeshRenderer(
|
||||
rasterizer=rasterizer,
|
||||
shader=HardFlatShader(
|
||||
lights=lights, cameras=cameras, materials=materials
|
||||
),
|
||||
)
|
||||
images = renderer(sphere_mesh)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
if DEBUG:
|
||||
filename = "DEBUG_simple_sphere_light_flat%s.png" % postfix
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / filename
|
||||
)
|
||||
|
||||
# Load reference image
|
||||
image_ref_flat = load_rgb_image(
|
||||
"test_simple_sphere_light_flat%s.png" % postfix
|
||||
)
|
||||
self.assertTrue(torch.allclose(rgb, image_ref_flat, atol=0.005))
|
||||
|
||||
def test_simple_sphere_elevated_camera(self):
|
||||
"""
|
||||
@@ -287,9 +314,6 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
materials = Materials(device=device)
|
||||
lights = PointLights(device=device)
|
||||
lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
|
||||
raster_settings = RasterizationSettings(
|
||||
image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
|
||||
)
|
||||
|
||||
# Init renderer
|
||||
renderer = MeshRenderer(
|
||||
@@ -327,3 +351,32 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
images = renderer(mesh2)
|
||||
images[0, ...].sum().backward()
|
||||
self.assertIsNotNone(verts.grad)
|
||||
|
||||
#################################
|
||||
# Add blurring to rasterization
|
||||
#################################
|
||||
|
||||
blend_params = BlendParams(sigma=5e-4, gamma=1e-4)
|
||||
raster_settings = RasterizationSettings(
|
||||
image_size=512,
|
||||
blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
|
||||
faces_per_pixel=100,
|
||||
bin_size=0,
|
||||
)
|
||||
|
||||
images = renderer(
|
||||
mesh.clone(),
|
||||
raster_settings=raster_settings,
|
||||
blend_params=blend_params,
|
||||
)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
|
||||
# Load reference image
|
||||
image_ref = load_rgb_image("test_blurry_textured_rendering.png")
|
||||
|
||||
if DEBUG:
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / "DEBUG_blurry_textured_rendering.png"
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
|
||||
|
||||
@@ -8,7 +8,6 @@ import torch.nn.functional as F
|
||||
|
||||
from pytorch3d.renderer.mesh.rasterizer import Fragments
|
||||
from pytorch3d.renderer.mesh.texturing import (
|
||||
_clip_barycentric_coordinates,
|
||||
interpolate_face_attributes,
|
||||
interpolate_texture_map,
|
||||
interpolate_vertex_colors,
|
||||
@@ -94,7 +93,9 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
|
||||
dists=pix_to_face,
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
interpolate_face_attributes(fragments, face_attributes)
|
||||
interpolate_face_attributes(
|
||||
fragments.pix_to_face, fragments.bary_coords, face_attributes
|
||||
)
|
||||
|
||||
# 2. pix_to_face must have shape (N, H, W, K)
|
||||
pix_to_face = torch.ones((1, 1, 1, 1, 3))
|
||||
@@ -105,7 +106,9 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
|
||||
dists=pix_to_face,
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
interpolate_face_attributes(fragments, face_attributes)
|
||||
interpolate_face_attributes(
|
||||
fragments.pix_to_face, fragments.bary_coords, face_attributes
|
||||
)
|
||||
|
||||
def test_interpolate_texture_map(self):
|
||||
barycentric_coords = torch.tensor(
|
||||
@@ -220,13 +223,3 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
tex_mesh.extend(N=-1)
|
||||
|
||||
def test_clip_barycentric_coords(self):
|
||||
barycentric_coords = torch.tensor(
|
||||
[[1.5, -0.3, -0.2], [1.2, 0.3, -0.5]], dtype=torch.float32
|
||||
)
|
||||
expected_out = torch.tensor(
|
||||
[[1.0, 0.0, 0.0], [1.0 / 1.3, 0.3 / 1.3, 0.0]], dtype=torch.float32
|
||||
)
|
||||
clipped = _clip_barycentric_coordinates(barycentric_coords)
|
||||
self.assertTrue(torch.allclose(clipped, expected_out))
|
||||
|
||||
Reference in New Issue
Block a user