add texture vertex sampling functionality to textures

Summary: Enhance every texture type with `faces_verts_textures_packed` that allows users to query the texture of each vertex in mesh

Reviewed By: nikhilaravi

Differential Revision: D24058778

fbshipit-source-id: 19d0e3a244fa96aae462c47bf52e07dfd3b7c6f0
This commit is contained in:
Georgia Gkioxari
2020-10-06 10:43:06 -07:00
committed by Facebook GitHub Bot
parent 327bd2b976
commit e651a4299c
3 changed files with 173 additions and 4 deletions

View File

@@ -112,6 +112,31 @@ class TestTexturesVertex(TestCaseMixin, unittest.TestCase):
with self.assertRaisesRegex(ValueError, "verts_features"):
TexturesVertex(verts_features=(1, 1, 1))
def test_faces_verts_textures(self):
device = torch.device("cuda:0")
verts = torch.randn((2, 4, 3), dtype=torch.float32, device=device)
faces = torch.tensor(
[[[2, 1, 0], [3, 1, 0]], [[1, 3, 0], [2, 1, 3]]],
dtype=torch.int64,
device=device,
)
# define TexturesVertex
verts_texture = torch.rand(verts.shape)
textures = TexturesVertex(verts_features=verts_texture)
# compute packed faces
ff = faces.unbind(0)
faces_packed = torch.cat([ff[0], ff[1] + verts.shape[1]])
# face verts textures
faces_verts_texts = textures.faces_verts_textures_packed(faces_packed)
verts_texts_packed = torch.cat(verts_texture.unbind(0))
faces_verts_texts_packed = verts_texts_packed[faces_packed]
self.assertClose(faces_verts_texts_packed, faces_verts_texts)
def test_clone(self):
tex = TexturesVertex(verts_features=torch.rand(size=(10, 100, 128)))
tex.verts_features_list()
@@ -303,6 +328,33 @@ class TestTexturesAtlas(TestCaseMixin, unittest.TestCase):
with self.assertRaisesRegex(ValueError, "atlas"):
TexturesAtlas(atlas=(1, 1, 1))
def test_faces_verts_textures(self):
device = torch.device("cuda:0")
N, F, R = 2, 2, 8
num_faces = torch.randint(low=1, high=F, size=(N,))
faces_atlas = [
torch.rand(size=(num_faces[i].item(), R, R, 3), device=device)
for i in range(N)
]
tex = TexturesAtlas(atlas=faces_atlas)
# faces_verts naive
faces_verts = []
for n in range(N):
ff = num_faces[n].item()
temp = torch.zeros(ff, 3, 3)
for f in range(ff):
t0 = faces_atlas[n][f, 0, -1] # for v0, bary = (1, 0)
t1 = faces_atlas[n][f, -1, 0] # for v1, bary = (0, 1)
t2 = faces_atlas[n][f, 0, 0] # for v2, bary = (0, 0)
temp[f, 0] = t0
temp[f, 1] = t1
temp[f, 2] = t2
faces_verts.append(temp)
faces_verts = torch.cat(faces_verts, 0)
self.assertClose(faces_verts, tex.faces_verts_textures_packed().cpu())
def test_clone(self):
tex = TexturesAtlas(atlas=torch.rand(size=(1, 10, 2, 2, 3)))
tex.atlas_list()
@@ -522,6 +574,43 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
verts_uvs=torch.rand(size=(5, 15, 2)),
)
def test_faces_verts_textures(self):
device = torch.device("cuda:0")
N, V, F, H, W = 2, 5, 12, 8, 8
vert_uvs = torch.rand((N, V, 2), dtype=torch.float32, device=device)
face_uvs = torch.randint(
high=V, size=(N, F, 3), dtype=torch.int64, device=device
)
maps = torch.rand((N, H, W, 3), dtype=torch.float32, device=device)
tex = TexturesUV(maps=maps, verts_uvs=vert_uvs, faces_uvs=face_uvs)
# naive faces_verts_textures
faces_verts_texs = []
for n in range(N):
temp = torch.zeros((F, 3, 3), device=device, dtype=torch.float32)
for f in range(F):
uv0 = vert_uvs[n, face_uvs[n, f, 0]]
uv1 = vert_uvs[n, face_uvs[n, f, 1]]
uv2 = vert_uvs[n, face_uvs[n, f, 2]]
idx = torch.stack((uv0, uv1, uv2), dim=0).view(1, 1, 3, 2) # 1x1x3x2
idx = idx * 2.0 - 1.0
imap = maps[n].view(1, H, W, 3).permute(0, 3, 1, 2) # 1x3xHxW
imap = torch.flip(imap, [2])
texts = torch.nn.functional.grid_sample(
imap,
idx,
align_corners=tex.align_corners,
padding_mode=tex.padding_mode,
) # 1x3x1x3
temp[f] = texts[0, :, 0, :].permute(1, 0)
faces_verts_texs.append(temp)
faces_verts_texs = torch.cat(faces_verts_texs, 0)
self.assertClose(faces_verts_texs, tex.faces_verts_textures_packed())
def test_clone(self):
tex = TexturesUV(
maps=torch.ones((5, 16, 16, 3)),