Texturing API updates

Summary:
A fairly big refactor of the texturing API with some breaking changes to how textures are defined.

Main changes:
- There are now 3 types of texture classes: `TexturesUV`, `TexturesAtlas` and `TexturesVertex`. Each class:
   - has a `sample_textures` function which accepts the `fragments` from rasterization and returns `texels`. This means that the shaders will not need to know the type of the mesh texture which will resolve several issues people were reporting on GitHub.
  -  has a `join_batch` method for joining multiple textures of the same type into a batch

Reviewed By: gkioxari

Differential Revision: D21067427

fbshipit-source-id: 4b346500a60181e72fdd1b0dd89b5505c7a33926
This commit is contained in:
Nikhila Ravi
2020-07-29 16:06:58 -07:00
committed by Facebook GitHub Bot
parent b73d3d6ed9
commit a3932960b3
19 changed files with 1872 additions and 785 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 31 KiB

View File

@@ -8,9 +8,9 @@ from pytorch3d.ops.interp_face_attrs import (
interpolate_face_attributes,
interpolate_face_attributes_python,
)
from pytorch3d.renderer.mesh import TexturesVertex
from pytorch3d.renderer.mesh.rasterizer import Fragments
from pytorch3d.renderer.mesh.texturing import interpolate_vertex_colors
from pytorch3d.structures import Meshes, Textures
from pytorch3d.structures import Meshes
class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):
@@ -96,16 +96,12 @@ class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):
self.assertClose(grad_face_attrs_py, grad_face_attrs_cu, rtol=1e-3)
def test_interpolate_attributes(self):
"""
This tests both interpolate_vertex_colors as well as
interpolate_face_attributes.
"""
verts = torch.randn((4, 3), dtype=torch.float32)
faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
vert_tex = torch.tensor(
[[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32
)
tex = Textures(verts_rgb=vert_tex[None, :])
tex = TexturesVertex(verts_features=vert_tex[None, :])
mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
barycentric_coords = torch.tensor(
@@ -120,7 +116,13 @@ class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):
zbuf=torch.ones_like(pix_to_face),
dists=torch.ones_like(pix_to_face),
)
texels = interpolate_vertex_colors(fragments, mesh)
verts_features_packed = mesh.textures.verts_features_packed()
faces_verts_features = verts_features_packed[mesh.faces_packed()]
texels = interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, faces_verts_features
)
self.assertTrue(torch.allclose(texels, expected_vals[None, :]))
def test_interpolate_attributes_grad(self):
@@ -131,7 +133,7 @@ class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):
dtype=torch.float32,
requires_grad=True,
)
tex = Textures(verts_rgb=vert_tex[None, :])
tex = TexturesVertex(verts_features=vert_tex[None, :])
mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
barycentric_coords = torch.tensor(
@@ -147,7 +149,12 @@ class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):
[[0.3, 0.3, 0.3], [0.9, 0.9, 0.9], [0.5, 0.5, 0.5], [0.3, 0.3, 0.3]],
dtype=torch.float32,
)
texels = interpolate_vertex_colors(fragments, mesh)
verts_features_packed = mesh.textures.verts_features_packed()
faces_verts_features = verts_features_packed[mesh.faces_packed()]
texels = interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, faces_verts_features
)
texels.sum().backward()
self.assertTrue(hasattr(vert_tex, "grad"))
self.assertTrue(torch.allclose(vert_tex.grad, grad_vert_tex[None, :]))

View File

@@ -13,8 +13,8 @@ from pytorch3d.io.mtl_io import (
_bilinear_interpolation_grid_sample,
_bilinear_interpolation_vectorized,
)
from pytorch3d.structures import Meshes, Textures, join_meshes_as_batch
from pytorch3d.structures.meshes import join_mesh
from pytorch3d.renderer import TexturesAtlas, TexturesUV, TexturesVertex
from pytorch3d.structures import Meshes, join_meshes_as_batch
from pytorch3d.utils import torus
@@ -590,17 +590,29 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
check_item(mesh.verts_padded(), mesh3.verts_padded())
check_item(mesh.faces_padded(), mesh3.faces_padded())
if mesh.textures is not None:
check_item(mesh.textures.maps_padded(), mesh3.textures.maps_padded())
check_item(
mesh.textures.faces_uvs_padded(), mesh3.textures.faces_uvs_padded()
)
check_item(
mesh.textures.verts_uvs_padded(), mesh3.textures.verts_uvs_padded()
)
check_item(
mesh.textures.verts_rgb_padded(), mesh3.textures.verts_rgb_padded()
)
if isinstance(mesh.textures, TexturesUV):
check_item(
mesh.textures.faces_uvs_padded(),
mesh3.textures.faces_uvs_padded(),
)
check_item(
mesh.textures.verts_uvs_padded(),
mesh3.textures.verts_uvs_padded(),
)
check_item(
mesh.textures.maps_padded(), mesh3.textures.maps_padded()
)
elif isinstance(mesh.textures, TexturesVertex):
check_item(
mesh.textures.verts_features_padded(),
mesh3.textures.verts_features_padded(),
)
elif isinstance(mesh.textures, TexturesAtlas):
check_item(
mesh.textures.atlas_padded(), mesh3.textures.atlas_padded()
)
DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
obj_filename = DATA_DIR / "cow_mesh/cow.obj"
@@ -623,16 +635,24 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
check_triple(mesh_notex, mesh3_notex)
self.assertIsNone(mesh_notex.textures)
# meshes with vertex texture, join into a batch.
verts = torch.randn((4, 3), dtype=torch.float32)
faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
vert_tex = torch.tensor(
[[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32
)
tex = Textures(verts_rgb=vert_tex[None, :])
mesh_rgb = Meshes(verts=[verts], faces=[faces], textures=tex)
vert_tex = torch.ones_like(verts)
rgb_tex = TexturesVertex(verts_features=[vert_tex])
mesh_rgb = Meshes(verts=[verts], faces=[faces], textures=rgb_tex)
mesh_rgb3 = join_meshes_as_batch([mesh_rgb, mesh_rgb, mesh_rgb])
check_triple(mesh_rgb, mesh_rgb3)
# meshes with texture atlas, join into a batch.
device = "cuda:0"
atlas = torch.rand((2, 4, 4, 3), dtype=torch.float32, device=device)
atlas_tex = TexturesAtlas(atlas=[atlas])
mesh_atlas = Meshes(verts=[verts], faces=[faces], textures=atlas_tex)
mesh_atlas3 = join_meshes_as_batch([mesh_atlas, mesh_atlas, mesh_atlas])
check_triple(mesh_atlas, mesh_atlas3)
# Test load multiple meshes with textures into a batch.
teapot_obj = DATA_DIR / "teapot.obj"
mesh_teapot = load_objs_as_meshes([teapot_obj])
teapot_verts, teapot_faces = mesh_teapot.get_mesh_verts_faces(0)
@@ -649,41 +669,9 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
self.assertClose(cow3_tea.verts_list()[3], mesh_teapot.verts_list()[0])
self.assertClose(cow3_tea.faces_list()[3], mesh_teapot.faces_list()[0])
def test_join_meshes(self):
"""
Test that join_mesh joins single meshes and the corresponding values are
consistent with the single meshes.
"""
# Load cow mesh.
DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
cow_obj = DATA_DIR / "cow_mesh/cow.obj"
cow_mesh = load_objs_as_meshes([cow_obj])
cow_verts, cow_faces = cow_mesh.get_mesh_verts_faces(0)
# Join a batch of three single meshes and check that the values are consistent
# with the individual meshes.
cow_mesh3 = join_mesh([cow_mesh, cow_mesh, cow_mesh])
def check_item(x, y, offset):
self.assertClose(torch.cat([x, x + offset, x + 2 * offset], dim=1), y)
check_item(cow_mesh.verts_padded(), cow_mesh3.verts_padded(), 0)
check_item(cow_mesh.faces_padded(), cow_mesh3.faces_padded(), cow_mesh._V)
# Test the joining of meshes of different sizes.
teapot_obj = DATA_DIR / "teapot.obj"
teapot_mesh = load_objs_as_meshes([teapot_obj])
teapot_verts, teapot_faces = teapot_mesh.get_mesh_verts_faces(0)
mix_mesh = join_mesh([cow_mesh, teapot_mesh])
mix_verts, mix_faces = mix_mesh.get_mesh_verts_faces(0)
self.assertEqual(len(mix_mesh), 1)
self.assertClose(mix_verts[: cow_mesh._V], cow_verts)
self.assertClose(mix_faces[: cow_mesh._F], cow_faces)
self.assertClose(mix_verts[cow_mesh._V :], teapot_verts)
self.assertClose(mix_faces[cow_mesh._F :], teapot_faces + cow_mesh._V)
# Check error raised if all meshes in the batch don't have the same texture type
with self.assertRaisesRegex(ValueError, "same type of texture"):
join_meshes_as_batch([mesh_atlas, mesh_rgb, mesh_atlas])
@staticmethod
def _bm_save_obj(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):

View File

@@ -11,10 +11,11 @@ import numpy as np
import torch
from common_testing import TestCaseMixin, load_rgb_image
from PIL import Image
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.io import load_obj
from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.lighting import PointLights
from pytorch3d.renderer.materials import Materials
from pytorch3d.renderer.mesh import TexturesAtlas, TexturesUV, TexturesVertex
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
from pytorch3d.renderer.mesh.renderer import MeshRenderer
from pytorch3d.renderer.mesh.shader import (
@@ -25,7 +26,6 @@ from pytorch3d.renderer.mesh.shader import (
SoftSilhouetteShader,
TexturedSoftPhongShader,
)
from pytorch3d.renderer.mesh.texturing import Textures
from pytorch3d.structures.meshes import Meshes, join_mesh
from pytorch3d.utils.ico_sphere import ico_sphere
@@ -52,7 +52,8 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
sphere_mesh = ico_sphere(5, device)
verts_padded = sphere_mesh.verts_padded()
faces_padded = sphere_mesh.faces_padded()
textures = Textures(verts_rgb=torch.ones_like(verts_padded))
feats = torch.ones_like(verts_padded, device=device)
textures = TexturesVertex(verts_features=feats)
sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)
# Init rasterizer settings
@@ -97,6 +98,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
filename = "simple_sphere_light_%s%s.png" % (name, postfix)
image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
filename = "DEBUG_%s" % filename
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
@@ -145,14 +147,15 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
Test a mesh with vertex textures can be extended to form a batch, and
is rendered correctly with Phong, Gouraud and Flat Shaders.
"""
batch_size = 20
batch_size = 5
device = torch.device("cuda:0")
# Init mesh with vertex textures.
sphere_meshes = ico_sphere(5, device).extend(batch_size)
verts_padded = sphere_meshes.verts_padded()
faces_padded = sphere_meshes.faces_padded()
textures = Textures(verts_rgb=torch.ones_like(verts_padded))
feats = torch.ones_like(verts_padded, device=device)
textures = TexturesVertex(verts_features=feats)
sphere_meshes = Meshes(
verts=verts_padded, faces=faces_padded, textures=textures
)
@@ -194,6 +197,11 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
)
for i in range(batch_size):
rgb = images[i, ..., :3].squeeze().cpu()
if i == 0 and DEBUG:
filename = "DEBUG_simple_sphere_batched_%s.png" % name
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / filename
)
self.assertClose(rgb, image_ref, atol=0.05)
def test_silhouette_with_grad(self):
@@ -233,6 +241,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
with Image.open(image_ref_filename) as raw_image_ref:
image_ref = torch.from_numpy(np.array(raw_image_ref))
image_ref = image_ref.to(dtype=torch.float32) / 255.0
self.assertClose(alpha, image_ref, atol=0.055)
@@ -253,11 +262,20 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
obj_filename = obj_dir / "cow_mesh/cow.obj"
# Load mesh + texture
mesh = load_objs_as_meshes([obj_filename], device=device)
verts, faces, aux = load_obj(
obj_filename, device=device, load_textures=True, texture_wrap=None
)
tex_map = list(aux.texture_images.values())[0]
tex_map = tex_map[None, ...].to(faces.textures_idx.device)
textures = TexturesUV(
maps=tex_map, faces_uvs=[faces.textures_idx], verts_uvs=[aux.verts_uvs]
)
mesh = Meshes(verts=[verts], faces=[faces.verts_idx], textures=textures)
# Init rasterizer settings
R, T = look_at_view_transform(2.7, 0, 0)
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1
)
@@ -405,8 +423,8 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
Meshes(verts=verts, faces=sphere_list[i].faces_padded())
)
joined_sphere_mesh = join_mesh(sphere_mesh_list)
joined_sphere_mesh.textures = Textures(
verts_rgb=torch.ones_like(joined_sphere_mesh.verts_padded())
joined_sphere_mesh.textures = TexturesVertex(
verts_features=torch.ones_like(joined_sphere_mesh.verts_padded())
)
# Init rasterizer settings
@@ -446,3 +464,61 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
)
image_ref = load_rgb_image("test_joined_spheres_%s.png" % name, DATA_DIR)
self.assertClose(rgb, image_ref, atol=0.05)
def test_texture_map_atlas(self):
"""
Test a mesh with a texture map as a per face atlas is loaded and rendered correctly.
"""
device = torch.device("cuda:0")
obj_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
obj_filename = obj_dir / "cow_mesh/cow.obj"
# Load mesh and texture as a per face texture atlas.
verts, faces, aux = load_obj(
obj_filename,
device=device,
load_textures=True,
create_texture_atlas=True,
texture_atlas_size=8,
texture_wrap=None,
)
mesh = Meshes(
verts=[verts],
faces=[faces.verts_idx],
textures=TexturesAtlas(atlas=[aux.texture_atlas]),
)
# Init rasterizer settings
R, T = look_at_view_transform(2.7, 0, 0)
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1, cull_backfaces=True
)
# Init shader settings
materials = Materials(device=device, specular_color=((0, 0, 0),), shininess=0.0)
lights = PointLights(device=device)
# Place light behind the cow in world space. The front of
# the cow is facing the -z direction.
lights.location = torch.tensor([0.0, 0.0, 2.0], device=device)[None]
# The HardPhongShader can be used directly with atlas textures.
renderer = MeshRenderer(
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
shader=HardPhongShader(lights=lights, cameras=cameras, materials=materials),
)
images = renderer(mesh)
rgb = images[0, ..., :3].squeeze().cpu()
# Load reference image
image_ref = load_rgb_image("test_texture_atlas_8x8_back.png", DATA_DIR)
if DEBUG:
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_texture_atlas_8x8_back.png"
)
self.assertClose(rgb, image_ref, atol=0.05)

View File

@@ -7,14 +7,376 @@ import torch
import torch.nn.functional as F
from common_testing import TestCaseMixin
from pytorch3d.renderer.mesh.rasterizer import Fragments
from pytorch3d.renderer.mesh.texturing import interpolate_texture_map
from pytorch3d.structures import Meshes, Textures
from pytorch3d.structures.utils import list_to_padded
from pytorch3d.renderer.mesh.textures import (
TexturesAtlas,
TexturesUV,
TexturesVertex,
_list_to_padded_wrapper,
)
from pytorch3d.structures import Meshes, list_to_packed, packed_to_list
from test_meshes import TestMeshes
class TestTexturing(TestCaseMixin, unittest.TestCase):
def test_interpolate_texture_map(self):
def tryindex(self, index, tex, meshes, source):
tex2 = tex[index]
meshes2 = meshes[index]
tex_from_meshes = meshes2.textures
for item in source:
basic = source[item][index]
from_texture = getattr(tex2, item + "_padded")()
from_meshes = getattr(tex_from_meshes, item + "_padded")()
if isinstance(index, int):
basic = basic[None]
if len(basic) == 0:
self.assertEquals(len(from_texture), 0)
self.assertEquals(len(from_meshes), 0)
else:
self.assertClose(basic, from_texture)
self.assertClose(basic, from_meshes)
self.assertEqual(from_texture.ndim, getattr(tex, item + "_padded")().ndim)
item_list = getattr(tex_from_meshes, item + "_list")()
self.assertEqual(basic.shape[0], len(item_list))
for i, elem in enumerate(item_list):
self.assertClose(elem, basic[i])
class TestTexturesVertex(TestCaseMixin, unittest.TestCase):
def test_sample_vertex_textures(self):
"""
This tests both interpolate_vertex_colors as well as
interpolate_face_attributes.
"""
verts = torch.randn((4, 3), dtype=torch.float32)
faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
vert_tex = torch.tensor(
[[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32
)
verts_features = vert_tex
tex = TexturesVertex(verts_features=[verts_features])
mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
barycentric_coords = torch.tensor(
[[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
).view(1, 1, 1, 2, -1)
expected_vals = torch.tensor(
[[0.5, 1.0, 0.3], [0.3, 1.0, 0.9]], dtype=torch.float32
).view(1, 1, 1, 2, -1)
fragments = Fragments(
pix_to_face=pix_to_face,
bary_coords=barycentric_coords,
zbuf=torch.ones_like(pix_to_face),
dists=torch.ones_like(pix_to_face),
)
# sample_textures calls interpolate_vertex_colors
texels = mesh.sample_textures(fragments)
self.assertTrue(torch.allclose(texels, expected_vals[None, :]))
def test_sample_vertex_textures_grad(self):
verts = torch.randn((4, 3), dtype=torch.float32)
faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
vert_tex = torch.tensor(
[[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]],
dtype=torch.float32,
requires_grad=True,
)
verts_features = vert_tex
tex = TexturesVertex(verts_features=[verts_features])
mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
barycentric_coords = torch.tensor(
[[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
).view(1, 1, 1, 2, -1)
fragments = Fragments(
pix_to_face=pix_to_face,
bary_coords=barycentric_coords,
zbuf=torch.ones_like(pix_to_face),
dists=torch.ones_like(pix_to_face),
)
grad_vert_tex = torch.tensor(
[[0.3, 0.3, 0.3], [0.9, 0.9, 0.9], [0.5, 0.5, 0.5], [0.3, 0.3, 0.3]],
dtype=torch.float32,
)
texels = mesh.sample_textures(fragments)
texels.sum().backward()
self.assertTrue(hasattr(vert_tex, "grad"))
self.assertTrue(torch.allclose(vert_tex.grad, grad_vert_tex[None, :]))
def test_textures_vertex_init_fail(self):
# Incorrect sized tensors
with self.assertRaisesRegex(ValueError, "verts_features"):
TexturesVertex(verts_features=torch.rand(size=(5, 10)))
# Not a list or a tensor
with self.assertRaisesRegex(ValueError, "verts_features"):
TexturesVertex(verts_features=(1, 1, 1))
def test_clone(self):
tex = TexturesVertex(verts_features=torch.rand(size=(10, 100, 128)))
tex_cloned = tex.clone()
self.assertSeparate(
tex._verts_features_padded, tex_cloned._verts_features_padded
)
self.assertSeparate(tex.valid, tex_cloned.valid)
def test_extend(self):
B = 10
mesh = TestMeshes.init_mesh(B, 30, 50)
V = mesh._V
tex_uv = TexturesVertex(verts_features=torch.randn((B, V, 3)))
tex_mesh = Meshes(
verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex_uv
)
N = 20
new_mesh = tex_mesh.extend(N)
self.assertEqual(len(tex_mesh) * N, len(new_mesh))
tex_init = tex_mesh.textures
new_tex = new_mesh.textures
for i in range(len(tex_mesh)):
for n in range(N):
self.assertClose(
tex_init.verts_features_list()[i],
new_tex.verts_features_list()[i * N + n],
)
self.assertClose(
tex_init._num_faces_per_mesh[i],
new_tex._num_faces_per_mesh[i * N + n],
)
self.assertAllSeparate(
[tex_init.verts_features_padded(), new_tex.verts_features_padded()]
)
with self.assertRaises(ValueError):
tex_mesh.extend(N=-1)
def test_padded_to_packed(self):
# Case where each face in the mesh has 3 unique uv vertex indices
# - i.e. even if a vertex is shared between multiple faces it will
# have a unique uv coordinate for each face.
num_verts_per_mesh = [9, 6]
D = 10
verts_features_list = [torch.rand(v, D) for v in num_verts_per_mesh]
verts_features_packed = list_to_packed(verts_features_list)[0]
verts_features_list = packed_to_list(verts_features_packed, num_verts_per_mesh)
tex = TexturesVertex(verts_features=verts_features_list)
# This is set inside Meshes when textures is passed as an input.
# Here we set _num_faces_per_mesh and _num_verts_per_mesh explicity.
tex1 = tex.clone()
tex1._num_verts_per_mesh = num_verts_per_mesh
verts_packed = tex1.verts_features_packed()
verts_verts_list = tex1.verts_features_list()
verts_padded = tex1.verts_features_padded()
for f1, f2 in zip(verts_verts_list, verts_features_list):
self.assertTrue((f1 == f2).all().item())
self.assertTrue(verts_packed.shape == (sum(num_verts_per_mesh), D))
self.assertTrue(verts_padded.shape == (2, 9, D))
# Case where num_verts_per_mesh is not set and textures
# are initialized with a padded tensor.
tex2 = TexturesVertex(verts_features=verts_padded)
verts_packed = tex2.verts_features_packed()
verts_list = tex2.verts_features_list()
# Packed is just flattened padded as num_verts_per_mesh
# has not been provided.
self.assertTrue(verts_packed.shape == (9 * 2, D))
for i, (f1, f2) in enumerate(zip(verts_list, verts_features_list)):
n = num_verts_per_mesh[i]
self.assertTrue((f1[:n] == f2).all().item())
def test_getitem(self):
N = 5
V = 20
source = {"verts_features": torch.randn(size=(N, 10, 128))}
tex = TexturesVertex(verts_features=source["verts_features"])
verts = torch.rand(size=(N, V, 3))
faces = torch.randint(size=(N, 10, 3), high=V)
meshes = Meshes(verts=verts, faces=faces, textures=tex)
tryindex(self, 2, tex, meshes, source)
tryindex(self, slice(0, 2, 1), tex, meshes, source)
index = torch.tensor([1, 0, 1, 0, 0], dtype=torch.bool)
tryindex(self, index, tex, meshes, source)
index = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool)
tryindex(self, index, tex, meshes, source)
index = torch.tensor([1, 2], dtype=torch.int64)
tryindex(self, index, tex, meshes, source)
tryindex(self, [2, 4], tex, meshes, source)
class TestTexturesAtlas(TestCaseMixin, unittest.TestCase):
def test_sample_texture_atlas(self):
N, F, R = 1, 2, 2
verts = torch.randn((4, 3), dtype=torch.float32)
faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
faces_atlas = torch.rand(size=(N, F, R, R, 3))
tex = TexturesAtlas(atlas=faces_atlas)
mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
barycentric_coords = torch.tensor(
[[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
).view(1, 1, 1, 2, -1)
expected_vals = torch.tensor(
[[0.5, 1.0, 0.3], [0.3, 1.0, 0.9]], dtype=torch.float32
)
expected_vals = torch.zeros((1, 1, 1, 2, 3), dtype=torch.float32)
expected_vals[..., 0, :] = faces_atlas[0, 0, 0, 1, ...]
expected_vals[..., 1, :] = faces_atlas[0, 1, 1, 0, ...]
fragments = Fragments(
pix_to_face=pix_to_face,
bary_coords=barycentric_coords,
zbuf=torch.ones_like(pix_to_face),
dists=torch.ones_like(pix_to_face),
)
texels = mesh.textures.sample_textures(fragments)
self.assertTrue(torch.allclose(texels, expected_vals))
def test_textures_atlas_grad(self):
N, F, R = 1, 2, 2
verts = torch.randn((4, 3), dtype=torch.float32)
faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
faces_atlas = torch.rand(size=(N, F, R, R, 3), requires_grad=True)
tex = TexturesAtlas(atlas=faces_atlas)
mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
barycentric_coords = torch.tensor(
[[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
).view(1, 1, 1, 2, -1)
fragments = Fragments(
pix_to_face=pix_to_face,
bary_coords=barycentric_coords,
zbuf=torch.ones_like(pix_to_face),
dists=torch.ones_like(pix_to_face),
)
texels = mesh.textures.sample_textures(fragments)
grad_tex = torch.rand_like(texels)
grad_expected = torch.zeros_like(faces_atlas)
grad_expected[0, 0, 0, 1, :] = grad_tex[..., 0:1, :]
grad_expected[0, 1, 1, 0, :] = grad_tex[..., 1:2, :]
texels.backward(grad_tex)
self.assertTrue(hasattr(faces_atlas, "grad"))
self.assertTrue(torch.allclose(faces_atlas.grad, grad_expected))
def test_textures_atlas_init_fail(self):
# Incorrect sized tensors
with self.assertRaisesRegex(ValueError, "atlas"):
TexturesAtlas(atlas=torch.rand(size=(5, 10, 3)))
# Not a list or a tensor
with self.assertRaisesRegex(ValueError, "atlas"):
TexturesAtlas(atlas=(1, 1, 1))
def test_clone(self):
tex = TexturesAtlas(atlas=torch.rand(size=(1, 10, 2, 2, 3)))
tex_cloned = tex.clone()
self.assertSeparate(tex._atlas_padded, tex_cloned._atlas_padded)
self.assertSeparate(tex.valid, tex_cloned.valid)
def test_extend(self):
B = 10
mesh = TestMeshes.init_mesh(B, 30, 50)
F = mesh._F
tex_uv = TexturesAtlas(atlas=torch.randn((B, F, 2, 2, 3)))
tex_mesh = Meshes(
verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex_uv
)
N = 20
new_mesh = tex_mesh.extend(N)
self.assertEqual(len(tex_mesh) * N, len(new_mesh))
tex_init = tex_mesh.textures
new_tex = new_mesh.textures
for i in range(len(tex_mesh)):
for n in range(N):
self.assertClose(
tex_init.atlas_list()[i], new_tex.atlas_list()[i * N + n]
)
self.assertClose(
tex_init._num_faces_per_mesh[i],
new_tex._num_faces_per_mesh[i * N + n],
)
self.assertAllSeparate([tex_init.atlas_padded(), new_tex.atlas_padded()])
with self.assertRaises(ValueError):
tex_mesh.extend(N=-1)
def test_padded_to_packed(self):
# Case where each face in the mesh has 3 unique uv vertex indices
# - i.e. even if a vertex is shared between multiple faces it will
# have a unique uv coordinate for each face.
R = 2
N = 20
num_faces_per_mesh = torch.randint(size=(N,), low=0, high=30)
atlas_list = [torch.rand(f, R, R, 3) for f in num_faces_per_mesh]
tex = TexturesAtlas(atlas=atlas_list)
# This is set inside Meshes when textures is passed as an input.
# Here we set _num_faces_per_mesh explicity.
tex1 = tex.clone()
tex1._num_faces_per_mesh = num_faces_per_mesh.tolist()
atlas_packed = tex1.atlas_packed()
atlas_list_new = tex1.atlas_list()
atlas_padded = tex1.atlas_padded()
for f1, f2 in zip(atlas_list_new, atlas_list):
self.assertTrue((f1 == f2).all().item())
sum_F = num_faces_per_mesh.sum()
max_F = num_faces_per_mesh.max().item()
self.assertTrue(atlas_packed.shape == (sum_F, R, R, 3))
self.assertTrue(atlas_padded.shape == (N, max_F, R, R, 3))
# Case where num_faces_per_mesh is not set and textures
# are initialized with a padded tensor.
atlas_list_padded = _list_to_padded_wrapper(atlas_list)
tex2 = TexturesAtlas(atlas=atlas_list_padded)
atlas_packed = tex2.atlas_packed()
atlas_list_new = tex2.atlas_list()
# Packed is just flattened padded as num_faces_per_mesh
# has not been provided.
self.assertTrue(atlas_packed.shape == (N * max_F, R, R, 3))
for i, (f1, f2) in enumerate(zip(atlas_list_new, atlas_list)):
n = num_faces_per_mesh[i]
self.assertTrue((f1[:n] == f2).all().item())
def test_getitem(self):
N = 5
V = 20
source = {"atlas": torch.randn(size=(N, 10, 4, 4, 3))}
tex = TexturesAtlas(atlas=source["atlas"])
verts = torch.rand(size=(N, V, 3))
faces = torch.randint(size=(N, 10, 3), high=V)
meshes = Meshes(verts=verts, faces=faces, textures=tex)
tryindex(self, 2, tex, meshes, source)
tryindex(self, slice(0, 2, 1), tex, meshes, source)
index = torch.tensor([1, 0, 1, 0, 0], dtype=torch.bool)
tryindex(self, index, tex, meshes, source)
index = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool)
tryindex(self, index, tex, meshes, source)
index = torch.tensor([1, 2], dtype=torch.int64)
tryindex(self, index, tex, meshes, source)
tryindex(self, [2, 4], tex, meshes, source)
class TestTexturesUV(TestCaseMixin, unittest.TestCase):
def test_sample_textures_uv(self):
barycentric_coords = torch.tensor(
[[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
).view(1, 1, 1, 2, -1)
@@ -38,11 +400,11 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
zbuf=pix_to_face,
dists=pix_to_face,
)
tex = Textures(
maps=tex_map, faces_uvs=face_uvs[None, ...], verts_uvs=vert_uvs[None, ...]
)
tex = TexturesUV(maps=tex_map, faces_uvs=[face_uvs], verts_uvs=[vert_uvs])
meshes = Meshes(verts=[dummy_verts], faces=[face_uvs], textures=tex)
texels = interpolate_texture_map(fragments, meshes)
mesh_textures = meshes.textures
texels = mesh_textures.sample_textures(fragments)
# Expected output
pixel_uvs = interpolated_uvs * 2.0 - 1.0
@@ -53,190 +415,92 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
expected_out = F.grid_sample(tex_map, pixel_uvs, align_corners=False)
self.assertTrue(torch.allclose(texels.squeeze(), expected_out.squeeze()))
def test_init_rgb_uv_fail(self):
V = 20
def test_textures_uv_init_fail(self):
# Maps has wrong shape
with self.assertRaisesRegex(ValueError, "maps"):
Textures(
TexturesUV(
maps=torch.ones((5, 16, 16, 3, 4)),
faces_uvs=torch.randint(size=(5, 10, 3), low=0, high=V),
verts_uvs=torch.ones((5, V, 2)),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
)
# faces_uvs has wrong shape
with self.assertRaisesRegex(ValueError, "faces_uvs"):
Textures(
TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
faces_uvs=torch.randint(size=(5, 10, 3, 3), low=0, high=V),
verts_uvs=torch.ones((5, V, 2)),
faces_uvs=torch.rand(size=(5, 10, 3, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
)
# verts_uvs has wrong shape
with self.assertRaisesRegex(ValueError, "verts_uvs"):
Textures(
TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
faces_uvs=torch.randint(size=(5, 10, 3), low=0, high=V),
verts_uvs=torch.ones((5, V, 2, 3)),
)
# verts_rgb has wrong shape
with self.assertRaisesRegex(ValueError, "verts_rgb"):
Textures(verts_rgb=torch.ones((5, 16, 16, 3)))
# maps provided without verts/faces uvs
with self.assertRaisesRegex(ValueError, "faces_uvs and verts_uvs are required"):
Textures(maps=torch.ones((5, 16, 16, 3)))
def test_padded_to_packed(self):
N = 2
# Case where each face in the mesh has 3 unique uv vertex indices
# - i.e. even if a vertex is shared between multiple faces it will
# have a unique uv coordinate for each face.
faces_uvs_list = [
torch.tensor([[0, 1, 2], [3, 5, 4], [7, 6, 8]]),
torch.tensor([[0, 1, 2], [3, 4, 5]]),
] # (N, 3, 3)
verts_uvs_list = [torch.ones(9, 2), torch.ones(6, 2)]
faces_uvs_padded = list_to_padded(faces_uvs_list, pad_value=-1)
verts_uvs_padded = list_to_padded(verts_uvs_list)
tex = Textures(
maps=torch.ones((N, 16, 16, 3)),
faces_uvs=faces_uvs_padded,
verts_uvs=verts_uvs_padded,
)
# This is set inside Meshes when textures is passed as an input.
# Here we set _num_faces_per_mesh and _num_verts_per_mesh explicity.
tex1 = tex.clone()
tex1._num_faces_per_mesh = faces_uvs_padded.gt(-1).all(-1).sum(-1).tolist()
tex1._num_verts_per_mesh = torch.tensor([5, 4])
faces_packed = tex1.faces_uvs_packed()
verts_packed = tex1.verts_uvs_packed()
faces_list = tex1.faces_uvs_list()
verts_list = tex1.verts_uvs_list()
for f1, f2 in zip(faces_uvs_list, faces_list):
self.assertTrue((f1 == f2).all().item())
for f, v1, v2 in zip(faces_list, verts_list, verts_uvs_list):
idx = f.unique()
self.assertTrue((v1[idx] == v2).all().item())
self.assertTrue(faces_packed.shape == (3 + 2, 3))
# verts_packed is just flattened verts_padded.
# split sizes are not used for verts_uvs.
self.assertTrue(verts_packed.shape == (9 * 2, 2))
# Case where num_faces_per_mesh is not set
tex2 = tex.clone()
faces_packed = tex2.faces_uvs_packed()
verts_packed = tex2.verts_uvs_packed()
faces_list = tex2.faces_uvs_list()
verts_list = tex2.verts_uvs_list()
# Packed is just flattened padded as num_faces_per_mesh
# has not been provided.
self.assertTrue(verts_packed.shape == (9 * 2, 2))
self.assertTrue(faces_packed.shape == (3 * 2, 3))
for i in range(N):
self.assertTrue(
(faces_list[i] == faces_uvs_padded[i, ...].squeeze()).all().item()
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2, 3)),
)
for i in range(N):
self.assertTrue(
(verts_list[i] == verts_uvs_padded[i, ...].squeeze()).all().item()
# verts has different batch dim to faces
with self.assertRaisesRegex(ValueError, "verts_uvs"):
TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(8, 15, 2)),
)
# maps has different batch dim to faces
with self.assertRaisesRegex(ValueError, "maps"):
TexturesUV(
maps=torch.ones((8, 16, 16, 3)),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
)
# verts on different device to faces
with self.assertRaisesRegex(ValueError, "verts_uvs"):
TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2, 3), device="cuda"),
)
# maps on different device to faces
with self.assertRaisesRegex(ValueError, "map"):
TexturesUV(
maps=torch.ones((5, 16, 16, 3), device="cuda"),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
)
def test_clone(self):
V = 20
tex = Textures(
tex = TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
faces_uvs=torch.randint(size=(5, 10, 3), low=0, high=V),
verts_uvs=torch.ones((5, V, 2)),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
)
tex_cloned = tex.clone()
self.assertSeparate(tex._faces_uvs_padded, tex_cloned._faces_uvs_padded)
self.assertSeparate(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded)
self.assertSeparate(tex._maps_padded, tex_cloned._maps_padded)
def test_getitem(self):
N = 5
V = 20
source = {
"maps": torch.rand(size=(N, 16, 16, 3)),
"faces_uvs": torch.randint(size=(N, 10, 3), low=0, high=V),
"verts_uvs": torch.rand((N, V, 2)),
}
tex = Textures(
maps=source["maps"],
faces_uvs=source["faces_uvs"],
verts_uvs=source["verts_uvs"],
)
verts = torch.rand(size=(N, V, 3))
faces = torch.randint(size=(N, 10, 3), high=V)
meshes = Meshes(verts=verts, faces=faces, textures=tex)
def tryindex(index):
tex2 = tex[index]
meshes2 = meshes[index]
tex_from_meshes = meshes2.textures
for item in source:
basic = source[item][index]
from_texture = getattr(tex2, item + "_padded")()
from_meshes = getattr(tex_from_meshes, item + "_padded")()
if isinstance(index, int):
basic = basic[None]
self.assertClose(basic, from_texture)
self.assertClose(basic, from_meshes)
self.assertEqual(
from_texture.ndim, getattr(tex, item + "_padded")().ndim
)
if item == "faces_uvs":
faces_uvs_list = tex_from_meshes.faces_uvs_list()
self.assertEqual(basic.shape[0], len(faces_uvs_list))
for i, faces_uvs in enumerate(faces_uvs_list):
self.assertClose(faces_uvs, basic[i])
tryindex(2)
tryindex(slice(0, 2, 1))
index = torch.tensor([1, 0, 1, 0, 0], dtype=torch.bool)
tryindex(index)
index = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool)
tryindex(index)
index = torch.tensor([1, 2], dtype=torch.int64)
tryindex(index)
tryindex([2, 4])
def test_to(self):
V = 20
tex = Textures(
maps=torch.ones((5, 16, 16, 3)),
faces_uvs=torch.randint(size=(5, 10, 3), low=0, high=V),
verts_uvs=torch.ones((5, V, 2)),
)
device = torch.device("cuda:0")
tex = tex.to(device)
self.assertTrue(tex._faces_uvs_padded.device == device)
self.assertTrue(tex._verts_uvs_padded.device == device)
self.assertTrue(tex._maps_padded.device == device)
self.assertSeparate(tex.valid, tex_cloned.valid)
def test_extend(self):
B = 10
B = 5
mesh = TestMeshes.init_mesh(B, 30, 50)
V = mesh._V
F = mesh._F
# 1. Texture uvs
tex_uv = Textures(
maps=torch.randn((B, 16, 16, 3)),
faces_uvs=torch.randint(size=(B, F, 3), low=0, high=V),
verts_uvs=torch.randn((B, V, 2)),
num_faces = mesh.num_faces_per_mesh()
num_verts = mesh.num_verts_per_mesh()
faces_uvs_list = [torch.randint(size=(f, 3), low=0, high=V) for f in num_faces]
verts_uvs_list = [torch.rand(v, 2) for v in num_verts]
tex_uv = TexturesUV(
maps=torch.ones((B, 16, 16, 3)),
faces_uvs=faces_uvs_list,
verts_uvs=verts_uvs_list,
)
tex_mesh = Meshes(
verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex_uv
verts=mesh.verts_list(), faces=mesh.faces_list(), textures=tex_uv
)
N = 20
N = 2
new_mesh = tex_mesh.extend(N)
self.assertEqual(len(tex_mesh) * N, len(new_mesh))
@@ -246,56 +510,142 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
for i in range(len(tex_mesh)):
for n in range(N):
self.assertClose(
tex_init.verts_uvs_list()[i], new_tex.verts_uvs_list()[i * N + n]
)
self.assertClose(
tex_init.faces_uvs_list()[i], new_tex.faces_uvs_list()[i * N + n]
)
self.assertClose(
tex_init.verts_uvs_list()[i], new_tex.verts_uvs_list()[i * N + n]
tex_init.maps_padded()[i, ...], new_tex.maps_padded()[i * N + n]
)
self.assertClose(
tex_init._num_faces_per_mesh[i],
new_tex._num_faces_per_mesh[i * N + n],
)
self.assertAllSeparate(
[
tex_init.faces_uvs_padded(),
new_tex.faces_uvs_padded(),
tex_init.faces_uvs_packed(),
new_tex.faces_uvs_packed(),
tex_init.verts_uvs_padded(),
new_tex.verts_uvs_padded(),
tex_init.verts_uvs_packed(),
new_tex.verts_uvs_packed(),
tex_init.maps_padded(),
new_tex.maps_padded(),
]
)
self.assertIsNone(new_tex.verts_rgb_list())
self.assertIsNone(new_tex.verts_rgb_padded())
self.assertIsNone(new_tex.verts_rgb_packed())
# 2. Texture vertex RGB
tex_rgb = Textures(verts_rgb=torch.randn((B, V, 3)))
tex_mesh_rgb = Meshes(
verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex_rgb
)
N = 20
new_mesh_rgb = tex_mesh_rgb.extend(N)
self.assertEqual(len(tex_mesh_rgb) * N, len(new_mesh_rgb))
tex_init = tex_mesh_rgb.textures
new_tex = new_mesh_rgb.textures
for i in range(len(tex_mesh_rgb)):
for n in range(N):
self.assertClose(
tex_init.verts_rgb_list()[i], new_tex.verts_rgb_list()[i * N + n]
)
self.assertAllSeparate(
[tex_init.verts_rgb_padded(), new_tex.verts_rgb_padded()]
)
self.assertIsNone(new_tex.verts_uvs_padded())
self.assertIsNone(new_tex.verts_uvs_list())
self.assertIsNone(new_tex.verts_uvs_packed())
self.assertIsNone(new_tex.faces_uvs_padded())
self.assertIsNone(new_tex.faces_uvs_list())
self.assertIsNone(new_tex.faces_uvs_packed())
# 3. Error
with self.assertRaises(ValueError):
tex_mesh.extend(N=-1)
def test_padded_to_packed(self):
# Case where each face in the mesh has 3 unique uv vertex indices
# - i.e. even if a vertex is shared between multiple faces it will
# have a unique uv coordinate for each face.
N = 2
faces_uvs_list = [
torch.tensor([[0, 1, 2], [3, 5, 4], [7, 6, 8]]),
torch.tensor([[0, 1, 2], [3, 4, 5]]),
] # (N, 3, 3)
verts_uvs_list = [torch.ones(9, 2), torch.ones(6, 2)]
num_faces_per_mesh = [f.shape[0] for f in faces_uvs_list]
num_verts_per_mesh = [v.shape[0] for v in verts_uvs_list]
tex = TexturesUV(
maps=torch.ones((N, 16, 16, 3)),
faces_uvs=faces_uvs_list,
verts_uvs=verts_uvs_list,
)
# This is set inside Meshes when textures is passed as an input.
# Here we set _num_faces_per_mesh and _num_verts_per_mesh explicity.
tex1 = tex.clone()
tex1._num_faces_per_mesh = num_faces_per_mesh
tex1._num_verts_per_mesh = num_verts_per_mesh
verts_packed = tex1.verts_uvs_packed()
verts_list = tex1.verts_uvs_list()
verts_padded = tex1.verts_uvs_padded()
faces_packed = tex1.faces_uvs_packed()
faces_list = tex1.faces_uvs_list()
faces_padded = tex1.faces_uvs_padded()
for f1, f2 in zip(faces_list, faces_uvs_list):
self.assertTrue((f1 == f2).all().item())
for f1, f2 in zip(verts_list, verts_uvs_list):
self.assertTrue((f1 == f2).all().item())
self.assertTrue(faces_packed.shape == (3 + 2, 3))
self.assertTrue(faces_padded.shape == (2, 3, 3))
self.assertTrue(verts_packed.shape == (9 + 6, 2))
self.assertTrue(verts_padded.shape == (2, 9, 2))
# Case where num_faces_per_mesh is not set and faces_verts_uvs
# are initialized with a padded tensor.
tex2 = TexturesUV(
maps=torch.ones((N, 16, 16, 3)),
verts_uvs=verts_padded,
faces_uvs=faces_padded,
)
faces_packed = tex2.faces_uvs_packed()
faces_list = tex2.faces_uvs_list()
verts_packed = tex2.verts_uvs_packed()
verts_list = tex2.verts_uvs_list()
# Packed is just flattened padded as num_faces_per_mesh
# has not been provided.
self.assertTrue(faces_packed.shape == (3 * 2, 3))
self.assertTrue(verts_packed.shape == (9 * 2, 2))
for i, (f1, f2) in enumerate(zip(faces_list, faces_uvs_list)):
n = num_faces_per_mesh[i]
self.assertTrue((f1[:n] == f2).all().item())
for i, (f1, f2) in enumerate(zip(verts_list, verts_uvs_list)):
n = num_verts_per_mesh[i]
self.assertTrue((f1[:n] == f2).all().item())
def test_to(self):
tex = TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
faces_uvs=torch.randint(size=(5, 10, 3), high=15),
verts_uvs=torch.rand(size=(5, 15, 2)),
)
device = torch.device("cuda:0")
tex = tex.to(device)
self.assertTrue(tex._faces_uvs_padded.device == device)
self.assertTrue(tex._verts_uvs_padded.device == device)
self.assertTrue(tex._maps_padded.device == device)
def test_getitem(self):
N = 5
V = 20
source = {
"maps": torch.rand(size=(N, 1, 1, 3)),
"faces_uvs": torch.randint(size=(N, 10, 3), high=V),
"verts_uvs": torch.randn(size=(N, V, 2)),
}
tex = TexturesUV(
maps=source["maps"],
faces_uvs=source["faces_uvs"],
verts_uvs=source["verts_uvs"],
)
verts = torch.rand(size=(N, V, 3))
faces = torch.randint(size=(N, 10, 3), high=V)
meshes = Meshes(verts=verts, faces=faces, textures=tex)
tryindex(self, 2, tex, meshes, source)
tryindex(self, slice(0, 2, 1), tex, meshes, source)
index = torch.tensor([1, 0, 1, 0, 0], dtype=torch.bool)
tryindex(self, index, tex, meshes, source)
index = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool)
tryindex(self, index, tex, meshes, source)
index = torch.tensor([1, 2], dtype=torch.int64)
tryindex(self, index, tex, meshes, source)
tryindex(self, [2, 4], tex, meshes, source)