Rendering texturing fixes

Summary:
Fix errors raised by issue on GitHub - extending mesh textures + rendering with Gourad and Phong shaders.

https://github.com/facebookresearch/pytorch3d/issues/97

Reviewed By: gkioxari

Differential Revision: D20319610

fbshipit-source-id: d1c692ff0b9397a77a9b829c5c731790de70c09f
This commit is contained in:
Nikhila Ravi
2020-03-17 08:55:57 -07:00
committed by Facebook GitHub Bot
parent f580ce1385
commit 5d3cc3569a
10 changed files with 348 additions and 152 deletions

View File

Before

Width:  |  Height:  |  Size: 21 KiB

After

Width:  |  Height:  |  Size: 21 KiB

View File

Before

Width:  |  Height:  |  Size: 10 KiB

After

Width:  |  Height:  |  Size: 10 KiB

View File

@@ -135,6 +135,15 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
def test_simple(self):
mesh = TestMeshes.init_simple_mesh("cuda:0")
# Check that faces/verts per mesh are set in init:
self.assertClose(
mesh._num_faces_per_mesh.cpu(), torch.tensor([1, 2, 7])
)
self.assertClose(
mesh._num_verts_per_mesh.cpu(), torch.tensor([3, 4, 5])
)
# Check computed tensors
self.assertClose(
mesh.verts_packed_to_mesh_idx().cpu(),
torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]),
@@ -142,9 +151,6 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
self.assertClose(
mesh.mesh_to_verts_packed_first_idx().cpu(), torch.tensor([0, 3, 7])
)
self.assertClose(
mesh.num_verts_per_mesh().cpu(), torch.tensor([3, 4, 5])
)
self.assertClose(
mesh.verts_padded_to_packed_idx().cpu(),
torch.tensor([0, 1, 2, 5, 6, 7, 8, 10, 11, 12, 13, 14]),
@@ -156,9 +162,6 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
self.assertClose(
mesh.mesh_to_faces_packed_first_idx().cpu(), torch.tensor([0, 1, 3])
)
self.assertClose(
mesh.num_faces_per_mesh().cpu(), torch.tensor([1, 2, 7])
)
self.assertClose(
mesh.num_edges_per_mesh().cpu(),
torch.tensor([3, 5, 10], dtype=torch.int32),
@@ -249,6 +252,8 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
self.assertEqual(mesh.faces_padded().shape[0], 0)
self.assertEqual(mesh.verts_packed().shape[0], 0)
self.assertEqual(mesh.faces_packed().shape[0], 0)
self.assertEqual(mesh.num_faces_per_mesh().shape[0], 0)
self.assertEqual(mesh.num_verts_per_mesh().shape[0], 0)
def test_empty(self):
N, V, F = 10, 100, 300
@@ -323,9 +328,11 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
mesh = Meshes(verts=torch.stack(verts), faces=torch.stack(faces))
# Check verts/faces per mesh are set correctly in init.
self.assertListEqual(
mesh.num_faces_per_mesh().tolist(), num_faces.tolist()
mesh._num_faces_per_mesh.tolist(), num_faces.tolist()
)
self.assertListEqual(mesh._num_verts_per_mesh.tolist(), [V] * N)
for n, (vv, ff) in enumerate(zip(mesh.verts_list(), mesh.faces_list())):
self.assertClose(ff, faces[n][: num_faces[n]])
@@ -364,7 +371,6 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
mesh._num_verts_per_mesh = torch.randint_like(
mesh.num_verts_per_mesh(), high=10
)
# Check cloned and original Meshes objects do not share tensors.
self.assertFalse(
torch.allclose(new_mesh._verts_list[0], mesh._verts_list[0])

View File

@@ -34,7 +34,7 @@ from pytorch3d.renderer.mesh.texturing import Textures
from pytorch3d.structures.meshes import Meshes
from pytorch3d.utils.ico_sphere import ico_sphere
# Save out images generated in the tests for debugging
# If DEBUG=True, save out images generated in the tests for debugging.
# All saved images have prefix DEBUG_
DEBUG = False
DATA_DIR = Path(__file__).resolve().parent / "data"
@@ -90,30 +90,31 @@ class TestRenderingMeshes(unittest.TestCase):
raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
)
# Init renderer
rasterizer = MeshRasterizer(
cameras=cameras, raster_settings=raster_settings
)
renderer = MeshRenderer(
rasterizer=rasterizer,
shader=HardPhongShader(
lights=lights, cameras=cameras, materials=materials
),
)
images = renderer(sphere_mesh)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
filename = "DEBUG_simple_sphere_light%s.png" % postfix
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / filename
)
# Load reference image
image_ref_phong = load_rgb_image(
"test_simple_sphere_light%s.png" % postfix
)
self.assertTrue(torch.allclose(rgb, image_ref_phong, atol=0.05))
# Test several shaders
shaders = {
"phong": HardPhongShader,
"gouraud": HardGouraudShader,
"flat": HardFlatShader,
}
for (name, shader_init) in shaders.items():
shader = shader_init(
lights=lights, cameras=cameras, materials=materials
)
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
images = renderer(sphere_mesh)
filename = "simple_sphere_light_%s%s.png" % (name, postfix)
image_ref = load_rgb_image("test_%s" % filename)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
filename = "DEBUG_" % filename
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / filename
)
self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
########################################################
# Move the light to the +z axis in world space so it is
@@ -121,7 +122,13 @@ class TestRenderingMeshes(unittest.TestCase):
# +X left for both world and camera space.
########################################################
lights.location[..., 2] = -2.0
images = renderer(sphere_mesh, lights=lights)
phong_shader = HardPhongShader(
lights=lights, cameras=cameras, materials=materials
)
phong_renderer = MeshRenderer(
rasterizer=rasterizer, shader=phong_shader
)
images = phong_renderer(sphere_mesh, lights=lights)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
filename = "DEBUG_simple_sphere_dark%s.png" % postfix
@@ -135,53 +142,6 @@ class TestRenderingMeshes(unittest.TestCase):
)
self.assertTrue(torch.allclose(rgb, image_ref_phong_dark, atol=0.05))
######################################
# Change the shader to a GouraudShader
######################################
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
renderer = MeshRenderer(
rasterizer=rasterizer,
shader=HardGouraudShader(
lights=lights, cameras=cameras, materials=materials
),
)
images = renderer(sphere_mesh)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
filename = "DEBUG_simple_sphere_light_gouraud%s.png" % postfix
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / filename
)
# Load reference image
image_ref_gouraud = load_rgb_image(
"test_simple_sphere_light_gouraud%s.png" % postfix
)
self.assertTrue(torch.allclose(rgb, image_ref_gouraud, atol=0.005))
######################################
# Change the shader to a HardFlatShader
######################################
renderer = MeshRenderer(
rasterizer=rasterizer,
shader=HardFlatShader(
lights=lights, cameras=cameras, materials=materials
),
)
images = renderer(sphere_mesh)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
filename = "DEBUG_simple_sphere_light_flat%s.png" % postfix
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / filename
)
# Load reference image
image_ref_flat = load_rgb_image(
"test_simple_sphere_light_flat%s.png" % postfix
)
self.assertTrue(torch.allclose(rgb, image_ref_flat, atol=0.005))
def test_simple_sphere_elevated_camera(self):
"""
Test output of phong and gouraud shading matches a reference image using
@@ -193,13 +153,13 @@ class TestRenderingMeshes(unittest.TestCase):
def test_simple_sphere_batched(self):
"""
Test output of phong shading matches a reference image using
the default values for the light sources.
Test a mesh with vertex textures can be extended to form a batch, and
is rendered correctly with Phong, Gouraud and Flat Shaders.
"""
batch_size = 5
batch_size = 20
device = torch.device("cuda:0")
# Init mesh
# Init mesh with vertex textures.
sphere_meshes = ico_sphere(5, device).extend(batch_size)
verts_padded = sphere_meshes.verts_padded()
faces_padded = sphere_meshes.faces_padded()
@@ -224,26 +184,24 @@ class TestRenderingMeshes(unittest.TestCase):
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
# Init renderer
renderer = MeshRenderer(
rasterizer=MeshRasterizer(
cameras=cameras, raster_settings=raster_settings
),
shader=HardPhongShader(
lights=lights, cameras=cameras, materials=materials
),
rasterizer = MeshRasterizer(
cameras=cameras, raster_settings=raster_settings
)
images = renderer(sphere_meshes)
# Load ref image
image_ref = load_rgb_image("test_simple_sphere_light.png")
for i in range(batch_size):
rgb = images[i, ..., :3].squeeze().cpu()
if DEBUG:
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / f"DEBUG_simple_sphere_{i}.png"
)
self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
shaders = {
"phong": HardGouraudShader,
"gouraud": HardGouraudShader,
"flat": HardFlatShader,
}
for (name, shader_init) in shaders.items():
shader = shader_init(
lights=lights, cameras=cameras, materials=materials
)
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
images = renderer(sphere_meshes)
image_ref = load_rgb_image("test_simple_sphere_light_%s.png" % name)
for i in range(batch_size):
rgb = images[i, ..., :3].squeeze().cpu()
self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
def test_silhouette_with_grad(self):
"""

View File

@@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import numpy as np
import unittest
import torch
@@ -61,3 +62,28 @@ class TestTensorProperties(TestCaseMixin, unittest.TestCase):
example = TensorPropertiesTestClass(x=(), y=())
self.assertTrue(len(example) == 0)
self.assertTrue(example.isempty())
def test_gather_props(self):
N = 4
x = torch.randn((N, 3, 4))
y = torch.randn((N, 5))
test_class = TensorPropertiesTestClass(x=x, y=y)
S = 15
idx = torch.tensor(np.random.choice(N, S))
test_class_gathered = test_class.gather_props(idx)
self.assertTrue(test_class_gathered.x.shape == (S, 3, 4))
self.assertTrue(test_class_gathered.y.shape == (S, 5))
for i in range(N):
inds = idx == i
if inds.sum() > 0:
# Check the gathered points in the output have the same value from
# the input.
self.assertClose(
test_class_gathered.x[inds].mean(dim=0), x[i, ...]
)
self.assertClose(
test_class_gathered.y[inds].mean(dim=0), y[i, ...]
)

View File

@@ -12,6 +12,7 @@ from pytorch3d.renderer.mesh.texturing import (
interpolate_vertex_colors,
)
from pytorch3d.structures import Meshes, Textures
from pytorch3d.structures.utils import list_to_padded
from common_testing import TestCaseMixin
from test_meshes import TestMeshes
@@ -154,6 +155,108 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
torch.allclose(texels.squeeze(), expected_out.squeeze())
)
def test_init_rgb_uv_fail(self):
V = 20
# Maps has wrong shape
with self.assertRaisesRegex(ValueError, "maps"):
Textures(
maps=torch.ones((5, 16, 16, 3, 4)),
faces_uvs=torch.randint(size=(5, 10, 3), low=0, high=V),
verts_uvs=torch.ones((5, V, 2)),
)
# faces_uvs has wrong shape
with self.assertRaisesRegex(ValueError, "faces_uvs"):
Textures(
maps=torch.ones((5, 16, 16, 3)),
faces_uvs=torch.randint(size=(5, 10, 3, 3), low=0, high=V),
verts_uvs=torch.ones((5, V, 2)),
)
# verts_uvs has wrong shape
with self.assertRaisesRegex(ValueError, "verts_uvs"):
Textures(
maps=torch.ones((5, 16, 16, 3)),
faces_uvs=torch.randint(size=(5, 10, 3), low=0, high=V),
verts_uvs=torch.ones((5, V, 2, 3)),
)
# verts_rgb has wrong shape
with self.assertRaisesRegex(ValueError, "verts_rgb"):
Textures(verts_rgb=torch.ones((5, 16, 16, 3)))
# maps provided without verts/faces uvs
with self.assertRaisesRegex(
ValueError, "faces_uvs and verts_uvs are required"
):
Textures(maps=torch.ones((5, 16, 16, 3)))
def test_padded_to_packed(self):
N = 2
# Case where each face in the mesh has 3 unique uv vertex indices
# - i.e. even if a vertex is shared between multiple faces it will
# have a unique uv coordinate for each face.
faces_uvs_list = [
torch.tensor([[0, 1, 2], [3, 5, 4], [7, 6, 8]]),
torch.tensor([[0, 1, 2], [3, 4, 5]]),
] # (N, 3, 3)
verts_uvs_list = [torch.ones(9, 2), torch.ones(6, 2)]
faces_uvs_padded = list_to_padded(faces_uvs_list, pad_value=-1)
verts_uvs_padded = list_to_padded(verts_uvs_list)
tex = Textures(
maps=torch.ones((N, 16, 16, 3)),
faces_uvs=faces_uvs_padded,
verts_uvs=verts_uvs_padded,
)
# This is set inside Meshes when textures is passed as an input.
# Here we set _num_faces_per_mesh and _num_verts_per_mesh explicity.
tex1 = tex.clone()
tex1._num_faces_per_mesh = (
faces_uvs_padded.gt(-1).all(-1).sum(-1).tolist()
)
tex1._num_verts_per_mesh = torch.tensor([5, 4])
faces_packed = tex1.faces_uvs_packed()
verts_packed = tex1.verts_uvs_packed()
faces_list = tex1.faces_uvs_list()
verts_list = tex1.verts_uvs_list()
for f1, f2 in zip(faces_uvs_list, faces_list):
self.assertTrue((f1 == f2).all().item())
for f, v1, v2 in zip(faces_list, verts_list, verts_uvs_list):
idx = f.unique()
self.assertTrue((v1[idx] == v2).all().item())
self.assertTrue(faces_packed.shape == (3 + 2, 3))
# verts_packed is just flattened verts_padded.
# split sizes are not used for verts_uvs.
self.assertTrue(verts_packed.shape == (9 * 2, 2))
# Case where num_faces_per_mesh is not set
tex2 = tex.clone()
faces_packed = tex2.faces_uvs_packed()
verts_packed = tex2.verts_uvs_packed()
faces_list = tex2.faces_uvs_list()
verts_list = tex2.verts_uvs_list()
# Packed is just flattened padded as num_faces_per_mesh
# has not been provided.
self.assertTrue(verts_packed.shape == (9 * 2, 2))
self.assertTrue(faces_packed.shape == (3 * 2, 3))
for i in range(N):
self.assertTrue(
(faces_list[i] == faces_uvs_padded[i, ...].squeeze())
.all()
.item()
)
for i in range(N):
self.assertTrue(
(verts_list[i] == verts_uvs_padded[i, ...].squeeze())
.all()
.item()
)
def test_clone(self):
V = 20
tex = Textures(
@@ -233,13 +336,17 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
mesh = TestMeshes.init_mesh(B, 30, 50)
V = mesh._V
F = mesh._F
tex = Textures(
# 1. Texture uvs
tex_uv = Textures(
maps=torch.randn((B, 16, 16, 3)),
faces_uvs=torch.randint(size=(B, F, 3), low=0, high=V),
verts_uvs=torch.randn((B, V, 2)),
)
tex_mesh = Meshes(
verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex
verts=mesh.verts_padded(),
faces=mesh.faces_padded(),
textures=tex_uv,
)
N = 20
new_mesh = tex_mesh.extend(N)
@@ -269,5 +376,43 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
new_tex.maps_padded(),
]
)
self.assertIsNone(new_tex.verts_rgb_list())
self.assertIsNone(new_tex.verts_rgb_padded())
self.assertIsNone(new_tex.verts_rgb_packed())
# 2. Texture vertex RGB
tex_rgb = Textures(verts_rgb=torch.randn((B, V, 3)))
tex_mesh_rgb = Meshes(
verts=mesh.verts_padded(),
faces=mesh.faces_padded(),
textures=tex_rgb,
)
N = 20
new_mesh_rgb = tex_mesh_rgb.extend(N)
self.assertEqual(len(tex_mesh_rgb) * N, len(new_mesh_rgb))
tex_init = tex_mesh_rgb.textures
new_tex = new_mesh_rgb.textures
for i in range(len(tex_mesh_rgb)):
for n in range(N):
self.assertClose(
tex_init.verts_rgb_list()[i],
new_tex.verts_rgb_list()[i * N + n],
)
self.assertAllSeparate(
[tex_init.verts_rgb_padded(), new_tex.verts_rgb_padded()]
)
self.assertIsNone(new_tex.verts_uvs_padded())
self.assertIsNone(new_tex.verts_uvs_list())
self.assertIsNone(new_tex.verts_uvs_packed())
self.assertIsNone(new_tex.faces_uvs_padded())
self.assertIsNone(new_tex.faces_uvs_list())
self.assertIsNone(new_tex.faces_uvs_packed())
# 3. Error
with self.assertRaises(ValueError):
tex_mesh.extend(N=-1)