Texturing API updates

Summary:
A fairly big refactor of the texturing API with some breaking changes to how textures are defined.

Main changes:
- There are now 3 types of texture classes: `TexturesUV`, `TexturesAtlas` and `TexturesVertex`. Each class:
   - has a `sample_textures` function which accepts the `fragments` from rasterization and returns `texels`. This means that the shaders will not need to know the type of the mesh texture which will resolve several issues people were reporting on GitHub.
  -  has a `join_batch` method for joining multiple textures of the same type into a batch

Reviewed By: gkioxari

Differential Revision: D21067427

fbshipit-source-id: 4b346500a60181e72fdd1b0dd89b5505c7a33926
This commit is contained in:
Nikhila Ravi
2020-07-29 16:06:58 -07:00
committed by Facebook GitHub Bot
parent b73d3d6ed9
commit a3932960b3
19 changed files with 1872 additions and 785 deletions

View File

@@ -11,7 +11,8 @@ import numpy as np
import torch
from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas
from pytorch3d.io.utils import _open_file
from pytorch3d.structures import Meshes, Textures, join_meshes_as_batch
from pytorch3d.renderer import TexturesAtlas, TexturesUV
from pytorch3d.structures import Meshes, join_meshes_as_batch
def _make_tensor(data, cols: int, dtype: torch.dtype, device="cpu") -> torch.Tensor:
@@ -41,6 +42,10 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
Args:
faces_indices: List of ints of indices.
max_index: Max index for the face property.
pad_value: if any of the face_indices are padded, specify
the value of the padding (e.g. -1). This is only used
for texture indices indices where there might
not be texture information for all the faces.
Returns:
faces_indices: List of ints of indices.
@@ -65,7 +70,9 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
faces_indices[mask] = pad_value
# Check indices are valid.
if torch.any(faces_indices >= max_index) or torch.any(faces_indices < 0):
if torch.any(faces_indices >= max_index) or (
pad_value is None and torch.any(faces_indices < 0)
):
warnings.warn("Faces have invalid indices")
return faces_indices
@@ -227,7 +234,14 @@ def load_obj(
)
def load_objs_as_meshes(files: list, device=None, load_textures: bool = True):
def load_objs_as_meshes(
files: list,
device=None,
load_textures: bool = True,
create_texture_atlas: bool = False,
texture_atlas_size: int = 4,
texture_wrap: Optional[str] = "repeat",
):
"""
Load meshes from a list of .obj files using the load_obj function, and
return them as a Meshes object. This only works for meshes which have a
@@ -246,18 +260,31 @@ def load_objs_as_meshes(files: list, device=None, load_textures: bool = True):
"""
mesh_list = []
for f_obj in files:
# TODO: update this function to support the two texturing options.
verts, faces, aux = load_obj(f_obj, load_textures=load_textures)
verts = verts.to(device)
verts, faces, aux = load_obj(
f_obj,
load_textures=load_textures,
create_texture_atlas=create_texture_atlas,
texture_atlas_size=texture_atlas_size,
texture_wrap=texture_wrap,
)
tex = None
tex_maps = aux.texture_images
if tex_maps is not None and len(tex_maps) > 0:
verts_uvs = aux.verts_uvs[None, ...].to(device) # (1, V, 2)
faces_uvs = faces.textures_idx[None, ...].to(device) # (1, F, 3)
image = list(tex_maps.values())[0].to(device)[None]
tex = Textures(verts_uvs=verts_uvs, faces_uvs=faces_uvs, maps=image)
if create_texture_atlas:
# TexturesAtlas type
tex = TexturesAtlas(atlas=[aux.texture_atlas])
else:
# TexturesUV type
tex_maps = aux.texture_images
if tex_maps is not None and len(tex_maps) > 0:
verts_uvs = aux.verts_uvs.to(device) # (V, 2)
faces_uvs = faces.textures_idx.to(device) # (F, 3)
image = list(tex_maps.values())[0].to(device)[None]
tex = TexturesUV(
verts_uvs=[verts_uvs], faces_uvs=[faces_uvs], maps=image
)
mesh = Meshes(verts=[verts], faces=[faces.verts_idx.to(device)], textures=tex)
mesh = Meshes(
verts=[verts.to(device)], faces=[faces.verts_idx.to(device)], textures=tex
)
mesh_list.append(mesh)
if len(mesh_list) == 1:
return mesh_list[0]