add texture vertex sampling functionality to textures

Summary: Enhance every texture type with `faces_verts_textures_packed` that allows users to query the texture of each vertex in mesh

Reviewed By: nikhilaravi

Differential Revision: D24058778

fbshipit-source-id: 19d0e3a244fa96aae462c47bf52e07dfd3b7c6f0
This commit is contained in:
Georgia Gkioxari 2020-10-06 10:43:06 -07:00 committed by Facebook GitHub Bot
parent 327bd2b976
commit e651a4299c
3 changed files with 173 additions and 4 deletions

View File

@ -129,7 +129,7 @@ def make_material_atlas(
the formulation from [1].
For a triangle with vertices (v0, v1, v2) we can create a barycentric coordinate system
with the x axis being the vector (v1 - v0) and the y axis being the vector (v2 - v0).
with the x axis being the vector (v0 - v2) and the y axis being the vector (v1 - v2).
The barycentric coordinates range from [0, 1] in the +x and +y direction so this creates
a triangular texture space with vertices at (0, 1), (0, 0) and (1, 0).

View File

@ -236,6 +236,17 @@ class TexturesBase:
"""
raise NotImplementedError()
def faces_verts_textures_packed(self):
"""
Returns the texture for each vertex for each face in the mesh.
For N meshes, this function returns sum(Fi)x3xC where Fi is the
number of faces in the i-th mesh and C is the dimensional of
the feature (C = 3 for RGB textures).
You can use the utils function in structures.utils to convert the
packed respresentation to a list or padded.
"""
raise NotImplementedError()
def clone(self):
"""
Each texture class should implement a method
@ -286,7 +297,7 @@ def Textures(
Returns:
a Textures class which is an instance of TexturesBase e.g. TexturesUV,
TexturesAtlas, TexturesVerte
TexturesAtlas, TexturesVertex
"""
@ -507,6 +518,23 @@ class TexturesAtlas(TexturesBase):
return texels
def faces_verts_textures_packed(self) -> torch.Tensor:
"""
Samples texture from each vertex for each face in the mesh.
For N meshes with {Fi} number of faces, it returns a
tensor of shape sum(Fi)x3xD (D = 3 for RGB).
You can use the utils function in structures.utils to convert the
packed respresentation to a list or padded.
"""
atlas_packed = self.atlas_packed()
# assume each face consists of (v0, v1, v2).
# to sample from the atlas we only need the first two barycentric coordinates.
# for details on how this texture sample works refer to the sample_textures function.
t0 = atlas_packed[:, 0, -1] # corresponding to v0 with bary = (1, 0)
t1 = atlas_packed[:, -1, 0] # corresponding to v1 with bary = (0, 1)
t2 = atlas_packed[:, 0, 0] # corresponding to v2 with bary = (0, 0)
return torch.stack((t0, t1, t2), dim=1)
def join_batch(self, textures: List["TexturesAtlas"]) -> "TexturesAtlas":
"""
Join the list of textures given by `textures` to
@ -514,10 +542,10 @@ class TexturesAtlas(TexturesBase):
TexturesAtlas object with the combined textures.
Args:
textures: List of TextureAtlas objects
textures: List of TexturesAtlas objects
Returns:
new_tex: TextureAtlas object with the combined
new_tex: TexturesAtlas object with the combined
textures from self and the list `textures`.
"""
tex_types_same = all(isinstance(tex, TexturesAtlas) for tex in textures)
@ -917,6 +945,46 @@ class TexturesUV(TexturesBase):
texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2)
return texels
def faces_verts_textures_packed(self) -> torch.Tensor:
"""
Samples texture from each vertex and for each face in the mesh.
For N meshes with {Fi} number of faces, it returns a
tensor of shape sum(Fi)x3xC (C = 3 for RGB).
You can use the utils function in structures.utils to convert the
packed representation to a list or padded.
"""
if self.isempty():
return torch.zeros(
(0, 3, self.maps_padded().shape[-1]),
dtype=torch.float32,
device=self.device,
)
else:
packing_list = [
i[j] for i, j in zip(self.verts_uvs_list(), self.faces_uvs_list())
]
faces_verts_uvs = _list_to_padded_wrapper(
packing_list, pad_value=0.0
) # Nxmax(Fi)x3x2
texture_maps = self.maps_padded() # NxHxWxC
texture_maps = texture_maps.permute(0, 3, 1, 2) # NxCxHxW
faces_verts_uvs = faces_verts_uvs * 2.0 - 1.0
texture_maps = torch.flip(texture_maps, [2]) # flip y axis of the texture map
textures = F.grid_sample(
texture_maps,
faces_verts_uvs,
align_corners=self.align_corners,
padding_mode=self.padding_mode,
) # NxCxmax(Fi)x3
textures = textures.permute(0, 2, 3, 1) # Nxmax(Fi)x3xC
textures = _padded_to_list_wrapper(
textures, split_size=self._num_faces_per_mesh
) # list of N {Fix3xC} tensors
return list_to_packed(textures)[0]
def join_batch(self, textures: List["TexturesUV"]) -> "TexturesUV":
"""
Join the list of textures given by `textures` to
@ -1268,6 +1336,18 @@ class TexturesVertex(TexturesBase):
)
return texels
def faces_verts_textures_packed(self, faces_packed=None) -> torch.Tensor:
"""
Samples texture from each vertex and for each face in the mesh.
For N meshes with {Fi} number of faces, it returns a
tensor of shape sum(Fi)x3xC (C = 3 for RGB).
You can use the utils function in structures.utils to convert the
packed respresentation to a list or padded.
"""
verts_features_packed = self.verts_features_packed()
faces_verts_features = verts_features_packed[faces_packed]
return faces_verts_features
def join_batch(self, textures: List["TexturesVertex"]) -> "TexturesVertex":
"""
Join the list of textures given by `textures` to

View File

@ -112,6 +112,31 @@ class TestTexturesVertex(TestCaseMixin, unittest.TestCase):
with self.assertRaisesRegex(ValueError, "verts_features"):
TexturesVertex(verts_features=(1, 1, 1))
def test_faces_verts_textures(self):
device = torch.device("cuda:0")
verts = torch.randn((2, 4, 3), dtype=torch.float32, device=device)
faces = torch.tensor(
[[[2, 1, 0], [3, 1, 0]], [[1, 3, 0], [2, 1, 3]]],
dtype=torch.int64,
device=device,
)
# define TexturesVertex
verts_texture = torch.rand(verts.shape)
textures = TexturesVertex(verts_features=verts_texture)
# compute packed faces
ff = faces.unbind(0)
faces_packed = torch.cat([ff[0], ff[1] + verts.shape[1]])
# face verts textures
faces_verts_texts = textures.faces_verts_textures_packed(faces_packed)
verts_texts_packed = torch.cat(verts_texture.unbind(0))
faces_verts_texts_packed = verts_texts_packed[faces_packed]
self.assertClose(faces_verts_texts_packed, faces_verts_texts)
def test_clone(self):
tex = TexturesVertex(verts_features=torch.rand(size=(10, 100, 128)))
tex.verts_features_list()
@ -303,6 +328,33 @@ class TestTexturesAtlas(TestCaseMixin, unittest.TestCase):
with self.assertRaisesRegex(ValueError, "atlas"):
TexturesAtlas(atlas=(1, 1, 1))
def test_faces_verts_textures(self):
device = torch.device("cuda:0")
N, F, R = 2, 2, 8
num_faces = torch.randint(low=1, high=F, size=(N,))
faces_atlas = [
torch.rand(size=(num_faces[i].item(), R, R, 3), device=device)
for i in range(N)
]
tex = TexturesAtlas(atlas=faces_atlas)
# faces_verts naive
faces_verts = []
for n in range(N):
ff = num_faces[n].item()
temp = torch.zeros(ff, 3, 3)
for f in range(ff):
t0 = faces_atlas[n][f, 0, -1] # for v0, bary = (1, 0)
t1 = faces_atlas[n][f, -1, 0] # for v1, bary = (0, 1)
t2 = faces_atlas[n][f, 0, 0] # for v2, bary = (0, 0)
temp[f, 0] = t0
temp[f, 1] = t1
temp[f, 2] = t2
faces_verts.append(temp)
faces_verts = torch.cat(faces_verts, 0)
self.assertClose(faces_verts, tex.faces_verts_textures_packed().cpu())
def test_clone(self):
tex = TexturesAtlas(atlas=torch.rand(size=(1, 10, 2, 2, 3)))
tex.atlas_list()
@ -522,6 +574,43 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
verts_uvs=torch.rand(size=(5, 15, 2)),
)
def test_faces_verts_textures(self):
device = torch.device("cuda:0")
N, V, F, H, W = 2, 5, 12, 8, 8
vert_uvs = torch.rand((N, V, 2), dtype=torch.float32, device=device)
face_uvs = torch.randint(
high=V, size=(N, F, 3), dtype=torch.int64, device=device
)
maps = torch.rand((N, H, W, 3), dtype=torch.float32, device=device)
tex = TexturesUV(maps=maps, verts_uvs=vert_uvs, faces_uvs=face_uvs)
# naive faces_verts_textures
faces_verts_texs = []
for n in range(N):
temp = torch.zeros((F, 3, 3), device=device, dtype=torch.float32)
for f in range(F):
uv0 = vert_uvs[n, face_uvs[n, f, 0]]
uv1 = vert_uvs[n, face_uvs[n, f, 1]]
uv2 = vert_uvs[n, face_uvs[n, f, 2]]
idx = torch.stack((uv0, uv1, uv2), dim=0).view(1, 1, 3, 2) # 1x1x3x2
idx = idx * 2.0 - 1.0
imap = maps[n].view(1, H, W, 3).permute(0, 3, 1, 2) # 1x3xHxW
imap = torch.flip(imap, [2])
texts = torch.nn.functional.grid_sample(
imap,
idx,
align_corners=tex.align_corners,
padding_mode=tex.padding_mode,
) # 1x3x1x3
temp[f] = texts[0, :, 0, :].permute(1, 0)
faces_verts_texs.append(temp)
faces_verts_texs = torch.cat(faces_verts_texs, 0)
self.assertClose(faces_verts_texs, tex.faces_verts_textures_packed())
def test_clone(self):
tex = TexturesUV(
maps=torch.ones((5, 16, 16, 3)),