Texturing API updates

Summary:
A fairly big refactor of the texturing API with some breaking changes to how textures are defined.

Main changes:
- There are now 3 types of texture classes: `TexturesUV`, `TexturesAtlas` and `TexturesVertex`. Each class:
   - has a `sample_textures` function which accepts the `fragments` from rasterization and returns `texels`. This means that the shaders will not need to know the type of the mesh texture which will resolve several issues people were reporting on GitHub.
  -  has a `join_batch` method for joining multiple textures of the same type into a batch

Reviewed By: gkioxari

Differential Revision: D21067427

fbshipit-source-id: 4b346500a60181e72fdd1b0dd89b5505c7a33926
This commit is contained in:
Nikhila Ravi 2020-07-29 16:06:58 -07:00 committed by Facebook GitHub Bot
parent b73d3d6ed9
commit a3932960b3
19 changed files with 1872 additions and 785 deletions

View File

@ -520,6 +520,9 @@
],
"metadata": {
"accelerator": "GPU",
"anp_metadata": {
"path": "fbsource/fbcode/vision/fair/pytorch3d/docs/tutorials/camera_position_optimization_with_differentiable_rendering.ipynb"
},
"bento_stylesheets": {
"bento/extensions/flow/main.css": true,
"bento/extensions/kernel_selector/main.css": true,
@ -533,6 +536,9 @@
"provenance": [],
"toc_visible": true
},
"disseminate_notebook_info": {
"backup_notebook_id": "1062179640844868"
},
"kernelspec": {
"display_name": "pytorch3d (local)",
"language": "python",

View File

@ -84,7 +84,7 @@
"from skimage.io import imread\n",
"\n",
"# Util function for loading meshes\n",
"from pytorch3d.io import load_objs_as_meshes\n",
"from pytorch3d.io import load_objs_as_meshes, load_obj\n",
"\n",
"# Data structures and functions for rendering\n",
"from pytorch3d.structures import Meshes, Textures\n",
@ -97,7 +97,7 @@
" RasterizationSettings, \n",
" MeshRenderer, \n",
" MeshRasterizer, \n",
" TexturedSoftPhongShader\n",
" SoftPhongShader\n",
")\n",
"\n",
"# add path for demo utils functions \n",
@ -316,7 +316,7 @@
" cameras=cameras, \n",
" raster_settings=raster_settings\n",
" ),\n",
" shader=TexturedSoftPhongShader(\n",
" shader=SoftPhongShader(\n",
" device=device, \n",
" cameras=cameras,\n",
" lights=lights\n",
@ -563,6 +563,9 @@
],
"metadata": {
"accelerator": "GPU",
"anp_metadata": {
"path": "fbsource/fbcode/vision/fair/pytorch3d/docs/tutorials/render_textured_meshes.ipynb"
},
"bento_stylesheets": {
"bento/extensions/flow/main.css": true,
"bento/extensions/kernel_selector/main.css": true,
@ -575,6 +578,9 @@
"name": "render_textured_meshes.ipynb",
"provenance": []
},
"disseminate_notebook_info": {
"backup_notebook_id": "569222367081034"
},
"kernelspec": {
"display_name": "pytorch3d (local)",
"language": "python",

View File

@ -13,8 +13,8 @@ from pytorch3d.renderer import (
OpenGLPerspectiveCameras,
PointLights,
RasterizationSettings,
TexturesVertex,
)
from pytorch3d.structures import Textures
class ShapeNetBase(torch.utils.data.Dataset):
@ -113,8 +113,8 @@ class ShapeNetBase(torch.utils.data.Dataset):
"""
paths = self._handle_render_inputs(model_ids, categories, sample_nums, idxs)
meshes = load_objs_as_meshes(paths, device=device, load_textures=False)
meshes.textures = Textures(
verts_rgb=torch.ones_like(meshes.verts_padded(), device=device)
meshes.textures = TexturesVertex(
verts_features=torch.ones_like(meshes.verts_padded(), device=device)
)
cameras = kwargs.get("cameras", OpenGLPerspectiveCameras()).to(device)
renderer = MeshRenderer(

View File

@ -11,7 +11,8 @@ import numpy as np
import torch
from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas
from pytorch3d.io.utils import _open_file
from pytorch3d.structures import Meshes, Textures, join_meshes_as_batch
from pytorch3d.renderer import TexturesAtlas, TexturesUV
from pytorch3d.structures import Meshes, join_meshes_as_batch
def _make_tensor(data, cols: int, dtype: torch.dtype, device="cpu") -> torch.Tensor:
@ -41,6 +42,10 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
Args:
faces_indices: List of ints of indices.
max_index: Max index for the face property.
pad_value: if any of the face_indices are padded, specify
the value of the padding (e.g. -1). This is only used
for texture indices indices where there might
not be texture information for all the faces.
Returns:
faces_indices: List of ints of indices.
@ -65,7 +70,9 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
faces_indices[mask] = pad_value
# Check indices are valid.
if torch.any(faces_indices >= max_index) or torch.any(faces_indices < 0):
if torch.any(faces_indices >= max_index) or (
pad_value is None and torch.any(faces_indices < 0)
):
warnings.warn("Faces have invalid indices")
return faces_indices
@ -227,7 +234,14 @@ def load_obj(
)
def load_objs_as_meshes(files: list, device=None, load_textures: bool = True):
def load_objs_as_meshes(
files: list,
device=None,
load_textures: bool = True,
create_texture_atlas: bool = False,
texture_atlas_size: int = 4,
texture_wrap: Optional[str] = "repeat",
):
"""
Load meshes from a list of .obj files using the load_obj function, and
return them as a Meshes object. This only works for meshes which have a
@ -246,18 +260,31 @@ def load_objs_as_meshes(files: list, device=None, load_textures: bool = True):
"""
mesh_list = []
for f_obj in files:
# TODO: update this function to support the two texturing options.
verts, faces, aux = load_obj(f_obj, load_textures=load_textures)
verts = verts.to(device)
verts, faces, aux = load_obj(
f_obj,
load_textures=load_textures,
create_texture_atlas=create_texture_atlas,
texture_atlas_size=texture_atlas_size,
texture_wrap=texture_wrap,
)
tex = None
tex_maps = aux.texture_images
if tex_maps is not None and len(tex_maps) > 0:
verts_uvs = aux.verts_uvs[None, ...].to(device) # (1, V, 2)
faces_uvs = faces.textures_idx[None, ...].to(device) # (1, F, 3)
image = list(tex_maps.values())[0].to(device)[None]
tex = Textures(verts_uvs=verts_uvs, faces_uvs=faces_uvs, maps=image)
if create_texture_atlas:
# TexturesAtlas type
tex = TexturesAtlas(atlas=[aux.texture_atlas])
else:
# TexturesUV type
tex_maps = aux.texture_images
if tex_maps is not None and len(tex_maps) > 0:
verts_uvs = aux.verts_uvs.to(device) # (V, 2)
faces_uvs = faces.textures_idx.to(device) # (F, 3)
image = list(tex_maps.values())[0].to(device)[None]
tex = TexturesUV(
verts_uvs=[verts_uvs], faces_uvs=[faces_uvs], maps=image
)
mesh = Meshes(verts=[verts], faces=[faces.verts_idx.to(device)], textures=tex)
mesh = Meshes(
verts=[verts.to(device)], faces=[faces.verts_idx.to(device)], textures=tex
)
mesh_list.append(mesh)
if len(mesh_list) == 1:
return mesh_list[0]

View File

@ -28,11 +28,11 @@ from .mesh import (
SoftGouraudShader,
SoftPhongShader,
SoftSilhouetteShader,
TexturedSoftPhongShader,
Textures,
TexturesAtlas,
TexturesUV,
TexturesVertex,
gouraud_shading,
interpolate_face_attributes,
interpolate_texture_map,
interpolate_vertex_colors,
phong_shading,
rasterize_meshes,
)

View File

@ -1,10 +1,10 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .texturing import interpolate_texture_map, interpolate_vertex_colors # isort:skip
from .rasterize_meshes import rasterize_meshes
from .rasterizer import MeshRasterizer, RasterizationSettings
from .renderer import MeshRenderer
from .shader import TexturedSoftPhongShader # DEPRECATED
from .shader import (
HardFlatShader,
HardGouraudShader,
@ -12,10 +12,10 @@ from .shader import (
SoftGouraudShader,
SoftPhongShader,
SoftSilhouetteShader,
TexturedSoftPhongShader,
)
from .shading import gouraud_shading, phong_shading
from .utils import interpolate_face_attributes
from .textures import Textures # DEPRECATED
from .textures import TexturesAtlas, TexturesUV, TexturesVertex
__all__ = [k for k in globals().keys() if not k.startswith("_")]

View File

@ -1,5 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import warnings
import torch
import torch.nn as nn
@ -13,7 +14,6 @@ from ..blending import (
from ..lighting import PointLights
from ..materials import Materials
from .shading import flat_shading, gouraud_shading, phong_shading
from .texturing import interpolate_texture_map, interpolate_vertex_colors
# A Shader should take as input fragments from the output of rasterization
@ -57,7 +57,7 @@ class HardPhongShader(nn.Module):
or in the forward pass of HardPhongShader"
raise ValueError(msg)
texels = interpolate_vertex_colors(fragments, meshes)
texels = meshes.sample_textures(fragments)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
blend_params = kwargs.get("blend_params", self.blend_params)
@ -104,9 +104,11 @@ class SoftPhongShader(nn.Module):
msg = "Cameras must be specified either at initialization \
or in the forward pass of SoftPhongShader"
raise ValueError(msg)
texels = interpolate_vertex_colors(fragments, meshes)
texels = meshes.sample_textures(fragments)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
blend_params = kwargs.get("blend_params", self.blend_params)
colors = phong_shading(
meshes=meshes,
fragments=fragments,
@ -115,7 +117,7 @@ class SoftPhongShader(nn.Module):
cameras=cameras,
materials=materials,
)
images = softmax_rgb_blend(colors, fragments, self.blend_params)
images = softmax_rgb_blend(colors, fragments, blend_params)
return images
@ -154,6 +156,12 @@ class HardGouraudShader(nn.Module):
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
blend_params = kwargs.get("blend_params", self.blend_params)
# As Gouraud shading applies the illumination to the vertex
# colors, the interpolated pixel texture is calculated in the
# shading step. In comparison, for Phong shading, the pixel
# textures are computed first after which the illumination is
# applied.
pixel_colors = gouraud_shading(
meshes=meshes,
fragments=fragments,
@ -210,54 +218,25 @@ class SoftGouraudShader(nn.Module):
return images
class TexturedSoftPhongShader(nn.Module):
def TexturedSoftPhongShader(
device="cpu", cameras=None, lights=None, materials=None, blend_params=None
):
"""
Per pixel lighting applied to a texture map. First interpolate the vertex
uv coordinates and sample from a texture map. Then apply the lighting model
using the interpolated coords and normals for each pixel.
The blending function returns the soft aggregated color using all
the faces per pixel.
To use the default values, simply initialize the shader with the desired
device e.g.
.. code-block::
shader = TexturedPhongShader(device=torch.device("cuda:0"))
TexturedSoftPhongShader class has been DEPRECATED. Use SoftPhongShader instead.
Preserving TexturedSoftPhongShader as a function for backwards compatibility.
"""
def __init__(
self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
):
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
self.materials = (
materials if materials is not None else Materials(device=device)
)
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams()
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of TexturedSoftPhongShader"
raise ValueError(msg)
texels = interpolate_texture_map(fragments, meshes)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
blend_params = kwargs.get("blend_params", self.blend_params)
colors = phong_shading(
meshes=meshes,
fragments=fragments,
texels=texels,
lights=lights,
cameras=cameras,
materials=materials,
)
images = softmax_rgb_blend(colors, fragments, blend_params)
return images
warnings.warn(
"""TexturedSoftPhongShader is now deprecated;
use SoftPhongShader instead.""",
PendingDeprecationWarning,
)
return SoftPhongShader(
device=device,
cameras=cameras,
lights=lights,
materials=materials,
blend_params=blend_params,
)
class HardFlatShader(nn.Module):
@ -291,7 +270,7 @@ class HardFlatShader(nn.Module):
msg = "Cameras must be specified either at initialization \
or in the forward pass of HardFlatShader"
raise ValueError(msg)
texels = interpolate_vertex_colors(fragments, meshes)
texels = meshes.sample_textures(fragments)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
blend_params = kwargs.get("blend_params", self.blend_params)

View File

@ -6,6 +6,8 @@ from typing import Tuple
import torch
from pytorch3d.ops import interpolate_face_attributes
from .textures import TexturesVertex
def _apply_lighting(
points, normals, lights, cameras, materials
@ -91,6 +93,9 @@ def gouraud_shading(meshes, fragments, lights, cameras, materials) -> torch.Tens
Then interpolate the vertex shaded colors using the barycentric coordinates
to get a color per pixel.
Gouraud shading is only supported for meshes with texture type `TexturesVertex`.
This is because the illumination is applied to the vertex colors.
Args:
meshes: Batch of meshes
fragments: Fragments named tuple with the outputs of rasterization
@ -101,10 +106,13 @@ def gouraud_shading(meshes, fragments, lights, cameras, materials) -> torch.Tens
Returns:
colors: (N, H, W, K, 3)
"""
if not isinstance(meshes.textures, TexturesVertex):
raise ValueError("Mesh textures must be an instance of TexturesVertex")
faces = meshes.faces_packed() # (F, 3)
verts = meshes.verts_packed()
vertex_normals = meshes.verts_normals_packed() # (V, 3)
vertex_colors = meshes.textures.verts_rgb_packed()
verts = meshes.verts_packed() # (V, 3)
verts_normals = meshes.verts_normals_packed() # (V, 3)
verts_colors = meshes.textures.verts_features_packed() # (V, D)
vert_to_mesh_idx = meshes.verts_packed_to_mesh_idx()
# Format properties of lights and materials so they are compatible
@ -119,9 +127,10 @@ def gouraud_shading(meshes, fragments, lights, cameras, materials) -> torch.Tens
# Calculate the illumination at each vertex
ambient, diffuse, specular = _apply_lighting(
verts, vertex_normals, lights, cameras, materials
verts, verts_normals, lights, cameras, materials
)
verts_colors_shaded = vertex_colors * (ambient + diffuse) + specular
verts_colors_shaded = verts_colors * (ambient + diffuse) + specular
face_colors = verts_colors_shaded[faces]
colors = interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, face_colors

File diff suppressed because it is too large Load Diff

View File

@ -1,113 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn.functional as F
from pytorch3d.ops import interpolate_face_attributes
from pytorch3d.structures.textures import Textures
def interpolate_texture_map(fragments, meshes) -> torch.Tensor:
"""
Interpolate a 2D texture map using uv vertex texture coordinates for each
face in the mesh. First interpolate the vertex uvs using barycentric coordinates
for each pixel in the rasterized output. Then interpolate the texture map
using the uv coordinate for each pixel.
Args:
fragments:
The outputs of rasterization. From this we use
- pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
of the faces (in the packed representation) which
overlap each pixel in the image.
- barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
the barycentric coordianates of each pixel
relative to the faces (in the packed
representation) which overlap the pixel.
meshes: Meshes representing a batch of meshes. It is expected that
meshes has a textures attribute which is an instance of the
Textures class.
Returns:
texels: tensor of shape (N, H, W, K, C) giving the interpolated
texture for each pixel in the rasterized image.
"""
if not isinstance(meshes.textures, Textures):
msg = "Expected meshes.textures to be an instance of Textures; got %r"
raise ValueError(msg % type(meshes.textures))
faces_uvs = meshes.textures.faces_uvs_packed()
verts_uvs = meshes.textures.verts_uvs_packed()
faces_verts_uvs = verts_uvs[faces_uvs]
texture_maps = meshes.textures.maps_padded()
# pixel_uvs: (N, H, W, K, 2)
pixel_uvs = interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, faces_verts_uvs
)
N, H_out, W_out, K = fragments.pix_to_face.shape
N, H_in, W_in, C = texture_maps.shape # 3 for RGB
# pixel_uvs: (N, H, W, K, 2) -> (N, K, H, W, 2) -> (NK, H, W, 2)
pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(N * K, H_out, W_out, 2)
# textures.map:
# (N, H, W, C) -> (N, C, H, W) -> (1, N, C, H, W)
# -> expand (K, N, C, H, W) -> reshape (N*K, C, H, W)
texture_maps = (
texture_maps.permute(0, 3, 1, 2)[None, ...]
.expand(K, -1, -1, -1, -1)
.transpose(0, 1)
.reshape(N * K, C, H_in, W_in)
)
# Textures: (N*K, C, H, W), pixel_uvs: (N*K, H, W, 2)
# Now need to format the pixel uvs and the texture map correctly!
# From pytorch docs, grid_sample takes `grid` and `input`:
# grid specifies the sampling pixel locations normalized by
# the input spatial dimensions It should have most
# values in the range of [-1, 1]. Values x = -1, y = -1
# is the left-top pixel of input, and values x = 1, y = 1 is the
# right-bottom pixel of input.
pixel_uvs = pixel_uvs * 2.0 - 1.0
texture_maps = torch.flip(texture_maps, [2]) # flip y axis of the texture map
if texture_maps.device != pixel_uvs.device:
texture_maps = texture_maps.to(pixel_uvs.device)
texels = F.grid_sample(texture_maps, pixel_uvs, align_corners=False)
texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2)
return texels
def interpolate_vertex_colors(fragments, meshes) -> torch.Tensor:
"""
Detemine the color for each rasterized face. Interpolate the colors for
vertices which form the face using the barycentric coordinates.
Args:
meshes: A Meshes class representing a batch of meshes.
fragments:
The outputs of rasterization. From this we use
- pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
of the faces (in the packed representation) which
overlap each pixel in the image.
- barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
the barycentric coordianates of each pixel
relative to the faces (in the packed
representation) which overlap the pixel.
Returns:
texels: An texture per pixel of shape (N, H, W, K, C).
There will be one C dimensional value for each element in
fragments.pix_to_face.
"""
vertex_textures = meshes.textures.verts_rgb_padded().reshape(-1, 3) # (V, C)
vertex_textures = vertex_textures[meshes.verts_padded_to_packed_idx(), :]
faces_packed = meshes.faces_packed()
faces_textures = vertex_textures[faces_packed] # (F, 3, C)
texels = interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, faces_textures
)
return texels

View File

@ -2,7 +2,6 @@
from .meshes import Meshes, join_meshes_as_batch
from .pointclouds import Pointclouds
from .textures import Textures
from .utils import list_to_packed, list_to_padded, packed_to_list, padded_to_list

View File

@ -5,7 +5,6 @@ from typing import List, Union
import torch
from . import utils as struct_utils
from .textures import Textures
class Meshes(object):
@ -234,9 +233,9 @@ class Meshes(object):
Refer to comments above for descriptions of List and Padded representations.
"""
self.device = None
if textures is not None and not isinstance(textures, Textures):
msg = "Expected textures to be of type Textures; got %r"
raise ValueError(msg % type(textures))
if textures is not None and not repr(textures) == "TexturesBase":
msg = "Expected textures to be an instance of type TexturesBase; got %r"
raise ValueError(msg % repr(textures))
self.textures = textures
# Indicates whether the meshes in the list/batch have the same number
@ -400,6 +399,8 @@ class Meshes(object):
if self.textures is not None:
self.textures._num_faces_per_mesh = self._num_faces_per_mesh.tolist()
self.textures._num_verts_per_mesh = self._num_verts_per_mesh.tolist()
self.textures._N = self._N
self.textures.valid = self.valid
def __len__(self):
return self._N
@ -1465,6 +1466,17 @@ class Meshes(object):
return self.__class__(verts=new_verts_list, faces=new_faces_list, textures=tex)
def sample_textures(self, fragments):
if self.textures is not None:
# Pass in faces packed. If the textures are defined per
# vertex, the face indices are needed in order to interpolate
# the vertex attributes across the face.
return self.textures.sample_textures(
fragments, faces_packed=self.faces_packed()
)
else:
raise ValueError("Meshes does not have textures")
def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True):
"""
@ -1499,44 +1511,14 @@ def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True):
raise ValueError("Inconsistent textures in join_meshes_as_batch.")
# Now we know there are multiple meshes and they have textures to merge.
first = meshes[0].textures
kwargs = {}
if first.maps_padded() is not None:
if any(mesh.textures.maps_padded() is None for mesh in meshes):
raise ValueError("Inconsistent maps_padded in join_meshes_as_batch.")
maps = [m for mesh in meshes for m in mesh.textures.maps_padded()]
kwargs["maps"] = maps
elif any(mesh.textures.maps_padded() is not None for mesh in meshes):
raise ValueError("Inconsistent maps_padded in join_meshes_as_batch.")
all_textures = [mesh.textures for mesh in meshes]
first = all_textures[0]
tex_types_same = all(type(tex) == type(first) for tex in all_textures)
if first.verts_uvs_padded() is not None:
if any(mesh.textures.verts_uvs_padded() is None for mesh in meshes):
raise ValueError("Inconsistent verts_uvs_padded in join_meshes_as_batch.")
uvs = [uv for mesh in meshes for uv in mesh.textures.verts_uvs_list()]
V = max(uv.shape[0] for uv in uvs)
kwargs["verts_uvs"] = struct_utils.list_to_padded(uvs, (V, 2), -1)
elif any(mesh.textures.verts_uvs_padded() is not None for mesh in meshes):
raise ValueError("Inconsistent verts_uvs_padded in join_meshes_as_batch.")
if not tex_types_same:
raise ValueError("All meshes in the batch must have the same type of texture.")
if first.faces_uvs_padded() is not None:
if any(mesh.textures.faces_uvs_padded() is None for mesh in meshes):
raise ValueError("Inconsistent faces_uvs_padded in join_meshes_as_batch.")
uvs = [uv for mesh in meshes for uv in mesh.textures.faces_uvs_list()]
F = max(uv.shape[0] for uv in uvs)
kwargs["faces_uvs"] = struct_utils.list_to_padded(uvs, (F, 3), -1)
elif any(mesh.textures.faces_uvs_padded() is not None for mesh in meshes):
raise ValueError("Inconsistent faces_uvs_padded in join_meshes_as_batch.")
if first.verts_rgb_padded() is not None:
if any(mesh.textures.verts_rgb_padded() is None for mesh in meshes):
raise ValueError("Inconsistent verts_rgb_padded in join_meshes_as_batch.")
rgb = [i for mesh in meshes for i in mesh.textures.verts_rgb_list()]
V = max(i.shape[0] for i in rgb)
kwargs["verts_rgb"] = struct_utils.list_to_padded(rgb, (V, 3))
elif any(mesh.textures.verts_rgb_padded() is not None for mesh in meshes):
raise ValueError("Inconsistent verts_rgb_padded in join_meshes_as_batch.")
tex = Textures(**kwargs)
tex = first.join_batch(all_textures[1:])
return Meshes(verts=verts, faces=faces, textures=tex)
@ -1544,7 +1526,7 @@ def join_mesh(meshes: Union[Meshes, List[Meshes]]) -> Meshes:
"""
Joins a batch of meshes in the form of a Meshes object or a list of Meshes
objects as a single mesh. If the input is a list, the Meshes objects in the list
must all be on the same device. This version ignores all textures in the input mehses.
must all be on the same device. This version ignores all textures in the input meshes.
Args:
meshes: Meshes object that contains a batch of meshes or a list of Meshes objects

View File

@ -1,279 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import List, Optional, Union
import torch
from torch.nn.functional import interpolate
from .utils import padded_to_list, padded_to_packed
"""
This file has functions for interpolating textures after rasterization.
"""
def _pad_texture_maps(images: List[torch.Tensor]) -> torch.Tensor:
"""
Pad all texture images so they have the same height and width.
Args:
images: list of N tensors of shape (H, W, 3)
Returns:
tex_maps: Tensor of shape (N, max_H, max_W, 3)
"""
tex_maps = []
max_H = 0
max_W = 0
for im in images:
h, w, _3 = im.shape
if h > max_H:
max_H = h
if w > max_W:
max_W = w
tex_maps.append(im)
max_shape = (max_H, max_W)
for i, image in enumerate(tex_maps):
if image.shape[:2] != max_shape:
image_BCHW = image.permute(2, 0, 1)[None]
new_image_BCHW = interpolate(
image_BCHW, size=max_shape, mode="bilinear", align_corners=False
)
tex_maps[i] = new_image_BCHW[0].permute(1, 2, 0)
tex_maps = torch.stack(tex_maps, dim=0) # (num_tex_maps, max_H, max_W, 3)
return tex_maps
def _extend_tensor(input_tensor: torch.Tensor, N: int) -> torch.Tensor:
"""
Extend a tensor `input_tensor` with ndim > 2, `N` times along the batch
dimension. This is done in the following sequence of steps (where `B` is
the batch dimension):
.. code-block:: python
input_tensor (B, ...)
-> add leading empty dimension (1, B, ...)
-> expand (N, B, ...)
-> reshape (N * B, ...)
Args:
input_tensor: torch.Tensor with ndim > 2 representing a batched input.
N: number of times to extend each element of the batch.
"""
# pyre-fixme[16]: `Tensor` has no attribute `ndim`.
if input_tensor.ndim < 2:
raise ValueError("Input tensor must have ndimensions >= 2.")
B = input_tensor.shape[0]
non_batch_dims = tuple(input_tensor.shape[1:])
constant_dims = (-1,) * input_tensor.ndim # these dims are not expanded.
return (
input_tensor.clone()[None, ...]
.expand(N, *constant_dims)
.transpose(0, 1)
.reshape(N * B, *non_batch_dims)
)
class Textures(object):
def __init__(
self,
maps: Union[List, torch.Tensor, None] = None,
faces_uvs: Optional[torch.Tensor] = None,
verts_uvs: Optional[torch.Tensor] = None,
verts_rgb: Optional[torch.Tensor] = None,
):
"""
Args:
maps: texture map per mesh. This can either be a list of maps
[(H, W, 3)] or a padded tensor of shape (N, H, W, 3).
faces_uvs: (N, F, 3) tensor giving the index into verts_uvs for each
vertex in the face. Padding value is assumed to be -1.
verts_uvs: (N, V, 2) tensor giving the uv coordinate per vertex.
verts_rgb: (N, V, 3) tensor giving the rgb color per vertex. Padding
value is assumed to be -1.
Note: only the padded representation of the textures is stored
and the packed/list representations are computed on the fly and
not cached.
"""
# pyre-fixme[16]: `Tensor` has no attribute `ndim`.
if faces_uvs is not None and faces_uvs.ndim != 3:
msg = "Expected faces_uvs to be of shape (N, F, 3); got %r"
raise ValueError(msg % repr(faces_uvs.shape))
if verts_uvs is not None and verts_uvs.ndim != 3:
msg = "Expected verts_uvs to be of shape (N, V, 2); got %r"
raise ValueError(msg % repr(verts_uvs.shape))
if verts_rgb is not None and verts_rgb.ndim != 3:
msg = "Expected verts_rgb to be of shape (N, V, 3); got %r"
raise ValueError(msg % repr(verts_rgb.shape))
if maps is not None:
# pyre-fixme[16]: `List` has no attribute `ndim`.
if torch.is_tensor(maps) and maps.ndim != 4:
msg = "Expected maps to be of shape (N, H, W, 3); got %r"
# pyre-fixme[16]: `List` has no attribute `shape`.
raise ValueError(msg % repr(maps.shape))
elif isinstance(maps, list):
maps = _pad_texture_maps(maps)
if faces_uvs is None or verts_uvs is None:
msg = "To use maps, faces_uvs and verts_uvs are required"
raise ValueError(msg)
self._faces_uvs_padded = faces_uvs
self._verts_uvs_padded = verts_uvs
self._verts_rgb_padded = verts_rgb
self._maps_padded = maps
# The number of faces/verts for each mesh is
# set inside the Meshes object when textures is
# passed into the Meshes constructor.
self._num_faces_per_mesh = None
self._num_verts_per_mesh = None
def clone(self):
other = self.__class__()
for k in dir(self):
v = getattr(self, k)
if torch.is_tensor(v):
setattr(other, k, v.clone())
return other
def to(self, device):
for k in dir(self):
v = getattr(self, k)
if torch.is_tensor(v) and v.device != device:
setattr(self, k, v.to(device))
return self
def __getitem__(self, index):
other = self.__class__()
for key in dir(self):
value = getattr(self, key)
if torch.is_tensor(value):
if isinstance(index, int):
setattr(other, key, value[index][None])
else:
setattr(other, key, value[index])
return other
def faces_uvs_padded(self) -> torch.Tensor:
# pyre-fixme[7]: Expected `Tensor` but got `Optional[torch.Tensor]`.
return self._faces_uvs_padded
def faces_uvs_list(self) -> Union[List[torch.Tensor], None]:
if self._faces_uvs_padded is None:
return None
return padded_to_list(
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
self._faces_uvs_padded,
split_size=self._num_faces_per_mesh,
)
def faces_uvs_packed(self) -> Union[torch.Tensor, None]:
if self._faces_uvs_padded is None:
return None
return padded_to_packed(
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
self._faces_uvs_padded,
split_size=self._num_faces_per_mesh,
)
def verts_uvs_padded(self) -> Union[torch.Tensor, None]:
return self._verts_uvs_padded
def verts_uvs_list(self) -> Union[List[torch.Tensor], None]:
if self._verts_uvs_padded is None:
return None
# Vertices shared between multiple faces
# may have a different uv coordinate for
# each face so the num_verts_uvs_per_mesh
# may be different from num_verts_per_mesh.
# Therefore don't use any split_size.
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
return padded_to_list(self._verts_uvs_padded)
def verts_uvs_packed(self) -> Union[torch.Tensor, None]:
if self._verts_uvs_padded is None:
return None
# Vertices shared between multiple faces
# may have a different uv coordinate for
# each face so the num_verts_uvs_per_mesh
# may be different from num_verts_per_mesh.
# Therefore don't use any split_size.
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
return padded_to_packed(self._verts_uvs_padded)
def verts_rgb_padded(self) -> Union[torch.Tensor, None]:
return self._verts_rgb_padded
def verts_rgb_list(self) -> Union[List[torch.Tensor], None]:
if self._verts_rgb_padded is None:
return None
return padded_to_list(
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
self._verts_rgb_padded,
split_size=self._num_verts_per_mesh,
)
def verts_rgb_packed(self) -> Union[torch.Tensor, None]:
if self._verts_rgb_padded is None:
return None
return padded_to_packed(
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
self._verts_rgb_padded,
split_size=self._num_verts_per_mesh,
)
# Currently only the padded maps are used.
def maps_padded(self) -> Union[torch.Tensor, None]:
# pyre-fixme[7]: Expected `Optional[torch.Tensor]` but got `Union[None,
# List[typing.Any], torch.Tensor]`.
return self._maps_padded
def extend(self, N: int) -> "Textures":
"""
Create new Textures class which contains each input texture N times
Args:
N: number of new copies of each texture.
Returns:
new Textures object.
"""
if not isinstance(N, int):
raise ValueError("N must be an integer.")
if N <= 0:
raise ValueError("N must be > 0.")
if all(
v is not None
for v in [self._faces_uvs_padded, self._verts_uvs_padded, self._maps_padded]
):
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
new_verts_uvs = _extend_tensor(self._verts_uvs_padded, N)
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
new_faces_uvs = _extend_tensor(self._faces_uvs_padded, N)
# pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[None,
# List[typing.Any], torch.Tensor]`.
new_maps = _extend_tensor(self._maps_padded, N)
return self.__class__(
verts_uvs=new_verts_uvs, faces_uvs=new_faces_uvs, maps=new_maps
)
elif self._verts_rgb_padded is not None:
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
new_verts_rgb = _extend_tensor(self._verts_rgb_padded, N)
return self.__class__(verts_rgb=new_verts_rgb)
else:
msg = "Either vertex colors or texture maps are required."
raise ValueError(msg)

View File

@ -73,6 +73,7 @@ def padded_to_list(x: torch.Tensor, split_size: Union[list, tuple, None] = None)
# pyre-fixme[16]: `Tensor` has no attribute `ndim`.
if x.ndim != 3:
raise ValueError("Supports only 3-dimensional input tensors")
x_list = list(x.unbind(0))
if split_size is None:

Binary file not shown.

After

Width:  |  Height:  |  Size: 31 KiB

View File

@ -8,9 +8,9 @@ from pytorch3d.ops.interp_face_attrs import (
interpolate_face_attributes,
interpolate_face_attributes_python,
)
from pytorch3d.renderer.mesh import TexturesVertex
from pytorch3d.renderer.mesh.rasterizer import Fragments
from pytorch3d.renderer.mesh.texturing import interpolate_vertex_colors
from pytorch3d.structures import Meshes, Textures
from pytorch3d.structures import Meshes
class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):
@ -96,16 +96,12 @@ class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):
self.assertClose(grad_face_attrs_py, grad_face_attrs_cu, rtol=1e-3)
def test_interpolate_attributes(self):
"""
This tests both interpolate_vertex_colors as well as
interpolate_face_attributes.
"""
verts = torch.randn((4, 3), dtype=torch.float32)
faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
vert_tex = torch.tensor(
[[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32
)
tex = Textures(verts_rgb=vert_tex[None, :])
tex = TexturesVertex(verts_features=vert_tex[None, :])
mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
barycentric_coords = torch.tensor(
@ -120,7 +116,13 @@ class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):
zbuf=torch.ones_like(pix_to_face),
dists=torch.ones_like(pix_to_face),
)
texels = interpolate_vertex_colors(fragments, mesh)
verts_features_packed = mesh.textures.verts_features_packed()
faces_verts_features = verts_features_packed[mesh.faces_packed()]
texels = interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, faces_verts_features
)
self.assertTrue(torch.allclose(texels, expected_vals[None, :]))
def test_interpolate_attributes_grad(self):
@ -131,7 +133,7 @@ class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):
dtype=torch.float32,
requires_grad=True,
)
tex = Textures(verts_rgb=vert_tex[None, :])
tex = TexturesVertex(verts_features=vert_tex[None, :])
mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
barycentric_coords = torch.tensor(
@ -147,7 +149,12 @@ class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):
[[0.3, 0.3, 0.3], [0.9, 0.9, 0.9], [0.5, 0.5, 0.5], [0.3, 0.3, 0.3]],
dtype=torch.float32,
)
texels = interpolate_vertex_colors(fragments, mesh)
verts_features_packed = mesh.textures.verts_features_packed()
faces_verts_features = verts_features_packed[mesh.faces_packed()]
texels = interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, faces_verts_features
)
texels.sum().backward()
self.assertTrue(hasattr(vert_tex, "grad"))
self.assertTrue(torch.allclose(vert_tex.grad, grad_vert_tex[None, :]))

View File

@ -13,8 +13,8 @@ from pytorch3d.io.mtl_io import (
_bilinear_interpolation_grid_sample,
_bilinear_interpolation_vectorized,
)
from pytorch3d.structures import Meshes, Textures, join_meshes_as_batch
from pytorch3d.structures.meshes import join_mesh
from pytorch3d.renderer import TexturesAtlas, TexturesUV, TexturesVertex
from pytorch3d.structures import Meshes, join_meshes_as_batch
from pytorch3d.utils import torus
@ -590,17 +590,29 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
check_item(mesh.verts_padded(), mesh3.verts_padded())
check_item(mesh.faces_padded(), mesh3.faces_padded())
if mesh.textures is not None:
check_item(mesh.textures.maps_padded(), mesh3.textures.maps_padded())
check_item(
mesh.textures.faces_uvs_padded(), mesh3.textures.faces_uvs_padded()
)
check_item(
mesh.textures.verts_uvs_padded(), mesh3.textures.verts_uvs_padded()
)
check_item(
mesh.textures.verts_rgb_padded(), mesh3.textures.verts_rgb_padded()
)
if isinstance(mesh.textures, TexturesUV):
check_item(
mesh.textures.faces_uvs_padded(),
mesh3.textures.faces_uvs_padded(),
)
check_item(
mesh.textures.verts_uvs_padded(),
mesh3.textures.verts_uvs_padded(),
)
check_item(
mesh.textures.maps_padded(), mesh3.textures.maps_padded()
)
elif isinstance(mesh.textures, TexturesVertex):
check_item(
mesh.textures.verts_features_padded(),
mesh3.textures.verts_features_padded(),
)
elif isinstance(mesh.textures, TexturesAtlas):
check_item(
mesh.textures.atlas_padded(), mesh3.textures.atlas_padded()
)
DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
obj_filename = DATA_DIR / "cow_mesh/cow.obj"
@ -623,16 +635,24 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
check_triple(mesh_notex, mesh3_notex)
self.assertIsNone(mesh_notex.textures)
# meshes with vertex texture, join into a batch.
verts = torch.randn((4, 3), dtype=torch.float32)
faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
vert_tex = torch.tensor(
[[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32
)
tex = Textures(verts_rgb=vert_tex[None, :])
mesh_rgb = Meshes(verts=[verts], faces=[faces], textures=tex)
vert_tex = torch.ones_like(verts)
rgb_tex = TexturesVertex(verts_features=[vert_tex])
mesh_rgb = Meshes(verts=[verts], faces=[faces], textures=rgb_tex)
mesh_rgb3 = join_meshes_as_batch([mesh_rgb, mesh_rgb, mesh_rgb])
check_triple(mesh_rgb, mesh_rgb3)
# meshes with texture atlas, join into a batch.
device = "cuda:0"
atlas = torch.rand((2, 4, 4, 3), dtype=torch.float32, device=device)
atlas_tex = TexturesAtlas(atlas=[atlas])
mesh_atlas = Meshes(verts=[verts], faces=[faces], textures=atlas_tex)
mesh_atlas3 = join_meshes_as_batch([mesh_atlas, mesh_atlas, mesh_atlas])
check_triple(mesh_atlas, mesh_atlas3)
# Test load multiple meshes with textures into a batch.
teapot_obj = DATA_DIR / "teapot.obj"
mesh_teapot = load_objs_as_meshes([teapot_obj])
teapot_verts, teapot_faces = mesh_teapot.get_mesh_verts_faces(0)
@ -649,41 +669,9 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
self.assertClose(cow3_tea.verts_list()[3], mesh_teapot.verts_list()[0])
self.assertClose(cow3_tea.faces_list()[3], mesh_teapot.faces_list()[0])
def test_join_meshes(self):
"""
Test that join_mesh joins single meshes and the corresponding values are
consistent with the single meshes.
"""
# Load cow mesh.
DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
cow_obj = DATA_DIR / "cow_mesh/cow.obj"
cow_mesh = load_objs_as_meshes([cow_obj])
cow_verts, cow_faces = cow_mesh.get_mesh_verts_faces(0)
# Join a batch of three single meshes and check that the values are consistent
# with the individual meshes.
cow_mesh3 = join_mesh([cow_mesh, cow_mesh, cow_mesh])
def check_item(x, y, offset):
self.assertClose(torch.cat([x, x + offset, x + 2 * offset], dim=1), y)
check_item(cow_mesh.verts_padded(), cow_mesh3.verts_padded(), 0)
check_item(cow_mesh.faces_padded(), cow_mesh3.faces_padded(), cow_mesh._V)
# Test the joining of meshes of different sizes.
teapot_obj = DATA_DIR / "teapot.obj"
teapot_mesh = load_objs_as_meshes([teapot_obj])
teapot_verts, teapot_faces = teapot_mesh.get_mesh_verts_faces(0)
mix_mesh = join_mesh([cow_mesh, teapot_mesh])
mix_verts, mix_faces = mix_mesh.get_mesh_verts_faces(0)
self.assertEqual(len(mix_mesh), 1)
self.assertClose(mix_verts[: cow_mesh._V], cow_verts)
self.assertClose(mix_faces[: cow_mesh._F], cow_faces)
self.assertClose(mix_verts[cow_mesh._V :], teapot_verts)
self.assertClose(mix_faces[cow_mesh._F :], teapot_faces + cow_mesh._V)
# Check error raised if all meshes in the batch don't have the same texture type
with self.assertRaisesRegex(ValueError, "same type of texture"):
join_meshes_as_batch([mesh_atlas, mesh_rgb, mesh_atlas])
@staticmethod
def _bm_save_obj(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):

View File

@ -11,10 +11,11 @@ import numpy as np
import torch
from common_testing import TestCaseMixin, load_rgb_image
from PIL import Image
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.io import load_obj
from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.lighting import PointLights
from pytorch3d.renderer.materials import Materials
from pytorch3d.renderer.mesh import TexturesAtlas, TexturesUV, TexturesVertex
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
from pytorch3d.renderer.mesh.renderer import MeshRenderer
from pytorch3d.renderer.mesh.shader import (
@ -25,7 +26,6 @@ from pytorch3d.renderer.mesh.shader import (
SoftSilhouetteShader,
TexturedSoftPhongShader,
)
from pytorch3d.renderer.mesh.texturing import Textures
from pytorch3d.structures.meshes import Meshes, join_mesh
from pytorch3d.utils.ico_sphere import ico_sphere
@ -52,7 +52,8 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
sphere_mesh = ico_sphere(5, device)
verts_padded = sphere_mesh.verts_padded()
faces_padded = sphere_mesh.faces_padded()
textures = Textures(verts_rgb=torch.ones_like(verts_padded))
feats = torch.ones_like(verts_padded, device=device)
textures = TexturesVertex(verts_features=feats)
sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)
# Init rasterizer settings
@ -97,6 +98,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
filename = "simple_sphere_light_%s%s.png" % (name, postfix)
image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
filename = "DEBUG_%s" % filename
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
@ -145,14 +147,15 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
Test a mesh with vertex textures can be extended to form a batch, and
is rendered correctly with Phong, Gouraud and Flat Shaders.
"""
batch_size = 20
batch_size = 5
device = torch.device("cuda:0")
# Init mesh with vertex textures.
sphere_meshes = ico_sphere(5, device).extend(batch_size)
verts_padded = sphere_meshes.verts_padded()
faces_padded = sphere_meshes.faces_padded()
textures = Textures(verts_rgb=torch.ones_like(verts_padded))
feats = torch.ones_like(verts_padded, device=device)
textures = TexturesVertex(verts_features=feats)
sphere_meshes = Meshes(
verts=verts_padded, faces=faces_padded, textures=textures
)
@ -194,6 +197,11 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
)
for i in range(batch_size):
rgb = images[i, ..., :3].squeeze().cpu()
if i == 0 and DEBUG:
filename = "DEBUG_simple_sphere_batched_%s.png" % name
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / filename
)
self.assertClose(rgb, image_ref, atol=0.05)
def test_silhouette_with_grad(self):
@ -233,6 +241,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
with Image.open(image_ref_filename) as raw_image_ref:
image_ref = torch.from_numpy(np.array(raw_image_ref))
image_ref = image_ref.to(dtype=torch.float32) / 255.0
self.assertClose(alpha, image_ref, atol=0.055)
@ -253,11 +262,20 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
obj_filename = obj_dir / "cow_mesh/cow.obj"
# Load mesh + texture
mesh = load_objs_as_meshes([obj_filename], device=device)
verts, faces, aux = load_obj(
obj_filename, device=device, load_textures=True, texture_wrap=None
)
tex_map = list(aux.texture_images.values())[0]
tex_map = tex_map[None, ...].to(faces.textures_idx.device)
textures = TexturesUV(
maps=tex_map, faces_uvs=[faces.textures_idx], verts_uvs=[aux.verts_uvs]
)
mesh = Meshes(verts=[verts], faces=[faces.verts_idx], textures=textures)
# Init rasterizer settings
R, T = look_at_view_transform(2.7, 0, 0)
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1
)
@ -405,8 +423,8 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
Meshes(verts=verts, faces=sphere_list[i].faces_padded())
)
joined_sphere_mesh = join_mesh(sphere_mesh_list)
joined_sphere_mesh.textures = Textures(
verts_rgb=torch.ones_like(joined_sphere_mesh.verts_padded())
joined_sphere_mesh.textures = TexturesVertex(
verts_features=torch.ones_like(joined_sphere_mesh.verts_padded())
)
# Init rasterizer settings
@ -446,3 +464,61 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
)
image_ref = load_rgb_image("test_joined_spheres_%s.png" % name, DATA_DIR)
self.assertClose(rgb, image_ref, atol=0.05)
def test_texture_map_atlas(self):
"""
Test a mesh with a texture map as a per face atlas is loaded and rendered correctly.
"""
device = torch.device("cuda:0")
obj_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
obj_filename = obj_dir / "cow_mesh/cow.obj"
# Load mesh and texture as a per face texture atlas.
verts, faces, aux = load_obj(
obj_filename,
device=device,
load_textures=True,
create_texture_atlas=True,
texture_atlas_size=8,
texture_wrap=None,
)
mesh = Meshes(
verts=[verts],
faces=[faces.verts_idx],
textures=TexturesAtlas(atlas=[aux.texture_atlas]),
)
# Init rasterizer settings
R, T = look_at_view_transform(2.7, 0, 0)
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1, cull_backfaces=True
)
# Init shader settings
materials = Materials(device=device, specular_color=((0, 0, 0),), shininess=0.0)
lights = PointLights(device=device)
# Place light behind the cow in world space. The front of
# the cow is facing the -z direction.
lights.location = torch.tensor([0.0, 0.0, 2.0], device=device)[None]
# The HardPhongShader can be used directly with atlas textures.
renderer = MeshRenderer(
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
shader=HardPhongShader(lights=lights, cameras=cameras, materials=materials),
)
images = renderer(mesh)
rgb = images[0, ..., :3].squeeze().cpu()
# Load reference image
image_ref = load_rgb_image("test_texture_atlas_8x8_back.png", DATA_DIR)
if DEBUG:
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_texture_atlas_8x8_back.png"
)
self.assertClose(rgb, image_ref, atol=0.05)

View File

@ -7,14 +7,376 @@ import torch
import torch.nn.functional as F
from common_testing import TestCaseMixin
from pytorch3d.renderer.mesh.rasterizer import Fragments
from pytorch3d.renderer.mesh.texturing import interpolate_texture_map
from pytorch3d.structures import Meshes, Textures
from pytorch3d.structures.utils import list_to_padded
from pytorch3d.renderer.mesh.textures import (
TexturesAtlas,
TexturesUV,
TexturesVertex,
_list_to_padded_wrapper,
)
from pytorch3d.structures import Meshes, list_to_packed, packed_to_list
from test_meshes import TestMeshes
class TestTexturing(TestCaseMixin, unittest.TestCase):
def test_interpolate_texture_map(self):
def tryindex(self, index, tex, meshes, source):
tex2 = tex[index]
meshes2 = meshes[index]
tex_from_meshes = meshes2.textures
for item in source:
basic = source[item][index]
from_texture = getattr(tex2, item + "_padded")()
from_meshes = getattr(tex_from_meshes, item + "_padded")()
if isinstance(index, int):
basic = basic[None]
if len(basic) == 0:
self.assertEquals(len(from_texture), 0)
self.assertEquals(len(from_meshes), 0)
else:
self.assertClose(basic, from_texture)
self.assertClose(basic, from_meshes)
self.assertEqual(from_texture.ndim, getattr(tex, item + "_padded")().ndim)
item_list = getattr(tex_from_meshes, item + "_list")()
self.assertEqual(basic.shape[0], len(item_list))
for i, elem in enumerate(item_list):
self.assertClose(elem, basic[i])
class TestTexturesVertex(TestCaseMixin, unittest.TestCase):
def test_sample_vertex_textures(self):
"""
This tests both interpolate_vertex_colors as well as
interpolate_face_attributes.
"""
verts = torch.randn((4, 3), dtype=torch.float32)
faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
vert_tex = torch.tensor(
[[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32
)
verts_features = vert_tex
tex = TexturesVertex(verts_features=[verts_features])
mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
barycentric_coords = torch.tensor(
[[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
).view(1, 1, 1, 2, -1)
expected_vals = torch.tensor(
[[0.5, 1.0, 0.3], [0.3, 1.0, 0.9]], dtype=torch.float32
).view(1, 1, 1, 2, -1)
fragments = Fragments(
pix_to_face=pix_to_face,
bary_coords=barycentric_coords,
zbuf=torch.ones_like(pix_to_face),
dists=torch.ones_like(pix_to_face),
)
# sample_textures calls interpolate_vertex_colors
texels = mesh.sample_textures(fragments)
self.assertTrue(torch.allclose(texels, expected_vals[None, :]))
def test_sample_vertex_textures_grad(self):
verts = torch.randn((4, 3), dtype=torch.float32)
faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
vert_tex = torch.tensor(
[[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]],
dtype=torch.float32,
requires_grad=True,
)
verts_features = vert_tex
tex = TexturesVertex(verts_features=[verts_features])
mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
barycentric_coords = torch.tensor(
[[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
).view(1, 1, 1, 2, -1)
fragments = Fragments(
pix_to_face=pix_to_face,
bary_coords=barycentric_coords,
zbuf=torch.ones_like(pix_to_face),
dists=torch.ones_like(pix_to_face),
)
grad_vert_tex = torch.tensor(
[[0.3, 0.3, 0.3], [0.9, 0.9, 0.9], [0.5, 0.5, 0.5], [0.3, 0.3, 0.3]],
dtype=torch.float32,
)
texels = mesh.sample_textures(fragments)
texels.sum().backward()
self.assertTrue(hasattr(vert_tex, "grad"))
self.assertTrue(torch.allclose(vert_tex.grad, grad_vert_tex[None, :]))
def test_textures_vertex_init_fail(self):
# Incorrect sized tensors
with self.assertRaisesRegex(ValueError, "verts_features"):
TexturesVertex(verts_features=torch.rand(size=(5, 10)))
# Not a list or a tensor
with self.assertRaisesRegex(ValueError, "verts_features"):
TexturesVertex(verts_features=(1, 1, 1))
def test_clone(self):
tex = TexturesVertex(verts_features=torch.rand(size=(10, 100, 128)))
tex_cloned = tex.clone()
self.assertSeparate(
tex._verts_features_padded, tex_cloned._verts_features_padded
)
self.assertSeparate(tex.valid, tex_cloned.valid)
def test_extend(self):
B = 10
mesh = TestMeshes.init_mesh(B, 30, 50)
V = mesh._V
tex_uv = TexturesVertex(verts_features=torch.randn((B, V, 3)))
tex_mesh = Meshes(
verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex_uv
)
N = 20
new_mesh = tex_mesh.extend(N)
self.assertEqual(len(tex_mesh) * N, len(new_mesh))
tex_init = tex_mesh.textures
new_tex = new_mesh.textures
for i in range(len(tex_mesh)):
for n in range(N):
self.assertClose(
tex_init.verts_features_list()[i],
new_tex.verts_features_list()[i * N + n],
)
self.assertClose(
tex_init._num_faces_per_mesh[i],
new_tex._num_faces_per_mesh[i * N + n],
)
self.assertAllSeparate(
[tex_init.verts_features_padded(), new_tex.verts_features_padded()]
)
with self.assertRaises(ValueError):
tex_mesh.extend(N=-1)
def test_padded_to_packed(self):
# Case where each face in the mesh has 3 unique uv vertex indices
# - i.e. even if a vertex is shared between multiple faces it will
# have a unique uv coordinate for each face.
num_verts_per_mesh = [9, 6]
D = 10
verts_features_list = [torch.rand(v, D) for v in num_verts_per_mesh]
verts_features_packed = list_to_packed(verts_features_list)[0]
verts_features_list = packed_to_list(verts_features_packed, num_verts_per_mesh)
tex = TexturesVertex(verts_features=verts_features_list)
# This is set inside Meshes when textures is passed as an input.
# Here we set _num_faces_per_mesh and _num_verts_per_mesh explicity.
tex1 = tex.clone()
tex1._num_verts_per_mesh = num_verts_per_mesh
verts_packed = tex1.verts_features_packed()
verts_verts_list = tex1.verts_features_list()
verts_padded = tex1.verts_features_padded()
for f1, f2 in zip(verts_verts_list, verts_features_list):
self.assertTrue((f1 == f2).all().item())
self.assertTrue(verts_packed.shape == (sum(num_verts_per_mesh), D))
self.assertTrue(verts_padded.shape == (2, 9, D))
# Case where num_verts_per_mesh is not set and textures
# are initialized with a padded tensor.
tex2 = TexturesVertex(verts_features=verts_padded)
verts_packed = tex2.verts_features_packed()
verts_list = tex2.verts_features_list()
# Packed is just flattened padded as num_verts_per_mesh
# has not been provided.
self.assertTrue(verts_packed.shape == (9 * 2, D))
for i, (f1, f2) in enumerate(zip(verts_list, verts_features_list)):
n = num_verts_per_mesh[i]
self.assertTrue((f1[:n] == f2).all().item())
def test_getitem(self):
N = 5
V = 20
source = {"verts_features": torch.randn(size=(N, 10, 128))}
tex = TexturesVertex(verts_features=source["verts_features"])
verts = torch.rand(size=(N, V, 3))
faces = torch.randint(size=(N, 10, 3), high=V)
meshes = Meshes(verts=verts, faces=faces, textures=tex)
tryindex(self, 2, tex, meshes, source)
tryindex(self, slice(0, 2, 1), tex, meshes, source)
index = torch.tensor([1, 0, 1, 0, 0], dtype=torch.bool)
tryindex(self, index, tex, meshes, source)
index = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool)
tryindex(self, index, tex, meshes, source)
index = torch.tensor([1, 2], dtype=torch.int64)
tryindex(self, index, tex, meshes, source)
tryindex(self, [2, 4], tex, meshes, source)
class TestTexturesAtlas(TestCaseMixin, unittest.TestCase):
def test_sample_texture_atlas(self):
N, F, R = 1, 2, 2
verts = torch.randn((4, 3), dtype=torch.float32)
faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
faces_atlas = torch.rand(size=(N, F, R, R, 3))
tex = TexturesAtlas(atlas=faces_atlas)
mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
barycentric_coords = torch.tensor(
[[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
).view(1, 1, 1, 2, -1)
expected_vals = torch.tensor(
[[0.5, 1.0, 0.3], [0.3, 1.0, 0.9]], dtype=torch.float32
)
expected_vals = torch.zeros((1, 1, 1, 2, 3), dtype=torch.float32)
expected_vals[..., 0, :] = faces_atlas[0, 0, 0, 1, ...]
expected_vals[..., 1, :] = faces_atlas[0, 1, 1, 0, ...]
fragments = Fragments(
pix_to_face=pix_to_face,
bary_coords=barycentric_coords,
zbuf=torch.ones_like(pix_to_face),
dists=torch.ones_like(pix_to_face),
)
texels = mesh.textures.sample_textures(fragments)
self.assertTrue(torch.allclose(texels, expected_vals))
def test_textures_atlas_grad(self):
N, F, R = 1, 2, 2
verts = torch.randn((4, 3), dtype=torch.float32)
faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
faces_atlas = torch.rand(size=(N, F, R, R, 3), requires_grad=True)
tex = TexturesAtlas(atlas=faces_atlas)
mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
barycentric_coords = torch.tensor(
[[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
).view(1, 1, 1, 2, -1)
fragments = Fragments(
pix_to_face=pix_to_face,
bary_coords=barycentric_coords,
zbuf=torch.ones_like(pix_to_face),
dists=torch.ones_like(pix_to_face),
)
texels = mesh.textures.sample_textures(fragments)
grad_tex = torch.rand_like(texels)
grad_expected = torch.zeros_like(faces_atlas)
grad_expected[0, 0, 0, 1, :] = grad_tex[..., 0:1, :]
grad_expected[0, 1, 1, 0, :] = grad_tex[..., 1:2, :]
texels.backward(grad_tex)
self.assertTrue(hasattr(faces_atlas, "grad"))
self.assertTrue(torch.allclose(faces_atlas.grad, grad_expected))
def test_textures_atlas_init_fail(self):
# Incorrect sized tensors
with self.assertRaisesRegex(ValueError, "atlas"):
TexturesAtlas(atlas=torch.rand(size=(5, 10, 3)))
# Not a list or a tensor
with self.assertRaisesRegex(ValueError, "atlas"):
TexturesAtlas(atlas=(1, 1, 1))
def test_clone(self):
tex = TexturesAtlas(atlas=torch.rand(size=(1, 10, 2, 2, 3)))
tex_cloned = tex.clone()
self.assertSeparate(tex._atlas_padded, tex_cloned._atlas_padded)
self.assertSeparate(tex.valid, tex_cloned.valid)
def test_extend(self):
B = 10
mesh = TestMeshes.init_mesh(B, 30, 50)
F = mesh._F
tex_uv = TexturesAtlas(atlas=torch.randn((B, F, 2, 2, 3)))
tex_mesh = Meshes(
verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex_uv
)
N = 20
new_mesh = tex_mesh.extend(N)
self.assertEqual(len(tex_mesh) * N, len(new_mesh))
tex_init = tex_mesh.textures
new_tex = new_mesh.textures
for i in range(len(tex_mesh)):
for n in range(N):
self.assertClose(
tex_init.atlas_list()[i], new_tex.atlas_list()[i * N + n]
)
self.assertClose(
tex_init._num_faces_per_mesh[i],
new_tex._num_faces_per_mesh[i * N + n],
)
self.assertAllSeparate([tex_init.atlas_padded(), new_tex.atlas_padded()])
with self.assertRaises(ValueError):
tex_mesh.extend(N=-1)
def test_padded_to_packed(self):
# Case where each face in the mesh has 3 unique uv vertex indices
# - i.e. even if a vertex is shared between multiple faces it will
# have a unique uv coordinate for each face.
R = 2
N = 20
num_faces_per_mesh = torch.randint(size=(N,), low=0, high=30)
atlas_list = [torch.rand(f, R, R, 3) for f in num_faces_per_mesh]
tex = TexturesAtlas(atlas=atlas_list)
# This is set inside Meshes when textures is passed as an input.
# Here we set _num_faces_per_mesh explicity.
tex1 = tex.clone()
tex1._num_faces_per_mesh = num_faces_per_mesh.tolist()
atlas_packed = tex1.atlas_packed()
atlas_list_new = tex1.atlas_list()
atlas_padded = tex1.atlas_padded()
for f1, f2 in zip(atlas_list_new, atlas_list):
self.assertTrue((f1 == f2).all().item())
sum_F = num_faces_per_mesh.sum()
max_F = num_faces_per_mesh.max().item()
self.assertTrue(atlas_packed.shape == (sum_F, R, R, 3))
self.assertTrue(atlas_padded.shape == (N, max_F, R, R, 3))
# Case where num_faces_per_mesh is not set and textures
# are initialized with a padded tensor.
atlas_list_padded = _list_to_padded_wrapper(atlas_list)
tex2 = TexturesAtlas(atlas=atlas_list_padded)
atlas_packed = tex2.atlas_packed()
atlas_list_new = tex2.atlas_list()
# Packed is just flattened padded as num_faces_per_mesh
# has not been provided.
self.assertTrue(atlas_packed.shape == (N * max_F, R, R, 3))
for i, (f1, f2) in enumerate(zip(atlas_list_new, atlas_list)):
n = num_faces_per_mesh[i]
self.assertTrue((f1[:n] == f2).all().item())
def test_getitem(self):
N = 5
V = 20
source = {"atlas": torch.randn(size=(N, 10, 4, 4, 3))}
tex = TexturesAtlas(atlas=source["atlas"])
verts = torch.rand(size=(N, V, 3))
faces = torch.randint(size=(N, 10, 3), high=V)
meshes = Meshes(verts=verts, faces=faces, textures=tex)
tryindex(self, 2, tex, meshes, source)
tryindex(self, slice(0, 2, 1), tex, meshes, source)
index = torch.tensor([1, 0, 1, 0, 0], dtype=torch.bool)
tryindex(self, index, tex, meshes, source)
index = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool)
tryindex(self, index, tex, meshes, source)
index = torch.tensor([1, 2], dtype=torch.int64)
tryindex(self, index, tex, meshes, source)
tryindex(self, [2, 4], tex, meshes, source)
class TestTexturesUV(TestCaseMixin, unittest.TestCase):
def test_sample_textures_uv(self):
barycentric_coords = torch.tensor(
[[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
).view(1, 1, 1, 2, -1)
@ -38,11 +400,11 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
zbuf=pix_to_face,
dists=pix_to_face,
)
tex = Textures(
maps=tex_map, faces_uvs=face_uvs[None, ...], verts_uvs=vert_uvs[None, ...]
)
tex = TexturesUV(maps=tex_map, faces_uvs=[face_uvs], verts_uvs=[vert_uvs])
meshes = Meshes(verts=[dummy_verts], faces=[face_uvs], textures=tex)
texels = interpolate_texture_map(fragments, meshes)
mesh_textures = meshes.textures
texels = mesh_textures.sample_textures(fragments)
# Expected output
pixel_uvs = interpolated_uvs * 2.0 - 1.0
@ -53,190 +415,92 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
expected_out = F.grid_sample(tex_map, pixel_uvs, align_corners=False)
self.assertTrue(torch.allclose(texels.squeeze(), expected_out.squeeze()))
def test_init_rgb_uv_fail(self):
V = 20
def test_textures_uv_init_fail(self):
# Maps has wrong shape
with self.assertRaisesRegex(ValueError, "maps"):
Textures(
TexturesUV(
maps=torch.ones((5, 16, 16, 3, 4)),
faces_uvs=torch.randint(size=(5, 10, 3), low=0, high=V),
verts_uvs=torch.ones((5, V, 2)),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
)
# faces_uvs has wrong shape
with self.assertRaisesRegex(ValueError, "faces_uvs"):
Textures(
TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
faces_uvs=torch.randint(size=(5, 10, 3, 3), low=0, high=V),
verts_uvs=torch.ones((5, V, 2)),
faces_uvs=torch.rand(size=(5, 10, 3, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
)
# verts_uvs has wrong shape
with self.assertRaisesRegex(ValueError, "verts_uvs"):
Textures(
TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
faces_uvs=torch.randint(size=(5, 10, 3), low=0, high=V),
verts_uvs=torch.ones((5, V, 2, 3)),
)
# verts_rgb has wrong shape
with self.assertRaisesRegex(ValueError, "verts_rgb"):
Textures(verts_rgb=torch.ones((5, 16, 16, 3)))
# maps provided without verts/faces uvs
with self.assertRaisesRegex(ValueError, "faces_uvs and verts_uvs are required"):
Textures(maps=torch.ones((5, 16, 16, 3)))
def test_padded_to_packed(self):
N = 2
# Case where each face in the mesh has 3 unique uv vertex indices
# - i.e. even if a vertex is shared between multiple faces it will
# have a unique uv coordinate for each face.
faces_uvs_list = [
torch.tensor([[0, 1, 2], [3, 5, 4], [7, 6, 8]]),
torch.tensor([[0, 1, 2], [3, 4, 5]]),
] # (N, 3, 3)
verts_uvs_list = [torch.ones(9, 2), torch.ones(6, 2)]
faces_uvs_padded = list_to_padded(faces_uvs_list, pad_value=-1)
verts_uvs_padded = list_to_padded(verts_uvs_list)
tex = Textures(
maps=torch.ones((N, 16, 16, 3)),
faces_uvs=faces_uvs_padded,
verts_uvs=verts_uvs_padded,
)
# This is set inside Meshes when textures is passed as an input.
# Here we set _num_faces_per_mesh and _num_verts_per_mesh explicity.
tex1 = tex.clone()
tex1._num_faces_per_mesh = faces_uvs_padded.gt(-1).all(-1).sum(-1).tolist()
tex1._num_verts_per_mesh = torch.tensor([5, 4])
faces_packed = tex1.faces_uvs_packed()
verts_packed = tex1.verts_uvs_packed()
faces_list = tex1.faces_uvs_list()
verts_list = tex1.verts_uvs_list()
for f1, f2 in zip(faces_uvs_list, faces_list):
self.assertTrue((f1 == f2).all().item())
for f, v1, v2 in zip(faces_list, verts_list, verts_uvs_list):
idx = f.unique()
self.assertTrue((v1[idx] == v2).all().item())
self.assertTrue(faces_packed.shape == (3 + 2, 3))
# verts_packed is just flattened verts_padded.
# split sizes are not used for verts_uvs.
self.assertTrue(verts_packed.shape == (9 * 2, 2))
# Case where num_faces_per_mesh is not set
tex2 = tex.clone()
faces_packed = tex2.faces_uvs_packed()
verts_packed = tex2.verts_uvs_packed()
faces_list = tex2.faces_uvs_list()
verts_list = tex2.verts_uvs_list()
# Packed is just flattened padded as num_faces_per_mesh
# has not been provided.
self.assertTrue(verts_packed.shape == (9 * 2, 2))
self.assertTrue(faces_packed.shape == (3 * 2, 3))
for i in range(N):
self.assertTrue(
(faces_list[i] == faces_uvs_padded[i, ...].squeeze()).all().item()
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2, 3)),
)
for i in range(N):
self.assertTrue(
(verts_list[i] == verts_uvs_padded[i, ...].squeeze()).all().item()
# verts has different batch dim to faces
with self.assertRaisesRegex(ValueError, "verts_uvs"):
TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(8, 15, 2)),
)
# maps has different batch dim to faces
with self.assertRaisesRegex(ValueError, "maps"):
TexturesUV(
maps=torch.ones((8, 16, 16, 3)),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
)
# verts on different device to faces
with self.assertRaisesRegex(ValueError, "verts_uvs"):
TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2, 3), device="cuda"),
)
# maps on different device to faces
with self.assertRaisesRegex(ValueError, "map"):
TexturesUV(
maps=torch.ones((5, 16, 16, 3), device="cuda"),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
)
def test_clone(self):
V = 20
tex = Textures(
tex = TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
faces_uvs=torch.randint(size=(5, 10, 3), low=0, high=V),
verts_uvs=torch.ones((5, V, 2)),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
)
tex_cloned = tex.clone()
self.assertSeparate(tex._faces_uvs_padded, tex_cloned._faces_uvs_padded)
self.assertSeparate(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded)
self.assertSeparate(tex._maps_padded, tex_cloned._maps_padded)
def test_getitem(self):
N = 5
V = 20
source = {
"maps": torch.rand(size=(N, 16, 16, 3)),
"faces_uvs": torch.randint(size=(N, 10, 3), low=0, high=V),
"verts_uvs": torch.rand((N, V, 2)),
}
tex = Textures(
maps=source["maps"],
faces_uvs=source["faces_uvs"],
verts_uvs=source["verts_uvs"],
)
verts = torch.rand(size=(N, V, 3))
faces = torch.randint(size=(N, 10, 3), high=V)
meshes = Meshes(verts=verts, faces=faces, textures=tex)
def tryindex(index):
tex2 = tex[index]
meshes2 = meshes[index]
tex_from_meshes = meshes2.textures
for item in source:
basic = source[item][index]
from_texture = getattr(tex2, item + "_padded")()
from_meshes = getattr(tex_from_meshes, item + "_padded")()
if isinstance(index, int):
basic = basic[None]
self.assertClose(basic, from_texture)
self.assertClose(basic, from_meshes)
self.assertEqual(
from_texture.ndim, getattr(tex, item + "_padded")().ndim
)
if item == "faces_uvs":
faces_uvs_list = tex_from_meshes.faces_uvs_list()
self.assertEqual(basic.shape[0], len(faces_uvs_list))
for i, faces_uvs in enumerate(faces_uvs_list):
self.assertClose(faces_uvs, basic[i])
tryindex(2)
tryindex(slice(0, 2, 1))
index = torch.tensor([1, 0, 1, 0, 0], dtype=torch.bool)
tryindex(index)
index = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool)
tryindex(index)
index = torch.tensor([1, 2], dtype=torch.int64)
tryindex(index)
tryindex([2, 4])
def test_to(self):
V = 20
tex = Textures(
maps=torch.ones((5, 16, 16, 3)),
faces_uvs=torch.randint(size=(5, 10, 3), low=0, high=V),
verts_uvs=torch.ones((5, V, 2)),
)
device = torch.device("cuda:0")
tex = tex.to(device)
self.assertTrue(tex._faces_uvs_padded.device == device)
self.assertTrue(tex._verts_uvs_padded.device == device)
self.assertTrue(tex._maps_padded.device == device)
self.assertSeparate(tex.valid, tex_cloned.valid)
def test_extend(self):
B = 10
B = 5
mesh = TestMeshes.init_mesh(B, 30, 50)
V = mesh._V
F = mesh._F
# 1. Texture uvs
tex_uv = Textures(
maps=torch.randn((B, 16, 16, 3)),
faces_uvs=torch.randint(size=(B, F, 3), low=0, high=V),
verts_uvs=torch.randn((B, V, 2)),
num_faces = mesh.num_faces_per_mesh()
num_verts = mesh.num_verts_per_mesh()
faces_uvs_list = [torch.randint(size=(f, 3), low=0, high=V) for f in num_faces]
verts_uvs_list = [torch.rand(v, 2) for v in num_verts]
tex_uv = TexturesUV(
maps=torch.ones((B, 16, 16, 3)),
faces_uvs=faces_uvs_list,
verts_uvs=verts_uvs_list,
)
tex_mesh = Meshes(
verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex_uv
verts=mesh.verts_list(), faces=mesh.faces_list(), textures=tex_uv
)
N = 20
N = 2
new_mesh = tex_mesh.extend(N)
self.assertEqual(len(tex_mesh) * N, len(new_mesh))
@ -246,56 +510,142 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
for i in range(len(tex_mesh)):
for n in range(N):
self.assertClose(
tex_init.verts_uvs_list()[i], new_tex.verts_uvs_list()[i * N + n]
)
self.assertClose(
tex_init.faces_uvs_list()[i], new_tex.faces_uvs_list()[i * N + n]
)
self.assertClose(
tex_init.verts_uvs_list()[i], new_tex.verts_uvs_list()[i * N + n]
tex_init.maps_padded()[i, ...], new_tex.maps_padded()[i * N + n]
)
self.assertClose(
tex_init._num_faces_per_mesh[i],
new_tex._num_faces_per_mesh[i * N + n],
)
self.assertAllSeparate(
[
tex_init.faces_uvs_padded(),
new_tex.faces_uvs_padded(),
tex_init.faces_uvs_packed(),
new_tex.faces_uvs_packed(),
tex_init.verts_uvs_padded(),
new_tex.verts_uvs_padded(),
tex_init.verts_uvs_packed(),
new_tex.verts_uvs_packed(),
tex_init.maps_padded(),
new_tex.maps_padded(),
]
)
self.assertIsNone(new_tex.verts_rgb_list())
self.assertIsNone(new_tex.verts_rgb_padded())
self.assertIsNone(new_tex.verts_rgb_packed())
# 2. Texture vertex RGB
tex_rgb = Textures(verts_rgb=torch.randn((B, V, 3)))
tex_mesh_rgb = Meshes(
verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex_rgb
)
N = 20
new_mesh_rgb = tex_mesh_rgb.extend(N)
self.assertEqual(len(tex_mesh_rgb) * N, len(new_mesh_rgb))
tex_init = tex_mesh_rgb.textures
new_tex = new_mesh_rgb.textures
for i in range(len(tex_mesh_rgb)):
for n in range(N):
self.assertClose(
tex_init.verts_rgb_list()[i], new_tex.verts_rgb_list()[i * N + n]
)
self.assertAllSeparate(
[tex_init.verts_rgb_padded(), new_tex.verts_rgb_padded()]
)
self.assertIsNone(new_tex.verts_uvs_padded())
self.assertIsNone(new_tex.verts_uvs_list())
self.assertIsNone(new_tex.verts_uvs_packed())
self.assertIsNone(new_tex.faces_uvs_padded())
self.assertIsNone(new_tex.faces_uvs_list())
self.assertIsNone(new_tex.faces_uvs_packed())
# 3. Error
with self.assertRaises(ValueError):
tex_mesh.extend(N=-1)
def test_padded_to_packed(self):
# Case where each face in the mesh has 3 unique uv vertex indices
# - i.e. even if a vertex is shared between multiple faces it will
# have a unique uv coordinate for each face.
N = 2
faces_uvs_list = [
torch.tensor([[0, 1, 2], [3, 5, 4], [7, 6, 8]]),
torch.tensor([[0, 1, 2], [3, 4, 5]]),
] # (N, 3, 3)
verts_uvs_list = [torch.ones(9, 2), torch.ones(6, 2)]
num_faces_per_mesh = [f.shape[0] for f in faces_uvs_list]
num_verts_per_mesh = [v.shape[0] for v in verts_uvs_list]
tex = TexturesUV(
maps=torch.ones((N, 16, 16, 3)),
faces_uvs=faces_uvs_list,
verts_uvs=verts_uvs_list,
)
# This is set inside Meshes when textures is passed as an input.
# Here we set _num_faces_per_mesh and _num_verts_per_mesh explicity.
tex1 = tex.clone()
tex1._num_faces_per_mesh = num_faces_per_mesh
tex1._num_verts_per_mesh = num_verts_per_mesh
verts_packed = tex1.verts_uvs_packed()
verts_list = tex1.verts_uvs_list()
verts_padded = tex1.verts_uvs_padded()
faces_packed = tex1.faces_uvs_packed()
faces_list = tex1.faces_uvs_list()
faces_padded = tex1.faces_uvs_padded()
for f1, f2 in zip(faces_list, faces_uvs_list):
self.assertTrue((f1 == f2).all().item())
for f1, f2 in zip(verts_list, verts_uvs_list):
self.assertTrue((f1 == f2).all().item())
self.assertTrue(faces_packed.shape == (3 + 2, 3))
self.assertTrue(faces_padded.shape == (2, 3, 3))
self.assertTrue(verts_packed.shape == (9 + 6, 2))
self.assertTrue(verts_padded.shape == (2, 9, 2))
# Case where num_faces_per_mesh is not set and faces_verts_uvs
# are initialized with a padded tensor.
tex2 = TexturesUV(
maps=torch.ones((N, 16, 16, 3)),
verts_uvs=verts_padded,
faces_uvs=faces_padded,
)
faces_packed = tex2.faces_uvs_packed()
faces_list = tex2.faces_uvs_list()
verts_packed = tex2.verts_uvs_packed()
verts_list = tex2.verts_uvs_list()
# Packed is just flattened padded as num_faces_per_mesh
# has not been provided.
self.assertTrue(faces_packed.shape == (3 * 2, 3))
self.assertTrue(verts_packed.shape == (9 * 2, 2))
for i, (f1, f2) in enumerate(zip(faces_list, faces_uvs_list)):
n = num_faces_per_mesh[i]
self.assertTrue((f1[:n] == f2).all().item())
for i, (f1, f2) in enumerate(zip(verts_list, verts_uvs_list)):
n = num_verts_per_mesh[i]
self.assertTrue((f1[:n] == f2).all().item())
def test_to(self):
tex = TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
faces_uvs=torch.randint(size=(5, 10, 3), high=15),
verts_uvs=torch.rand(size=(5, 15, 2)),
)
device = torch.device("cuda:0")
tex = tex.to(device)
self.assertTrue(tex._faces_uvs_padded.device == device)
self.assertTrue(tex._verts_uvs_padded.device == device)
self.assertTrue(tex._maps_padded.device == device)
def test_getitem(self):
N = 5
V = 20
source = {
"maps": torch.rand(size=(N, 1, 1, 3)),
"faces_uvs": torch.randint(size=(N, 10, 3), high=V),
"verts_uvs": torch.randn(size=(N, V, 2)),
}
tex = TexturesUV(
maps=source["maps"],
faces_uvs=source["faces_uvs"],
verts_uvs=source["verts_uvs"],
)
verts = torch.rand(size=(N, V, 3))
faces = torch.randint(size=(N, 10, 3), high=V)
meshes = Meshes(verts=verts, faces=faces, textures=tex)
tryindex(self, 2, tex, meshes, source)
tryindex(self, slice(0, 2, 1), tex, meshes, source)
index = torch.tensor([1, 0, 1, 0, 0], dtype=torch.bool)
tryindex(self, index, tex, meshes, source)
index = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool)
tryindex(self, index, tex, meshes, source)
index = torch.tensor([1, 2], dtype=torch.int64)
tryindex(self, index, tex, meshes, source)
tryindex(self, [2, 4], tex, meshes, source)