Texturing API updates

Summary:
A fairly big refactor of the texturing API with some breaking changes to how textures are defined.

Main changes:
- There are now three texture classes: `TexturesUV`, `TexturesAtlas` and `TexturesVertex`. Each class:
  - has a `sample_textures` function which accepts the `fragments` from rasterization and returns `texels`. This means that shaders no longer need to know the type of the mesh texture, which resolves several issues people were reporting on GitHub.
  - has a `join_batch` method for joining multiple textures of the same type into a batch (see the sketch after this list).
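
A minimal sketch of the new usage. The class names, `sample_textures` and `join_batch` come from this commit; the `pytorch3d.renderer` import path and the `verts_features` argument name are assumptions not confirmed by this diff:

```python
import torch
from pytorch3d.structures import Meshes
from pytorch3d.renderer import TexturesVertex  # assumed import path

# Two single-mesh textures defined by per-vertex RGB features
# (the `verts_features` argument name is an assumption).
tex_a = TexturesVertex(verts_features=[torch.rand(4, 3)])
tex_b = TexturesVertex(verts_features=[torch.rand(4, 3)])

# Join textures of the same type into one batched texture.
batched_tex = tex_a.join_batch([tex_b])

# Attach to a Meshes batch; downstream shaders call
# meshes.sample_textures(fragments) without knowing the texture type.
verts = [torch.rand(4, 3), torch.rand(4, 3)]
faces = [torch.tensor([[0, 1, 2], [0, 2, 3]])] * 2
meshes = Meshes(verts=verts, faces=faces, textures=batched_tex)
```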

Reviewed By: gkioxari

Differential Revision: D21067427

fbshipit-source-id: 4b346500a60181e72fdd1b0dd89b5505c7a33926
Nikhila Ravi authored on 2020-07-29 16:06:58 -07:00
committed by Facebook GitHub Bot
parent b73d3d6ed9
commit a3932960b3
19 changed files with 1872 additions and 785 deletions


@@ -2,7 +2,6 @@
from .meshes import Meshes, join_meshes_as_batch
from .pointclouds import Pointclouds
from .textures import Textures
from .utils import list_to_packed, list_to_padded, packed_to_list, padded_to_list


@@ -5,7 +5,6 @@ from typing import List, Union
import torch
from . import utils as struct_utils
from .textures import Textures
class Meshes(object):
@@ -234,9 +233,9 @@ class Meshes(object):
Refer to comments above for descriptions of List and Padded representations.
"""
self.device = None
if textures is not None and not isinstance(textures, Textures):
msg = "Expected textures to be of type Textures; got %r"
raise ValueError(msg % type(textures))
if textures is not None and not repr(textures) == "TexturesBase":
msg = "Expected textures to be an instance of type TexturesBase; got %r"
raise ValueError(msg % repr(textures))
self.textures = textures
# Indicates whether the meshes in the list/batch have the same number
@@ -400,6 +399,8 @@ class Meshes(object):
if self.textures is not None:
self.textures._num_faces_per_mesh = self._num_faces_per_mesh.tolist()
self.textures._num_verts_per_mesh = self._num_verts_per_mesh.tolist()
self.textures._N = self._N
self.textures.valid = self.valid
def __len__(self):
return self._N
@@ -1465,6 +1466,17 @@ class Meshes(object):
return self.__class__(verts=new_verts_list, faces=new_faces_list, textures=tex)
def sample_textures(self, fragments):
if self.textures is not None:
# Pass in faces packed. If the textures are defined per
# vertex, the face indices are needed in order to interpolate
# the vertex attributes across the face.
return self.textures.sample_textures(
fragments, faces_packed=self.faces_packed()
)
else:
raise ValueError("Meshes does not have textures")
def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True):
"""
@@ -1499,44 +1511,14 @@ def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True):
raise ValueError("Inconsistent textures in join_meshes_as_batch.")
# Now we know there are multiple meshes and they have textures to merge.
first = meshes[0].textures
kwargs = {}
if first.maps_padded() is not None:
if any(mesh.textures.maps_padded() is None for mesh in meshes):
raise ValueError("Inconsistent maps_padded in join_meshes_as_batch.")
maps = [m for mesh in meshes for m in mesh.textures.maps_padded()]
kwargs["maps"] = maps
elif any(mesh.textures.maps_padded() is not None for mesh in meshes):
raise ValueError("Inconsistent maps_padded in join_meshes_as_batch.")
all_textures = [mesh.textures for mesh in meshes]
first = all_textures[0]
tex_types_same = all(type(tex) == type(first) for tex in all_textures)
if first.verts_uvs_padded() is not None:
if any(mesh.textures.verts_uvs_padded() is None for mesh in meshes):
raise ValueError("Inconsistent verts_uvs_padded in join_meshes_as_batch.")
uvs = [uv for mesh in meshes for uv in mesh.textures.verts_uvs_list()]
V = max(uv.shape[0] for uv in uvs)
kwargs["verts_uvs"] = struct_utils.list_to_padded(uvs, (V, 2), -1)
elif any(mesh.textures.verts_uvs_padded() is not None for mesh in meshes):
raise ValueError("Inconsistent verts_uvs_padded in join_meshes_as_batch.")
if not tex_types_same:
raise ValueError("All meshes in the batch must have the same type of texture.")
if first.faces_uvs_padded() is not None:
if any(mesh.textures.faces_uvs_padded() is None for mesh in meshes):
raise ValueError("Inconsistent faces_uvs_padded in join_meshes_as_batch.")
uvs = [uv for mesh in meshes for uv in mesh.textures.faces_uvs_list()]
F = max(uv.shape[0] for uv in uvs)
kwargs["faces_uvs"] = struct_utils.list_to_padded(uvs, (F, 3), -1)
elif any(mesh.textures.faces_uvs_padded() is not None for mesh in meshes):
raise ValueError("Inconsistent faces_uvs_padded in join_meshes_as_batch.")
if first.verts_rgb_padded() is not None:
if any(mesh.textures.verts_rgb_padded() is None for mesh in meshes):
raise ValueError("Inconsistent verts_rgb_padded in join_meshes_as_batch.")
rgb = [i for mesh in meshes for i in mesh.textures.verts_rgb_list()]
V = max(i.shape[0] for i in rgb)
kwargs["verts_rgb"] = struct_utils.list_to_padded(rgb, (V, 3))
elif any(mesh.textures.verts_rgb_padded() is not None for mesh in meshes):
raise ValueError("Inconsistent verts_rgb_padded in join_meshes_as_batch.")
tex = Textures(**kwargs)
tex = first.join_batch(all_textures[1:])
return Meshes(verts=verts, faces=faces, textures=tex)
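
The per-attribute merging above collapses into a single `join_batch` call on the first texture plus a type check. A hedged usage sketch (the `pytorch3d.renderer` import path and `verts_features` argument name are assumptions):

```python
import torch
from pytorch3d.structures import Meshes, join_meshes_as_batch
from pytorch3d.renderer import TexturesVertex  # assumed import path

verts = torch.rand(4, 3)
faces = torch.tensor([[0, 1, 2], [0, 2, 3]])

mesh_a = Meshes(
    verts=[verts], faces=[faces],
    textures=TexturesVertex(verts_features=[torch.rand(4, 3)]),
)
mesh_b = Meshes(
    verts=[verts], faces=[faces],
    textures=TexturesVertex(verts_features=[torch.rand(4, 3)]),
)

# Same texture type on every mesh -> textures merged via join_batch.
batch = join_meshes_as_batch([mesh_a, mesh_b])
assert len(batch) == 2

# Mixing texture types now raises:
# "All meshes in the batch must have the same type of texture."
```
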
@@ -1544,7 +1526,7 @@ def join_mesh(meshes: Union[Meshes, List[Meshes]]) -> Meshes:
"""
Joins a batch of meshes in the form of a Meshes object or a list of Meshes
objects as a single mesh. If the input is a list, the Meshes objects in the list
must all be on the same device. This version ignores all textures in the input mehses.
must all be on the same device. This version ignores all textures in the input meshes.
Args:
meshes: Meshes object that contains a batch of meshes or a list of Meshes objects


@@ -1,279 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import List, Optional, Union
import torch
from torch.nn.functional import interpolate
from .utils import padded_to_list, padded_to_packed
"""
This file has functions for interpolating textures after rasterization.
"""
def _pad_texture_maps(images: List[torch.Tensor]) -> torch.Tensor:
"""
Pad all texture images so they have the same height and width.
Args:
images: list of N tensors of shape (H, W, 3)
Returns:
tex_maps: Tensor of shape (N, max_H, max_W, 3)
"""
tex_maps = []
max_H = 0
max_W = 0
for im in images:
h, w, _3 = im.shape
if h > max_H:
max_H = h
if w > max_W:
max_W = w
tex_maps.append(im)
max_shape = (max_H, max_W)
for i, image in enumerate(tex_maps):
if image.shape[:2] != max_shape:
image_BCHW = image.permute(2, 0, 1)[None]
new_image_BCHW = interpolate(
image_BCHW, size=max_shape, mode="bilinear", align_corners=False
)
tex_maps[i] = new_image_BCHW[0].permute(1, 2, 0)
tex_maps = torch.stack(tex_maps, dim=0) # (num_tex_maps, max_H, max_W, 3)
return tex_maps
def _extend_tensor(input_tensor: torch.Tensor, N: int) -> torch.Tensor:
"""
Extend a tensor `input_tensor` with ndim > 2, `N` times along the batch
dimension. This is done in the following sequence of steps (where `B` is
the batch dimension):
.. code-block:: python
input_tensor (B, ...)
-> add leading empty dimension (1, B, ...)
-> expand (N, B, ...)
-> reshape (N * B, ...)
Args:
input_tensor: torch.Tensor with ndim > 2 representing a batched input.
N: number of times to extend each element of the batch.
"""
# pyre-fixme[16]: `Tensor` has no attribute `ndim`.
if input_tensor.ndim < 2:
raise ValueError("Input tensor must have ndimensions >= 2.")
B = input_tensor.shape[0]
non_batch_dims = tuple(input_tensor.shape[1:])
constant_dims = (-1,) * input_tensor.ndim # these dims are not expanded.
return (
input_tensor.clone()[None, ...]
.expand(N, *constant_dims)
.transpose(0, 1)
.reshape(N * B, *non_batch_dims)
)
class Textures(object):
def __init__(
self,
maps: Union[List, torch.Tensor, None] = None,
faces_uvs: Optional[torch.Tensor] = None,
verts_uvs: Optional[torch.Tensor] = None,
verts_rgb: Optional[torch.Tensor] = None,
):
"""
Args:
maps: texture map per mesh. This can either be a list of maps
[(H, W, 3)] or a padded tensor of shape (N, H, W, 3).
faces_uvs: (N, F, 3) tensor giving the index into verts_uvs for each
vertex in the face. Padding value is assumed to be -1.
verts_uvs: (N, V, 2) tensor giving the uv coordinate per vertex.
verts_rgb: (N, V, 3) tensor giving the rgb color per vertex. Padding
value is assumed to be -1.
Note: only the padded representation of the textures is stored
and the packed/list representations are computed on the fly and
not cached.
"""
# pyre-fixme[16]: `Tensor` has no attribute `ndim`.
if faces_uvs is not None and faces_uvs.ndim != 3:
msg = "Expected faces_uvs to be of shape (N, F, 3); got %r"
raise ValueError(msg % repr(faces_uvs.shape))
if verts_uvs is not None and verts_uvs.ndim != 3:
msg = "Expected verts_uvs to be of shape (N, V, 2); got %r"
raise ValueError(msg % repr(verts_uvs.shape))
if verts_rgb is not None and verts_rgb.ndim != 3:
msg = "Expected verts_rgb to be of shape (N, V, 3); got %r"
raise ValueError(msg % repr(verts_rgb.shape))
if maps is not None:
# pyre-fixme[16]: `List` has no attribute `ndim`.
if torch.is_tensor(maps) and maps.ndim != 4:
msg = "Expected maps to be of shape (N, H, W, 3); got %r"
# pyre-fixme[16]: `List` has no attribute `shape`.
raise ValueError(msg % repr(maps.shape))
elif isinstance(maps, list):
maps = _pad_texture_maps(maps)
if faces_uvs is None or verts_uvs is None:
msg = "To use maps, faces_uvs and verts_uvs are required"
raise ValueError(msg)
self._faces_uvs_padded = faces_uvs
self._verts_uvs_padded = verts_uvs
self._verts_rgb_padded = verts_rgb
self._maps_padded = maps
# The number of faces/verts for each mesh is
# set inside the Meshes object when textures is
# passed into the Meshes constructor.
self._num_faces_per_mesh = None
self._num_verts_per_mesh = None
def clone(self):
other = self.__class__()
for k in dir(self):
v = getattr(self, k)
if torch.is_tensor(v):
setattr(other, k, v.clone())
return other
def to(self, device):
for k in dir(self):
v = getattr(self, k)
if torch.is_tensor(v) and v.device != device:
setattr(self, k, v.to(device))
return self
def __getitem__(self, index):
other = self.__class__()
for key in dir(self):
value = getattr(self, key)
if torch.is_tensor(value):
if isinstance(index, int):
setattr(other, key, value[index][None])
else:
setattr(other, key, value[index])
return other
def faces_uvs_padded(self) -> torch.Tensor:
# pyre-fixme[7]: Expected `Tensor` but got `Optional[torch.Tensor]`.
return self._faces_uvs_padded
def faces_uvs_list(self) -> Union[List[torch.Tensor], None]:
if self._faces_uvs_padded is None:
return None
return padded_to_list(
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
self._faces_uvs_padded,
split_size=self._num_faces_per_mesh,
)
def faces_uvs_packed(self) -> Union[torch.Tensor, None]:
if self._faces_uvs_padded is None:
return None
return padded_to_packed(
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
self._faces_uvs_padded,
split_size=self._num_faces_per_mesh,
)
def verts_uvs_padded(self) -> Union[torch.Tensor, None]:
return self._verts_uvs_padded
def verts_uvs_list(self) -> Union[List[torch.Tensor], None]:
if self._verts_uvs_padded is None:
return None
# Vertices shared between multiple faces
# may have a different uv coordinate for
# each face so the num_verts_uvs_per_mesh
# may be different from num_verts_per_mesh.
# Therefore don't use any split_size.
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
return padded_to_list(self._verts_uvs_padded)
def verts_uvs_packed(self) -> Union[torch.Tensor, None]:
if self._verts_uvs_padded is None:
return None
# Vertices shared between multiple faces
# may have a different uv coordinate for
# each face so the num_verts_uvs_per_mesh
# may be different from num_verts_per_mesh.
# Therefore don't use any split_size.
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
return padded_to_packed(self._verts_uvs_padded)
def verts_rgb_padded(self) -> Union[torch.Tensor, None]:
return self._verts_rgb_padded
def verts_rgb_list(self) -> Union[List[torch.Tensor], None]:
if self._verts_rgb_padded is None:
return None
return padded_to_list(
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
self._verts_rgb_padded,
split_size=self._num_verts_per_mesh,
)
def verts_rgb_packed(self) -> Union[torch.Tensor, None]:
if self._verts_rgb_padded is None:
return None
return padded_to_packed(
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
self._verts_rgb_padded,
split_size=self._num_verts_per_mesh,
)
# Currently only the padded maps are used.
def maps_padded(self) -> Union[torch.Tensor, None]:
# pyre-fixme[7]: Expected `Optional[torch.Tensor]` but got `Union[None,
# List[typing.Any], torch.Tensor]`.
return self._maps_padded
def extend(self, N: int) -> "Textures":
"""
Create new Textures class which contains each input texture N times
Args:
N: number of new copies of each texture.
Returns:
new Textures object.
"""
if not isinstance(N, int):
raise ValueError("N must be an integer.")
if N <= 0:
raise ValueError("N must be > 0.")
if all(
v is not None
for v in [self._faces_uvs_padded, self._verts_uvs_padded, self._maps_padded]
):
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
new_verts_uvs = _extend_tensor(self._verts_uvs_padded, N)
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
new_faces_uvs = _extend_tensor(self._faces_uvs_padded, N)
# pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[None,
# List[typing.Any], torch.Tensor]`.
new_maps = _extend_tensor(self._maps_padded, N)
return self.__class__(
verts_uvs=new_verts_uvs, faces_uvs=new_faces_uvs, maps=new_maps
)
elif self._verts_rgb_padded is not None:
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
new_verts_rgb = _extend_tensor(self._verts_rgb_padded, N)
return self.__class__(verts_rgb=new_verts_rgb)
else:
msg = "Either vertex colors or texture maps are required."
raise ValueError(msg)
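
For reference, the deleted `Textures` class above maps roughly onto the new per-type classes. A hedged sketch of the correspondence (the new constructors' argument names and import path are assumptions inferred from the old signature):

```python
import torch
from pytorch3d.renderer import TexturesUV, TexturesVertex  # assumed import path

N, F, V, H, W = 1, 2, 4, 8, 8
maps = torch.rand(N, H, W, 3)
faces_uvs = torch.zeros(N, F, 3, dtype=torch.int64)
verts_uvs = torch.rand(N, V, 2)
verts_rgb = torch.rand(N, V, 3)

# Old: Textures(maps=maps, faces_uvs=faces_uvs, verts_uvs=verts_uvs)
uv_tex = TexturesUV(maps=maps, faces_uvs=faces_uvs, verts_uvs=verts_uvs)

# Old: Textures(verts_rgb=verts_rgb)
vert_tex = TexturesVertex(verts_features=verts_rgb)
```
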


@@ -73,6 +73,7 @@ def padded_to_list(x: torch.Tensor, split_size: Union[list, tuple, None] = None)
# pyre-fixme[16]: `Tensor` has no attribute `ndim`.
if x.ndim != 3:
raise ValueError("Supports only 3-dimensional input tensors")
x_list = list(x.unbind(0))
if split_size is None:
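
For context, `padded_to_list` converts a padded (N, M, K) tensor into a per-element list, trimming each element to its true length when `split_size` is given. A minimal sketch of that behaviour (not the library's exact implementation):

```python
import torch

def padded_to_list_sketch(x: torch.Tensor, split_size=None):
    """Split a padded (N, M, K) tensor into a list of N tensors,
    keeping only the first split_size[i] rows of element i."""
    x_list = list(x.unbind(0))
    if split_size is None:
        return x_list
    return [t[:n] for t, n in zip(x_list, split_size)]

padded = torch.zeros(2, 5, 3)
out = padded_to_list_sketch(padded, split_size=[3, 5])
assert out[0].shape == (3, 3) and out[1].shape == (5, 3)
```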