mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 22:30:35 +08:00
Texturing API updates
Summary: A fairly big refactor of the texturing API with some breaking changes to how textures are defined. Main changes: - There are now 3 types of texture classes: `TexturesUV`, `TexturesAtlas` and `TexturesVertex`. Each class: - has a `sample_textures` function which accepts the `fragments` from rasterization and returns `texels`. This means that the shaders will not need to know the type of the mesh texture which will resolve several issues people were reporting on GitHub. - has a `join_batch` method for joining multiple textures of the same type into a batch Reviewed By: gkioxari Differential Revision: D21067427 fbshipit-source-id: 4b346500a60181e72fdd1b0dd89b5505c7a33926
This commit is contained in:
committed by
Facebook GitHub Bot
parent
b73d3d6ed9
commit
a3932960b3
@@ -11,10 +11,11 @@ import numpy as np
|
||||
import torch
|
||||
from common_testing import TestCaseMixin, load_rgb_image
|
||||
from PIL import Image
|
||||
from pytorch3d.io import load_objs_as_meshes
|
||||
from pytorch3d.io import load_obj
|
||||
from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
|
||||
from pytorch3d.renderer.lighting import PointLights
|
||||
from pytorch3d.renderer.materials import Materials
|
||||
from pytorch3d.renderer.mesh import TexturesAtlas, TexturesUV, TexturesVertex
|
||||
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
|
||||
from pytorch3d.renderer.mesh.renderer import MeshRenderer
|
||||
from pytorch3d.renderer.mesh.shader import (
|
||||
@@ -25,7 +26,6 @@ from pytorch3d.renderer.mesh.shader import (
|
||||
SoftSilhouetteShader,
|
||||
TexturedSoftPhongShader,
|
||||
)
|
||||
from pytorch3d.renderer.mesh.texturing import Textures
|
||||
from pytorch3d.structures.meshes import Meshes, join_mesh
|
||||
from pytorch3d.utils.ico_sphere import ico_sphere
|
||||
|
||||
@@ -52,7 +52,8 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
sphere_mesh = ico_sphere(5, device)
|
||||
verts_padded = sphere_mesh.verts_padded()
|
||||
faces_padded = sphere_mesh.faces_padded()
|
||||
textures = Textures(verts_rgb=torch.ones_like(verts_padded))
|
||||
feats = torch.ones_like(verts_padded, device=device)
|
||||
textures = TexturesVertex(verts_features=feats)
|
||||
sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)
|
||||
|
||||
# Init rasterizer settings
|
||||
@@ -97,6 +98,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
filename = "simple_sphere_light_%s%s.png" % (name, postfix)
|
||||
image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
|
||||
if DEBUG:
|
||||
filename = "DEBUG_%s" % filename
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
@@ -145,14 +147,15 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
Test a mesh with vertex textures can be extended to form a batch, and
|
||||
is rendered correctly with Phong, Gouraud and Flat Shaders.
|
||||
"""
|
||||
batch_size = 20
|
||||
batch_size = 5
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
# Init mesh with vertex textures.
|
||||
sphere_meshes = ico_sphere(5, device).extend(batch_size)
|
||||
verts_padded = sphere_meshes.verts_padded()
|
||||
faces_padded = sphere_meshes.faces_padded()
|
||||
textures = Textures(verts_rgb=torch.ones_like(verts_padded))
|
||||
feats = torch.ones_like(verts_padded, device=device)
|
||||
textures = TexturesVertex(verts_features=feats)
|
||||
sphere_meshes = Meshes(
|
||||
verts=verts_padded, faces=faces_padded, textures=textures
|
||||
)
|
||||
@@ -194,6 +197,11 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
for i in range(batch_size):
|
||||
rgb = images[i, ..., :3].squeeze().cpu()
|
||||
if i == 0 and DEBUG:
|
||||
filename = "DEBUG_simple_sphere_batched_%s.png" % name
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / filename
|
||||
)
|
||||
self.assertClose(rgb, image_ref, atol=0.05)
|
||||
|
||||
def test_silhouette_with_grad(self):
|
||||
@@ -233,6 +241,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
with Image.open(image_ref_filename) as raw_image_ref:
|
||||
image_ref = torch.from_numpy(np.array(raw_image_ref))
|
||||
|
||||
image_ref = image_ref.to(dtype=torch.float32) / 255.0
|
||||
self.assertClose(alpha, image_ref, atol=0.055)
|
||||
|
||||
@@ -253,11 +262,20 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
obj_filename = obj_dir / "cow_mesh/cow.obj"
|
||||
|
||||
# Load mesh + texture
|
||||
mesh = load_objs_as_meshes([obj_filename], device=device)
|
||||
verts, faces, aux = load_obj(
|
||||
obj_filename, device=device, load_textures=True, texture_wrap=None
|
||||
)
|
||||
tex_map = list(aux.texture_images.values())[0]
|
||||
tex_map = tex_map[None, ...].to(faces.textures_idx.device)
|
||||
textures = TexturesUV(
|
||||
maps=tex_map, faces_uvs=[faces.textures_idx], verts_uvs=[aux.verts_uvs]
|
||||
)
|
||||
mesh = Meshes(verts=[verts], faces=[faces.verts_idx], textures=textures)
|
||||
|
||||
# Init rasterizer settings
|
||||
R, T = look_at_view_transform(2.7, 0, 0)
|
||||
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
|
||||
|
||||
raster_settings = RasterizationSettings(
|
||||
image_size=512, blur_radius=0.0, faces_per_pixel=1
|
||||
)
|
||||
@@ -405,8 +423,8 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
Meshes(verts=verts, faces=sphere_list[i].faces_padded())
|
||||
)
|
||||
joined_sphere_mesh = join_mesh(sphere_mesh_list)
|
||||
joined_sphere_mesh.textures = Textures(
|
||||
verts_rgb=torch.ones_like(joined_sphere_mesh.verts_padded())
|
||||
joined_sphere_mesh.textures = TexturesVertex(
|
||||
verts_features=torch.ones_like(joined_sphere_mesh.verts_padded())
|
||||
)
|
||||
|
||||
# Init rasterizer settings
|
||||
@@ -446,3 +464,61 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
image_ref = load_rgb_image("test_joined_spheres_%s.png" % name, DATA_DIR)
|
||||
self.assertClose(rgb, image_ref, atol=0.05)
|
||||
|
||||
def test_texture_map_atlas(self):
|
||||
"""
|
||||
Test a mesh with a texture map as a per face atlas is loaded and rendered correctly.
|
||||
"""
|
||||
device = torch.device("cuda:0")
|
||||
obj_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
|
||||
obj_filename = obj_dir / "cow_mesh/cow.obj"
|
||||
|
||||
# Load mesh and texture as a per face texture atlas.
|
||||
verts, faces, aux = load_obj(
|
||||
obj_filename,
|
||||
device=device,
|
||||
load_textures=True,
|
||||
create_texture_atlas=True,
|
||||
texture_atlas_size=8,
|
||||
texture_wrap=None,
|
||||
)
|
||||
mesh = Meshes(
|
||||
verts=[verts],
|
||||
faces=[faces.verts_idx],
|
||||
textures=TexturesAtlas(atlas=[aux.texture_atlas]),
|
||||
)
|
||||
|
||||
# Init rasterizer settings
|
||||
R, T = look_at_view_transform(2.7, 0, 0)
|
||||
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
|
||||
|
||||
raster_settings = RasterizationSettings(
|
||||
image_size=512, blur_radius=0.0, faces_per_pixel=1, cull_backfaces=True
|
||||
)
|
||||
|
||||
# Init shader settings
|
||||
materials = Materials(device=device, specular_color=((0, 0, 0),), shininess=0.0)
|
||||
lights = PointLights(device=device)
|
||||
|
||||
# Place light behind the cow in world space. The front of
|
||||
# the cow is facing the -z direction.
|
||||
lights.location = torch.tensor([0.0, 0.0, 2.0], device=device)[None]
|
||||
|
||||
# The HardPhongShader can be used directly with atlas textures.
|
||||
renderer = MeshRenderer(
|
||||
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
|
||||
shader=HardPhongShader(lights=lights, cameras=cameras, materials=materials),
|
||||
)
|
||||
|
||||
images = renderer(mesh)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
|
||||
# Load reference image
|
||||
image_ref = load_rgb_image("test_texture_atlas_8x8_back.png", DATA_DIR)
|
||||
|
||||
if DEBUG:
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / "DEBUG_texture_atlas_8x8_back.png"
|
||||
)
|
||||
|
||||
self.assertClose(rgb, image_ref, atol=0.05)
|
||||
|
||||
Reference in New Issue
Block a user