TexturesUV multiple maps

Summary: Implements the  the TexturesUV with multiple map ids.

Reviewed By: bottler

Differential Revision: D53944063

fbshipit-source-id: 06c25eb6d69f72db0484f16566dd2ca32a560b82
This commit is contained in:
Cijo Jose
2024-03-12 06:59:31 -07:00
committed by Facebook GitHub Bot
parent 7566530669
commit 38cf0dc1c5
2 changed files with 483 additions and 81 deletions

View File

@@ -718,6 +718,22 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
verts_uvs=torch.rand(size=(5, 15, 2)),
)
# maps ids are not none but maps doesn't have multiple map indices
with self.assertRaisesRegex(ValueError, "map"):
TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
maps_ids=torch.randint(0, 1, (5, 10), dtype=torch.long),
)
# maps ids is none but maps have multiple map indices
with self.assertRaisesRegex(ValueError, "map"):
TexturesUV(
maps=torch.ones((5, 2, 16, 16, 3)),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
)
def test_faces_verts_textures(self):
device = torch.device("cuda:0")
N, V, F, H, W = 2, 5, 12, 8, 8
@@ -755,6 +771,47 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
self.assertClose(faces_verts_texs, tex.faces_verts_textures_packed())
def test_faces_verts_multiple_map_textures(self):
device = torch.device("cuda:0")
N, M, V, F, H, W = 2, 3, 5, 12, 8, 8
vert_uvs = torch.rand((N, V, 2), dtype=torch.float32, device=device)
face_uvs = torch.randint(
high=V, size=(N, F, 3), dtype=torch.int64, device=device
)
map_ids = torch.randint(0, M, (N, F), device=device)
maps = torch.rand((N, M, H, W, 3), dtype=torch.float32, device=device)
tex = TexturesUV(
maps=maps, verts_uvs=vert_uvs, faces_uvs=face_uvs, maps_ids=map_ids
)
# naive faces_verts_textures
faces_verts_texs = []
for n in range(N):
temp = torch.zeros((F, 3, 3), device=device, dtype=torch.float32)
for f in range(F):
uv0 = vert_uvs[n, face_uvs[n, f, 0]]
uv1 = vert_uvs[n, face_uvs[n, f, 1]]
uv2 = vert_uvs[n, face_uvs[n, f, 2]]
map_id = map_ids[n, f]
idx = torch.stack((uv0, uv1, uv2), dim=0).view(1, 1, 3, 2) # 1x1x3x2
idx = idx * 2.0 - 1.0
imap = maps[n, map_id].view(1, H, W, 3).permute(0, 3, 1, 2) # 1x3xHxW
imap = torch.flip(imap, [2])
texts = torch.nn.functional.grid_sample(
imap,
idx,
align_corners=tex.align_corners,
padding_mode=tex.padding_mode,
) # 1x3x1x3
temp[f] = texts[0, :, 0, :].permute(1, 0)
faces_verts_texs.append(temp)
faces_verts_texs = torch.cat(faces_verts_texs, 0)
self.assertClose(faces_verts_texs, tex.faces_verts_textures_packed())
def test_clone(self):
tex = TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
@@ -781,6 +838,37 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
self.assertSeparate(tex.maps_list()[i], tex_cloned.maps_list()[i])
self.assertClose(tex.maps_list()[i], tex_cloned.maps_list()[i])
def test_multiple_maps_clone(self):
tex = TexturesUV(
maps=torch.ones((5, 3, 16, 16, 3)),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
maps_ids=torch.randint(0, 3, (5, 10)),
)
tex.faces_uvs_list()
tex.verts_uvs_list()
tex_cloned = tex.clone()
self.assertSeparate(tex._faces_uvs_padded, tex_cloned._faces_uvs_padded)
self.assertClose(tex._faces_uvs_padded, tex_cloned._faces_uvs_padded)
self.assertSeparate(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded)
self.assertClose(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded)
self.assertSeparate(tex._maps_padded, tex_cloned._maps_padded)
self.assertClose(tex._maps_padded, tex_cloned._maps_padded)
self.assertSeparate(tex.valid, tex_cloned.valid)
self.assertTrue(tex.valid.eq(tex_cloned.valid).all())
self.assertSeparate(tex._maps_ids_padded, tex_cloned._maps_ids_padded)
self.assertClose(tex._maps_ids_padded, tex_cloned._maps_ids_padded)
for i in range(tex._N):
self.assertSeparate(tex._faces_uvs_list[i], tex_cloned._faces_uvs_list[i])
self.assertClose(tex._faces_uvs_list[i], tex_cloned._faces_uvs_list[i])
self.assertSeparate(tex._verts_uvs_list[i], tex_cloned._verts_uvs_list[i])
self.assertClose(tex._verts_uvs_list[i], tex_cloned._verts_uvs_list[i])
# tex._maps_list is not use anywhere so it's not stored. We call it explicitly
self.assertSeparate(tex.maps_list()[i], tex_cloned.maps_list()[i])
self.assertClose(tex.maps_list()[i], tex_cloned.maps_list()[i])
self.assertSeparate(tex.maps_ids_list()[i], tex_cloned.maps_ids_list()[i])
self.assertClose(tex.maps_ids_list()[i], tex_cloned.maps_ids_list()[i])
def test_detach(self):
tex = TexturesUV(
maps=torch.ones((5, 16, 16, 3), requires_grad=True),
@@ -805,6 +893,35 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
self.assertFalse(tex_detached.maps_list()[i].requires_grad)
self.assertClose(tex.maps_list()[i], tex_detached.maps_list()[i])
def test_multiple_maps_detach(self):
tex = TexturesUV(
maps=torch.ones((5, 3, 16, 16, 3), requires_grad=True),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
maps_ids=torch.randint(0, 3, (5, 10)),
)
tex.faces_uvs_list()
tex.verts_uvs_list()
tex_detached = tex.detach()
self.assertFalse(tex_detached._maps_padded.requires_grad)
self.assertClose(tex._maps_padded, tex_detached._maps_padded)
self.assertFalse(tex_detached._verts_uvs_padded.requires_grad)
self.assertClose(tex._verts_uvs_padded, tex_detached._verts_uvs_padded)
self.assertFalse(tex_detached._faces_uvs_padded.requires_grad)
self.assertClose(tex._faces_uvs_padded, tex_detached._faces_uvs_padded)
self.assertFalse(tex_detached._maps_ids_padded.requires_grad)
self.assertClose(tex._maps_ids_padded, tex_detached._maps_ids_padded)
for i in range(tex._N):
self.assertFalse(tex_detached._verts_uvs_list[i].requires_grad)
self.assertClose(tex._verts_uvs_list[i], tex_detached._verts_uvs_list[i])
self.assertFalse(tex_detached._faces_uvs_list[i].requires_grad)
self.assertClose(tex._faces_uvs_list[i], tex_detached._faces_uvs_list[i])
# tex._maps_list is not use anywhere so it's not stored. We call it explicitly
self.assertFalse(tex_detached.maps_list()[i].requires_grad)
self.assertClose(tex.maps_list()[i], tex_detached.maps_list()[i])
self.assertFalse(tex_detached.maps_ids_list()[i].requires_grad)
self.assertClose(tex.maps_ids_list()[i], tex_detached.maps_ids_list()[i])
def test_extend(self):
B = 5
mesh = init_mesh(B, 30, 50)
@@ -878,13 +995,15 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
torch.tensor([[0, 1, 2], [3, 4, 5]]),
] # (N, 3, 3)
verts_uvs_list = [torch.ones(9, 2), torch.ones(6, 2)]
maps_ids_given_list = [torch.randint(0, 3, (3,)), torch.randint(0, 3, (2,))]
num_faces_per_mesh = [f.shape[0] for f in faces_uvs_list]
num_verts_per_mesh = [v.shape[0] for v in verts_uvs_list]
tex = TexturesUV(
maps=torch.ones((N, 16, 16, 3)),
maps=torch.ones((N, 3, 16, 16, 3)),
faces_uvs=faces_uvs_list,
verts_uvs=verts_uvs_list,
maps_ids=maps_ids_given_list,
)
# This is set inside Meshes when textures is passed as an input.
@@ -898,24 +1017,33 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
faces_list = tex1.faces_uvs_list()
faces_padded = tex1.faces_uvs_padded()
maps_ids_list = tex1.maps_ids_list()
maps_ids_padded = tex1.maps_ids_padded()
for f1, f2 in zip(faces_list, faces_uvs_list):
self.assertTrue((f1 == f2).all().item())
for f1, f2 in zip(verts_list, verts_uvs_list):
self.assertTrue((f1 == f2).all().item())
for f1, f2 in zip(maps_ids_given_list, maps_ids_list):
self.assertTrue((f1 == f2).all().item())
self.assertTrue(faces_padded.shape == (2, 3, 3))
self.assertTrue(verts_padded.shape == (2, 9, 2))
self.assertTrue(maps_ids_padded.shape == (2, 3))
# Case where num_faces_per_mesh is not set and faces_verts_uvs
# are initialized with a padded tensor.
tex2 = TexturesUV(
maps=torch.ones((N, 16, 16, 3)),
maps=torch.ones((N, 3, 16, 16, 3)),
verts_uvs=verts_padded,
faces_uvs=faces_padded,
maps_ids=maps_ids_padded,
)
faces_list = tex2.faces_uvs_list()
verts_list = tex2.verts_uvs_list()
maps_ids_list = tex2.maps_ids_list()
for i, (f1, f2) in enumerate(zip(faces_list, faces_uvs_list)):
n = num_faces_per_mesh[i]
@@ -925,23 +1053,30 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
n = num_verts_per_mesh[i]
self.assertTrue((f1[:n] == f2).all().item())
for i, (f1, f2) in enumerate(zip(maps_ids_list, maps_ids_given_list)):
n = num_faces_per_mesh[i]
self.assertTrue((f1[:n] == f2).all().item())
def test_to(self):
tex = TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
maps=torch.ones((5, 3, 16, 16, 3)),
faces_uvs=torch.randint(size=(5, 10, 3), high=15),
verts_uvs=torch.rand(size=(5, 15, 2)),
maps_ids=torch.randint(0, 3, (5, 10)),
)
device = torch.device("cuda:0")
tex = tex.to(device)
self.assertEqual(tex._faces_uvs_padded.device, device)
self.assertEqual(tex._verts_uvs_padded.device, device)
self.assertEqual(tex._maps_padded.device, device)
self.assertEqual(tex._maps_ids_padded.device, device)
def test_mesh_to(self):
tex_cpu = TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
maps=torch.ones((5, 3, 16, 16, 3)),
faces_uvs=torch.randint(size=(5, 10, 3), high=15),
verts_uvs=torch.rand(size=(5, 15, 2)),
maps_ids=torch.randint(0, 3, (5, 10)),
)
verts = torch.rand(size=(5, 15, 3))
faces = torch.randint(size=(5, 10, 3), high=15)
@@ -952,24 +1087,29 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
self.assertEqual(tex._faces_uvs_padded.device, device)
self.assertEqual(tex._verts_uvs_padded.device, device)
self.assertEqual(tex._maps_padded.device, device)
self.assertEqual(tex._maps_ids_padded.device, device)
self.assertEqual(tex_cpu._verts_uvs_padded.device, cpu)
self.assertEqual(tex_cpu._maps_ids_padded.device, cpu)
self.assertEqual(tex_cpu.device, cpu)
self.assertEqual(tex.device, device)
def test_getitem(self):
N = 5
M = 3
V = 20
F = 10
source = {
"maps": torch.rand(size=(N, 1, 1, 3)),
"maps": torch.rand(size=(N, M, 1, 1, 3)),
"faces_uvs": torch.randint(size=(N, F, 3), high=V),
"verts_uvs": torch.randn(size=(N, V, 2)),
"maps_ids": torch.randint(0, M, (N, F)),
}
tex = TexturesUV(
maps=source["maps"],
faces_uvs=source["faces_uvs"],
verts_uvs=source["verts_uvs"],
maps_ids=source["maps_ids"],
)
verts = torch.rand(size=(N, V, 3))