mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 14:50:36 +08:00
Texturing API updates
Summary: A fairly big refactor of the texturing API with some breaking changes to how textures are defined. Main changes: - There are now 3 types of texture classes: `TexturesUV`, `TexturesAtlas` and `TexturesVertex`. Each class: - has a `sample_textures` function which accepts the `fragments` from rasterization and returns `texels`. This means that the shaders will not need to know the type of the mesh texture which will resolve several issues people were reporting on GitHub. - has a `join_batch` method for joining multiple textures of the same type into a batch Reviewed By: gkioxari Differential Revision: D21067427 fbshipit-source-id: 4b346500a60181e72fdd1b0dd89b5505c7a33926
This commit is contained in:
committed by
Facebook GitHub Bot
parent
b73d3d6ed9
commit
a3932960b3
@@ -2,7 +2,6 @@
|
||||
|
||||
from .meshes import Meshes, join_meshes_as_batch
|
||||
from .pointclouds import Pointclouds
|
||||
from .textures import Textures
|
||||
from .utils import list_to_packed, list_to_padded, packed_to_list, padded_to_list
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import List, Union
|
||||
import torch
|
||||
|
||||
from . import utils as struct_utils
|
||||
from .textures import Textures
|
||||
|
||||
|
||||
class Meshes(object):
|
||||
@@ -234,9 +233,9 @@ class Meshes(object):
|
||||
Refer to comments above for descriptions of List and Padded representations.
|
||||
"""
|
||||
self.device = None
|
||||
if textures is not None and not isinstance(textures, Textures):
|
||||
msg = "Expected textures to be of type Textures; got %r"
|
||||
raise ValueError(msg % type(textures))
|
||||
if textures is not None and not repr(textures) == "TexturesBase":
|
||||
msg = "Expected textures to be an instance of type TexturesBase; got %r"
|
||||
raise ValueError(msg % repr(textures))
|
||||
self.textures = textures
|
||||
|
||||
# Indicates whether the meshes in the list/batch have the same number
|
||||
@@ -400,6 +399,8 @@ class Meshes(object):
|
||||
if self.textures is not None:
|
||||
self.textures._num_faces_per_mesh = self._num_faces_per_mesh.tolist()
|
||||
self.textures._num_verts_per_mesh = self._num_verts_per_mesh.tolist()
|
||||
self.textures._N = self._N
|
||||
self.textures.valid = self.valid
|
||||
|
||||
def __len__(self):
|
||||
return self._N
|
||||
@@ -1465,6 +1466,17 @@ class Meshes(object):
|
||||
|
||||
return self.__class__(verts=new_verts_list, faces=new_faces_list, textures=tex)
|
||||
|
||||
def sample_textures(self, fragments):
|
||||
if self.textures is not None:
|
||||
# Pass in faces packed. If the textures are defined per
|
||||
# vertex, the face indices are needed in order to interpolate
|
||||
# the vertex attributes across the face.
|
||||
return self.textures.sample_textures(
|
||||
fragments, faces_packed=self.faces_packed()
|
||||
)
|
||||
else:
|
||||
raise ValueError("Meshes does not have textures")
|
||||
|
||||
|
||||
def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True):
|
||||
"""
|
||||
@@ -1499,44 +1511,14 @@ def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True):
|
||||
raise ValueError("Inconsistent textures in join_meshes_as_batch.")
|
||||
|
||||
# Now we know there are multiple meshes and they have textures to merge.
|
||||
first = meshes[0].textures
|
||||
kwargs = {}
|
||||
if first.maps_padded() is not None:
|
||||
if any(mesh.textures.maps_padded() is None for mesh in meshes):
|
||||
raise ValueError("Inconsistent maps_padded in join_meshes_as_batch.")
|
||||
maps = [m for mesh in meshes for m in mesh.textures.maps_padded()]
|
||||
kwargs["maps"] = maps
|
||||
elif any(mesh.textures.maps_padded() is not None for mesh in meshes):
|
||||
raise ValueError("Inconsistent maps_padded in join_meshes_as_batch.")
|
||||
all_textures = [mesh.textures for mesh in meshes]
|
||||
first = all_textures[0]
|
||||
tex_types_same = all(type(tex) == type(first) for tex in all_textures)
|
||||
|
||||
if first.verts_uvs_padded() is not None:
|
||||
if any(mesh.textures.verts_uvs_padded() is None for mesh in meshes):
|
||||
raise ValueError("Inconsistent verts_uvs_padded in join_meshes_as_batch.")
|
||||
uvs = [uv for mesh in meshes for uv in mesh.textures.verts_uvs_list()]
|
||||
V = max(uv.shape[0] for uv in uvs)
|
||||
kwargs["verts_uvs"] = struct_utils.list_to_padded(uvs, (V, 2), -1)
|
||||
elif any(mesh.textures.verts_uvs_padded() is not None for mesh in meshes):
|
||||
raise ValueError("Inconsistent verts_uvs_padded in join_meshes_as_batch.")
|
||||
if not tex_types_same:
|
||||
raise ValueError("All meshes in the batch must have the same type of texture.")
|
||||
|
||||
if first.faces_uvs_padded() is not None:
|
||||
if any(mesh.textures.faces_uvs_padded() is None for mesh in meshes):
|
||||
raise ValueError("Inconsistent faces_uvs_padded in join_meshes_as_batch.")
|
||||
uvs = [uv for mesh in meshes for uv in mesh.textures.faces_uvs_list()]
|
||||
F = max(uv.shape[0] for uv in uvs)
|
||||
kwargs["faces_uvs"] = struct_utils.list_to_padded(uvs, (F, 3), -1)
|
||||
elif any(mesh.textures.faces_uvs_padded() is not None for mesh in meshes):
|
||||
raise ValueError("Inconsistent faces_uvs_padded in join_meshes_as_batch.")
|
||||
|
||||
if first.verts_rgb_padded() is not None:
|
||||
if any(mesh.textures.verts_rgb_padded() is None for mesh in meshes):
|
||||
raise ValueError("Inconsistent verts_rgb_padded in join_meshes_as_batch.")
|
||||
rgb = [i for mesh in meshes for i in mesh.textures.verts_rgb_list()]
|
||||
V = max(i.shape[0] for i in rgb)
|
||||
kwargs["verts_rgb"] = struct_utils.list_to_padded(rgb, (V, 3))
|
||||
elif any(mesh.textures.verts_rgb_padded() is not None for mesh in meshes):
|
||||
raise ValueError("Inconsistent verts_rgb_padded in join_meshes_as_batch.")
|
||||
|
||||
tex = Textures(**kwargs)
|
||||
tex = first.join_batch(all_textures[1:])
|
||||
return Meshes(verts=verts, faces=faces, textures=tex)
|
||||
|
||||
|
||||
@@ -1544,7 +1526,7 @@ def join_mesh(meshes: Union[Meshes, List[Meshes]]) -> Meshes:
|
||||
"""
|
||||
Joins a batch of meshes in the form of a Meshes object or a list of Meshes
|
||||
objects as a single mesh. If the input is a list, the Meshes objects in the list
|
||||
must all be on the same device. This version ignores all textures in the input mehses.
|
||||
must all be on the same device. This version ignores all textures in the input meshes.
|
||||
|
||||
Args:
|
||||
meshes: Meshes object that contains a batch of meshes or a list of Meshes objects
|
||||
|
||||
@@ -1,279 +0,0 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn.functional import interpolate
|
||||
|
||||
from .utils import padded_to_list, padded_to_packed
|
||||
|
||||
|
||||
"""
|
||||
This file has functions for interpolating textures after rasterization.
|
||||
"""
|
||||
|
||||
|
||||
def _pad_texture_maps(images: List[torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
Pad all texture images so they have the same height and width.
|
||||
|
||||
Args:
|
||||
images: list of N tensors of shape (H, W, 3)
|
||||
|
||||
Returns:
|
||||
tex_maps: Tensor of shape (N, max_H, max_W, 3)
|
||||
"""
|
||||
tex_maps = []
|
||||
max_H = 0
|
||||
max_W = 0
|
||||
for im in images:
|
||||
h, w, _3 = im.shape
|
||||
if h > max_H:
|
||||
max_H = h
|
||||
if w > max_W:
|
||||
max_W = w
|
||||
tex_maps.append(im)
|
||||
max_shape = (max_H, max_W)
|
||||
|
||||
for i, image in enumerate(tex_maps):
|
||||
if image.shape[:2] != max_shape:
|
||||
image_BCHW = image.permute(2, 0, 1)[None]
|
||||
new_image_BCHW = interpolate(
|
||||
image_BCHW, size=max_shape, mode="bilinear", align_corners=False
|
||||
)
|
||||
tex_maps[i] = new_image_BCHW[0].permute(1, 2, 0)
|
||||
tex_maps = torch.stack(tex_maps, dim=0) # (num_tex_maps, max_H, max_W, 3)
|
||||
return tex_maps
|
||||
|
||||
|
||||
def _extend_tensor(input_tensor: torch.Tensor, N: int) -> torch.Tensor:
|
||||
"""
|
||||
Extend a tensor `input_tensor` with ndim > 2, `N` times along the batch
|
||||
dimension. This is done in the following sequence of steps (where `B` is
|
||||
the batch dimension):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
input_tensor (B, ...)
|
||||
-> add leading empty dimension (1, B, ...)
|
||||
-> expand (N, B, ...)
|
||||
-> reshape (N * B, ...)
|
||||
|
||||
Args:
|
||||
input_tensor: torch.Tensor with ndim > 2 representing a batched input.
|
||||
N: number of times to extend each element of the batch.
|
||||
"""
|
||||
# pyre-fixme[16]: `Tensor` has no attribute `ndim`.
|
||||
if input_tensor.ndim < 2:
|
||||
raise ValueError("Input tensor must have ndimensions >= 2.")
|
||||
B = input_tensor.shape[0]
|
||||
non_batch_dims = tuple(input_tensor.shape[1:])
|
||||
constant_dims = (-1,) * input_tensor.ndim # these dims are not expanded.
|
||||
return (
|
||||
input_tensor.clone()[None, ...]
|
||||
.expand(N, *constant_dims)
|
||||
.transpose(0, 1)
|
||||
.reshape(N * B, *non_batch_dims)
|
||||
)
|
||||
|
||||
|
||||
class Textures(object):
|
||||
def __init__(
|
||||
self,
|
||||
maps: Union[List, torch.Tensor, None] = None,
|
||||
faces_uvs: Optional[torch.Tensor] = None,
|
||||
verts_uvs: Optional[torch.Tensor] = None,
|
||||
verts_rgb: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
maps: texture map per mesh. This can either be a list of maps
|
||||
[(H, W, 3)] or a padded tensor of shape (N, H, W, 3).
|
||||
faces_uvs: (N, F, 3) tensor giving the index into verts_uvs for each
|
||||
vertex in the face. Padding value is assumed to be -1.
|
||||
verts_uvs: (N, V, 2) tensor giving the uv coordinate per vertex.
|
||||
verts_rgb: (N, V, 3) tensor giving the rgb color per vertex. Padding
|
||||
value is assumed to be -1.
|
||||
|
||||
Note: only the padded representation of the textures is stored
|
||||
and the packed/list representations are computed on the fly and
|
||||
not cached.
|
||||
"""
|
||||
# pyre-fixme[16]: `Tensor` has no attribute `ndim`.
|
||||
if faces_uvs is not None and faces_uvs.ndim != 3:
|
||||
msg = "Expected faces_uvs to be of shape (N, F, 3); got %r"
|
||||
raise ValueError(msg % repr(faces_uvs.shape))
|
||||
if verts_uvs is not None and verts_uvs.ndim != 3:
|
||||
msg = "Expected verts_uvs to be of shape (N, V, 2); got %r"
|
||||
raise ValueError(msg % repr(verts_uvs.shape))
|
||||
if verts_rgb is not None and verts_rgb.ndim != 3:
|
||||
msg = "Expected verts_rgb to be of shape (N, V, 3); got %r"
|
||||
raise ValueError(msg % repr(verts_rgb.shape))
|
||||
if maps is not None:
|
||||
# pyre-fixme[16]: `List` has no attribute `ndim`.
|
||||
if torch.is_tensor(maps) and maps.ndim != 4:
|
||||
msg = "Expected maps to be of shape (N, H, W, 3); got %r"
|
||||
# pyre-fixme[16]: `List` has no attribute `shape`.
|
||||
raise ValueError(msg % repr(maps.shape))
|
||||
elif isinstance(maps, list):
|
||||
maps = _pad_texture_maps(maps)
|
||||
if faces_uvs is None or verts_uvs is None:
|
||||
msg = "To use maps, faces_uvs and verts_uvs are required"
|
||||
raise ValueError(msg)
|
||||
|
||||
self._faces_uvs_padded = faces_uvs
|
||||
self._verts_uvs_padded = verts_uvs
|
||||
self._verts_rgb_padded = verts_rgb
|
||||
self._maps_padded = maps
|
||||
|
||||
# The number of faces/verts for each mesh is
|
||||
# set inside the Meshes object when textures is
|
||||
# passed into the Meshes constructor.
|
||||
self._num_faces_per_mesh = None
|
||||
self._num_verts_per_mesh = None
|
||||
|
||||
def clone(self):
|
||||
other = self.__class__()
|
||||
for k in dir(self):
|
||||
v = getattr(self, k)
|
||||
if torch.is_tensor(v):
|
||||
setattr(other, k, v.clone())
|
||||
return other
|
||||
|
||||
def to(self, device):
|
||||
for k in dir(self):
|
||||
v = getattr(self, k)
|
||||
if torch.is_tensor(v) and v.device != device:
|
||||
setattr(self, k, v.to(device))
|
||||
return self
|
||||
|
||||
def __getitem__(self, index):
|
||||
other = self.__class__()
|
||||
for key in dir(self):
|
||||
value = getattr(self, key)
|
||||
if torch.is_tensor(value):
|
||||
if isinstance(index, int):
|
||||
setattr(other, key, value[index][None])
|
||||
else:
|
||||
setattr(other, key, value[index])
|
||||
return other
|
||||
|
||||
def faces_uvs_padded(self) -> torch.Tensor:
|
||||
# pyre-fixme[7]: Expected `Tensor` but got `Optional[torch.Tensor]`.
|
||||
return self._faces_uvs_padded
|
||||
|
||||
def faces_uvs_list(self) -> Union[List[torch.Tensor], None]:
|
||||
if self._faces_uvs_padded is None:
|
||||
return None
|
||||
return padded_to_list(
|
||||
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
|
||||
# `Optional[torch.Tensor]`.
|
||||
self._faces_uvs_padded,
|
||||
split_size=self._num_faces_per_mesh,
|
||||
)
|
||||
|
||||
def faces_uvs_packed(self) -> Union[torch.Tensor, None]:
|
||||
if self._faces_uvs_padded is None:
|
||||
return None
|
||||
return padded_to_packed(
|
||||
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
|
||||
# `Optional[torch.Tensor]`.
|
||||
self._faces_uvs_padded,
|
||||
split_size=self._num_faces_per_mesh,
|
||||
)
|
||||
|
||||
def verts_uvs_padded(self) -> Union[torch.Tensor, None]:
|
||||
return self._verts_uvs_padded
|
||||
|
||||
def verts_uvs_list(self) -> Union[List[torch.Tensor], None]:
|
||||
if self._verts_uvs_padded is None:
|
||||
return None
|
||||
# Vertices shared between multiple faces
|
||||
# may have a different uv coordinate for
|
||||
# each face so the num_verts_uvs_per_mesh
|
||||
# may be different from num_verts_per_mesh.
|
||||
# Therefore don't use any split_size.
|
||||
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
|
||||
# `Optional[torch.Tensor]`.
|
||||
return padded_to_list(self._verts_uvs_padded)
|
||||
|
||||
def verts_uvs_packed(self) -> Union[torch.Tensor, None]:
|
||||
if self._verts_uvs_padded is None:
|
||||
return None
|
||||
# Vertices shared between multiple faces
|
||||
# may have a different uv coordinate for
|
||||
# each face so the num_verts_uvs_per_mesh
|
||||
# may be different from num_verts_per_mesh.
|
||||
# Therefore don't use any split_size.
|
||||
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
|
||||
# `Optional[torch.Tensor]`.
|
||||
return padded_to_packed(self._verts_uvs_padded)
|
||||
|
||||
def verts_rgb_padded(self) -> Union[torch.Tensor, None]:
|
||||
return self._verts_rgb_padded
|
||||
|
||||
def verts_rgb_list(self) -> Union[List[torch.Tensor], None]:
|
||||
if self._verts_rgb_padded is None:
|
||||
return None
|
||||
return padded_to_list(
|
||||
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
|
||||
# `Optional[torch.Tensor]`.
|
||||
self._verts_rgb_padded,
|
||||
split_size=self._num_verts_per_mesh,
|
||||
)
|
||||
|
||||
def verts_rgb_packed(self) -> Union[torch.Tensor, None]:
|
||||
if self._verts_rgb_padded is None:
|
||||
return None
|
||||
return padded_to_packed(
|
||||
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
|
||||
# `Optional[torch.Tensor]`.
|
||||
self._verts_rgb_padded,
|
||||
split_size=self._num_verts_per_mesh,
|
||||
)
|
||||
|
||||
# Currently only the padded maps are used.
|
||||
def maps_padded(self) -> Union[torch.Tensor, None]:
|
||||
# pyre-fixme[7]: Expected `Optional[torch.Tensor]` but got `Union[None,
|
||||
# List[typing.Any], torch.Tensor]`.
|
||||
return self._maps_padded
|
||||
|
||||
def extend(self, N: int) -> "Textures":
|
||||
"""
|
||||
Create new Textures class which contains each input texture N times
|
||||
|
||||
Args:
|
||||
N: number of new copies of each texture.
|
||||
|
||||
Returns:
|
||||
new Textures object.
|
||||
"""
|
||||
if not isinstance(N, int):
|
||||
raise ValueError("N must be an integer.")
|
||||
if N <= 0:
|
||||
raise ValueError("N must be > 0.")
|
||||
|
||||
if all(
|
||||
v is not None
|
||||
for v in [self._faces_uvs_padded, self._verts_uvs_padded, self._maps_padded]
|
||||
):
|
||||
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
|
||||
# `Optional[torch.Tensor]`.
|
||||
new_verts_uvs = _extend_tensor(self._verts_uvs_padded, N)
|
||||
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
|
||||
# `Optional[torch.Tensor]`.
|
||||
new_faces_uvs = _extend_tensor(self._faces_uvs_padded, N)
|
||||
# pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[None,
|
||||
# List[typing.Any], torch.Tensor]`.
|
||||
new_maps = _extend_tensor(self._maps_padded, N)
|
||||
return self.__class__(
|
||||
verts_uvs=new_verts_uvs, faces_uvs=new_faces_uvs, maps=new_maps
|
||||
)
|
||||
elif self._verts_rgb_padded is not None:
|
||||
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
|
||||
# `Optional[torch.Tensor]`.
|
||||
new_verts_rgb = _extend_tensor(self._verts_rgb_padded, N)
|
||||
return self.__class__(verts_rgb=new_verts_rgb)
|
||||
else:
|
||||
msg = "Either vertex colors or texture maps are required."
|
||||
raise ValueError(msg)
|
||||
@@ -73,6 +73,7 @@ def padded_to_list(x: torch.Tensor, split_size: Union[list, tuple, None] = None)
|
||||
# pyre-fixme[16]: `Tensor` has no attribute `ndim`.
|
||||
if x.ndim != 3:
|
||||
raise ValueError("Supports only 3-dimensional input tensors")
|
||||
|
||||
x_list = list(x.unbind(0))
|
||||
|
||||
if split_size is None:
|
||||
|
||||
Reference in New Issue
Block a user