pytorch3d/tests/test_texturing.py
John Reese 3b2300641a apply import merging for fbcode (11 of 11)
Summary:
Applies new import merging and sorting from µsort v1.0.

When merging imports, µsort will make a best-effort to move associated
comments to match merged elements, but there are known limitations due to
the dynamic nature of Python and developer tooling. These changes should
not produce any dangerous runtime changes, but may require touch-ups to
satisfy linters and other tooling.

Note that µsort uses case-insensitive, lexicographical sorting, which
results in a different ordering compared to isort. This provides a more
consistent sorting order, matching the case-insensitive order used when
sorting import statements by module name, and ensures that "frog", "FROG",
and "Frog" always sort next to each other.

For details on µsort's sorting and merging semantics, see the user guide:
https://usort.readthedocs.io/en/stable/guide.html#sorting

Reviewed By: lisroach

Differential Revision: D36402260

fbshipit-source-id: 7cb52f09b740ccc580e61e6d1787d27381a8ce00
2022-05-15 12:53:03 -07:00

1109 lines
44 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
import torch.nn.functional as F
from common_testing import TestCaseMixin
from pytorch3d.renderer.mesh.rasterizer import Fragments
from pytorch3d.renderer.mesh.textures import (
_list_to_padded_wrapper,
TexturesAtlas,
TexturesUV,
TexturesVertex,
)
from pytorch3d.renderer.mesh.utils import (
pack_rectangles,
pack_unique_rectangles,
Rectangle,
)
from pytorch3d.structures import list_to_packed, Meshes, packed_to_list
from test_meshes import init_mesh
def tryindex(self, index, tex, meshes, source):
    """
    Check that indexing a textures object is consistent with indexing the
    raw data it was built from.

    Args:
        self: the running test case; supplies assertEqual / assertClose.
        index: the index applied to `tex`, `meshes` and every entry of
            `source` (the callers pass ints, slices, lists and tensors).
        tex: textures object to index directly.
        meshes: meshes object whose `.textures` is checked after indexing.
        source: mapping from an attribute stem (e.g. "maps") to the data the
            textures were created from; `<stem>_padded` and `<stem>_list`
            accessors are looked up on the textures objects.
    """
    picked_tex = tex[index]
    picked_mesh_tex = meshes[index].textures
    for stem, raw in source.items():
        padded_name = stem + "_padded"
        expected = raw[index]
        via_texture = getattr(picked_tex, padded_name)()
        via_meshes = getattr(picked_mesh_tex, padded_name)()
        # An int index drops the leading dimension; restore it so that it
        # lines up with the padded outputs.
        if isinstance(index, int):
            expected = expected[None]

        if len(expected) == 0:
            # Nothing selected: both routes must come back empty too.
            self.assertEqual(len(via_texture), 0)
            self.assertEqual(len(via_meshes), 0)
        else:
            self.assertClose(expected, via_texture)
            self.assertClose(expected, via_meshes)
            original_ndim = getattr(tex, padded_name)().ndim
            self.assertEqual(via_texture.ndim, original_ndim)
            per_mesh = getattr(picked_mesh_tex, stem + "_list")()
            self.assertEqual(expected.shape[0], len(per_mesh))
            for position, entry in enumerate(per_mesh):
                self.assertClose(entry, expected[position])
class TestTexturesVertex(TestCaseMixin, unittest.TestCase):
    """Tests for TexturesVertex: sampling, gradients, copying and indexing."""

    def test_sample_vertex_textures(self):
        """
        This tests both interpolate_vertex_colors as well as
        interpolate_face_attributes.
        """
        verts = torch.randn((4, 3), dtype=torch.float32)
        faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
        vert_tex = torch.tensor(
            [[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32
        )
        verts_features = vert_tex
        tex = TexturesVertex(verts_features=[verts_features])
        mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
        pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
        barycentric_coords = torch.tensor(
            [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
        ).view(1, 1, 1, 2, -1)
        expected_vals = torch.tensor(
            [[0.5, 1.0, 0.3], [0.3, 1.0, 0.9]], dtype=torch.float32
        ).view(1, 1, 1, 2, -1)
        fragments = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=barycentric_coords,
            zbuf=torch.ones_like(pix_to_face),
            dists=torch.ones_like(pix_to_face),
        )
        # sample_textures calls interpolate_vertex_colors
        texels = mesh.sample_textures(fragments)
        self.assertTrue(torch.allclose(texels, expected_vals[None, :]))

    def test_sample_vertex_textures_grad(self):
        """Gradients of the summed texels must reach the vertex features."""
        verts = torch.randn((4, 3), dtype=torch.float32)
        faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
        vert_tex = torch.tensor(
            [[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]],
            dtype=torch.float32,
            requires_grad=True,
        )
        verts_features = vert_tex
        tex = TexturesVertex(verts_features=[verts_features])
        mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
        pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
        barycentric_coords = torch.tensor(
            [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
        ).view(1, 1, 1, 2, -1)
        fragments = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=barycentric_coords,
            zbuf=torch.ones_like(pix_to_face),
            dists=torch.ones_like(pix_to_face),
        )
        # Each row is the sum of the barycentric weights that the two
        # sampled faces assign to that vertex.
        grad_vert_tex = torch.tensor(
            [[0.3, 0.3, 0.3], [0.9, 0.9, 0.9], [0.5, 0.5, 0.5], [0.3, 0.3, 0.3]],
            dtype=torch.float32,
        )
        texels = mesh.sample_textures(fragments)
        texels.sum().backward()
        # Every tensor has a ``grad`` attribute, so a hasattr check is
        # vacuous; verify that backward() actually populated it.
        self.assertIsNotNone(vert_tex.grad)
        self.assertTrue(torch.allclose(vert_tex.grad, grad_vert_tex[None, :]))

    def test_textures_vertex_init_fail(self):
        """Invalid verts_features must raise a ValueError naming the argument."""
        # Incorrect sized tensors
        with self.assertRaisesRegex(ValueError, "verts_features"):
            TexturesVertex(verts_features=torch.rand(size=(5, 10)))

        # Not a list or a tensor
        with self.assertRaisesRegex(ValueError, "verts_features"):
            TexturesVertex(verts_features=(1, 1, 1))

    def test_faces_verts_textures(self):
        """
        faces_verts_textures_packed must equal gathering the packed vertex
        features with the packed faces. Uses the "cuda:0" device.
        """
        device = torch.device("cuda:0")
        verts = torch.randn((2, 4, 3), dtype=torch.float32, device=device)
        faces = torch.tensor(
            [[[2, 1, 0], [3, 1, 0]], [[1, 3, 0], [2, 1, 3]]],
            dtype=torch.int64,
            device=device,
        )

        # define TexturesVertex
        # Keep the features on the same device as the faces so that the
        # indexing below never mixes CPU data with CUDA indices.
        verts_texture = torch.rand(verts.shape, device=device)
        textures = TexturesVertex(verts_features=verts_texture)

        # compute packed faces
        ff = faces.unbind(0)
        faces_packed = torch.cat([ff[0], ff[1] + verts.shape[1]])

        # face verts textures
        faces_verts_texts = textures.faces_verts_textures_packed(faces_packed)

        verts_texts_packed = torch.cat(verts_texture.unbind(0))
        faces_verts_texts_packed = verts_texts_packed[faces_packed]

        self.assertClose(faces_verts_texts_packed, faces_verts_texts)

    def test_submeshes(self):
        """submeshes must select the requested vertex rows for each mesh."""
        # define TexturesVertex
        verts_features = torch.tensor(
            [
                [1, 0, 0],
                [1, 0, 0],
                [1, 0, 0],
                [1, 0, 0],
                [0, 1, 0],
                [0, 1, 0],
                [0, 1, 0],
                [0, 1, 0],
            ],
            dtype=torch.float32,
        )

        textures = TexturesVertex(
            verts_features=[verts_features, verts_features, verts_features]
        )
        subtextures = textures.submeshes(
            [
                [
                    torch.LongTensor([0, 2, 3]),
                    torch.LongTensor(list(range(8))),
                ],
                [],
                [
                    torch.LongTensor([4]),
                ],
            ],
            None,
        )

        subtextures_features = subtextures.verts_features_list()

        self.assertEqual(len(subtextures_features), 3)
        self.assertTrue(
            torch.equal(
                subtextures_features[0],
                torch.FloatTensor([[1, 0, 0], [1, 0, 0], [1, 0, 0]]),
            )
        )
        self.assertTrue(torch.equal(subtextures_features[1], verts_features))
        self.assertTrue(
            torch.equal(subtextures_features[2], torch.FloatTensor([[0, 1, 0]]))
        )

    def test_clone(self):
        """clone must return equal values held in separate tensors."""
        tex = TexturesVertex(verts_features=torch.rand(size=(10, 100, 128)))
        tex.verts_features_list()
        tex_cloned = tex.clone()
        self.assertSeparate(
            tex._verts_features_padded, tex_cloned._verts_features_padded
        )
        self.assertClose(tex._verts_features_padded, tex_cloned._verts_features_padded)
        self.assertSeparate(tex.valid, tex_cloned.valid)
        self.assertTrue(tex.valid.eq(tex_cloned.valid).all())
        for i in range(tex._N):
            self.assertSeparate(
                tex._verts_features_list[i], tex_cloned._verts_features_list[i]
            )
            self.assertClose(
                tex._verts_features_list[i], tex_cloned._verts_features_list[i]
            )

    def test_detach(self):
        """detach must keep the values but drop requires_grad."""
        tex = TexturesVertex(
            verts_features=torch.rand(size=(10, 100, 128), requires_grad=True)
        )
        tex.verts_features_list()
        tex_detached = tex.detach()
        self.assertFalse(tex_detached._verts_features_padded.requires_grad)
        self.assertClose(
            tex_detached._verts_features_padded, tex._verts_features_padded
        )
        for i in range(tex._N):
            self.assertClose(
                tex._verts_features_list[i], tex_detached._verts_features_list[i]
            )
            self.assertFalse(tex_detached._verts_features_list[i].requires_grad)

    def test_extend(self):
        """
        Meshes.extend(N) must repeat each texture N times in separate
        storage, and must reject a negative N.
        """
        B = 10
        mesh = init_mesh(B, 30, 50)
        V = mesh._V
        tex_uv = TexturesVertex(verts_features=torch.randn((B, V, 3)))
        tex_mesh = Meshes(
            verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex_uv
        )
        N = 20
        new_mesh = tex_mesh.extend(N)

        self.assertEqual(len(tex_mesh) * N, len(new_mesh))

        tex_init = tex_mesh.textures
        new_tex = new_mesh.textures

        for i in range(len(tex_mesh)):
            for n in range(N):
                self.assertClose(
                    tex_init.verts_features_list()[i],
                    new_tex.verts_features_list()[i * N + n],
                )
                self.assertClose(
                    tex_init._num_faces_per_mesh[i],
                    new_tex._num_faces_per_mesh[i * N + n],
                )

        self.assertAllSeparate(
            [tex_init.verts_features_padded(), new_tex.verts_features_padded()]
        )

        with self.assertRaises(ValueError):
            tex_mesh.extend(N=-1)

    def test_padded_to_packed(self):
        """Check the list / padded / packed representations agree in shape and value."""
        # Case where each face in the mesh has 3 unique uv vertex indices
        # - i.e. even if a vertex is shared between multiple faces it will
        # have a unique uv coordinate for each face.
        num_verts_per_mesh = [9, 6]
        D = 10
        verts_features_list = [torch.rand(v, D) for v in num_verts_per_mesh]
        verts_features_packed = list_to_packed(verts_features_list)[0]
        verts_features_list = packed_to_list(verts_features_packed, num_verts_per_mesh)
        tex = TexturesVertex(verts_features=verts_features_list)

        # This is set inside Meshes when textures is passed as an input.
        # Here we set _num_faces_per_mesh and _num_verts_per_mesh explicitly.
        tex1 = tex.clone()
        tex1._num_verts_per_mesh = num_verts_per_mesh
        verts_packed = tex1.verts_features_packed()
        verts_verts_list = tex1.verts_features_list()
        verts_padded = tex1.verts_features_padded()

        for f1, f2 in zip(verts_verts_list, verts_features_list):
            self.assertTrue((f1 == f2).all().item())

        self.assertTrue(verts_packed.shape == (sum(num_verts_per_mesh), D))
        self.assertTrue(verts_padded.shape == (2, 9, D))

        # Case where num_verts_per_mesh is not set and textures
        # are initialized with a padded tensor.
        tex2 = TexturesVertex(verts_features=verts_padded)
        verts_packed = tex2.verts_features_packed()
        verts_list = tex2.verts_features_list()

        # Packed is just flattened padded as num_verts_per_mesh
        # has not been provided.
        self.assertTrue(verts_packed.shape == (9 * 2, D))

        for i, (f1, f2) in enumerate(zip(verts_list, verts_features_list)):
            n = num_verts_per_mesh[i]
            self.assertTrue((f1[:n] == f2).all().item())

    def test_getitem(self):
        """Indexing with int, slice, bool tensor, int tensor and list."""
        N = 5
        V = 20
        source = {"verts_features": torch.randn(size=(N, V, 128))}
        tex = TexturesVertex(verts_features=source["verts_features"])

        verts = torch.rand(size=(N, V, 3))
        faces = torch.randint(size=(N, 10, 3), high=V)
        meshes = Meshes(verts=verts, faces=faces, textures=tex)

        tryindex(self, 2, tex, meshes, source)
        tryindex(self, slice(0, 2, 1), tex, meshes, source)
        index = torch.tensor([1, 0, 1, 0, 0], dtype=torch.bool)
        tryindex(self, index, tex, meshes, source)
        index = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool)
        tryindex(self, index, tex, meshes, source)
        index = torch.tensor([1, 2], dtype=torch.int64)
        tryindex(self, index, tex, meshes, source)
        tryindex(self, [2, 4], tex, meshes, source)

    def test_sample_textures_error(self):
        """Textures whose dimensions do not match the mesh must raise ValueError."""
        N = 5
        V = 20
        verts = torch.rand(size=(N, V, 3))
        faces = torch.randint(size=(N, 10, 3), high=V)
        tex = TexturesVertex(verts_features=torch.randn(size=(N, 10, 128)))

        # Verts features have the wrong number of verts
        with self.assertRaisesRegex(ValueError, "do not match the dimensions"):
            Meshes(verts=verts, faces=faces, textures=tex)

        # Verts features have the wrong batch dim
        tex = TexturesVertex(verts_features=torch.randn(size=(1, V, 128)))
        with self.assertRaisesRegex(ValueError, "do not match the dimensions"):
            Meshes(verts=verts, faces=faces, textures=tex)

        meshes = Meshes(verts=verts, faces=faces)
        meshes.textures = tex

        # Cannot use the texture attribute set on meshes for sampling
        # textures if the dimensions don't match
        with self.assertRaisesRegex(ValueError, "do not match the dimensions"):
            meshes.sample_textures(None)
class TestTexturesAtlas(TestCaseMixin, unittest.TestCase):
    """Tests for TexturesAtlas: sampling, gradients, copying and indexing."""

    def test_sample_texture_atlas(self):
        """Sampling must return the atlas texel picked by the barycentrics."""
        # Note: F here is the number of faces and locally shadows the
        # module-level alias for torch.nn.functional.
        N, F, R = 1, 2, 2
        verts = torch.randn((4, 3), dtype=torch.float32)
        faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
        faces_atlas = torch.rand(size=(N, F, R, R, 3))
        tex = TexturesAtlas(atlas=faces_atlas)
        mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
        pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
        barycentric_coords = torch.tensor(
            [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
        ).view(1, 1, 1, 2, -1)
        # Expected texels are read straight out of the random atlas.
        expected_vals = torch.zeros((1, 1, 1, 2, 3), dtype=torch.float32)
        expected_vals[..., 0, :] = faces_atlas[0, 0, 0, 1, ...]
        expected_vals[..., 1, :] = faces_atlas[0, 1, 1, 0, ...]

        fragments = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=barycentric_coords,
            zbuf=torch.ones_like(pix_to_face),
            dists=torch.ones_like(pix_to_face),
        )
        texels = mesh.textures.sample_textures(fragments)
        self.assertTrue(torch.allclose(texels, expected_vals))

    def test_textures_atlas_grad(self):
        """Gradients must flow only to the atlas texels that were sampled."""
        N, F, R = 1, 2, 2
        verts = torch.randn((4, 3), dtype=torch.float32)
        faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
        faces_atlas = torch.rand(size=(N, F, R, R, 3), requires_grad=True)
        tex = TexturesAtlas(atlas=faces_atlas)
        mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
        pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
        barycentric_coords = torch.tensor(
            [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
        ).view(1, 1, 1, 2, -1)
        fragments = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=barycentric_coords,
            zbuf=torch.ones_like(pix_to_face),
            dists=torch.ones_like(pix_to_face),
        )
        texels = mesh.textures.sample_textures(fragments)
        grad_tex = torch.rand_like(texels)
        grad_expected = torch.zeros_like(faces_atlas)
        grad_expected[0, 0, 0, 1, :] = grad_tex[..., 0:1, :]
        grad_expected[0, 1, 1, 0, :] = grad_tex[..., 1:2, :]
        texels.backward(grad_tex)
        # Every tensor has a ``grad`` attribute, so a hasattr check is
        # vacuous; verify that backward() actually populated it.
        self.assertIsNotNone(faces_atlas.grad)
        self.assertTrue(torch.allclose(faces_atlas.grad, grad_expected))

    def test_textures_atlas_init_fail(self):
        """Invalid atlas inputs must raise a ValueError naming the argument."""
        # Incorrect sized tensors
        with self.assertRaisesRegex(ValueError, "atlas"):
            TexturesAtlas(atlas=torch.rand(size=(5, 10, 3)))

        # Not a list or a tensor
        with self.assertRaisesRegex(ValueError, "atlas"):
            TexturesAtlas(atlas=(1, 1, 1))

    def test_faces_verts_textures(self):
        """
        Compare faces_verts_textures_packed with a naive per-face lookup of
        the atlas corners. Uses the "cuda:0" device.
        """
        device = torch.device("cuda:0")
        N, F, R = 2, 2, 8
        num_faces = torch.randint(low=1, high=F, size=(N,))
        faces_atlas = [
            torch.rand(size=(num_faces[i].item(), R, R, 3), device=device)
            for i in range(N)
        ]
        tex = TexturesAtlas(atlas=faces_atlas)

        # faces_verts naive
        faces_verts = []
        for n in range(N):
            ff = num_faces[n].item()
            temp = torch.zeros(ff, 3, 3)
            for f in range(ff):
                t0 = faces_atlas[n][f, 0, -1]  # for v0, bary = (1, 0)
                t1 = faces_atlas[n][f, -1, 0]  # for v1, bary = (0, 1)
                t2 = faces_atlas[n][f, 0, 0]  # for v2, bary = (0, 0)
                temp[f, 0] = t0
                temp[f, 1] = t1
                temp[f, 2] = t2
            faces_verts.append(temp)
        faces_verts = torch.cat(faces_verts, 0)

        self.assertClose(faces_verts, tex.faces_verts_textures_packed().cpu())

    def test_clone(self):
        """clone must return equal values held in separate tensors."""
        tex = TexturesAtlas(atlas=torch.rand(size=(1, 10, 2, 2, 3)))
        tex.atlas_list()
        tex_cloned = tex.clone()
        self.assertSeparate(tex._atlas_padded, tex_cloned._atlas_padded)
        self.assertClose(tex._atlas_padded, tex_cloned._atlas_padded)
        self.assertSeparate(tex.valid, tex_cloned.valid)
        self.assertTrue(tex.valid.eq(tex_cloned.valid).all())
        for i in range(tex._N):
            self.assertSeparate(tex._atlas_list[i], tex_cloned._atlas_list[i])
            self.assertClose(tex._atlas_list[i], tex_cloned._atlas_list[i])

    def test_detach(self):
        """detach must keep the values but drop requires_grad."""
        tex = TexturesAtlas(atlas=torch.rand(size=(1, 10, 2, 2, 3), requires_grad=True))
        tex.atlas_list()
        tex_detached = tex.detach()
        self.assertFalse(tex_detached._atlas_padded.requires_grad)
        self.assertClose(tex_detached._atlas_padded, tex._atlas_padded)
        for i in range(tex._N):
            self.assertFalse(tex_detached._atlas_list[i].requires_grad)
            self.assertClose(tex._atlas_list[i], tex_detached._atlas_list[i])

    def test_extend(self):
        """
        Meshes.extend(N) must repeat each atlas N times in separate
        storage, and must reject a negative N.
        """
        B = 10
        mesh = init_mesh(B, 30, 50)
        F = mesh._F
        tex_uv = TexturesAtlas(atlas=torch.randn((B, F, 2, 2, 3)))
        tex_mesh = Meshes(
            verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex_uv
        )
        N = 20
        new_mesh = tex_mesh.extend(N)

        self.assertEqual(len(tex_mesh) * N, len(new_mesh))

        tex_init = tex_mesh.textures
        new_tex = new_mesh.textures

        for i in range(len(tex_mesh)):
            for n in range(N):
                self.assertClose(
                    tex_init.atlas_list()[i], new_tex.atlas_list()[i * N + n]
                )
                self.assertClose(
                    tex_init._num_faces_per_mesh[i],
                    new_tex._num_faces_per_mesh[i * N + n],
                )

        self.assertAllSeparate([tex_init.atlas_padded(), new_tex.atlas_padded()])

        with self.assertRaises(ValueError):
            tex_mesh.extend(N=-1)

    def test_padded_to_packed(self):
        """Check the list / padded / packed representations agree in shape and value."""
        # Case where each face in the mesh has 3 unique uv vertex indices
        # - i.e. even if a vertex is shared between multiple faces it will
        # have a unique uv coordinate for each face.
        R = 2
        N = 20
        num_faces_per_mesh = torch.randint(size=(N,), low=0, high=30)
        atlas_list = [torch.rand(f, R, R, 3) for f in num_faces_per_mesh]
        tex = TexturesAtlas(atlas=atlas_list)

        # This is set inside Meshes when textures is passed as an input.
        # Here we set _num_faces_per_mesh explicitly.
        tex1 = tex.clone()
        tex1._num_faces_per_mesh = num_faces_per_mesh.tolist()
        atlas_packed = tex1.atlas_packed()
        atlas_list_new = tex1.atlas_list()
        atlas_padded = tex1.atlas_padded()

        for f1, f2 in zip(atlas_list_new, atlas_list):
            self.assertTrue((f1 == f2).all().item())

        sum_F = num_faces_per_mesh.sum()
        max_F = num_faces_per_mesh.max().item()
        self.assertTrue(atlas_packed.shape == (sum_F, R, R, 3))
        self.assertTrue(atlas_padded.shape == (N, max_F, R, R, 3))

        # Case where num_faces_per_mesh is not set and textures
        # are initialized with a padded tensor.
        atlas_list_padded = _list_to_padded_wrapper(atlas_list)
        tex2 = TexturesAtlas(atlas=atlas_list_padded)
        atlas_packed = tex2.atlas_packed()
        atlas_list_new = tex2.atlas_list()

        # Packed is just flattened padded as num_faces_per_mesh
        # has not been provided.
        self.assertTrue(atlas_packed.shape == (N * max_F, R, R, 3))

        for i, (f1, f2) in enumerate(zip(atlas_list_new, atlas_list)):
            n = num_faces_per_mesh[i]
            self.assertTrue((f1[:n] == f2).all().item())

    def test_getitem(self):
        """Indexing with int, slice, bool tensor, int tensor and list."""
        N = 5
        V = 20
        F = 10
        source = {"atlas": torch.randn(size=(N, F, 4, 4, 3))}
        tex = TexturesAtlas(atlas=source["atlas"])

        verts = torch.rand(size=(N, V, 3))
        faces = torch.randint(size=(N, F, 3), high=V)
        meshes = Meshes(verts=verts, faces=faces, textures=tex)

        tryindex(self, 2, tex, meshes, source)
        tryindex(self, slice(0, 2, 1), tex, meshes, source)
        index = torch.tensor([1, 0, 1, 0, 0], dtype=torch.bool)
        tryindex(self, index, tex, meshes, source)
        index = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool)
        tryindex(self, index, tex, meshes, source)
        index = torch.tensor([1, 2], dtype=torch.int64)
        tryindex(self, index, tex, meshes, source)
        tryindex(self, [2, 4], tex, meshes, source)

    def test_sample_textures_error(self):
        """An atlas whose dimensions do not match the mesh must raise ValueError."""
        N = 1
        V = 20
        F = 10
        verts = torch.rand(size=(5, V, 3))
        faces = torch.randint(size=(5, F, 3), high=V)

        # TexturesAtlas have the wrong batch dim
        tex = TexturesAtlas(atlas=torch.randn(size=(1, F, 4, 4, 3)))
        with self.assertRaisesRegex(ValueError, "do not match the dimensions"):
            Meshes(verts=verts, faces=faces, textures=tex)

        # TexturesAtlas have the wrong number of faces
        tex = TexturesAtlas(atlas=torch.randn(size=(N, 15, 4, 4, 3)))
        with self.assertRaisesRegex(ValueError, "do not match the dimensions"):
            Meshes(verts=verts, faces=faces, textures=tex)

        meshes = Meshes(verts=verts, faces=faces)
        meshes.textures = tex

        # Cannot use the texture attribute set on meshes for sampling
        # textures if the dimensions don't match
        with self.assertRaisesRegex(ValueError, "do not match the dimensions"):
            meshes.sample_textures(None)
class TestTexturesUV(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(42)
def test_sample_textures_uv(self):
barycentric_coords = torch.tensor(
[[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
).view(1, 1, 1, 2, -1)
dummy_verts = torch.zeros(4, 3)
vert_uvs = torch.tensor([[1, 0], [0, 1], [1, 1], [0, 0]], dtype=torch.float32)
face_uvs = torch.tensor([[0, 1, 2], [1, 2, 3]], dtype=torch.int64)
interpolated_uvs = torch.tensor(
[[0.5 + 0.2, 0.3 + 0.2], [0.6, 0.3 + 0.6]], dtype=torch.float32
)
# Create a dummy texture map
H = 2
W = 2
x = torch.linspace(0, 1, W).view(1, W).expand(H, W)
y = torch.linspace(0, 1, H).view(H, 1).expand(H, W)
tex_map = torch.stack([x, y], dim=2).view(1, H, W, 2)
pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
fragments = Fragments(
pix_to_face=pix_to_face,
bary_coords=barycentric_coords,
zbuf=pix_to_face,
dists=pix_to_face,
)
for align_corners in [True, False]:
tex = TexturesUV(
maps=tex_map,
faces_uvs=[face_uvs],
verts_uvs=[vert_uvs],
align_corners=align_corners,
)
meshes = Meshes(verts=[dummy_verts], faces=[face_uvs], textures=tex)
mesh_textures = meshes.textures
texels = mesh_textures.sample_textures(fragments)
# Expected output
pixel_uvs = interpolated_uvs * 2.0 - 1.0
pixel_uvs = pixel_uvs.view(2, 1, 1, 2)
tex_map_ = torch.flip(tex_map, [1]).permute(0, 3, 1, 2)
tex_map_ = torch.cat([tex_map_, tex_map_], dim=0)
expected_out = F.grid_sample(
tex_map_, pixel_uvs, align_corners=align_corners, padding_mode="border"
)
self.assertTrue(torch.allclose(texels.squeeze(), expected_out.squeeze()))
def test_textures_uv_init_fail(self):
# Maps has wrong shape
with self.assertRaisesRegex(ValueError, "maps"):
TexturesUV(
maps=torch.ones((5, 16, 16, 3, 4)),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
)
# faces_uvs has wrong shape
with self.assertRaisesRegex(ValueError, "faces_uvs"):
TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
faces_uvs=torch.rand(size=(5, 10, 3, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
)
# verts_uvs has wrong shape
with self.assertRaisesRegex(ValueError, "verts_uvs"):
TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2, 3)),
)
# verts has different batch dim to faces
with self.assertRaisesRegex(ValueError, "verts_uvs"):
TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(8, 15, 2)),
)
# maps has different batch dim to faces
with self.assertRaisesRegex(ValueError, "maps"):
TexturesUV(
maps=torch.ones((8, 16, 16, 3)),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
)
# verts on different device to faces
with self.assertRaisesRegex(ValueError, "verts_uvs"):
TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2, 3), device="cuda"),
)
# maps on different device to faces
with self.assertRaisesRegex(ValueError, "map"):
TexturesUV(
maps=torch.ones((5, 16, 16, 3), device="cuda"),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
)
def test_faces_verts_textures(self):
device = torch.device("cuda:0")
N, V, F, H, W = 2, 5, 12, 8, 8
vert_uvs = torch.rand((N, V, 2), dtype=torch.float32, device=device)
face_uvs = torch.randint(
high=V, size=(N, F, 3), dtype=torch.int64, device=device
)
maps = torch.rand((N, H, W, 3), dtype=torch.float32, device=device)
tex = TexturesUV(maps=maps, verts_uvs=vert_uvs, faces_uvs=face_uvs)
# naive faces_verts_textures
faces_verts_texs = []
for n in range(N):
temp = torch.zeros((F, 3, 3), device=device, dtype=torch.float32)
for f in range(F):
uv0 = vert_uvs[n, face_uvs[n, f, 0]]
uv1 = vert_uvs[n, face_uvs[n, f, 1]]
uv2 = vert_uvs[n, face_uvs[n, f, 2]]
idx = torch.stack((uv0, uv1, uv2), dim=0).view(1, 1, 3, 2) # 1x1x3x2
idx = idx * 2.0 - 1.0
imap = maps[n].view(1, H, W, 3).permute(0, 3, 1, 2) # 1x3xHxW
imap = torch.flip(imap, [2])
texts = torch.nn.functional.grid_sample(
imap,
idx,
align_corners=tex.align_corners,
padding_mode=tex.padding_mode,
) # 1x3x1x3
temp[f] = texts[0, :, 0, :].permute(1, 0)
faces_verts_texs.append(temp)
faces_verts_texs = torch.cat(faces_verts_texs, 0)
self.assertClose(faces_verts_texs, tex.faces_verts_textures_packed())
def test_clone(self):
tex = TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
)
tex.faces_uvs_list()
tex.verts_uvs_list()
tex_cloned = tex.clone()
self.assertSeparate(tex._faces_uvs_padded, tex_cloned._faces_uvs_padded)
self.assertClose(tex._faces_uvs_padded, tex_cloned._faces_uvs_padded)
self.assertSeparate(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded)
self.assertClose(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded)
self.assertSeparate(tex._maps_padded, tex_cloned._maps_padded)
self.assertClose(tex._maps_padded, tex_cloned._maps_padded)
self.assertSeparate(tex.valid, tex_cloned.valid)
self.assertTrue(tex.valid.eq(tex_cloned.valid).all())
for i in range(tex._N):
self.assertSeparate(tex._faces_uvs_list[i], tex_cloned._faces_uvs_list[i])
self.assertClose(tex._faces_uvs_list[i], tex_cloned._faces_uvs_list[i])
self.assertSeparate(tex._verts_uvs_list[i], tex_cloned._verts_uvs_list[i])
self.assertClose(tex._verts_uvs_list[i], tex_cloned._verts_uvs_list[i])
# tex._maps_list is not use anywhere so it's not stored. We call it explicitly
self.assertSeparate(tex.maps_list()[i], tex_cloned.maps_list()[i])
self.assertClose(tex.maps_list()[i], tex_cloned.maps_list()[i])
def test_detach(self):
tex = TexturesUV(
maps=torch.ones((5, 16, 16, 3), requires_grad=True),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
)
tex.faces_uvs_list()
tex.verts_uvs_list()
tex_detached = tex.detach()
self.assertFalse(tex_detached._maps_padded.requires_grad)
self.assertClose(tex._maps_padded, tex_detached._maps_padded)
self.assertFalse(tex_detached._verts_uvs_padded.requires_grad)
self.assertClose(tex._verts_uvs_padded, tex_detached._verts_uvs_padded)
self.assertFalse(tex_detached._faces_uvs_padded.requires_grad)
self.assertClose(tex._faces_uvs_padded, tex_detached._faces_uvs_padded)
for i in range(tex._N):
self.assertFalse(tex_detached._verts_uvs_list[i].requires_grad)
self.assertClose(tex._verts_uvs_list[i], tex_detached._verts_uvs_list[i])
self.assertFalse(tex_detached._faces_uvs_list[i].requires_grad)
self.assertClose(tex._faces_uvs_list[i], tex_detached._faces_uvs_list[i])
# tex._maps_list is not use anywhere so it's not stored. We call it explicitly
self.assertFalse(tex_detached.maps_list()[i].requires_grad)
self.assertClose(tex.maps_list()[i], tex_detached.maps_list()[i])
def test_extend(self):
B = 5
mesh = init_mesh(B, 30, 50)
V = mesh._V
num_faces = mesh.num_faces_per_mesh()
num_verts = mesh.num_verts_per_mesh()
faces_uvs_list = [torch.randint(size=(f, 3), low=0, high=V) for f in num_faces]
verts_uvs_list = [torch.rand(v, 2) for v in num_verts]
tex_uv = TexturesUV(
maps=torch.ones((B, 16, 16, 3)),
faces_uvs=faces_uvs_list,
verts_uvs=verts_uvs_list,
)
tex_mesh = Meshes(
verts=mesh.verts_list(), faces=mesh.faces_list(), textures=tex_uv
)
N = 2
new_mesh = tex_mesh.extend(N)
self.assertEqual(len(tex_mesh) * N, len(new_mesh))
tex_init = tex_mesh.textures
new_tex = new_mesh.textures
new_tex_num_verts = new_mesh.num_verts_per_mesh()
for i in range(len(tex_mesh)):
for n in range(N):
tex_nv = new_tex_num_verts[i * N + n]
self.assertClose(
# The original textures were initialized using
# verts uvs list
tex_init.verts_uvs_list()[i],
# In the new textures, the verts_uvs are initialized
# from padded. The verts per mesh are not used to
# convert from padded to list. See TexturesUV for an
# explanation.
new_tex.verts_uvs_list()[i * N + n][:tex_nv, ...],
)
self.assertClose(
tex_init.faces_uvs_list()[i], new_tex.faces_uvs_list()[i * N + n]
)
self.assertClose(
tex_init.maps_padded()[i, ...], new_tex.maps_padded()[i * N + n]
)
self.assertClose(
tex_init._num_faces_per_mesh[i],
new_tex._num_faces_per_mesh[i * N + n],
)
self.assertAllSeparate(
[
tex_init.faces_uvs_padded(),
new_tex.faces_uvs_padded(),
tex_init.verts_uvs_padded(),
new_tex.verts_uvs_padded(),
tex_init.maps_padded(),
new_tex.maps_padded(),
]
)
with self.assertRaises(ValueError):
tex_mesh.extend(N=-1)
def test_padded_to_packed(self):
# Case where each face in the mesh has 3 unique uv vertex indices
# - i.e. even if a vertex is shared between multiple faces it will
# have a unique uv coordinate for each face.
N = 2
faces_uvs_list = [
torch.tensor([[0, 1, 2], [3, 5, 4], [7, 6, 8]]),
torch.tensor([[0, 1, 2], [3, 4, 5]]),
] # (N, 3, 3)
verts_uvs_list = [torch.ones(9, 2), torch.ones(6, 2)]
num_faces_per_mesh = [f.shape[0] for f in faces_uvs_list]
num_verts_per_mesh = [v.shape[0] for v in verts_uvs_list]
tex = TexturesUV(
maps=torch.ones((N, 16, 16, 3)),
faces_uvs=faces_uvs_list,
verts_uvs=verts_uvs_list,
)
# This is set inside Meshes when textures is passed as an input.
# Here we set _num_faces_per_mesh and _num_verts_per_mesh explicity.
tex1 = tex.clone()
tex1._num_faces_per_mesh = num_faces_per_mesh
tex1._num_verts_per_mesh = num_verts_per_mesh
verts_list = tex1.verts_uvs_list()
verts_padded = tex1.verts_uvs_padded()
faces_list = tex1.faces_uvs_list()
faces_padded = tex1.faces_uvs_padded()
for f1, f2 in zip(faces_list, faces_uvs_list):
self.assertTrue((f1 == f2).all().item())
for f1, f2 in zip(verts_list, verts_uvs_list):
self.assertTrue((f1 == f2).all().item())
self.assertTrue(faces_padded.shape == (2, 3, 3))
self.assertTrue(verts_padded.shape == (2, 9, 2))
# Case where num_faces_per_mesh is not set and faces_verts_uvs
# are initialized with a padded tensor.
tex2 = TexturesUV(
maps=torch.ones((N, 16, 16, 3)),
verts_uvs=verts_padded,
faces_uvs=faces_padded,
)
faces_list = tex2.faces_uvs_list()
verts_list = tex2.verts_uvs_list()
for i, (f1, f2) in enumerate(zip(faces_list, faces_uvs_list)):
n = num_faces_per_mesh[i]
self.assertTrue((f1[:n] == f2).all().item())
for i, (f1, f2) in enumerate(zip(verts_list, verts_uvs_list)):
n = num_verts_per_mesh[i]
self.assertTrue((f1[:n] == f2).all().item())
def test_to(self):
tex = TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
faces_uvs=torch.randint(size=(5, 10, 3), high=15),
verts_uvs=torch.rand(size=(5, 15, 2)),
)
device = torch.device("cuda:0")
tex = tex.to(device)
self.assertEqual(tex._faces_uvs_padded.device, device)
self.assertEqual(tex._verts_uvs_padded.device, device)
self.assertEqual(tex._maps_padded.device, device)
def test_mesh_to(self):
tex_cpu = TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
faces_uvs=torch.randint(size=(5, 10, 3), high=15),
verts_uvs=torch.rand(size=(5, 15, 2)),
)
verts = torch.rand(size=(5, 15, 3))
faces = torch.randint(size=(5, 10, 3), high=15)
mesh_cpu = Meshes(faces=faces, verts=verts, textures=tex_cpu)
cpu = torch.device("cpu")
device = torch.device("cuda:0")
tex = mesh_cpu.to(device).textures
self.assertEqual(tex._faces_uvs_padded.device, device)
self.assertEqual(tex._verts_uvs_padded.device, device)
self.assertEqual(tex._maps_padded.device, device)
self.assertEqual(tex_cpu._verts_uvs_padded.device, cpu)
self.assertEqual(tex_cpu.device, cpu)
self.assertEqual(tex.device, device)
def test_getitem(self):
N = 5
V = 20
F = 10
source = {
"maps": torch.rand(size=(N, 1, 1, 3)),
"faces_uvs": torch.randint(size=(N, F, 3), high=V),
"verts_uvs": torch.randn(size=(N, V, 2)),
}
tex = TexturesUV(
maps=source["maps"],
faces_uvs=source["faces_uvs"],
verts_uvs=source["verts_uvs"],
)
verts = torch.rand(size=(N, V, 3))
faces = torch.randint(size=(N, F, 3), high=V)
meshes = Meshes(verts=verts, faces=faces, textures=tex)
tryindex(self, 2, tex, meshes, source)
tryindex(self, slice(0, 2, 1), tex, meshes, source)
index = torch.tensor([1, 0, 1, 0, 0], dtype=torch.bool)
tryindex(self, index, tex, meshes, source)
index = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool)
tryindex(self, index, tex, meshes, source)
index = torch.tensor([1, 2], dtype=torch.int64)
tryindex(self, index, tex, meshes, source)
tryindex(self, [2, 4], tex, meshes, source)
def test_centers_for_image(self):
maps = torch.rand(size=(1, 257, 129, 3))
verts_uvs = torch.FloatTensor([[[0.25, 0.125], [0.5, 0.625], [0.5, 0.5]]])
faces_uvs = torch.zeros(size=(1, 0, 3), dtype=torch.int64)
tex = TexturesUV(maps=maps, faces_uvs=faces_uvs, verts_uvs=verts_uvs)
expected = torch.FloatTensor([[32, 224], [64, 96], [64, 128]])
self.assertClose(tex.centers_for_image(0), expected)
def test_sample_textures_error(self):
    """
    Textures whose dimensions disagree with those of the mesh must be
    rejected with a ValueError, both when they are passed to the
    Meshes constructor and when they are attached to an existing
    Meshes object and then used for sampling.
    """
    tex_batch = 1
    mesh_batch = 5
    num_verts = 20
    num_faces = 10
    mismatch_regex = "do not match the dimensions"

    maps = torch.rand(size=(tex_batch, 1, 1, 3))
    verts_uvs = torch.randn(size=(tex_batch, num_verts, 2))
    tex = TexturesUV(
        maps=maps,
        faces_uvs=torch.randint(size=(tex_batch, 15, 3), high=num_verts),
        verts_uvs=verts_uvs,
    )
    verts = torch.rand(size=(mesh_batch, num_verts, 3))
    faces = torch.randint(size=(mesh_batch, num_faces, 3), high=num_verts)

    # Wrong number of faces
    with self.assertRaisesRegex(ValueError, mismatch_regex):
        Meshes(verts=verts, faces=faces, textures=tex)

    # Wrong batch dim for faces
    tex = TexturesUV(
        maps=maps,
        faces_uvs=torch.randint(size=(tex_batch, num_faces, 3), high=num_verts),
        verts_uvs=verts_uvs,
    )
    with self.assertRaisesRegex(ValueError, mismatch_regex):
        Meshes(verts=verts, faces=faces, textures=tex)

    # Wrong batch dim for verts_uvs is not necessary to check as
    # there is already a check inside TexturesUV for a batch dim
    # mismatch with faces_uvs

    # Assigning the textures attribute directly is not rejected by itself ...
    meshes = Meshes(verts=verts, faces=faces)
    meshes.textures = tex

    # ... but the texture attribute set on meshes cannot be used for
    # sampling textures if the dimensions don't match
    with self.assertRaisesRegex(ValueError, mismatch_regex):
        meshes.sample_textures(None)
class TestRectanglePacking(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
    """
    Run the parent class set-up and then seed torch's random number
    generator with a fixed value before each test.
    """
    super().setUp()
    # Fixed seed: test_random draws its rectangle sizes from torch.
    torch.manual_seed(42)
def wrap_pack(self, sizes):
    """
    Call the pack_rectangles function, which we want to test,
    and return its outputs.
    Additionally makes some sanity checks on the output.

    The checks are:
      - the total size is non-negative in both dimensions,
      - there is exactly one location per input rectangle,
      - every location is marked is_first,
      - every rectangle (after applying its flip) lies inside the
        total size and overlaps no other rectangle,
      - at least one rectangle reaches the upper bound in x and at
        least one reaches the upper bound in y,
      - every row and every column of the total area contains at
        least one occupied cell.

    Args:
        sizes: list of (x size, y size) pairs to pack.

    Returns:
        The unmodified return value of pack_rectangles(sizes).
    """
    res = pack_rectangles(sizes)
    total = res.total_size
    self.assertGreaterEqual(total[0], 0)
    self.assertGreaterEqual(total[1], 0)

    # zip() below stops at the shorter of its inputs, so without this
    # a missing or extra location would silently escape the checks.
    self.assertEqual(len(res.locations), len(sizes))

    # Occupancy grid of the whole bounding box, one cell per unit square.
    mask = torch.zeros(total, dtype=torch.bool)
    seen_x_bound = False
    seen_y_bound = False
    for (in_x, in_y), (out_x, out_y, flipped, is_first) in zip(
        sizes, res.locations
    ):
        self.assertTrue(is_first)
        self.assertGreaterEqual(out_x, 0)
        self.assertGreaterEqual(out_y, 0)

        # A flipped rectangle is placed with its two sides exchanged.
        placed_x, placed_y = (in_y, in_x) if flipped else (in_x, in_y)
        upper_x = placed_x + out_x
        upper_y = placed_y + out_y

        self.assertGreaterEqual(total[0], upper_x)
        if total[0] == upper_x:
            seen_x_bound = True
        self.assertGreaterEqual(total[1], upper_y)
        if total[1] == upper_y:
            seen_y_bound = True

        # No cell of this rectangle may have been claimed already.
        already_taken = torch.sum(mask[out_x:upper_x, out_y:upper_y])
        self.assertEqual(already_taken, 0)
        mask[out_x:upper_x, out_y:upper_y] = 1

    # The bounding box must be tight in both directions.
    self.assertTrue(seen_x_bound)
    self.assertTrue(seen_y_bound)

    # No row or column of the bounding box is left completely empty.
    self.assertTrue(torch.all(torch.sum(mask, dim=0, dtype=torch.int32) > 0))
    self.assertTrue(torch.all(torch.sum(mask, dim=1, dtype=torch.int32) > 0))
    return res
def assert_bb(self, sizes, expected):
    """
    Pack sizes via wrap_pack and check that the set of the two
    dimensions of the resulting total size equals expected.

    Comparing as a set means the order of the two dimensions does
    not matter.
    """
    packed = self.wrap_pack(sizes)
    bounding_box = set(packed.total_size)
    self.assertSetEqual(bounding_box, expected)
def test_simple(self):
    """
    Pack a handful of small hand-written inputs and compare the
    dimensions of the resulting bounding box with known values.
    """
    big_square = [(3, 3)]
    unit_square = [(1, 1)]
    # Each entry is (rectangle sizes, expected set of box dimensions).
    cases = [
        ([(3, 4), (4, 3)], {6, 4}),
        # {4, 4} is the one-element set {4}: a square box.
        ([(2, 2), (2, 4), (2, 2)], {4, 4}),
        # many squares
        ([(2, 2)] * 9, {2, 18}),
        # One big square and many small ones.
        (big_square + unit_square * 2, {3, 4}),
        (big_square + unit_square * 3, {3, 4}),
        (big_square + unit_square * 4, {3, 5}),
        (big_square + unit_square * 5, {3, 5}),
        (unit_square * 6 + big_square, {3, 5}),
        (big_square + unit_square * 7, {3, 6}),
        # many identical rectangles
        ([(7, 190)] * 4 + [(190, 7)] * 4, {190, 56}),
        # require placing the flipped version of a rectangle
        ([(1, 100), (5, 96), (4, 5)], {100, 6}),
    ]
    for sizes, expected in cases:
        self.assert_bb(sizes, expected)
def test_random(self):
    """
    Apply the sanity checks of wrap_pack to five random inputs of
    20 rectangles each, with side lengths drawn by torch.randint
    using low=1 and high=18.
    """
    num_trials = 5
    for _ in range(num_trials):
        vals = torch.randint(size=(20, 2), low=1, high=18)
        # Turn the (20, 2) tensor into a list of pairs of Python ints.
        sizes = [(x_size, y_size) for x_size, y_size in vals.tolist()]
        self.wrap_pack(sizes)
def test_all_identical(self):
    """
    Three rectangles with the same identifier are expected to share
    one spot: the total size is that of a single rectangle, all
    three sit unflipped at the origin, and only the first of them is
    marked is_first.
    """
    rectangle = Rectangle(xsize=61, ysize=82, identifier=1729)
    sizes = [rectangle] * 3
    total_size, locations = pack_unique_rectangles(sizes)

    self.assertEqual(total_size, (61, 82))
    self.assertEqual(len(locations), 3)
    for position, location in enumerate(locations):
        x, y, is_flipped, is_first = location
        self.assertEqual(x, 0)
        self.assertEqual(y, 0)
        self.assertFalse(is_flipped)
        expect_first = position == 0
        self.assertEqual(is_first, expect_first)
def test_one_different_id(self):
    """
    Two groups of three equally sized rectangles, each group with
    its own identifier. The expected packing has total size
    (82, 122) with every rectangle flipped and at x = 0; the first
    group is at y = 61 and the second at y = 0, and only the first
    member of each group is marked is_first.
    """
    group_size = 3
    first_group = [Rectangle(xsize=61, ysize=82, identifier=220)] * group_size
    second_group = [Rectangle(xsize=61, ysize=82, identifier=284)] * group_size
    sizes = first_group + second_group
    total_size, locations = pack_unique_rectangles(sizes)

    self.assertEqual(total_size, (82, 122))
    self.assertEqual(len(locations), 6)
    for position, location in enumerate(locations):
        x, y, is_flipped, is_first = location
        self.assertTrue(is_flipped)
        self.assertEqual(is_first, position % 3 == 0)
        self.assertEqual(x, 0)
        expected_y = 61 if position < 3 else 0
        self.assertEqual(y, expected_y)