diff --git a/docs/tutorials/camera_position_optimization_with_differentiable_rendering.ipynb b/docs/tutorials/camera_position_optimization_with_differentiable_rendering.ipynb index fbbffb6d..5bafce91 100644 --- a/docs/tutorials/camera_position_optimization_with_differentiable_rendering.ipynb +++ b/docs/tutorials/camera_position_optimization_with_differentiable_rendering.ipynb @@ -520,6 +520,9 @@ ], "metadata": { "accelerator": "GPU", + "anp_metadata": { + "path": "fbsource/fbcode/vision/fair/pytorch3d/docs/tutorials/camera_position_optimization_with_differentiable_rendering.ipynb" + }, "bento_stylesheets": { "bento/extensions/flow/main.css": true, "bento/extensions/kernel_selector/main.css": true, @@ -533,6 +536,9 @@ "provenance": [], "toc_visible": true }, + "disseminate_notebook_info": { + "backup_notebook_id": "1062179640844868" + }, "kernelspec": { "display_name": "pytorch3d (local)", "language": "python", diff --git a/docs/tutorials/render_textured_meshes.ipynb b/docs/tutorials/render_textured_meshes.ipynb index 1b32145a..1f63307a 100644 --- a/docs/tutorials/render_textured_meshes.ipynb +++ b/docs/tutorials/render_textured_meshes.ipynb @@ -84,7 +84,7 @@ "from skimage.io import imread\n", "\n", "# Util function for loading meshes\n", - "from pytorch3d.io import load_objs_as_meshes\n", + "from pytorch3d.io import load_objs_as_meshes, load_obj\n", "\n", "# Data structures and functions for rendering\n", "from pytorch3d.structures import Meshes, Textures\n", @@ -97,7 +97,7 @@ " RasterizationSettings, \n", " MeshRenderer, \n", " MeshRasterizer, \n", - " TexturedSoftPhongShader\n", + " SoftPhongShader\n", ")\n", "\n", "# add path for demo utils functions \n", @@ -316,7 +316,7 @@ " cameras=cameras, \n", " raster_settings=raster_settings\n", " ),\n", - " shader=TexturedSoftPhongShader(\n", + " shader=SoftPhongShader(\n", " device=device, \n", " cameras=cameras,\n", " lights=lights\n", @@ -563,6 +563,9 @@ ], "metadata": { "accelerator": "GPU", + "anp_metadata": { + "path": "fbsource/fbcode/vision/fair/pytorch3d/docs/tutorials/render_textured_meshes.ipynb" + }, "bento_stylesheets": { "bento/extensions/flow/main.css": true, "bento/extensions/kernel_selector/main.css": true, @@ -575,6 +578,9 @@ "name": "render_textured_meshes.ipynb", "provenance": [] }, + "disseminate_notebook_info": { + "backup_notebook_id": "569222367081034" + }, "kernelspec": { "display_name": "pytorch3d (local)", "language": "python", diff --git a/pytorch3d/datasets/shapenet_base.py b/pytorch3d/datasets/shapenet_base.py index 85c4ce68..722bde09 100644 --- a/pytorch3d/datasets/shapenet_base.py +++ b/pytorch3d/datasets/shapenet_base.py @@ -13,8 +13,8 @@ from pytorch3d.renderer import ( OpenGLPerspectiveCameras, PointLights, RasterizationSettings, + TexturesVertex, ) -from pytorch3d.structures import Textures class ShapeNetBase(torch.utils.data.Dataset): @@ -113,8 +113,8 @@ class ShapeNetBase(torch.utils.data.Dataset): """ paths = self._handle_render_inputs(model_ids, categories, sample_nums, idxs) meshes = load_objs_as_meshes(paths, device=device, load_textures=False) - meshes.textures = Textures( - verts_rgb=torch.ones_like(meshes.verts_padded(), device=device) + meshes.textures = TexturesVertex( + verts_features=torch.ones_like(meshes.verts_padded(), device=device) ) cameras = kwargs.get("cameras", OpenGLPerspectiveCameras()).to(device) renderer = MeshRenderer( diff --git a/pytorch3d/io/obj_io.py b/pytorch3d/io/obj_io.py index 03de8213..3012526f 100644 --- a/pytorch3d/io/obj_io.py +++ b/pytorch3d/io/obj_io.py @@ -11,7 +11,8 @@ import numpy as np import torch from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas from pytorch3d.io.utils import _open_file -from pytorch3d.structures import Meshes, Textures, join_meshes_as_batch +from pytorch3d.renderer import TexturesAtlas, TexturesUV +from pytorch3d.structures import Meshes, join_meshes_as_batch def _make_tensor(data, cols: int, dtype: torch.dtype, device="cpu") -> torch.Tensor: @@ -41,6 +42,10 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None): Args: faces_indices: List of ints of indices. max_index: Max index for the face property. + pad_value: if any of the face_indices are padded, specify + the value of the padding (e.g. -1). This is only used + for texture indices indices where there might + not be texture information for all the faces. Returns: faces_indices: List of ints of indices. @@ -65,7 +70,9 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None): faces_indices[mask] = pad_value # Check indices are valid. - if torch.any(faces_indices >= max_index) or torch.any(faces_indices < 0): + if torch.any(faces_indices >= max_index) or ( + pad_value is None and torch.any(faces_indices < 0) + ): warnings.warn("Faces have invalid indices") return faces_indices @@ -227,7 +234,14 @@ def load_obj( ) -def load_objs_as_meshes(files: list, device=None, load_textures: bool = True): +def load_objs_as_meshes( + files: list, + device=None, + load_textures: bool = True, + create_texture_atlas: bool = False, + texture_atlas_size: int = 4, + texture_wrap: Optional[str] = "repeat", +): """ Load meshes from a list of .obj files using the load_obj function, and return them as a Meshes object. This only works for meshes which have a @@ -246,18 +260,31 @@ def load_objs_as_meshes(files: list, device=None, load_textures: bool = True): """ mesh_list = [] for f_obj in files: - # TODO: update this function to support the two texturing options. - verts, faces, aux = load_obj(f_obj, load_textures=load_textures) - verts = verts.to(device) + verts, faces, aux = load_obj( + f_obj, + load_textures=load_textures, + create_texture_atlas=create_texture_atlas, + texture_atlas_size=texture_atlas_size, + texture_wrap=texture_wrap, + ) tex = None - tex_maps = aux.texture_images - if tex_maps is not None and len(tex_maps) > 0: - verts_uvs = aux.verts_uvs[None, ...].to(device) # (1, V, 2) - faces_uvs = faces.textures_idx[None, ...].to(device) # (1, F, 3) - image = list(tex_maps.values())[0].to(device)[None] - tex = Textures(verts_uvs=verts_uvs, faces_uvs=faces_uvs, maps=image) + if create_texture_atlas: + # TexturesAtlas type + tex = TexturesAtlas(atlas=[aux.texture_atlas]) + else: + # TexturesUV type + tex_maps = aux.texture_images + if tex_maps is not None and len(tex_maps) > 0: + verts_uvs = aux.verts_uvs.to(device) # (V, 2) + faces_uvs = faces.textures_idx.to(device) # (F, 3) + image = list(tex_maps.values())[0].to(device)[None] + tex = TexturesUV( + verts_uvs=[verts_uvs], faces_uvs=[faces_uvs], maps=image + ) - mesh = Meshes(verts=[verts], faces=[faces.verts_idx.to(device)], textures=tex) + mesh = Meshes( + verts=[verts.to(device)], faces=[faces.verts_idx.to(device)], textures=tex + ) mesh_list.append(mesh) if len(mesh_list) == 1: return mesh_list[0] diff --git a/pytorch3d/renderer/__init__.py b/pytorch3d/renderer/__init__.py index 0e8b44e9..91b2a47f 100644 --- a/pytorch3d/renderer/__init__.py +++ b/pytorch3d/renderer/__init__.py @@ -28,11 +28,11 @@ from .mesh import ( SoftGouraudShader, SoftPhongShader, SoftSilhouetteShader, - TexturedSoftPhongShader, + Textures, + TexturesAtlas, + TexturesUV, + TexturesVertex, gouraud_shading, - interpolate_face_attributes, - interpolate_texture_map, - interpolate_vertex_colors, phong_shading, rasterize_meshes, ) diff --git a/pytorch3d/renderer/mesh/__init__.py b/pytorch3d/renderer/mesh/__init__.py index 22ce56dc..a0a01086 100644 --- a/pytorch3d/renderer/mesh/__init__.py +++ b/pytorch3d/renderer/mesh/__init__.py @@ -1,10 +1,10 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -from .texturing import interpolate_texture_map, interpolate_vertex_colors # isort:skip from .rasterize_meshes import rasterize_meshes from .rasterizer import MeshRasterizer, RasterizationSettings from .renderer import MeshRenderer +from .shader import TexturedSoftPhongShader # DEPRECATED from .shader import ( HardFlatShader, HardGouraudShader, @@ -12,10 +12,10 @@ from .shader import ( SoftGouraudShader, SoftPhongShader, SoftSilhouetteShader, - TexturedSoftPhongShader, ) from .shading import gouraud_shading, phong_shading -from .utils import interpolate_face_attributes +from .textures import Textures # DEPRECATED +from .textures import TexturesAtlas, TexturesUV, TexturesVertex __all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/pytorch3d/renderer/mesh/shader.py b/pytorch3d/renderer/mesh/shader.py index ecdc6ebe..b0945373 100644 --- a/pytorch3d/renderer/mesh/shader.py +++ b/pytorch3d/renderer/mesh/shader.py @@ -1,5 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +import warnings import torch import torch.nn as nn @@ -13,7 +14,6 @@ from ..blending import ( from ..lighting import PointLights from ..materials import Materials from .shading import flat_shading, gouraud_shading, phong_shading -from .texturing import interpolate_texture_map, interpolate_vertex_colors # A Shader should take as input fragments from the output of rasterization @@ -57,7 +57,7 @@ class HardPhongShader(nn.Module): or in the forward pass of HardPhongShader" raise ValueError(msg) - texels = interpolate_vertex_colors(fragments, meshes) + texels = meshes.sample_textures(fragments) lights = kwargs.get("lights", self.lights) materials = kwargs.get("materials", self.materials) blend_params = kwargs.get("blend_params", self.blend_params) @@ -104,9 +104,11 @@ class SoftPhongShader(nn.Module): msg = "Cameras must be specified either at initialization \ or in the forward pass of SoftPhongShader" raise ValueError(msg) - texels = interpolate_vertex_colors(fragments, meshes) + + texels = meshes.sample_textures(fragments) lights = kwargs.get("lights", self.lights) materials = kwargs.get("materials", self.materials) + blend_params = kwargs.get("blend_params", self.blend_params) colors = phong_shading( meshes=meshes, fragments=fragments, @@ -115,7 +117,7 @@ class SoftPhongShader(nn.Module): cameras=cameras, materials=materials, ) - images = softmax_rgb_blend(colors, fragments, self.blend_params) + images = softmax_rgb_blend(colors, fragments, blend_params) return images @@ -154,6 +156,12 @@ class HardGouraudShader(nn.Module): lights = kwargs.get("lights", self.lights) materials = kwargs.get("materials", self.materials) blend_params = kwargs.get("blend_params", self.blend_params) + + # As Gouraud shading applies the illumination to the vertex + # colors, the interpolated pixel texture is calculated in the + # shading step. In comparison, for Phong shading, the pixel + # textures are computed first after which the illumination is + # applied. pixel_colors = gouraud_shading( meshes=meshes, fragments=fragments, @@ -210,54 +218,25 @@ class SoftGouraudShader(nn.Module): return images -class TexturedSoftPhongShader(nn.Module): +def TexturedSoftPhongShader( + device="cpu", cameras=None, lights=None, materials=None, blend_params=None +): """ - Per pixel lighting applied to a texture map. First interpolate the vertex - uv coordinates and sample from a texture map. Then apply the lighting model - using the interpolated coords and normals for each pixel. - - The blending function returns the soft aggregated color using all - the faces per pixel. - - To use the default values, simply initialize the shader with the desired - device e.g. - - .. code-block:: - - shader = TexturedPhongShader(device=torch.device("cuda:0")) + TexturedSoftPhongShader class has been DEPRECATED. Use SoftPhongShader instead. + Preserving TexturedSoftPhongShader as a function for backwards compatibility. """ - - def __init__( - self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None - ): - super().__init__() - self.lights = lights if lights is not None else PointLights(device=device) - self.materials = ( - materials if materials is not None else Materials(device=device) - ) - self.cameras = cameras - self.blend_params = blend_params if blend_params is not None else BlendParams() - - def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: - cameras = kwargs.get("cameras", self.cameras) - if cameras is None: - msg = "Cameras must be specified either at initialization \ - or in the forward pass of TexturedSoftPhongShader" - raise ValueError(msg) - texels = interpolate_texture_map(fragments, meshes) - lights = kwargs.get("lights", self.lights) - materials = kwargs.get("materials", self.materials) - blend_params = kwargs.get("blend_params", self.blend_params) - colors = phong_shading( - meshes=meshes, - fragments=fragments, - texels=texels, - lights=lights, - cameras=cameras, - materials=materials, - ) - images = softmax_rgb_blend(colors, fragments, blend_params) - return images + warnings.warn( + """TexturedSoftPhongShader is now deprecated; + use SoftPhongShader instead.""", + PendingDeprecationWarning, + ) + return SoftPhongShader( + device=device, + cameras=cameras, + lights=lights, + materials=materials, + blend_params=blend_params, + ) class HardFlatShader(nn.Module): @@ -291,7 +270,7 @@ class HardFlatShader(nn.Module): msg = "Cameras must be specified either at initialization \ or in the forward pass of HardFlatShader" raise ValueError(msg) - texels = interpolate_vertex_colors(fragments, meshes) + texels = meshes.sample_textures(fragments) lights = kwargs.get("lights", self.lights) materials = kwargs.get("materials", self.materials) blend_params = kwargs.get("blend_params", self.blend_params) diff --git a/pytorch3d/renderer/mesh/shading.py b/pytorch3d/renderer/mesh/shading.py index b6ff84ac..2d248823 100644 --- a/pytorch3d/renderer/mesh/shading.py +++ b/pytorch3d/renderer/mesh/shading.py @@ -6,6 +6,8 @@ from typing import Tuple import torch from pytorch3d.ops import interpolate_face_attributes +from .textures import TexturesVertex + def _apply_lighting( points, normals, lights, cameras, materials @@ -91,6 +93,9 @@ def gouraud_shading(meshes, fragments, lights, cameras, materials) -> torch.Tens Then interpolate the vertex shaded colors using the barycentric coordinates to get a color per pixel. + Gouraud shading is only supported for meshes with texture type `TexturesVertex`. + This is because the illumination is applied to the vertex colors. + Args: meshes: Batch of meshes fragments: Fragments named tuple with the outputs of rasterization @@ -101,10 +106,13 @@ def gouraud_shading(meshes, fragments, lights, cameras, materials) -> torch.Tens Returns: colors: (N, H, W, K, 3) """ + if not isinstance(meshes.textures, TexturesVertex): + raise ValueError("Mesh textures must be an instance of TexturesVertex") + faces = meshes.faces_packed() # (F, 3) - verts = meshes.verts_packed() - vertex_normals = meshes.verts_normals_packed() # (V, 3) - vertex_colors = meshes.textures.verts_rgb_packed() + verts = meshes.verts_packed() # (V, 3) + verts_normals = meshes.verts_normals_packed() # (V, 3) + verts_colors = meshes.textures.verts_features_packed() # (V, D) vert_to_mesh_idx = meshes.verts_packed_to_mesh_idx() # Format properties of lights and materials so they are compatible @@ -119,9 +127,10 @@ def gouraud_shading(meshes, fragments, lights, cameras, materials) -> torch.Tens # Calculate the illumination at each vertex ambient, diffuse, specular = _apply_lighting( - verts, vertex_normals, lights, cameras, materials + verts, verts_normals, lights, cameras, materials ) - verts_colors_shaded = vertex_colors * (ambient + diffuse) + specular + + verts_colors_shaded = verts_colors * (ambient + diffuse) + specular face_colors = verts_colors_shaded[faces] colors = interpolate_face_attributes( fragments.pix_to_face, fragments.bary_coords, face_colors diff --git a/pytorch3d/renderer/mesh/textures.py b/pytorch3d/renderer/mesh/textures.py new file mode 100644 index 00000000..defd610c --- /dev/null +++ b/pytorch3d/renderer/mesh/textures.py @@ -0,0 +1,1049 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import itertools +import warnings +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from pytorch3d.ops import interpolate_face_attributes +from pytorch3d.structures.utils import list_to_packed, list_to_padded, padded_to_list +from torch.nn.functional import interpolate + + +# This file contains classes and helper functions for texturing. +# There are three types of textures: TexturesVertex, TexturesAtlas +# and TexturesUV which inherit from a base textures class TexturesBase. +# +# Each texture class has a method 'sample_textures' to sample a +# value given barycentric coordinates. +# +# All the textures accept either list or padded inputs. The values +# are stored as either per face values (TexturesAtlas, TexturesUV), +# or per face vertex features (TexturesVertex). + + +def _list_to_padded_wrapper( + x: List[torch.Tensor], + pad_size: Union[list, tuple, None] = None, + pad_value: float = 0.0, +) -> torch.Tensor: + r""" + This is a wrapper function for + pytorch3d.structures.utils.list_to_padded function which only accepts + 3-dimensional inputs. + + For this use case, the input x is of shape (F, 3, ...) where only F + is different for each element in the list + + Transforms a list of N tensors each of shape (Mi, ...) into a single tensor + of shape (N, pad_size, ...), or (N, max(Mi), ...) + if pad_size is None. + + Args: + x: list of Tensors + pad_size: int specifying the size of the first dimension + of the padded tensor + pad_value: float value to be used to fill the padded tensor + + Returns: + x_padded: tensor consisting of padded input tensors + """ + N = len(x) + dims = x[0].ndim + reshape_dims = x[0].shape[1:] + D = torch.prod(torch.tensor(reshape_dims)).item() + x_reshaped = [] + for y in x: + if y.ndim != dims and y.shape[1:] != reshape_dims: + msg = ( + "list_to_padded requires tensors to have the same number of dimensions" + ) + raise ValueError(msg) + x_reshaped.append(y.reshape(-1, D)) + x_padded = list_to_padded(x_reshaped, pad_size=pad_size, pad_value=pad_value) + return x_padded.reshape((N, -1) + reshape_dims) + + +def _padded_to_list_wrapper( + x: torch.Tensor, split_size: Union[list, tuple, None] = None +) -> List[torch.Tensor]: + r""" + This is a wrapper function for pytorch3d.structures.utils.padded_to_list + which only accepts 3-dimensional inputs. + + For this use case, the input x is of shape (N, F, ...) where F + is the number of faces which is different for each tensor in the batch. + + This function transforms a padded tensor of shape (N, M, ...) into a + list of N tensors of shape (Mi, ...) where (Mi) is specified in + split_size(i), or of shape (M,) if split_size is None. + + Args: + x: padded Tensor + split_size: list of ints defining the number of items for each tensor + in the output list. + + Returns: + x_list: a list of tensors + """ + N, M = x.shape[:2] + reshape_dims = x.shape[2:] + D = torch.prod(torch.tensor(reshape_dims)).item() + x_reshaped = x.reshape(N, M, D) + x_list = padded_to_list(x_reshaped, split_size=split_size) + x_list = [xl.reshape((xl.shape[0],) + reshape_dims) for xl in x_list] + return x_list + + +def _pad_texture_maps( + images: Union[Tuple[torch.Tensor], List[torch.Tensor]] +) -> torch.Tensor: + """ + Pad all texture images so they have the same height and width. + + Args: + images: list of N tensors of shape (H, W, 3) + + Returns: + tex_maps: Tensor of shape (N, max_H, max_W, 3) + """ + tex_maps = [] + max_H = 0 + max_W = 0 + for im in images: + h, w, _3 = im.shape + if h > max_H: + max_H = h + if w > max_W: + max_W = w + tex_maps.append(im) + max_shape = (max_H, max_W) + + for i, image in enumerate(tex_maps): + if image.shape[:2] != max_shape: + image_BCHW = image.permute(2, 0, 1)[None] + new_image_BCHW = interpolate( + image_BCHW, size=max_shape, mode="bilinear", align_corners=False + ) + tex_maps[i] = new_image_BCHW[0].permute(1, 2, 0) + tex_maps = torch.stack(tex_maps, dim=0) # (num_tex_maps, max_H, max_W, 3) + return tex_maps + + +# A base class for defining a batch of textures +# with helper methods. +# This is also useful to have so that inside `Meshes` +# we can allow the input textures to be any texture +# type which is an instance of the base class. +class TexturesBase(object): + def __init__(self): + self._N = 0 + self.valid = None + + def isempty(self): + if self._N is not None and self.valid is not None: + return self._N == 0 or self.valid.eq(False).all() + return False + + def to(self, device): + for k in dir(self): + v = getattr(self, k) + if isinstance(v, (list, tuple)) and all( + torch.is_tensor(elem) for elem in v + ): + v = [elem.to(device) for elem in v] + setattr(self, k, v) + if torch.is_tensor(v) and v.device != device: + setattr(self, k, v.to(device)) + return self + + def _extend(self, N: int, props: List[str]) -> Dict[str, Union[torch.Tensor, List]]: + """ + Create a dict with the specified properties + repeated N times per batch element. + + Args: + N: number of new copies of each texture + in the batch. + props: a List of strings which refer to either + class attributes or class methods which + return tensors or lists. + + Returns: + Dict with the same keys as props. The values are the + extended properties. + """ + if not isinstance(N, int): + raise ValueError("N must be an integer.") + if N <= 0: + raise ValueError("N must be > 0.") + + new_props = {} + for p in props: + t = getattr(self, p) + if callable(t): + t = t() # class method + if isinstance(t, list): + if not all(isinstance(elem, (int, float)) for elem in t): + raise ValueError("Extend only supports lists of scalars") + t = [[ti] * N for ti in t] + new_props[p] = list(itertools.chain(*t)) + elif torch.is_tensor(t): + new_props[p] = t.repeat_interleave(N, dim=0) + return new_props + + def _getitem(self, index: Union[int, slice], props: List[str]): + """ + Helper function for __getitem__ + """ + new_props = {} + if isinstance(index, (int, slice)): + for p in props: + t = getattr(self, p) + if callable(t): + t = t() # class method + new_props[p] = t[index] + elif isinstance(index, list): + index = torch.tensor(index) + if isinstance(index, torch.Tensor): + if index.dtype == torch.bool: + index = index.nonzero() + index = index.squeeze(1) if index.numel() > 0 else index + index = index.tolist() + for p in props: + t = getattr(self, p) + if callable(t): + t = t() # class method + new_props[p] = [t[i] for i in index] + + return new_props + + def sample_textures(self): + """ + Different texture classes sample textures in different ways + e.g. for vertex textures, the values at each vertex + are interpolated across the face using the barycentric + coordinates. + Each texture class should implement a sample_textures + method to take the `fragments` from rasterization. + Using `fragments.pix_to_face` and `fragments.bary_coords` + this function should return the sampled texture values for + each pixel in the output image. + """ + raise NotImplementedError() + + def clone(self): + """ + Each texture class should implement a method + to clone all necessary internal tensors. + """ + raise NotImplementedError() + + def __getitem__(self, index): + """ + Each texture class should implement a method + to get the texture properites for the + specified elements in the batch. + The TexturesBase._getitem(i) method + can be used as a helper funtion to retrieve the + class attributes for item i. Then, a new + instance of the child class can be created with + the attributes. + """ + raise NotImplementedError() + + def __repr__(self): + return "TexturesBase" + + +def Textures( + maps: Union[List, torch.Tensor, None] = None, + faces_uvs: Optional[torch.Tensor] = None, + verts_uvs: Optional[torch.Tensor] = None, + verts_rgb: Optional[torch.Tensor] = None, +) -> TexturesBase: + """ + Textures class has been DEPRECATED. + Preserving Textures as a function for backwards compatibility. + + Args: + maps: texture map per mesh. This can either be a list of maps + [(H, W, 3)] or a padded tensor of shape (N, H, W, 3). + faces_uvs: (N, F, 3) tensor giving the index into verts_uvs for each + vertex in the face. Padding value is assumed to be -1. + verts_uvs: (N, V, 2) tensor giving the uv coordinate per vertex. + verts_rgb: (N, V, 3) tensor giving the rgb color per vertex. Padding + value is assumed to be -1. + + + Returns: + a Textures class which is an instance of TexturesBase e.g. TexturesUV, + TexturesAtlas, TexturesVerte + + """ + + warnings.warn( + """Textures class is deprecated, + use TexturesUV, TexturesAtlas, TexturesVertex instead. + Textures class will be removed in future releases.""", + PendingDeprecationWarning, + ) + + if all(x is not None for x in [faces_uvs, verts_uvs, maps]): + return TexturesUV(maps=maps, faces_uvs=faces_uvs, verts_uvs=verts_uvs) + elif verts_rgb is not None: + return TexturesVertex(verts_features=verts_rgb) + else: + raise ValueError( + "Textures either requires all three of (faces uvs, verts uvs, maps) or verts rgb" + ) + + +class TexturesAtlas(TexturesBase): + def __init__(self, atlas: Union[torch.Tensor, List, None]): + """ + A texture representation where each face has a square texture map. + This is based on the implementation from SoftRasterizer [1]. + + Args: + atlas: (N, F, R, R, D) tensor giving the per face texture map. + The atlas can be created during obj loading with the + pytorch3d.io.load_obj function - in the input arguments + set `create_texture_atlas=True`. The atlas will be + returned in aux.texture_atlas. + + + The padded and list representations of the textures are stored + and the packed representations is computed on the fly and + not cached. + + [1] Liu et al, 'Soft Rasterizer: A Differentiable Renderer for Image-based + 3D Reasoning', ICCV 2019 + """ + if isinstance(atlas, (list, tuple)): + correct_format = all( + ( + torch.is_tensor(elem) + and elem.ndim == 4 + and elem.shape[1] == elem.shape[2] + ) + for elem in atlas + ) + if not correct_format: + msg = "Expected atlas to be a list of tensors of shape (F, R, R, D)" + raise ValueError(msg) + self._atlas_list = atlas + self._atlas_padded = None + self.device = torch.device("cpu") + + # These values may be overridden when textures is + # passed into the Meshes constructor. For more details + # refer to the __init__ of Meshes. + self._N = len(atlas) + self._num_faces_per_mesh = [len(a) for a in atlas] + + if self._N > 0: + self.device = atlas[0].device + + elif torch.is_tensor(atlas): + if atlas.ndim != 5: + msg = "Expected atlas to be of shape (N, F, R, R, D); got %r" + raise ValueError(msg % repr(atlas.ndim)) + self._atlas_padded = atlas + self._atlas_list = None + self.device = atlas.device + + # These values may be overridden when textures is + # passed into the Meshes constructor. For more details + # refer to the __init__ of Meshes. + self._N = len(atlas) + max_F = atlas.shape[1] + self._num_faces_per_mesh = [max_F] * self._N + else: + raise ValueError("Expected atlas to be a tensor or list") + + # The num_faces_per_mesh, N and valid + # are reset inside the Meshes object when textures is + # passed into the Meshes constructor. For more details + # refer to the __init__ of Meshes. + self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device) + + # This is a hack to allow the child classes to also have the same representation + # as the parent. In meshes.py we check that the input textures have the correct + # type. However due to circular imports issues, we can't import the texture + # classes into any files in pytorch3d.structures. Instead we check + # for repr(textures) == "TexturesBase". + def __repr__(self): + return super().__repr__() + + def clone(self): + tex = self.__class__(atlas=self.atlas_padded().clone()) + num_faces = ( + self._num_faces_per_mesh.clone() + if torch.is_tensor(self._num_faces_per_mesh) + else self._num_faces_per_mesh + ) + tex.valid = self.valid.clone() + tex._num_faces_per_mesh = num_faces + return tex + + def __getitem__(self, index): + props = ["atlas_list", "_num_faces_per_mesh"] + new_props = self._getitem(index, props=props) + atlas = new_props["atlas_list"] + if isinstance(atlas, list): + # multiple batch elements + new_tex = self.__class__(atlas=atlas) + elif torch.is_tensor(atlas): + # single element + new_tex = self.__class__(atlas=[atlas]) + else: + raise ValueError("Not all values are provided in the correct format") + new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"] + return new_tex + + def atlas_padded(self) -> torch.Tensor: + if self._atlas_padded is None: + if self.isempty(): + self._atlas_padded = torch.zeros( + (self._N, 0, 0, 0, 3), dtype=torch.float32, device=self.device + ) + else: + self._atlas_padded = _list_to_padded_wrapper( + self._atlas_list, pad_value=0.0 + ) + return self._atlas_padded + + def atlas_list(self) -> List[torch.Tensor]: + if self._atlas_list is None: + if self.isempty(): + self._atlas_padded = [ + torch.empty((0, 0, 0, 3), dtype=torch.float32, device=self.device) + ] * self._N + self._atlas_list = _padded_to_list_wrapper( + self._atlas_padded, split_size=self._num_faces_per_mesh + ) + return self._atlas_list + + def atlas_packed(self) -> torch.Tensor: + if self.isempty(): + return torch.zeros( + (self._N, 0, 0, 3), dtype=torch.float32, device=self.device + ) + atlas_list = self.atlas_list() + return list_to_packed(atlas_list)[0] + + def extend(self, N: int) -> "TexturesAtlas": + new_props = self._extend(N, ["atlas_padded", "_num_faces_per_mesh"]) + new_tex = TexturesAtlas(atlas=new_props["atlas_padded"]) + new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"] + return new_tex + + def sample_textures(self, fragments, **kwargs) -> torch.Tensor: + """ + Args: + fragments: + The outputs of rasterization. From this we use + + - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices + of the faces (in the packed representation) which + overlap each pixel in the image. + - barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying + the barycentric coordianates of each pixel + relative to the faces (in the packed + representation) which overlap the pixel. + + Returns: + texels: (N, H, W, K, 3) + """ + N, H, W, K = fragments.pix_to_face.shape + atlas_packed = self.atlas_packed() + R = atlas_packed.shape[1] + bary = fragments.bary_coords + pix_to_face = fragments.pix_to_face + + bary_w01 = bary[..., :2] + mask = (pix_to_face < 0)[..., None] + bary_w01 = torch.where(mask, torch.zeros_like(bary_w01), bary_w01) + w_xy = (bary_w01 * R).to(torch.int64) # (N, H, W, K, 2) + + below_diag = ( + bary_w01.sum(dim=-1) * R - w_xy.float().sum(dim=-1) + ) <= 1.0 # (N, H, W, K) + w_x, w_y = w_xy.unbind(-1) + w_x = torch.where(below_diag, w_x, (R - 1 - w_x)) + w_y = torch.where(below_diag, w_y, (R - 1 - w_y)) + + texels = atlas_packed[pix_to_face, w_y, w_x] + texels = texels * (pix_to_face >= 0)[..., None].float() + + return texels + + def join_batch(self, textures: List["TexturesAtlas"]) -> "TexturesAtlas": + """ + Join the list of textures given by `textures` to + self to create a batch of textures. Return a new + TexturesAtlas object with the combined textures. + + Args: + textures: List of TextureAtlas objects + + Returns: + new_tex: TextureAtlas object with the combined + textures from self and the list `textures`. + """ + tex_types_same = all(isinstance(tex, TexturesAtlas) for tex in textures) + if not tex_types_same: + raise ValueError("All textures must be of type TexturesAtlas.") + + atlas_list = [] + atlas_list += self.atlas_list() + num_faces_per_mesh = self._num_faces_per_mesh + for tex in textures: + atlas_list += tex.atlas_list() + num_faces_per_mesh += tex._num_faces_per_mesh + new_tex = self.__class__(atlas=atlas_list) + new_tex._num_faces_per_mesh = num_faces_per_mesh + return new_tex + + +class TexturesUV(TexturesBase): + def __init__( + self, + maps: Union[torch.Tensor, List[torch.Tensor]], + faces_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]], + verts_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]], + ): + """ + Textures are represented as a per mesh texture map and uv coordinates for each + vertex in each face. NOTE: this class only supports one texture map per mesh. + + Args: + maps: texture map per mesh. This can either be a list of maps + [(H, W, 3)] or a padded tensor of shape (N, H, W, 3) + faces_uvs: (N, F, 3) tensor giving the index into verts_uvs for each face + verts_uvs: (N, V, 2) tensor giving the uv coordinates per vertex + + Note: only the padded and list representation of the textures are stored + and the packed representations is computed on the fly and + not cached. + """ + super().__init__() + if isinstance(faces_uvs, (list, tuple)): + for fv in faces_uvs: + if fv.ndim != 2 or fv.shape[-1] != 3: + msg = "Expected faces_uvs to be of shape (F, 3); got %r" + raise ValueError(msg % repr(fv.shape)) + self._faces_uvs_list = faces_uvs + self._faces_uvs_padded = None + self.device = torch.device("cpu") + + # These values may be overridden when textures is + # passed into the Meshes constructor. For more details + # refer to the __init__ of Meshes. + self._N = len(faces_uvs) + self._num_faces_per_mesh = [len(fv) for fv in faces_uvs] + + if self._N > 0: + self.device = faces_uvs[0].device + + elif torch.is_tensor(faces_uvs): + if faces_uvs.ndim != 3 or faces_uvs.shape[-1] != 3: + msg = "Expected faces_uvs to be of shape (N, F, 3); got %r" + raise ValueError(msg % repr(faces_uvs.shape)) + self._faces_uvs_padded = faces_uvs + self._faces_uvs_list = None + self.device = faces_uvs.device + + # These values may be overridden when textures is + # passed into the Meshes constructor. For more details + # refer to the __init__ of Meshes. + self._N = len(faces_uvs) + max_F = faces_uvs.shape[1] + self._num_faces_per_mesh = [max_F] * self._N + else: + raise ValueError("Expected faces_uvs to be a tensor or list") + + if isinstance(verts_uvs, (list, tuple)): + for fv in verts_uvs: + if fv.ndim != 2 or fv.shape[-1] != 2: + msg = "Expected verts_uvs to be of shape (V, 2); got %r" + raise ValueError(msg % repr(fv.shape)) + self._verts_uvs_list = verts_uvs + self._verts_uvs_padded = None + + if len(verts_uvs) != self._N: + raise ValueError( + "verts_uvs and faces_uvs must have the same batch dimension" + ) + if not all(v.device == self.device for v in verts_uvs): + import pdb + + pdb.set_trace() + raise ValueError("verts_uvs and faces_uvs must be on the same device") + + # These values may be overridden when textures is + # passed into the Meshes constructor. For more details + # refer to the __init__ of Meshes. + self._num_verts_per_mesh = [len(v) for v in verts_uvs] + + elif torch.is_tensor(verts_uvs): + if ( + verts_uvs.ndim != 3 + or verts_uvs.shape[-1] != 2 + or verts_uvs.shape[0] != self._N + ): + msg = "Expected verts_uvs to be of shape (N, V, 2); got %r" + raise ValueError(msg % repr(verts_uvs.shape)) + self._verts_uvs_padded = verts_uvs + self._verts_uvs_list = None + + if verts_uvs.device != self.device: + raise ValueError("verts_uvs and faces_uvs must be on the same device") + + # These values may be overridden when textures is + # passed into the Meshes constructor. + max_V = verts_uvs.shape[1] + self._num_verts_per_mesh = [max_V] * self._N + else: + raise ValueError("Expected verts_uvs to be a tensor or list") + + if torch.is_tensor(maps): + if maps.ndim != 4 or maps.shape[0] != self._N: + msg = "Expected maps to be of shape (N, H, W, 3); got %r" + raise ValueError(msg % repr(maps.shape)) + self._maps_padded = maps + self._maps_list = None + elif isinstance(maps, (list, tuple)): + if len(maps) != self._N: + raise ValueError("Expected one texture map per mesh in the batch.") + self._maps_list = maps + if self._N > 0: + maps = _pad_texture_maps(maps) + else: + maps = torch.empty( + (self._N, 0, 0, 3), dtype=torch.float32, device=self.device + ) + self._maps_padded = maps + else: + raise ValueError("Expected maps to be a tensor or list.") + + if self._maps_padded.device != self.device: + raise ValueError("maps must be on the same device as verts/faces uvs.") + + self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device) + + def __repr__(self): + return super().__repr__() + + def clone(self): + tex = self.__class__( + self.maps_padded().clone(), + self.faces_uvs_padded().clone(), + self.verts_uvs_padded().clone(), + ) + num_faces = ( + self._num_faces_per_mesh.clone() + if torch.is_tensor(self._num_faces_per_mesh) + else self._num_faces_per_mesh + ) + tex._num_faces_per_mesh = num_faces + tex.valid = self.valid.clone() + return tex + + def __getitem__(self, index): + props = ["verts_uvs_list", "faces_uvs_list", "maps_list", "_num_faces_per_mesh"] + new_props = self._getitem(index, props) + faces_uvs = new_props["faces_uvs_list"] + verts_uvs = new_props["verts_uvs_list"] + maps = new_props["maps_list"] + + # if index has multiple values then faces/verts/maps may be a list of tensors + if all(isinstance(f, (list, tuple)) for f in [faces_uvs, verts_uvs, maps]): + new_tex = self.__class__( + faces_uvs=faces_uvs, verts_uvs=verts_uvs, maps=maps + ) + elif all(torch.is_tensor(f) for f in [faces_uvs, verts_uvs, maps]): + new_tex = self.__class__( + faces_uvs=[faces_uvs], verts_uvs=[verts_uvs], maps=[maps] + ) + else: + raise ValueError("Not all values are provided in the correct format") + new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"] + return new_tex + + def faces_uvs_padded(self) -> torch.Tensor: + if self._faces_uvs_padded is None: + if self.isempty(): + self._faces_uvs_padded = torch.zeros( + (self._N, 0, 3), dtype=torch.float32, device=self.device + ) + else: + self._faces_uvs_padded = list_to_padded( + self._faces_uvs_list, pad_value=0.0 + ) + return self._faces_uvs_padded + + def faces_uvs_list(self) -> List[torch.Tensor]: + if self._faces_uvs_list is None: + if self.isempty(): + self._faces_uvs_list = [ + torch.empty((0, 3), dtype=torch.float32, device=self.device) + ] * self._N + else: + self._faces_uvs_list = padded_to_list( + self._faces_uvs_padded, split_size=self._num_faces_per_mesh + ) + return self._faces_uvs_list + + def faces_uvs_packed(self) -> torch.Tensor: + if self.isempty(): + return torch.zeros((self._N, 3), dtype=torch.float32, device=self.device) + faces_uvs_list = self.faces_uvs_list() + return list_to_packed(faces_uvs_list)[0] + + def verts_uvs_padded(self) -> torch.Tensor: + if self._verts_uvs_padded is None: + if self.isempty(): + self._verts_uvs_padded = torch.zeros( + (self._N, 0, 2), dtype=torch.float32, device=self.device + ) + else: + self._verts_uvs_padded = list_to_padded( + self._verts_uvs_list, pad_value=0.0 + ) + return self._verts_uvs_padded + + def verts_uvs_list(self) -> List[torch.Tensor]: + if self._verts_uvs_list is None: + if self.isempty(): + self._verts_uvs_list = [ + torch.empty((0, 2), dtype=torch.float32, device=self.device) + ] * self._N + else: + self._verts_uvs_list = padded_to_list( + self._verts_uvs_padded, split_size=self._num_verts_per_mesh + ) + return self._verts_uvs_list + + def verts_uvs_packed(self) -> torch.Tensor: + if self.isempty(): + return torch.zeros((self._N, 2), dtype=torch.float32, device=self.device) + verts_uvs_list = self.verts_uvs_list() + return list_to_packed(verts_uvs_list)[0] + + # Currently only the padded maps are used. + def maps_padded(self) -> torch.Tensor: + return self._maps_padded + + def maps_list(self) -> torch.Tensor: + # maps_list is not used anywhere currently - maps + # are padded to ensure the (H, W) of all maps is the + # same across the batch and we don't store the + # unpadded sizes of the maps. Therefore just return + # the unbinded padded tensor. + return self._maps_padded.unbind(0) + + def extend(self, N: int) -> "TexturesUV": + new_props = self._extend( + N, + [ + "maps_padded", + "verts_uvs_padded", + "faces_uvs_padded", + "_num_faces_per_mesh", + "_num_verts_per_mesh", + ], + ) + new_tex = TexturesUV( + maps=new_props["maps_padded"], + faces_uvs=new_props["faces_uvs_padded"], + verts_uvs=new_props["verts_uvs_padded"], + ) + new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"] + new_tex._num_verts_per_mesh = new_props["_num_verts_per_mesh"] + return new_tex + + def sample_textures(self, fragments, **kwargs) -> torch.Tensor: + """ + Interpolate a 2D texture map using uv vertex texture coordinates for each + face in the mesh. First interpolate the vertex uvs using barycentric coordinates + for each pixel in the rasterized output. Then interpolate the texture map + using the uv coordinate for each pixel. + + Args: + fragments: + The outputs of rasterization. From this we use + + - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices + of the faces (in the packed representation) which + overlap each pixel in the image. + - barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying + the barycentric coordianates of each pixel + relative to the faces (in the packed + representation) which overlap the pixel. + + Returns: + texels: tensor of shape (N, H, W, K, C) giving the interpolated + texture for each pixel in the rasterized image. + """ + verts_uvs = self.verts_uvs_packed() + faces_uvs = self.faces_uvs_packed() + faces_verts_uvs = verts_uvs[faces_uvs] + texture_maps = self.maps_padded() + + # pixel_uvs: (N, H, W, K, 2) + pixel_uvs = interpolate_face_attributes( + fragments.pix_to_face, fragments.bary_coords, faces_verts_uvs + ) + + N, H_out, W_out, K = fragments.pix_to_face.shape + N, H_in, W_in, C = texture_maps.shape # 3 for RGB + + # pixel_uvs: (N, H, W, K, 2) -> (N, K, H, W, 2) -> (NK, H, W, 2) + pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(N * K, H_out, W_out, 2) + + # textures.map: + # (N, H, W, C) -> (N, C, H, W) -> (1, N, C, H, W) + # -> expand (K, N, C, H, W) -> reshape (N*K, C, H, W) + texture_maps = ( + texture_maps.permute(0, 3, 1, 2)[None, ...] + .expand(K, -1, -1, -1, -1) + .transpose(0, 1) + .reshape(N * K, C, H_in, W_in) + ) + + # Textures: (N*K, C, H, W), pixel_uvs: (N*K, H, W, 2) + # Now need to format the pixel uvs and the texture map correctly! + # From pytorch docs, grid_sample takes `grid` and `input`: + # grid specifies the sampling pixel locations normalized by + # the input spatial dimensions It should have most + # values in the range of [-1, 1]. Values x = -1, y = -1 + # is the left-top pixel of input, and values x = 1, y = 1 is the + # right-bottom pixel of input. + + pixel_uvs = pixel_uvs * 2.0 - 1.0 + texture_maps = torch.flip(texture_maps, [2]) # flip y axis of the texture map + if texture_maps.device != pixel_uvs.device: + texture_maps = texture_maps.to(pixel_uvs.device) + texels = F.grid_sample(texture_maps, pixel_uvs, align_corners=False) + texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2) + return texels + + def join_batch(self, textures: List["TexturesUV"]) -> "TexturesUV": + """ + Join the list of textures given by `textures` to + self to create a batch of textures. Return a new + TexturesUV object with the combined textures. + + Args: + textures: List of TexturesUV objects + + Returns: + new_tex: TexturesUV object with the combined + textures from self and the list `textures`. + """ + tex_types_same = all(isinstance(tex, TexturesUV) for tex in textures) + if not tex_types_same: + raise ValueError("All textures must be of type TexturesUV.") + + verts_uvs_list = [] + faces_uvs_list = [] + maps_list = [] + faces_uvs_list += self.faces_uvs_list() + verts_uvs_list += self.verts_uvs_list() + maps_list += list(self.maps_padded().unbind(0)) + num_faces_per_mesh = self._num_faces_per_mesh + for tex in textures: + verts_uvs_list += tex.verts_uvs_list() + faces_uvs_list += tex.faces_uvs_list() + num_faces_per_mesh += tex._num_faces_per_mesh + tex_map_list = list(tex.maps_padded().unbind(0)) + maps_list += tex_map_list + + new_tex = self.__class__( + maps=maps_list, verts_uvs=verts_uvs_list, faces_uvs=faces_uvs_list + ) + new_tex._num_faces_per_mesh = num_faces_per_mesh + return new_tex + + +class TexturesVertex(TexturesBase): + def __init__( + self, + verts_features: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]], + ): + """ + Batched texture representation where each vertex in a mesh + has a D dimensional feature vector. + + Args: + verts_features: (N, V, D) tensor giving a feature vector with + artbitrary dimensions for each vertex. + """ + if isinstance(verts_features, (tuple, list)): + correct_shape = all( + (torch.is_tensor(v) and v.ndim == 2) for v in verts_features + ) + if not correct_shape: + raise ValueError( + "Expected verts_features to be a list of tensors of shape (V, D)." + ) + + self._verts_features_list = verts_features + self._verts_features_padded = None + self.device = torch.device("cpu") + + # These values may be overridden when textures is + # passed into the Meshes constructor. For more details + # refer to the __init__ of Meshes. + self._N = len(verts_features) + self._num_verts_per_mesh = [len(fv) for fv in verts_features] + + if self._N > 0: + self.device = verts_features[0].device + + elif torch.is_tensor(verts_features): + if verts_features.ndim != 3: + msg = "Expected verts_features to be of shape (N, V, D); got %r" + raise ValueError(msg % repr(verts_features.shape)) + self._verts_features_padded = verts_features + self._verts_features_list = None + self.device = verts_features.device + + # These values may be overridden when textures is + # passed into the Meshes constructor. For more details + # refer to the __init__ of Meshes. + self._N = len(verts_features) + max_F = verts_features.shape[1] + self._num_verts_per_mesh = [max_F] * self._N + else: + raise ValueError("verts_features must be a tensor or list of tensors") + + # This is set inside the Meshes object when textures is + # passed into the Meshes constructor. For more details + # refer to the __init__ of Meshes. + self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device) + + def __repr__(self): + return super().__repr__() + + def clone(self): + tex = self.__class__(self.verts_features_padded().clone()) + if self._verts_features_list is not None: + tex._verts_features_list = [f.clone() for f in self._verts_features_list] + num_faces = ( + self._num_verts_per_mesh.clone() + if torch.is_tensor(self._num_verts_per_mesh) + else self._num_verts_per_mesh + ) + tex._num_verts_per_mesh = num_faces + tex.valid = self.valid.clone() + return tex + + def __getitem__(self, index): + props = ["verts_features_list", "_num_verts_per_mesh"] + new_props = self._getitem(index, props) + verts_features = new_props["verts_features_list"] + if isinstance(verts_features, list): + new_tex = self.__class__(verts_features=verts_features) + elif torch.is_tensor(verts_features): + new_tex = self.__class__(verts_features=[verts_features]) + else: + raise ValueError("Not all values are provided in the correct format") + new_tex._num_verts_per_mesh = new_props["_num_verts_per_mesh"] + return new_tex + + def verts_features_padded(self) -> torch.Tensor: + if self._verts_features_padded is None: + if self.isempty(): + self._verts_features_padded = torch.zeros( + (self._N, 0, 3, 0), dtype=torch.float32, device=self.device + ) + else: + self._verts_features_padded = list_to_padded( + self._verts_features_list, pad_value=0.0 + ) + return self._verts_features_padded + + def verts_features_list(self) -> List[torch.Tensor]: + if self._verts_features_list is None: + if self.isempty(): + self._verts_features_list = [ + torch.empty((0, 3, 0), dtype=torch.float32, device=self.device) + ] * self._N + else: + self._verts_features_list = padded_to_list( + self._verts_features_padded, split_size=self._num_verts_per_mesh + ) + return self._verts_features_list + + def verts_features_packed(self) -> torch.Tensor: + if self.isempty(): + return torch.zeros((self._N, 3, 0), dtype=torch.float32, device=self.device) + verts_features_list = self.verts_features_list() + return list_to_packed(verts_features_list)[0] + + def extend(self, N: int) -> "TexturesVertex": + new_props = self._extend(N, ["verts_features_padded", "_num_verts_per_mesh"]) + new_tex = TexturesVertex(verts_features=new_props["verts_features_padded"]) + new_tex._num_verts_per_mesh = new_props["_num_verts_per_mesh"] + return new_tex + + def sample_textures(self, fragments, faces_packed=None) -> torch.Tensor: + """ + Detemine the color for each rasterized face. Interpolate the colors for + vertices which form the face using the barycentric coordinates. + Args: + fragments: + The outputs of rasterization. From this we use + + - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices + of the faces (in the packed representation) which + overlap each pixel in the image. + - barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying + the barycentric coordianates of each pixel + relative to the faces (in the packed + representation) which overlap the pixel. + + Returns: + texels: An texture per pixel of shape (N, H, W, K, C). + There will be one C dimensional value for each element in + fragments.pix_to_face. + """ + verts_features_packed = self.verts_features_packed() + faces_verts_features = verts_features_packed[faces_packed] + + texels = interpolate_face_attributes( + fragments.pix_to_face, fragments.bary_coords, faces_verts_features + ) + return texels + + def join_batch(self, textures: List["TexturesVertex"]) -> "TexturesVertex": + """ + Join the list of textures given by `textures` to + self to create a batch of textures. Return a new + TexturesVertex object with the combined textures. + + Args: + textures: List of TexturesVertex objects + + Returns: + new_tex: TexturesVertex object with the combined + textures from self and the list `textures`. + """ + tex_types_same = all(isinstance(tex, TexturesVertex) for tex in textures) + if not tex_types_same: + raise ValueError("All textures must be of type TexturesVertex.") + + verts_features_list = [] + verts_features_list += self.verts_features_list() + num_faces_per_mesh = self._num_verts_per_mesh + for tex in textures: + verts_features_list += tex.verts_features_list() + num_faces_per_mesh += tex._num_verts_per_mesh + + new_tex = self.__class__(verts_features=verts_features_list) + new_tex._num_verts_per_mesh = num_faces_per_mesh + return new_tex diff --git a/pytorch3d/renderer/mesh/texturing.py b/pytorch3d/renderer/mesh/texturing.py deleted file mode 100644 index b2ac7eba..00000000 --- a/pytorch3d/renderer/mesh/texturing.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. - - -import torch -import torch.nn.functional as F -from pytorch3d.ops import interpolate_face_attributes -from pytorch3d.structures.textures import Textures - - -def interpolate_texture_map(fragments, meshes) -> torch.Tensor: - """ - Interpolate a 2D texture map using uv vertex texture coordinates for each - face in the mesh. First interpolate the vertex uvs using barycentric coordinates - for each pixel in the rasterized output. Then interpolate the texture map - using the uv coordinate for each pixel. - - Args: - fragments: - The outputs of rasterization. From this we use - - - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices - of the faces (in the packed representation) which - overlap each pixel in the image. - - barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying - the barycentric coordianates of each pixel - relative to the faces (in the packed - representation) which overlap the pixel. - meshes: Meshes representing a batch of meshes. It is expected that - meshes has a textures attribute which is an instance of the - Textures class. - - Returns: - texels: tensor of shape (N, H, W, K, C) giving the interpolated - texture for each pixel in the rasterized image. - """ - if not isinstance(meshes.textures, Textures): - msg = "Expected meshes.textures to be an instance of Textures; got %r" - raise ValueError(msg % type(meshes.textures)) - - faces_uvs = meshes.textures.faces_uvs_packed() - verts_uvs = meshes.textures.verts_uvs_packed() - faces_verts_uvs = verts_uvs[faces_uvs] - texture_maps = meshes.textures.maps_padded() - - # pixel_uvs: (N, H, W, K, 2) - pixel_uvs = interpolate_face_attributes( - fragments.pix_to_face, fragments.bary_coords, faces_verts_uvs - ) - - N, H_out, W_out, K = fragments.pix_to_face.shape - N, H_in, W_in, C = texture_maps.shape # 3 for RGB - - # pixel_uvs: (N, H, W, K, 2) -> (N, K, H, W, 2) -> (NK, H, W, 2) - pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(N * K, H_out, W_out, 2) - - # textures.map: - # (N, H, W, C) -> (N, C, H, W) -> (1, N, C, H, W) - # -> expand (K, N, C, H, W) -> reshape (N*K, C, H, W) - texture_maps = ( - texture_maps.permute(0, 3, 1, 2)[None, ...] - .expand(K, -1, -1, -1, -1) - .transpose(0, 1) - .reshape(N * K, C, H_in, W_in) - ) - - # Textures: (N*K, C, H, W), pixel_uvs: (N*K, H, W, 2) - # Now need to format the pixel uvs and the texture map correctly! - # From pytorch docs, grid_sample takes `grid` and `input`: - # grid specifies the sampling pixel locations normalized by - # the input spatial dimensions It should have most - # values in the range of [-1, 1]. Values x = -1, y = -1 - # is the left-top pixel of input, and values x = 1, y = 1 is the - # right-bottom pixel of input. - - pixel_uvs = pixel_uvs * 2.0 - 1.0 - texture_maps = torch.flip(texture_maps, [2]) # flip y axis of the texture map - if texture_maps.device != pixel_uvs.device: - texture_maps = texture_maps.to(pixel_uvs.device) - texels = F.grid_sample(texture_maps, pixel_uvs, align_corners=False) - texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2) - return texels - - -def interpolate_vertex_colors(fragments, meshes) -> torch.Tensor: - """ - Detemine the color for each rasterized face. Interpolate the colors for - vertices which form the face using the barycentric coordinates. - Args: - meshes: A Meshes class representing a batch of meshes. - fragments: - The outputs of rasterization. From this we use - - - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices - of the faces (in the packed representation) which - overlap each pixel in the image. - - barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying - the barycentric coordianates of each pixel - relative to the faces (in the packed - representation) which overlap the pixel. - - Returns: - texels: An texture per pixel of shape (N, H, W, K, C). - There will be one C dimensional value for each element in - fragments.pix_to_face. - """ - vertex_textures = meshes.textures.verts_rgb_padded().reshape(-1, 3) # (V, C) - vertex_textures = vertex_textures[meshes.verts_padded_to_packed_idx(), :] - faces_packed = meshes.faces_packed() - faces_textures = vertex_textures[faces_packed] # (F, 3, C) - texels = interpolate_face_attributes( - fragments.pix_to_face, fragments.bary_coords, faces_textures - ) - return texels diff --git a/pytorch3d/structures/__init__.py b/pytorch3d/structures/__init__.py index 67cca9fb..78d24a26 100644 --- a/pytorch3d/structures/__init__.py +++ b/pytorch3d/structures/__init__.py @@ -2,7 +2,6 @@ from .meshes import Meshes, join_meshes_as_batch from .pointclouds import Pointclouds -from .textures import Textures from .utils import list_to_packed, list_to_padded, packed_to_list, padded_to_list diff --git a/pytorch3d/structures/meshes.py b/pytorch3d/structures/meshes.py index 4746cc9d..2bcd6a56 100644 --- a/pytorch3d/structures/meshes.py +++ b/pytorch3d/structures/meshes.py @@ -5,7 +5,6 @@ from typing import List, Union import torch from . import utils as struct_utils -from .textures import Textures class Meshes(object): @@ -234,9 +233,9 @@ class Meshes(object): Refer to comments above for descriptions of List and Padded representations. """ self.device = None - if textures is not None and not isinstance(textures, Textures): - msg = "Expected textures to be of type Textures; got %r" - raise ValueError(msg % type(textures)) + if textures is not None and not repr(textures) == "TexturesBase": + msg = "Expected textures to be an instance of type TexturesBase; got %r" + raise ValueError(msg % repr(textures)) self.textures = textures # Indicates whether the meshes in the list/batch have the same number @@ -400,6 +399,8 @@ class Meshes(object): if self.textures is not None: self.textures._num_faces_per_mesh = self._num_faces_per_mesh.tolist() self.textures._num_verts_per_mesh = self._num_verts_per_mesh.tolist() + self.textures._N = self._N + self.textures.valid = self.valid def __len__(self): return self._N @@ -1465,6 +1466,17 @@ class Meshes(object): return self.__class__(verts=new_verts_list, faces=new_faces_list, textures=tex) + def sample_textures(self, fragments): + if self.textures is not None: + # Pass in faces packed. If the textures are defined per + # vertex, the face indices are needed in order to interpolate + # the vertex attributes across the face. + return self.textures.sample_textures( + fragments, faces_packed=self.faces_packed() + ) + else: + raise ValueError("Meshes does not have textures") + def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True): """ @@ -1499,44 +1511,14 @@ def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True): raise ValueError("Inconsistent textures in join_meshes_as_batch.") # Now we know there are multiple meshes and they have textures to merge. - first = meshes[0].textures - kwargs = {} - if first.maps_padded() is not None: - if any(mesh.textures.maps_padded() is None for mesh in meshes): - raise ValueError("Inconsistent maps_padded in join_meshes_as_batch.") - maps = [m for mesh in meshes for m in mesh.textures.maps_padded()] - kwargs["maps"] = maps - elif any(mesh.textures.maps_padded() is not None for mesh in meshes): - raise ValueError("Inconsistent maps_padded in join_meshes_as_batch.") + all_textures = [mesh.textures for mesh in meshes] + first = all_textures[0] + tex_types_same = all(type(tex) == type(first) for tex in all_textures) - if first.verts_uvs_padded() is not None: - if any(mesh.textures.verts_uvs_padded() is None for mesh in meshes): - raise ValueError("Inconsistent verts_uvs_padded in join_meshes_as_batch.") - uvs = [uv for mesh in meshes for uv in mesh.textures.verts_uvs_list()] - V = max(uv.shape[0] for uv in uvs) - kwargs["verts_uvs"] = struct_utils.list_to_padded(uvs, (V, 2), -1) - elif any(mesh.textures.verts_uvs_padded() is not None for mesh in meshes): - raise ValueError("Inconsistent verts_uvs_padded in join_meshes_as_batch.") + if not tex_types_same: + raise ValueError("All meshes in the batch must have the same type of texture.") - if first.faces_uvs_padded() is not None: - if any(mesh.textures.faces_uvs_padded() is None for mesh in meshes): - raise ValueError("Inconsistent faces_uvs_padded in join_meshes_as_batch.") - uvs = [uv for mesh in meshes for uv in mesh.textures.faces_uvs_list()] - F = max(uv.shape[0] for uv in uvs) - kwargs["faces_uvs"] = struct_utils.list_to_padded(uvs, (F, 3), -1) - elif any(mesh.textures.faces_uvs_padded() is not None for mesh in meshes): - raise ValueError("Inconsistent faces_uvs_padded in join_meshes_as_batch.") - - if first.verts_rgb_padded() is not None: - if any(mesh.textures.verts_rgb_padded() is None for mesh in meshes): - raise ValueError("Inconsistent verts_rgb_padded in join_meshes_as_batch.") - rgb = [i for mesh in meshes for i in mesh.textures.verts_rgb_list()] - V = max(i.shape[0] for i in rgb) - kwargs["verts_rgb"] = struct_utils.list_to_padded(rgb, (V, 3)) - elif any(mesh.textures.verts_rgb_padded() is not None for mesh in meshes): - raise ValueError("Inconsistent verts_rgb_padded in join_meshes_as_batch.") - - tex = Textures(**kwargs) + tex = first.join_batch(all_textures[1:]) return Meshes(verts=verts, faces=faces, textures=tex) @@ -1544,7 +1526,7 @@ def join_mesh(meshes: Union[Meshes, List[Meshes]]) -> Meshes: """ Joins a batch of meshes in the form of a Meshes object or a list of Meshes objects as a single mesh. If the input is a list, the Meshes objects in the list - must all be on the same device. This version ignores all textures in the input mehses. + must all be on the same device. This version ignores all textures in the input meshes. Args: meshes: Meshes object that contains a batch of meshes or a list of Meshes objects diff --git a/pytorch3d/structures/textures.py b/pytorch3d/structures/textures.py deleted file mode 100644 index 25129108..00000000 --- a/pytorch3d/structures/textures.py +++ /dev/null @@ -1,279 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. - -from typing import List, Optional, Union - -import torch -from torch.nn.functional import interpolate - -from .utils import padded_to_list, padded_to_packed - - -""" -This file has functions for interpolating textures after rasterization. -""" - - -def _pad_texture_maps(images: List[torch.Tensor]) -> torch.Tensor: - """ - Pad all texture images so they have the same height and width. - - Args: - images: list of N tensors of shape (H, W, 3) - - Returns: - tex_maps: Tensor of shape (N, max_H, max_W, 3) - """ - tex_maps = [] - max_H = 0 - max_W = 0 - for im in images: - h, w, _3 = im.shape - if h > max_H: - max_H = h - if w > max_W: - max_W = w - tex_maps.append(im) - max_shape = (max_H, max_W) - - for i, image in enumerate(tex_maps): - if image.shape[:2] != max_shape: - image_BCHW = image.permute(2, 0, 1)[None] - new_image_BCHW = interpolate( - image_BCHW, size=max_shape, mode="bilinear", align_corners=False - ) - tex_maps[i] = new_image_BCHW[0].permute(1, 2, 0) - tex_maps = torch.stack(tex_maps, dim=0) # (num_tex_maps, max_H, max_W, 3) - return tex_maps - - -def _extend_tensor(input_tensor: torch.Tensor, N: int) -> torch.Tensor: - """ - Extend a tensor `input_tensor` with ndim > 2, `N` times along the batch - dimension. This is done in the following sequence of steps (where `B` is - the batch dimension): - - .. code-block:: python - - input_tensor (B, ...) - -> add leading empty dimension (1, B, ...) - -> expand (N, B, ...) - -> reshape (N * B, ...) - - Args: - input_tensor: torch.Tensor with ndim > 2 representing a batched input. - N: number of times to extend each element of the batch. - """ - # pyre-fixme[16]: `Tensor` has no attribute `ndim`. - if input_tensor.ndim < 2: - raise ValueError("Input tensor must have ndimensions >= 2.") - B = input_tensor.shape[0] - non_batch_dims = tuple(input_tensor.shape[1:]) - constant_dims = (-1,) * input_tensor.ndim # these dims are not expanded. - return ( - input_tensor.clone()[None, ...] - .expand(N, *constant_dims) - .transpose(0, 1) - .reshape(N * B, *non_batch_dims) - ) - - -class Textures(object): - def __init__( - self, - maps: Union[List, torch.Tensor, None] = None, - faces_uvs: Optional[torch.Tensor] = None, - verts_uvs: Optional[torch.Tensor] = None, - verts_rgb: Optional[torch.Tensor] = None, - ): - """ - Args: - maps: texture map per mesh. This can either be a list of maps - [(H, W, 3)] or a padded tensor of shape (N, H, W, 3). - faces_uvs: (N, F, 3) tensor giving the index into verts_uvs for each - vertex in the face. Padding value is assumed to be -1. - verts_uvs: (N, V, 2) tensor giving the uv coordinate per vertex. - verts_rgb: (N, V, 3) tensor giving the rgb color per vertex. Padding - value is assumed to be -1. - - Note: only the padded representation of the textures is stored - and the packed/list representations are computed on the fly and - not cached. - """ - # pyre-fixme[16]: `Tensor` has no attribute `ndim`. - if faces_uvs is not None and faces_uvs.ndim != 3: - msg = "Expected faces_uvs to be of shape (N, F, 3); got %r" - raise ValueError(msg % repr(faces_uvs.shape)) - if verts_uvs is not None and verts_uvs.ndim != 3: - msg = "Expected verts_uvs to be of shape (N, V, 2); got %r" - raise ValueError(msg % repr(verts_uvs.shape)) - if verts_rgb is not None and verts_rgb.ndim != 3: - msg = "Expected verts_rgb to be of shape (N, V, 3); got %r" - raise ValueError(msg % repr(verts_rgb.shape)) - if maps is not None: - # pyre-fixme[16]: `List` has no attribute `ndim`. - if torch.is_tensor(maps) and maps.ndim != 4: - msg = "Expected maps to be of shape (N, H, W, 3); got %r" - # pyre-fixme[16]: `List` has no attribute `shape`. - raise ValueError(msg % repr(maps.shape)) - elif isinstance(maps, list): - maps = _pad_texture_maps(maps) - if faces_uvs is None or verts_uvs is None: - msg = "To use maps, faces_uvs and verts_uvs are required" - raise ValueError(msg) - - self._faces_uvs_padded = faces_uvs - self._verts_uvs_padded = verts_uvs - self._verts_rgb_padded = verts_rgb - self._maps_padded = maps - - # The number of faces/verts for each mesh is - # set inside the Meshes object when textures is - # passed into the Meshes constructor. - self._num_faces_per_mesh = None - self._num_verts_per_mesh = None - - def clone(self): - other = self.__class__() - for k in dir(self): - v = getattr(self, k) - if torch.is_tensor(v): - setattr(other, k, v.clone()) - return other - - def to(self, device): - for k in dir(self): - v = getattr(self, k) - if torch.is_tensor(v) and v.device != device: - setattr(self, k, v.to(device)) - return self - - def __getitem__(self, index): - other = self.__class__() - for key in dir(self): - value = getattr(self, key) - if torch.is_tensor(value): - if isinstance(index, int): - setattr(other, key, value[index][None]) - else: - setattr(other, key, value[index]) - return other - - def faces_uvs_padded(self) -> torch.Tensor: - # pyre-fixme[7]: Expected `Tensor` but got `Optional[torch.Tensor]`. - return self._faces_uvs_padded - - def faces_uvs_list(self) -> Union[List[torch.Tensor], None]: - if self._faces_uvs_padded is None: - return None - return padded_to_list( - # pyre-fixme[6]: Expected `Tensor` for 1st param but got - # `Optional[torch.Tensor]`. - self._faces_uvs_padded, - split_size=self._num_faces_per_mesh, - ) - - def faces_uvs_packed(self) -> Union[torch.Tensor, None]: - if self._faces_uvs_padded is None: - return None - return padded_to_packed( - # pyre-fixme[6]: Expected `Tensor` for 1st param but got - # `Optional[torch.Tensor]`. - self._faces_uvs_padded, - split_size=self._num_faces_per_mesh, - ) - - def verts_uvs_padded(self) -> Union[torch.Tensor, None]: - return self._verts_uvs_padded - - def verts_uvs_list(self) -> Union[List[torch.Tensor], None]: - if self._verts_uvs_padded is None: - return None - # Vertices shared between multiple faces - # may have a different uv coordinate for - # each face so the num_verts_uvs_per_mesh - # may be different from num_verts_per_mesh. - # Therefore don't use any split_size. - # pyre-fixme[6]: Expected `Tensor` for 1st param but got - # `Optional[torch.Tensor]`. - return padded_to_list(self._verts_uvs_padded) - - def verts_uvs_packed(self) -> Union[torch.Tensor, None]: - if self._verts_uvs_padded is None: - return None - # Vertices shared between multiple faces - # may have a different uv coordinate for - # each face so the num_verts_uvs_per_mesh - # may be different from num_verts_per_mesh. - # Therefore don't use any split_size. - # pyre-fixme[6]: Expected `Tensor` for 1st param but got - # `Optional[torch.Tensor]`. - return padded_to_packed(self._verts_uvs_padded) - - def verts_rgb_padded(self) -> Union[torch.Tensor, None]: - return self._verts_rgb_padded - - def verts_rgb_list(self) -> Union[List[torch.Tensor], None]: - if self._verts_rgb_padded is None: - return None - return padded_to_list( - # pyre-fixme[6]: Expected `Tensor` for 1st param but got - # `Optional[torch.Tensor]`. - self._verts_rgb_padded, - split_size=self._num_verts_per_mesh, - ) - - def verts_rgb_packed(self) -> Union[torch.Tensor, None]: - if self._verts_rgb_padded is None: - return None - return padded_to_packed( - # pyre-fixme[6]: Expected `Tensor` for 1st param but got - # `Optional[torch.Tensor]`. - self._verts_rgb_padded, - split_size=self._num_verts_per_mesh, - ) - - # Currently only the padded maps are used. - def maps_padded(self) -> Union[torch.Tensor, None]: - # pyre-fixme[7]: Expected `Optional[torch.Tensor]` but got `Union[None, - # List[typing.Any], torch.Tensor]`. - return self._maps_padded - - def extend(self, N: int) -> "Textures": - """ - Create new Textures class which contains each input texture N times - - Args: - N: number of new copies of each texture. - - Returns: - new Textures object. - """ - if not isinstance(N, int): - raise ValueError("N must be an integer.") - if N <= 0: - raise ValueError("N must be > 0.") - - if all( - v is not None - for v in [self._faces_uvs_padded, self._verts_uvs_padded, self._maps_padded] - ): - # pyre-fixme[6]: Expected `Tensor` for 1st param but got - # `Optional[torch.Tensor]`. - new_verts_uvs = _extend_tensor(self._verts_uvs_padded, N) - # pyre-fixme[6]: Expected `Tensor` for 1st param but got - # `Optional[torch.Tensor]`. - new_faces_uvs = _extend_tensor(self._faces_uvs_padded, N) - # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[None, - # List[typing.Any], torch.Tensor]`. - new_maps = _extend_tensor(self._maps_padded, N) - return self.__class__( - verts_uvs=new_verts_uvs, faces_uvs=new_faces_uvs, maps=new_maps - ) - elif self._verts_rgb_padded is not None: - # pyre-fixme[6]: Expected `Tensor` for 1st param but got - # `Optional[torch.Tensor]`. - new_verts_rgb = _extend_tensor(self._verts_rgb_padded, N) - return self.__class__(verts_rgb=new_verts_rgb) - else: - msg = "Either vertex colors or texture maps are required." - raise ValueError(msg) diff --git a/pytorch3d/structures/utils.py b/pytorch3d/structures/utils.py index b71f8394..a130db93 100644 --- a/pytorch3d/structures/utils.py +++ b/pytorch3d/structures/utils.py @@ -73,6 +73,7 @@ def padded_to_list(x: torch.Tensor, split_size: Union[list, tuple, None] = None) # pyre-fixme[16]: `Tensor` has no attribute `ndim`. if x.ndim != 3: raise ValueError("Supports only 3-dimensional input tensors") + x_list = list(x.unbind(0)) if split_size is None: diff --git a/tests/data/test_texture_atlas_8x8_back.png b/tests/data/test_texture_atlas_8x8_back.png new file mode 100644 index 00000000..015f34f1 Binary files /dev/null and b/tests/data/test_texture_atlas_8x8_back.png differ diff --git a/tests/test_interpolate_face_attributes.py b/tests/test_interpolate_face_attributes.py index e943f3ff..ef5f57e4 100644 --- a/tests/test_interpolate_face_attributes.py +++ b/tests/test_interpolate_face_attributes.py @@ -8,9 +8,9 @@ from pytorch3d.ops.interp_face_attrs import ( interpolate_face_attributes, interpolate_face_attributes_python, ) +from pytorch3d.renderer.mesh import TexturesVertex from pytorch3d.renderer.mesh.rasterizer import Fragments -from pytorch3d.renderer.mesh.texturing import interpolate_vertex_colors -from pytorch3d.structures import Meshes, Textures +from pytorch3d.structures import Meshes class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase): @@ -96,16 +96,12 @@ class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase): self.assertClose(grad_face_attrs_py, grad_face_attrs_cu, rtol=1e-3) def test_interpolate_attributes(self): - """ - This tests both interpolate_vertex_colors as well as - interpolate_face_attributes. - """ verts = torch.randn((4, 3), dtype=torch.float32) faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64) vert_tex = torch.tensor( [[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32 ) - tex = Textures(verts_rgb=vert_tex[None, :]) + tex = TexturesVertex(verts_features=vert_tex[None, :]) mesh = Meshes(verts=[verts], faces=[faces], textures=tex) pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2) barycentric_coords = torch.tensor( @@ -120,7 +116,13 @@ class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase): zbuf=torch.ones_like(pix_to_face), dists=torch.ones_like(pix_to_face), ) - texels = interpolate_vertex_colors(fragments, mesh) + + verts_features_packed = mesh.textures.verts_features_packed() + faces_verts_features = verts_features_packed[mesh.faces_packed()] + + texels = interpolate_face_attributes( + fragments.pix_to_face, fragments.bary_coords, faces_verts_features + ) self.assertTrue(torch.allclose(texels, expected_vals[None, :])) def test_interpolate_attributes_grad(self): @@ -131,7 +133,7 @@ class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase): dtype=torch.float32, requires_grad=True, ) - tex = Textures(verts_rgb=vert_tex[None, :]) + tex = TexturesVertex(verts_features=vert_tex[None, :]) mesh = Meshes(verts=[verts], faces=[faces], textures=tex) pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2) barycentric_coords = torch.tensor( @@ -147,7 +149,12 @@ class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase): [[0.3, 0.3, 0.3], [0.9, 0.9, 0.9], [0.5, 0.5, 0.5], [0.3, 0.3, 0.3]], dtype=torch.float32, ) - texels = interpolate_vertex_colors(fragments, mesh) + verts_features_packed = mesh.textures.verts_features_packed() + faces_verts_features = verts_features_packed[mesh.faces_packed()] + + texels = interpolate_face_attributes( + fragments.pix_to_face, fragments.bary_coords, faces_verts_features + ) texels.sum().backward() self.assertTrue(hasattr(vert_tex, "grad")) self.assertTrue(torch.allclose(vert_tex.grad, grad_vert_tex[None, :])) diff --git a/tests/test_obj_io.py b/tests/test_obj_io.py index 43f2a411..f609556a 100644 --- a/tests/test_obj_io.py +++ b/tests/test_obj_io.py @@ -13,8 +13,8 @@ from pytorch3d.io.mtl_io import ( _bilinear_interpolation_grid_sample, _bilinear_interpolation_vectorized, ) -from pytorch3d.structures import Meshes, Textures, join_meshes_as_batch -from pytorch3d.structures.meshes import join_mesh +from pytorch3d.renderer import TexturesAtlas, TexturesUV, TexturesVertex +from pytorch3d.structures import Meshes, join_meshes_as_batch from pytorch3d.utils import torus @@ -590,17 +590,29 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): check_item(mesh.verts_padded(), mesh3.verts_padded()) check_item(mesh.faces_padded(), mesh3.faces_padded()) + if mesh.textures is not None: - check_item(mesh.textures.maps_padded(), mesh3.textures.maps_padded()) - check_item( - mesh.textures.faces_uvs_padded(), mesh3.textures.faces_uvs_padded() - ) - check_item( - mesh.textures.verts_uvs_padded(), mesh3.textures.verts_uvs_padded() - ) - check_item( - mesh.textures.verts_rgb_padded(), mesh3.textures.verts_rgb_padded() - ) + if isinstance(mesh.textures, TexturesUV): + check_item( + mesh.textures.faces_uvs_padded(), + mesh3.textures.faces_uvs_padded(), + ) + check_item( + mesh.textures.verts_uvs_padded(), + mesh3.textures.verts_uvs_padded(), + ) + check_item( + mesh.textures.maps_padded(), mesh3.textures.maps_padded() + ) + elif isinstance(mesh.textures, TexturesVertex): + check_item( + mesh.textures.verts_features_padded(), + mesh3.textures.verts_features_padded(), + ) + elif isinstance(mesh.textures, TexturesAtlas): + check_item( + mesh.textures.atlas_padded(), mesh3.textures.atlas_padded() + ) DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data" obj_filename = DATA_DIR / "cow_mesh/cow.obj" @@ -623,16 +635,24 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): check_triple(mesh_notex, mesh3_notex) self.assertIsNone(mesh_notex.textures) + # meshes with vertex texture, join into a batch. verts = torch.randn((4, 3), dtype=torch.float32) faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64) - vert_tex = torch.tensor( - [[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32 - ) - tex = Textures(verts_rgb=vert_tex[None, :]) - mesh_rgb = Meshes(verts=[verts], faces=[faces], textures=tex) + vert_tex = torch.ones_like(verts) + rgb_tex = TexturesVertex(verts_features=[vert_tex]) + mesh_rgb = Meshes(verts=[verts], faces=[faces], textures=rgb_tex) mesh_rgb3 = join_meshes_as_batch([mesh_rgb, mesh_rgb, mesh_rgb]) check_triple(mesh_rgb, mesh_rgb3) + # meshes with texture atlas, join into a batch. + device = "cuda:0" + atlas = torch.rand((2, 4, 4, 3), dtype=torch.float32, device=device) + atlas_tex = TexturesAtlas(atlas=[atlas]) + mesh_atlas = Meshes(verts=[verts], faces=[faces], textures=atlas_tex) + mesh_atlas3 = join_meshes_as_batch([mesh_atlas, mesh_atlas, mesh_atlas]) + check_triple(mesh_atlas, mesh_atlas3) + + # Test load multiple meshes with textures into a batch. teapot_obj = DATA_DIR / "teapot.obj" mesh_teapot = load_objs_as_meshes([teapot_obj]) teapot_verts, teapot_faces = mesh_teapot.get_mesh_verts_faces(0) @@ -649,41 +669,9 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): self.assertClose(cow3_tea.verts_list()[3], mesh_teapot.verts_list()[0]) self.assertClose(cow3_tea.faces_list()[3], mesh_teapot.faces_list()[0]) - def test_join_meshes(self): - """ - Test that join_mesh joins single meshes and the corresponding values are - consistent with the single meshes. - """ - - # Load cow mesh. - DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data" - cow_obj = DATA_DIR / "cow_mesh/cow.obj" - - cow_mesh = load_objs_as_meshes([cow_obj]) - cow_verts, cow_faces = cow_mesh.get_mesh_verts_faces(0) - # Join a batch of three single meshes and check that the values are consistent - # with the individual meshes. - cow_mesh3 = join_mesh([cow_mesh, cow_mesh, cow_mesh]) - - def check_item(x, y, offset): - self.assertClose(torch.cat([x, x + offset, x + 2 * offset], dim=1), y) - - check_item(cow_mesh.verts_padded(), cow_mesh3.verts_padded(), 0) - check_item(cow_mesh.faces_padded(), cow_mesh3.faces_padded(), cow_mesh._V) - - # Test the joining of meshes of different sizes. - teapot_obj = DATA_DIR / "teapot.obj" - teapot_mesh = load_objs_as_meshes([teapot_obj]) - teapot_verts, teapot_faces = teapot_mesh.get_mesh_verts_faces(0) - - mix_mesh = join_mesh([cow_mesh, teapot_mesh]) - mix_verts, mix_faces = mix_mesh.get_mesh_verts_faces(0) - self.assertEqual(len(mix_mesh), 1) - - self.assertClose(mix_verts[: cow_mesh._V], cow_verts) - self.assertClose(mix_faces[: cow_mesh._F], cow_faces) - self.assertClose(mix_verts[cow_mesh._V :], teapot_verts) - self.assertClose(mix_faces[cow_mesh._F :], teapot_faces + cow_mesh._V) + # Check error raised if all meshes in the batch don't have the same texture type + with self.assertRaisesRegex(ValueError, "same type of texture"): + join_meshes_as_batch([mesh_atlas, mesh_rgb, mesh_atlas]) @staticmethod def _bm_save_obj(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int): diff --git a/tests/test_render_meshes.py b/tests/test_render_meshes.py index c4325e30..be5b9ec6 100644 --- a/tests/test_render_meshes.py +++ b/tests/test_render_meshes.py @@ -11,10 +11,11 @@ import numpy as np import torch from common_testing import TestCaseMixin, load_rgb_image from PIL import Image -from pytorch3d.io import load_objs_as_meshes +from pytorch3d.io import load_obj from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform from pytorch3d.renderer.lighting import PointLights from pytorch3d.renderer.materials import Materials +from pytorch3d.renderer.mesh import TexturesAtlas, TexturesUV, TexturesVertex from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings from pytorch3d.renderer.mesh.renderer import MeshRenderer from pytorch3d.renderer.mesh.shader import ( @@ -25,7 +26,6 @@ from pytorch3d.renderer.mesh.shader import ( SoftSilhouetteShader, TexturedSoftPhongShader, ) -from pytorch3d.renderer.mesh.texturing import Textures from pytorch3d.structures.meshes import Meshes, join_mesh from pytorch3d.utils.ico_sphere import ico_sphere @@ -52,7 +52,8 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): sphere_mesh = ico_sphere(5, device) verts_padded = sphere_mesh.verts_padded() faces_padded = sphere_mesh.faces_padded() - textures = Textures(verts_rgb=torch.ones_like(verts_padded)) + feats = torch.ones_like(verts_padded, device=device) + textures = TexturesVertex(verts_features=feats) sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures) # Init rasterizer settings @@ -97,6 +98,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): filename = "simple_sphere_light_%s%s.png" % (name, postfix) image_ref = load_rgb_image("test_%s" % filename, DATA_DIR) rgb = images[0, ..., :3].squeeze().cpu() + if DEBUG: filename = "DEBUG_%s" % filename Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( @@ -145,14 +147,15 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): Test a mesh with vertex textures can be extended to form a batch, and is rendered correctly with Phong, Gouraud and Flat Shaders. """ - batch_size = 20 + batch_size = 5 device = torch.device("cuda:0") # Init mesh with vertex textures. sphere_meshes = ico_sphere(5, device).extend(batch_size) verts_padded = sphere_meshes.verts_padded() faces_padded = sphere_meshes.faces_padded() - textures = Textures(verts_rgb=torch.ones_like(verts_padded)) + feats = torch.ones_like(verts_padded, device=device) + textures = TexturesVertex(verts_features=feats) sphere_meshes = Meshes( verts=verts_padded, faces=faces_padded, textures=textures ) @@ -194,6 +197,11 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): ) for i in range(batch_size): rgb = images[i, ..., :3].squeeze().cpu() + if i == 0 and DEBUG: + filename = "DEBUG_simple_sphere_batched_%s.png" % name + Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( + DATA_DIR / filename + ) self.assertClose(rgb, image_ref, atol=0.05) def test_silhouette_with_grad(self): @@ -233,6 +241,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): with Image.open(image_ref_filename) as raw_image_ref: image_ref = torch.from_numpy(np.array(raw_image_ref)) + image_ref = image_ref.to(dtype=torch.float32) / 255.0 self.assertClose(alpha, image_ref, atol=0.055) @@ -253,11 +262,20 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): obj_filename = obj_dir / "cow_mesh/cow.obj" # Load mesh + texture - mesh = load_objs_as_meshes([obj_filename], device=device) + verts, faces, aux = load_obj( + obj_filename, device=device, load_textures=True, texture_wrap=None + ) + tex_map = list(aux.texture_images.values())[0] + tex_map = tex_map[None, ...].to(faces.textures_idx.device) + textures = TexturesUV( + maps=tex_map, faces_uvs=[faces.textures_idx], verts_uvs=[aux.verts_uvs] + ) + mesh = Meshes(verts=[verts], faces=[faces.verts_idx], textures=textures) # Init rasterizer settings R, T = look_at_view_transform(2.7, 0, 0) cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T) + raster_settings = RasterizationSettings( image_size=512, blur_radius=0.0, faces_per_pixel=1 ) @@ -405,8 +423,8 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): Meshes(verts=verts, faces=sphere_list[i].faces_padded()) ) joined_sphere_mesh = join_mesh(sphere_mesh_list) - joined_sphere_mesh.textures = Textures( - verts_rgb=torch.ones_like(joined_sphere_mesh.verts_padded()) + joined_sphere_mesh.textures = TexturesVertex( + verts_features=torch.ones_like(joined_sphere_mesh.verts_padded()) ) # Init rasterizer settings @@ -446,3 +464,61 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): ) image_ref = load_rgb_image("test_joined_spheres_%s.png" % name, DATA_DIR) self.assertClose(rgb, image_ref, atol=0.05) + + def test_texture_map_atlas(self): + """ + Test a mesh with a texture map as a per face atlas is loaded and rendered correctly. + """ + device = torch.device("cuda:0") + obj_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data" + obj_filename = obj_dir / "cow_mesh/cow.obj" + + # Load mesh and texture as a per face texture atlas. + verts, faces, aux = load_obj( + obj_filename, + device=device, + load_textures=True, + create_texture_atlas=True, + texture_atlas_size=8, + texture_wrap=None, + ) + mesh = Meshes( + verts=[verts], + faces=[faces.verts_idx], + textures=TexturesAtlas(atlas=[aux.texture_atlas]), + ) + + # Init rasterizer settings + R, T = look_at_view_transform(2.7, 0, 0) + cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T) + + raster_settings = RasterizationSettings( + image_size=512, blur_radius=0.0, faces_per_pixel=1, cull_backfaces=True + ) + + # Init shader settings + materials = Materials(device=device, specular_color=((0, 0, 0),), shininess=0.0) + lights = PointLights(device=device) + + # Place light behind the cow in world space. The front of + # the cow is facing the -z direction. + lights.location = torch.tensor([0.0, 0.0, 2.0], device=device)[None] + + # The HardPhongShader can be used directly with atlas textures. + renderer = MeshRenderer( + rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings), + shader=HardPhongShader(lights=lights, cameras=cameras, materials=materials), + ) + + images = renderer(mesh) + rgb = images[0, ..., :3].squeeze().cpu() + + # Load reference image + image_ref = load_rgb_image("test_texture_atlas_8x8_back.png", DATA_DIR) + + if DEBUG: + Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( + DATA_DIR / "DEBUG_texture_atlas_8x8_back.png" + ) + + self.assertClose(rgb, image_ref, atol=0.05) diff --git a/tests/test_texturing.py b/tests/test_texturing.py index c1abfbbf..e71ff760 100644 --- a/tests/test_texturing.py +++ b/tests/test_texturing.py @@ -7,14 +7,376 @@ import torch import torch.nn.functional as F from common_testing import TestCaseMixin from pytorch3d.renderer.mesh.rasterizer import Fragments -from pytorch3d.renderer.mesh.texturing import interpolate_texture_map -from pytorch3d.structures import Meshes, Textures -from pytorch3d.structures.utils import list_to_padded +from pytorch3d.renderer.mesh.textures import ( + TexturesAtlas, + TexturesUV, + TexturesVertex, + _list_to_padded_wrapper, +) +from pytorch3d.structures import Meshes, list_to_packed, packed_to_list from test_meshes import TestMeshes -class TestTexturing(TestCaseMixin, unittest.TestCase): - def test_interpolate_texture_map(self): +def tryindex(self, index, tex, meshes, source): + tex2 = tex[index] + meshes2 = meshes[index] + tex_from_meshes = meshes2.textures + for item in source: + basic = source[item][index] + from_texture = getattr(tex2, item + "_padded")() + from_meshes = getattr(tex_from_meshes, item + "_padded")() + if isinstance(index, int): + basic = basic[None] + + if len(basic) == 0: + self.assertEquals(len(from_texture), 0) + self.assertEquals(len(from_meshes), 0) + else: + self.assertClose(basic, from_texture) + self.assertClose(basic, from_meshes) + self.assertEqual(from_texture.ndim, getattr(tex, item + "_padded")().ndim) + item_list = getattr(tex_from_meshes, item + "_list")() + self.assertEqual(basic.shape[0], len(item_list)) + for i, elem in enumerate(item_list): + self.assertClose(elem, basic[i]) + + +class TestTexturesVertex(TestCaseMixin, unittest.TestCase): + def test_sample_vertex_textures(self): + """ + This tests both interpolate_vertex_colors as well as + interpolate_face_attributes. + """ + verts = torch.randn((4, 3), dtype=torch.float32) + faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64) + vert_tex = torch.tensor( + [[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32 + ) + verts_features = vert_tex + tex = TexturesVertex(verts_features=[verts_features]) + mesh = Meshes(verts=[verts], faces=[faces], textures=tex) + pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2) + barycentric_coords = torch.tensor( + [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32 + ).view(1, 1, 1, 2, -1) + expected_vals = torch.tensor( + [[0.5, 1.0, 0.3], [0.3, 1.0, 0.9]], dtype=torch.float32 + ).view(1, 1, 1, 2, -1) + fragments = Fragments( + pix_to_face=pix_to_face, + bary_coords=barycentric_coords, + zbuf=torch.ones_like(pix_to_face), + dists=torch.ones_like(pix_to_face), + ) + # sample_textures calls interpolate_vertex_colors + texels = mesh.sample_textures(fragments) + self.assertTrue(torch.allclose(texels, expected_vals[None, :])) + + def test_sample_vertex_textures_grad(self): + verts = torch.randn((4, 3), dtype=torch.float32) + faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64) + vert_tex = torch.tensor( + [[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], + dtype=torch.float32, + requires_grad=True, + ) + verts_features = vert_tex + tex = TexturesVertex(verts_features=[verts_features]) + mesh = Meshes(verts=[verts], faces=[faces], textures=tex) + pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2) + barycentric_coords = torch.tensor( + [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32 + ).view(1, 1, 1, 2, -1) + fragments = Fragments( + pix_to_face=pix_to_face, + bary_coords=barycentric_coords, + zbuf=torch.ones_like(pix_to_face), + dists=torch.ones_like(pix_to_face), + ) + grad_vert_tex = torch.tensor( + [[0.3, 0.3, 0.3], [0.9, 0.9, 0.9], [0.5, 0.5, 0.5], [0.3, 0.3, 0.3]], + dtype=torch.float32, + ) + texels = mesh.sample_textures(fragments) + texels.sum().backward() + self.assertTrue(hasattr(vert_tex, "grad")) + self.assertTrue(torch.allclose(vert_tex.grad, grad_vert_tex[None, :])) + + def test_textures_vertex_init_fail(self): + # Incorrect sized tensors + with self.assertRaisesRegex(ValueError, "verts_features"): + TexturesVertex(verts_features=torch.rand(size=(5, 10))) + + # Not a list or a tensor + with self.assertRaisesRegex(ValueError, "verts_features"): + TexturesVertex(verts_features=(1, 1, 1)) + + def test_clone(self): + tex = TexturesVertex(verts_features=torch.rand(size=(10, 100, 128))) + tex_cloned = tex.clone() + self.assertSeparate( + tex._verts_features_padded, tex_cloned._verts_features_padded + ) + self.assertSeparate(tex.valid, tex_cloned.valid) + + def test_extend(self): + B = 10 + mesh = TestMeshes.init_mesh(B, 30, 50) + V = mesh._V + tex_uv = TexturesVertex(verts_features=torch.randn((B, V, 3))) + tex_mesh = Meshes( + verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex_uv + ) + N = 20 + new_mesh = tex_mesh.extend(N) + + self.assertEqual(len(tex_mesh) * N, len(new_mesh)) + + tex_init = tex_mesh.textures + new_tex = new_mesh.textures + + for i in range(len(tex_mesh)): + for n in range(N): + self.assertClose( + tex_init.verts_features_list()[i], + new_tex.verts_features_list()[i * N + n], + ) + self.assertClose( + tex_init._num_faces_per_mesh[i], + new_tex._num_faces_per_mesh[i * N + n], + ) + + self.assertAllSeparate( + [tex_init.verts_features_padded(), new_tex.verts_features_padded()] + ) + + with self.assertRaises(ValueError): + tex_mesh.extend(N=-1) + + def test_padded_to_packed(self): + # Case where each face in the mesh has 3 unique uv vertex indices + # - i.e. even if a vertex is shared between multiple faces it will + # have a unique uv coordinate for each face. + num_verts_per_mesh = [9, 6] + D = 10 + verts_features_list = [torch.rand(v, D) for v in num_verts_per_mesh] + verts_features_packed = list_to_packed(verts_features_list)[0] + verts_features_list = packed_to_list(verts_features_packed, num_verts_per_mesh) + tex = TexturesVertex(verts_features=verts_features_list) + + # This is set inside Meshes when textures is passed as an input. + # Here we set _num_faces_per_mesh and _num_verts_per_mesh explicity. + tex1 = tex.clone() + tex1._num_verts_per_mesh = num_verts_per_mesh + verts_packed = tex1.verts_features_packed() + verts_verts_list = tex1.verts_features_list() + verts_padded = tex1.verts_features_padded() + + for f1, f2 in zip(verts_verts_list, verts_features_list): + self.assertTrue((f1 == f2).all().item()) + + self.assertTrue(verts_packed.shape == (sum(num_verts_per_mesh), D)) + self.assertTrue(verts_padded.shape == (2, 9, D)) + + # Case where num_verts_per_mesh is not set and textures + # are initialized with a padded tensor. + tex2 = TexturesVertex(verts_features=verts_padded) + verts_packed = tex2.verts_features_packed() + verts_list = tex2.verts_features_list() + + # Packed is just flattened padded as num_verts_per_mesh + # has not been provided. + self.assertTrue(verts_packed.shape == (9 * 2, D)) + + for i, (f1, f2) in enumerate(zip(verts_list, verts_features_list)): + n = num_verts_per_mesh[i] + self.assertTrue((f1[:n] == f2).all().item()) + + def test_getitem(self): + N = 5 + V = 20 + source = {"verts_features": torch.randn(size=(N, 10, 128))} + tex = TexturesVertex(verts_features=source["verts_features"]) + + verts = torch.rand(size=(N, V, 3)) + faces = torch.randint(size=(N, 10, 3), high=V) + meshes = Meshes(verts=verts, faces=faces, textures=tex) + + tryindex(self, 2, tex, meshes, source) + tryindex(self, slice(0, 2, 1), tex, meshes, source) + index = torch.tensor([1, 0, 1, 0, 0], dtype=torch.bool) + tryindex(self, index, tex, meshes, source) + index = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool) + tryindex(self, index, tex, meshes, source) + index = torch.tensor([1, 2], dtype=torch.int64) + tryindex(self, index, tex, meshes, source) + tryindex(self, [2, 4], tex, meshes, source) + + +class TestTexturesAtlas(TestCaseMixin, unittest.TestCase): + def test_sample_texture_atlas(self): + N, F, R = 1, 2, 2 + verts = torch.randn((4, 3), dtype=torch.float32) + faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64) + faces_atlas = torch.rand(size=(N, F, R, R, 3)) + tex = TexturesAtlas(atlas=faces_atlas) + mesh = Meshes(verts=[verts], faces=[faces], textures=tex) + pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2) + barycentric_coords = torch.tensor( + [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32 + ).view(1, 1, 1, 2, -1) + expected_vals = torch.tensor( + [[0.5, 1.0, 0.3], [0.3, 1.0, 0.9]], dtype=torch.float32 + ) + expected_vals = torch.zeros((1, 1, 1, 2, 3), dtype=torch.float32) + expected_vals[..., 0, :] = faces_atlas[0, 0, 0, 1, ...] + expected_vals[..., 1, :] = faces_atlas[0, 1, 1, 0, ...] + + fragments = Fragments( + pix_to_face=pix_to_face, + bary_coords=barycentric_coords, + zbuf=torch.ones_like(pix_to_face), + dists=torch.ones_like(pix_to_face), + ) + texels = mesh.textures.sample_textures(fragments) + self.assertTrue(torch.allclose(texels, expected_vals)) + + def test_textures_atlas_grad(self): + N, F, R = 1, 2, 2 + verts = torch.randn((4, 3), dtype=torch.float32) + faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64) + faces_atlas = torch.rand(size=(N, F, R, R, 3), requires_grad=True) + tex = TexturesAtlas(atlas=faces_atlas) + mesh = Meshes(verts=[verts], faces=[faces], textures=tex) + pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2) + barycentric_coords = torch.tensor( + [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32 + ).view(1, 1, 1, 2, -1) + fragments = Fragments( + pix_to_face=pix_to_face, + bary_coords=barycentric_coords, + zbuf=torch.ones_like(pix_to_face), + dists=torch.ones_like(pix_to_face), + ) + texels = mesh.textures.sample_textures(fragments) + grad_tex = torch.rand_like(texels) + grad_expected = torch.zeros_like(faces_atlas) + grad_expected[0, 0, 0, 1, :] = grad_tex[..., 0:1, :] + grad_expected[0, 1, 1, 0, :] = grad_tex[..., 1:2, :] + texels.backward(grad_tex) + self.assertTrue(hasattr(faces_atlas, "grad")) + self.assertTrue(torch.allclose(faces_atlas.grad, grad_expected)) + + def test_textures_atlas_init_fail(self): + # Incorrect sized tensors + with self.assertRaisesRegex(ValueError, "atlas"): + TexturesAtlas(atlas=torch.rand(size=(5, 10, 3))) + + # Not a list or a tensor + with self.assertRaisesRegex(ValueError, "atlas"): + TexturesAtlas(atlas=(1, 1, 1)) + + def test_clone(self): + tex = TexturesAtlas(atlas=torch.rand(size=(1, 10, 2, 2, 3))) + tex_cloned = tex.clone() + self.assertSeparate(tex._atlas_padded, tex_cloned._atlas_padded) + self.assertSeparate(tex.valid, tex_cloned.valid) + + def test_extend(self): + B = 10 + mesh = TestMeshes.init_mesh(B, 30, 50) + F = mesh._F + tex_uv = TexturesAtlas(atlas=torch.randn((B, F, 2, 2, 3))) + tex_mesh = Meshes( + verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex_uv + ) + N = 20 + new_mesh = tex_mesh.extend(N) + + self.assertEqual(len(tex_mesh) * N, len(new_mesh)) + + tex_init = tex_mesh.textures + new_tex = new_mesh.textures + + for i in range(len(tex_mesh)): + for n in range(N): + self.assertClose( + tex_init.atlas_list()[i], new_tex.atlas_list()[i * N + n] + ) + self.assertClose( + tex_init._num_faces_per_mesh[i], + new_tex._num_faces_per_mesh[i * N + n], + ) + + self.assertAllSeparate([tex_init.atlas_padded(), new_tex.atlas_padded()]) + + with self.assertRaises(ValueError): + tex_mesh.extend(N=-1) + + def test_padded_to_packed(self): + # Case where each face in the mesh has 3 unique uv vertex indices + # - i.e. even if a vertex is shared between multiple faces it will + # have a unique uv coordinate for each face. + R = 2 + N = 20 + num_faces_per_mesh = torch.randint(size=(N,), low=0, high=30) + atlas_list = [torch.rand(f, R, R, 3) for f in num_faces_per_mesh] + tex = TexturesAtlas(atlas=atlas_list) + + # This is set inside Meshes when textures is passed as an input. + # Here we set _num_faces_per_mesh explicity. + tex1 = tex.clone() + tex1._num_faces_per_mesh = num_faces_per_mesh.tolist() + atlas_packed = tex1.atlas_packed() + atlas_list_new = tex1.atlas_list() + atlas_padded = tex1.atlas_padded() + + for f1, f2 in zip(atlas_list_new, atlas_list): + self.assertTrue((f1 == f2).all().item()) + + sum_F = num_faces_per_mesh.sum() + max_F = num_faces_per_mesh.max().item() + self.assertTrue(atlas_packed.shape == (sum_F, R, R, 3)) + self.assertTrue(atlas_padded.shape == (N, max_F, R, R, 3)) + + # Case where num_faces_per_mesh is not set and textures + # are initialized with a padded tensor. + atlas_list_padded = _list_to_padded_wrapper(atlas_list) + tex2 = TexturesAtlas(atlas=atlas_list_padded) + atlas_packed = tex2.atlas_packed() + atlas_list_new = tex2.atlas_list() + + # Packed is just flattened padded as num_faces_per_mesh + # has not been provided. + self.assertTrue(atlas_packed.shape == (N * max_F, R, R, 3)) + + for i, (f1, f2) in enumerate(zip(atlas_list_new, atlas_list)): + n = num_faces_per_mesh[i] + self.assertTrue((f1[:n] == f2).all().item()) + + def test_getitem(self): + N = 5 + V = 20 + source = {"atlas": torch.randn(size=(N, 10, 4, 4, 3))} + tex = TexturesAtlas(atlas=source["atlas"]) + + verts = torch.rand(size=(N, V, 3)) + faces = torch.randint(size=(N, 10, 3), high=V) + meshes = Meshes(verts=verts, faces=faces, textures=tex) + + tryindex(self, 2, tex, meshes, source) + tryindex(self, slice(0, 2, 1), tex, meshes, source) + index = torch.tensor([1, 0, 1, 0, 0], dtype=torch.bool) + tryindex(self, index, tex, meshes, source) + index = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool) + tryindex(self, index, tex, meshes, source) + index = torch.tensor([1, 2], dtype=torch.int64) + tryindex(self, index, tex, meshes, source) + tryindex(self, [2, 4], tex, meshes, source) + + +class TestTexturesUV(TestCaseMixin, unittest.TestCase): + def test_sample_textures_uv(self): barycentric_coords = torch.tensor( [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32 ).view(1, 1, 1, 2, -1) @@ -38,11 +400,11 @@ class TestTexturing(TestCaseMixin, unittest.TestCase): zbuf=pix_to_face, dists=pix_to_face, ) - tex = Textures( - maps=tex_map, faces_uvs=face_uvs[None, ...], verts_uvs=vert_uvs[None, ...] - ) + + tex = TexturesUV(maps=tex_map, faces_uvs=[face_uvs], verts_uvs=[vert_uvs]) meshes = Meshes(verts=[dummy_verts], faces=[face_uvs], textures=tex) - texels = interpolate_texture_map(fragments, meshes) + mesh_textures = meshes.textures + texels = mesh_textures.sample_textures(fragments) # Expected output pixel_uvs = interpolated_uvs * 2.0 - 1.0 @@ -53,190 +415,92 @@ class TestTexturing(TestCaseMixin, unittest.TestCase): expected_out = F.grid_sample(tex_map, pixel_uvs, align_corners=False) self.assertTrue(torch.allclose(texels.squeeze(), expected_out.squeeze())) - def test_init_rgb_uv_fail(self): - V = 20 + def test_textures_uv_init_fail(self): # Maps has wrong shape with self.assertRaisesRegex(ValueError, "maps"): - Textures( + TexturesUV( maps=torch.ones((5, 16, 16, 3, 4)), - faces_uvs=torch.randint(size=(5, 10, 3), low=0, high=V), - verts_uvs=torch.ones((5, V, 2)), + faces_uvs=torch.rand(size=(5, 10, 3)), + verts_uvs=torch.rand(size=(5, 15, 2)), ) + # faces_uvs has wrong shape with self.assertRaisesRegex(ValueError, "faces_uvs"): - Textures( + TexturesUV( maps=torch.ones((5, 16, 16, 3)), - faces_uvs=torch.randint(size=(5, 10, 3, 3), low=0, high=V), - verts_uvs=torch.ones((5, V, 2)), + faces_uvs=torch.rand(size=(5, 10, 3, 3)), + verts_uvs=torch.rand(size=(5, 15, 2)), ) + # verts_uvs has wrong shape with self.assertRaisesRegex(ValueError, "verts_uvs"): - Textures( + TexturesUV( maps=torch.ones((5, 16, 16, 3)), - faces_uvs=torch.randint(size=(5, 10, 3), low=0, high=V), - verts_uvs=torch.ones((5, V, 2, 3)), - ) - # verts_rgb has wrong shape - with self.assertRaisesRegex(ValueError, "verts_rgb"): - Textures(verts_rgb=torch.ones((5, 16, 16, 3))) - - # maps provided without verts/faces uvs - with self.assertRaisesRegex(ValueError, "faces_uvs and verts_uvs are required"): - Textures(maps=torch.ones((5, 16, 16, 3))) - - def test_padded_to_packed(self): - N = 2 - # Case where each face in the mesh has 3 unique uv vertex indices - # - i.e. even if a vertex is shared between multiple faces it will - # have a unique uv coordinate for each face. - faces_uvs_list = [ - torch.tensor([[0, 1, 2], [3, 5, 4], [7, 6, 8]]), - torch.tensor([[0, 1, 2], [3, 4, 5]]), - ] # (N, 3, 3) - verts_uvs_list = [torch.ones(9, 2), torch.ones(6, 2)] - faces_uvs_padded = list_to_padded(faces_uvs_list, pad_value=-1) - verts_uvs_padded = list_to_padded(verts_uvs_list) - tex = Textures( - maps=torch.ones((N, 16, 16, 3)), - faces_uvs=faces_uvs_padded, - verts_uvs=verts_uvs_padded, - ) - - # This is set inside Meshes when textures is passed as an input. - # Here we set _num_faces_per_mesh and _num_verts_per_mesh explicity. - tex1 = tex.clone() - tex1._num_faces_per_mesh = faces_uvs_padded.gt(-1).all(-1).sum(-1).tolist() - tex1._num_verts_per_mesh = torch.tensor([5, 4]) - faces_packed = tex1.faces_uvs_packed() - verts_packed = tex1.verts_uvs_packed() - faces_list = tex1.faces_uvs_list() - verts_list = tex1.verts_uvs_list() - - for f1, f2 in zip(faces_uvs_list, faces_list): - self.assertTrue((f1 == f2).all().item()) - - for f, v1, v2 in zip(faces_list, verts_list, verts_uvs_list): - idx = f.unique() - self.assertTrue((v1[idx] == v2).all().item()) - - self.assertTrue(faces_packed.shape == (3 + 2, 3)) - - # verts_packed is just flattened verts_padded. - # split sizes are not used for verts_uvs. - self.assertTrue(verts_packed.shape == (9 * 2, 2)) - - # Case where num_faces_per_mesh is not set - tex2 = tex.clone() - faces_packed = tex2.faces_uvs_packed() - verts_packed = tex2.verts_uvs_packed() - faces_list = tex2.faces_uvs_list() - verts_list = tex2.verts_uvs_list() - - # Packed is just flattened padded as num_faces_per_mesh - # has not been provided. - self.assertTrue(verts_packed.shape == (9 * 2, 2)) - self.assertTrue(faces_packed.shape == (3 * 2, 3)) - - for i in range(N): - self.assertTrue( - (faces_list[i] == faces_uvs_padded[i, ...].squeeze()).all().item() + faces_uvs=torch.rand(size=(5, 10, 3)), + verts_uvs=torch.rand(size=(5, 15, 2, 3)), ) - for i in range(N): - self.assertTrue( - (verts_list[i] == verts_uvs_padded[i, ...].squeeze()).all().item() + # verts has different batch dim to faces + with self.assertRaisesRegex(ValueError, "verts_uvs"): + TexturesUV( + maps=torch.ones((5, 16, 16, 3)), + faces_uvs=torch.rand(size=(5, 10, 3)), + verts_uvs=torch.rand(size=(8, 15, 2)), + ) + + # maps has different batch dim to faces + with self.assertRaisesRegex(ValueError, "maps"): + TexturesUV( + maps=torch.ones((8, 16, 16, 3)), + faces_uvs=torch.rand(size=(5, 10, 3)), + verts_uvs=torch.rand(size=(5, 15, 2)), + ) + + # verts on different device to faces + with self.assertRaisesRegex(ValueError, "verts_uvs"): + TexturesUV( + maps=torch.ones((5, 16, 16, 3)), + faces_uvs=torch.rand(size=(5, 10, 3)), + verts_uvs=torch.rand(size=(5, 15, 2, 3), device="cuda"), + ) + + # maps on different device to faces + with self.assertRaisesRegex(ValueError, "map"): + TexturesUV( + maps=torch.ones((5, 16, 16, 3), device="cuda"), + faces_uvs=torch.rand(size=(5, 10, 3)), + verts_uvs=torch.rand(size=(5, 15, 2)), ) def test_clone(self): - V = 20 - tex = Textures( + tex = TexturesUV( maps=torch.ones((5, 16, 16, 3)), - faces_uvs=torch.randint(size=(5, 10, 3), low=0, high=V), - verts_uvs=torch.ones((5, V, 2)), + faces_uvs=torch.rand(size=(5, 10, 3)), + verts_uvs=torch.rand(size=(5, 15, 2)), ) tex_cloned = tex.clone() self.assertSeparate(tex._faces_uvs_padded, tex_cloned._faces_uvs_padded) self.assertSeparate(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded) self.assertSeparate(tex._maps_padded, tex_cloned._maps_padded) - - def test_getitem(self): - N = 5 - V = 20 - source = { - "maps": torch.rand(size=(N, 16, 16, 3)), - "faces_uvs": torch.randint(size=(N, 10, 3), low=0, high=V), - "verts_uvs": torch.rand((N, V, 2)), - } - tex = Textures( - maps=source["maps"], - faces_uvs=source["faces_uvs"], - verts_uvs=source["verts_uvs"], - ) - - verts = torch.rand(size=(N, V, 3)) - faces = torch.randint(size=(N, 10, 3), high=V) - meshes = Meshes(verts=verts, faces=faces, textures=tex) - - def tryindex(index): - tex2 = tex[index] - meshes2 = meshes[index] - tex_from_meshes = meshes2.textures - for item in source: - basic = source[item][index] - from_texture = getattr(tex2, item + "_padded")() - from_meshes = getattr(tex_from_meshes, item + "_padded")() - if isinstance(index, int): - basic = basic[None] - self.assertClose(basic, from_texture) - self.assertClose(basic, from_meshes) - self.assertEqual( - from_texture.ndim, getattr(tex, item + "_padded")().ndim - ) - if item == "faces_uvs": - faces_uvs_list = tex_from_meshes.faces_uvs_list() - self.assertEqual(basic.shape[0], len(faces_uvs_list)) - for i, faces_uvs in enumerate(faces_uvs_list): - self.assertClose(faces_uvs, basic[i]) - - tryindex(2) - tryindex(slice(0, 2, 1)) - index = torch.tensor([1, 0, 1, 0, 0], dtype=torch.bool) - tryindex(index) - index = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool) - tryindex(index) - index = torch.tensor([1, 2], dtype=torch.int64) - tryindex(index) - tryindex([2, 4]) - - def test_to(self): - V = 20 - tex = Textures( - maps=torch.ones((5, 16, 16, 3)), - faces_uvs=torch.randint(size=(5, 10, 3), low=0, high=V), - verts_uvs=torch.ones((5, V, 2)), - ) - device = torch.device("cuda:0") - tex = tex.to(device) - self.assertTrue(tex._faces_uvs_padded.device == device) - self.assertTrue(tex._verts_uvs_padded.device == device) - self.assertTrue(tex._maps_padded.device == device) + self.assertSeparate(tex.valid, tex_cloned.valid) def test_extend(self): - B = 10 + B = 5 mesh = TestMeshes.init_mesh(B, 30, 50) V = mesh._V - F = mesh._F - - # 1. Texture uvs - tex_uv = Textures( - maps=torch.randn((B, 16, 16, 3)), - faces_uvs=torch.randint(size=(B, F, 3), low=0, high=V), - verts_uvs=torch.randn((B, V, 2)), + num_faces = mesh.num_faces_per_mesh() + num_verts = mesh.num_verts_per_mesh() + faces_uvs_list = [torch.randint(size=(f, 3), low=0, high=V) for f in num_faces] + verts_uvs_list = [torch.rand(v, 2) for v in num_verts] + tex_uv = TexturesUV( + maps=torch.ones((B, 16, 16, 3)), + faces_uvs=faces_uvs_list, + verts_uvs=verts_uvs_list, ) tex_mesh = Meshes( - verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex_uv + verts=mesh.verts_list(), faces=mesh.faces_list(), textures=tex_uv ) - N = 20 + N = 2 new_mesh = tex_mesh.extend(N) self.assertEqual(len(tex_mesh) * N, len(new_mesh)) @@ -246,56 +510,142 @@ class TestTexturing(TestCaseMixin, unittest.TestCase): for i in range(len(tex_mesh)): for n in range(N): + self.assertClose( + tex_init.verts_uvs_list()[i], new_tex.verts_uvs_list()[i * N + n] + ) self.assertClose( tex_init.faces_uvs_list()[i], new_tex.faces_uvs_list()[i * N + n] ) self.assertClose( - tex_init.verts_uvs_list()[i], new_tex.verts_uvs_list()[i * N + n] + tex_init.maps_padded()[i, ...], new_tex.maps_padded()[i * N + n] ) + self.assertClose( + tex_init._num_faces_per_mesh[i], + new_tex._num_faces_per_mesh[i * N + n], + ) + self.assertAllSeparate( [ tex_init.faces_uvs_padded(), new_tex.faces_uvs_padded(), + tex_init.faces_uvs_packed(), + new_tex.faces_uvs_packed(), tex_init.verts_uvs_padded(), new_tex.verts_uvs_padded(), + tex_init.verts_uvs_packed(), + new_tex.verts_uvs_packed(), tex_init.maps_padded(), new_tex.maps_padded(), ] ) - self.assertIsNone(new_tex.verts_rgb_list()) - self.assertIsNone(new_tex.verts_rgb_padded()) - self.assertIsNone(new_tex.verts_rgb_packed()) - - # 2. Texture vertex RGB - tex_rgb = Textures(verts_rgb=torch.randn((B, V, 3))) - tex_mesh_rgb = Meshes( - verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex_rgb - ) - N = 20 - new_mesh_rgb = tex_mesh_rgb.extend(N) - - self.assertEqual(len(tex_mesh_rgb) * N, len(new_mesh_rgb)) - - tex_init = tex_mesh_rgb.textures - new_tex = new_mesh_rgb.textures - - for i in range(len(tex_mesh_rgb)): - for n in range(N): - self.assertClose( - tex_init.verts_rgb_list()[i], new_tex.verts_rgb_list()[i * N + n] - ) - self.assertAllSeparate( - [tex_init.verts_rgb_padded(), new_tex.verts_rgb_padded()] - ) - - self.assertIsNone(new_tex.verts_uvs_padded()) - self.assertIsNone(new_tex.verts_uvs_list()) - self.assertIsNone(new_tex.verts_uvs_packed()) - self.assertIsNone(new_tex.faces_uvs_padded()) - self.assertIsNone(new_tex.faces_uvs_list()) - self.assertIsNone(new_tex.faces_uvs_packed()) - - # 3. Error with self.assertRaises(ValueError): tex_mesh.extend(N=-1) + + def test_padded_to_packed(self): + # Case where each face in the mesh has 3 unique uv vertex indices + # - i.e. even if a vertex is shared between multiple faces it will + # have a unique uv coordinate for each face. + N = 2 + faces_uvs_list = [ + torch.tensor([[0, 1, 2], [3, 5, 4], [7, 6, 8]]), + torch.tensor([[0, 1, 2], [3, 4, 5]]), + ] # (N, 3, 3) + verts_uvs_list = [torch.ones(9, 2), torch.ones(6, 2)] + + num_faces_per_mesh = [f.shape[0] for f in faces_uvs_list] + num_verts_per_mesh = [v.shape[0] for v in verts_uvs_list] + tex = TexturesUV( + maps=torch.ones((N, 16, 16, 3)), + faces_uvs=faces_uvs_list, + verts_uvs=verts_uvs_list, + ) + + # This is set inside Meshes when textures is passed as an input. + # Here we set _num_faces_per_mesh and _num_verts_per_mesh explicity. + tex1 = tex.clone() + tex1._num_faces_per_mesh = num_faces_per_mesh + tex1._num_verts_per_mesh = num_verts_per_mesh + verts_packed = tex1.verts_uvs_packed() + verts_list = tex1.verts_uvs_list() + verts_padded = tex1.verts_uvs_padded() + + faces_packed = tex1.faces_uvs_packed() + faces_list = tex1.faces_uvs_list() + faces_padded = tex1.faces_uvs_padded() + + for f1, f2 in zip(faces_list, faces_uvs_list): + self.assertTrue((f1 == f2).all().item()) + + for f1, f2 in zip(verts_list, verts_uvs_list): + self.assertTrue((f1 == f2).all().item()) + + self.assertTrue(faces_packed.shape == (3 + 2, 3)) + self.assertTrue(faces_padded.shape == (2, 3, 3)) + self.assertTrue(verts_packed.shape == (9 + 6, 2)) + self.assertTrue(verts_padded.shape == (2, 9, 2)) + + # Case where num_faces_per_mesh is not set and faces_verts_uvs + # are initialized with a padded tensor. + tex2 = TexturesUV( + maps=torch.ones((N, 16, 16, 3)), + verts_uvs=verts_padded, + faces_uvs=faces_padded, + ) + faces_packed = tex2.faces_uvs_packed() + faces_list = tex2.faces_uvs_list() + verts_packed = tex2.verts_uvs_packed() + verts_list = tex2.verts_uvs_list() + + # Packed is just flattened padded as num_faces_per_mesh + # has not been provided. + self.assertTrue(faces_packed.shape == (3 * 2, 3)) + self.assertTrue(verts_packed.shape == (9 * 2, 2)) + + for i, (f1, f2) in enumerate(zip(faces_list, faces_uvs_list)): + n = num_faces_per_mesh[i] + self.assertTrue((f1[:n] == f2).all().item()) + + for i, (f1, f2) in enumerate(zip(verts_list, verts_uvs_list)): + n = num_verts_per_mesh[i] + self.assertTrue((f1[:n] == f2).all().item()) + + def test_to(self): + tex = TexturesUV( + maps=torch.ones((5, 16, 16, 3)), + faces_uvs=torch.randint(size=(5, 10, 3), high=15), + verts_uvs=torch.rand(size=(5, 15, 2)), + ) + device = torch.device("cuda:0") + tex = tex.to(device) + self.assertTrue(tex._faces_uvs_padded.device == device) + self.assertTrue(tex._verts_uvs_padded.device == device) + self.assertTrue(tex._maps_padded.device == device) + + def test_getitem(self): + N = 5 + V = 20 + source = { + "maps": torch.rand(size=(N, 1, 1, 3)), + "faces_uvs": torch.randint(size=(N, 10, 3), high=V), + "verts_uvs": torch.randn(size=(N, V, 2)), + } + tex = TexturesUV( + maps=source["maps"], + faces_uvs=source["faces_uvs"], + verts_uvs=source["verts_uvs"], + ) + + verts = torch.rand(size=(N, V, 3)) + faces = torch.randint(size=(N, 10, 3), high=V) + meshes = Meshes(verts=verts, faces=faces, textures=tex) + + tryindex(self, 2, tex, meshes, source) + tryindex(self, slice(0, 2, 1), tex, meshes, source) + index = torch.tensor([1, 0, 1, 0, 0], dtype=torch.bool) + tryindex(self, index, tex, meshes, source) + index = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool) + tryindex(self, index, tex, meshes, source) + index = torch.tensor([1, 2], dtype=torch.int64) + tryindex(self, index, tex, meshes, source) + tryindex(self, [2, 4], tex, meshes, source)