mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Texturing API updates
Summary: A fairly big refactor of the texturing API with some breaking changes to how textures are defined. Main changes: - There are now 3 types of texture classes: `TexturesUV`, `TexturesAtlas` and `TexturesVertex`. Each class: - has a `sample_textures` function which accepts the `fragments` from rasterization and returns `texels`. This means that the shaders will not need to know the type of the mesh texture which will resolve several issues people were reporting on GitHub. - has a `join_batch` method for joining multiple textures of the same type into a batch Reviewed By: gkioxari Differential Revision: D21067427 fbshipit-source-id: 4b346500a60181e72fdd1b0dd89b5505c7a33926
This commit is contained in:
		
							parent
							
								
									b73d3d6ed9
								
							
						
					
					
						commit
						a3932960b3
					
				@ -520,6 +520,9 @@
 | 
			
		||||
 ],
 | 
			
		||||
 "metadata": {
 | 
			
		||||
  "accelerator": "GPU",
 | 
			
		||||
  "anp_metadata": {
 | 
			
		||||
   "path": "fbsource/fbcode/vision/fair/pytorch3d/docs/tutorials/camera_position_optimization_with_differentiable_rendering.ipynb"
 | 
			
		||||
  },
 | 
			
		||||
  "bento_stylesheets": {
 | 
			
		||||
   "bento/extensions/flow/main.css": true,
 | 
			
		||||
   "bento/extensions/kernel_selector/main.css": true,
 | 
			
		||||
@ -533,6 +536,9 @@
 | 
			
		||||
   "provenance": [],
 | 
			
		||||
   "toc_visible": true
 | 
			
		||||
  },
 | 
			
		||||
  "disseminate_notebook_info": {
 | 
			
		||||
   "backup_notebook_id": "1062179640844868"
 | 
			
		||||
  },
 | 
			
		||||
  "kernelspec": {
 | 
			
		||||
   "display_name": "pytorch3d (local)",
 | 
			
		||||
   "language": "python",
 | 
			
		||||
 | 
			
		||||
@ -84,7 +84,7 @@
 | 
			
		||||
    "from skimage.io import imread\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "# Util function for loading meshes\n",
 | 
			
		||||
    "from pytorch3d.io import load_objs_as_meshes\n",
 | 
			
		||||
    "from pytorch3d.io import load_objs_as_meshes, load_obj\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "# Data structures and functions for rendering\n",
 | 
			
		||||
    "from pytorch3d.structures import Meshes, Textures\n",
 | 
			
		||||
@ -97,7 +97,7 @@
 | 
			
		||||
    "    RasterizationSettings, \n",
 | 
			
		||||
    "    MeshRenderer, \n",
 | 
			
		||||
    "    MeshRasterizer,  \n",
 | 
			
		||||
    "    TexturedSoftPhongShader\n",
 | 
			
		||||
    "    SoftPhongShader\n",
 | 
			
		||||
    ")\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "# add path for demo utils functions \n",
 | 
			
		||||
@ -316,7 +316,7 @@
 | 
			
		||||
    "        cameras=cameras, \n",
 | 
			
		||||
    "        raster_settings=raster_settings\n",
 | 
			
		||||
    "    ),\n",
 | 
			
		||||
    "    shader=TexturedSoftPhongShader(\n",
 | 
			
		||||
    "    shader=SoftPhongShader(\n",
 | 
			
		||||
    "        device=device, \n",
 | 
			
		||||
    "        cameras=cameras,\n",
 | 
			
		||||
    "        lights=lights\n",
 | 
			
		||||
@ -563,6 +563,9 @@
 | 
			
		||||
 ],
 | 
			
		||||
 "metadata": {
 | 
			
		||||
  "accelerator": "GPU",
 | 
			
		||||
  "anp_metadata": {
 | 
			
		||||
   "path": "fbsource/fbcode/vision/fair/pytorch3d/docs/tutorials/render_textured_meshes.ipynb"
 | 
			
		||||
  },
 | 
			
		||||
  "bento_stylesheets": {
 | 
			
		||||
   "bento/extensions/flow/main.css": true,
 | 
			
		||||
   "bento/extensions/kernel_selector/main.css": true,
 | 
			
		||||
@ -575,6 +578,9 @@
 | 
			
		||||
   "name": "render_textured_meshes.ipynb",
 | 
			
		||||
   "provenance": []
 | 
			
		||||
  },
 | 
			
		||||
  "disseminate_notebook_info": {
 | 
			
		||||
   "backup_notebook_id": "569222367081034"
 | 
			
		||||
  },
 | 
			
		||||
  "kernelspec": {
 | 
			
		||||
   "display_name": "pytorch3d (local)",
 | 
			
		||||
   "language": "python",
 | 
			
		||||
 | 
			
		||||
@ -13,8 +13,8 @@ from pytorch3d.renderer import (
 | 
			
		||||
    OpenGLPerspectiveCameras,
 | 
			
		||||
    PointLights,
 | 
			
		||||
    RasterizationSettings,
 | 
			
		||||
    TexturesVertex,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.structures import Textures
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ShapeNetBase(torch.utils.data.Dataset):
 | 
			
		||||
@ -113,8 +113,8 @@ class ShapeNetBase(torch.utils.data.Dataset):
 | 
			
		||||
        """
 | 
			
		||||
        paths = self._handle_render_inputs(model_ids, categories, sample_nums, idxs)
 | 
			
		||||
        meshes = load_objs_as_meshes(paths, device=device, load_textures=False)
 | 
			
		||||
        meshes.textures = Textures(
 | 
			
		||||
            verts_rgb=torch.ones_like(meshes.verts_padded(), device=device)
 | 
			
		||||
        meshes.textures = TexturesVertex(
 | 
			
		||||
            verts_features=torch.ones_like(meshes.verts_padded(), device=device)
 | 
			
		||||
        )
 | 
			
		||||
        cameras = kwargs.get("cameras", OpenGLPerspectiveCameras()).to(device)
 | 
			
		||||
        renderer = MeshRenderer(
 | 
			
		||||
 | 
			
		||||
@ -11,7 +11,8 @@ import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas
 | 
			
		||||
from pytorch3d.io.utils import _open_file
 | 
			
		||||
from pytorch3d.structures import Meshes, Textures, join_meshes_as_batch
 | 
			
		||||
from pytorch3d.renderer import TexturesAtlas, TexturesUV
 | 
			
		||||
from pytorch3d.structures import Meshes, join_meshes_as_batch
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _make_tensor(data, cols: int, dtype: torch.dtype, device="cpu") -> torch.Tensor:
 | 
			
		||||
@ -41,6 +42,10 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
 | 
			
		||||
    Args:
 | 
			
		||||
        faces_indices: List of ints of indices.
 | 
			
		||||
        max_index: Max index for the face property.
 | 
			
		||||
        pad_value: if any of the face_indices are padded, specify
 | 
			
		||||
            the value of the padding (e.g. -1). This is only used
 | 
			
		||||
            for texture indices indices where there might
 | 
			
		||||
            not be texture information for all the faces.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        faces_indices: List of ints of indices.
 | 
			
		||||
@ -65,7 +70,9 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
 | 
			
		||||
        faces_indices[mask] = pad_value
 | 
			
		||||
 | 
			
		||||
    # Check indices are valid.
 | 
			
		||||
    if torch.any(faces_indices >= max_index) or torch.any(faces_indices < 0):
 | 
			
		||||
    if torch.any(faces_indices >= max_index) or (
 | 
			
		||||
        pad_value is None and torch.any(faces_indices < 0)
 | 
			
		||||
    ):
 | 
			
		||||
        warnings.warn("Faces have invalid indices")
 | 
			
		||||
 | 
			
		||||
    return faces_indices
 | 
			
		||||
@ -227,7 +234,14 @@ def load_obj(
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_objs_as_meshes(files: list, device=None, load_textures: bool = True):
 | 
			
		||||
def load_objs_as_meshes(
 | 
			
		||||
    files: list,
 | 
			
		||||
    device=None,
 | 
			
		||||
    load_textures: bool = True,
 | 
			
		||||
    create_texture_atlas: bool = False,
 | 
			
		||||
    texture_atlas_size: int = 4,
 | 
			
		||||
    texture_wrap: Optional[str] = "repeat",
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    Load meshes from a list of .obj files using the load_obj function, and
 | 
			
		||||
    return them as a Meshes object. This only works for meshes which have a
 | 
			
		||||
@ -246,18 +260,31 @@ def load_objs_as_meshes(files: list, device=None, load_textures: bool = True):
 | 
			
		||||
    """
 | 
			
		||||
    mesh_list = []
 | 
			
		||||
    for f_obj in files:
 | 
			
		||||
        # TODO: update this function to support the two texturing options.
 | 
			
		||||
        verts, faces, aux = load_obj(f_obj, load_textures=load_textures)
 | 
			
		||||
        verts = verts.to(device)
 | 
			
		||||
        verts, faces, aux = load_obj(
 | 
			
		||||
            f_obj,
 | 
			
		||||
            load_textures=load_textures,
 | 
			
		||||
            create_texture_atlas=create_texture_atlas,
 | 
			
		||||
            texture_atlas_size=texture_atlas_size,
 | 
			
		||||
            texture_wrap=texture_wrap,
 | 
			
		||||
        )
 | 
			
		||||
        tex = None
 | 
			
		||||
        tex_maps = aux.texture_images
 | 
			
		||||
        if tex_maps is not None and len(tex_maps) > 0:
 | 
			
		||||
            verts_uvs = aux.verts_uvs[None, ...].to(device)  # (1, V, 2)
 | 
			
		||||
            faces_uvs = faces.textures_idx[None, ...].to(device)  # (1, F, 3)
 | 
			
		||||
            image = list(tex_maps.values())[0].to(device)[None]
 | 
			
		||||
            tex = Textures(verts_uvs=verts_uvs, faces_uvs=faces_uvs, maps=image)
 | 
			
		||||
        if create_texture_atlas:
 | 
			
		||||
            # TexturesAtlas type
 | 
			
		||||
            tex = TexturesAtlas(atlas=[aux.texture_atlas])
 | 
			
		||||
        else:
 | 
			
		||||
            # TexturesUV type
 | 
			
		||||
            tex_maps = aux.texture_images
 | 
			
		||||
            if tex_maps is not None and len(tex_maps) > 0:
 | 
			
		||||
                verts_uvs = aux.verts_uvs.to(device)  # (V, 2)
 | 
			
		||||
                faces_uvs = faces.textures_idx.to(device)  # (F, 3)
 | 
			
		||||
                image = list(tex_maps.values())[0].to(device)[None]
 | 
			
		||||
                tex = TexturesUV(
 | 
			
		||||
                    verts_uvs=[verts_uvs], faces_uvs=[faces_uvs], maps=image
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        mesh = Meshes(verts=[verts], faces=[faces.verts_idx.to(device)], textures=tex)
 | 
			
		||||
        mesh = Meshes(
 | 
			
		||||
            verts=[verts.to(device)], faces=[faces.verts_idx.to(device)], textures=tex
 | 
			
		||||
        )
 | 
			
		||||
        mesh_list.append(mesh)
 | 
			
		||||
    if len(mesh_list) == 1:
 | 
			
		||||
        return mesh_list[0]
 | 
			
		||||
 | 
			
		||||
@ -28,11 +28,11 @@ from .mesh import (
 | 
			
		||||
    SoftGouraudShader,
 | 
			
		||||
    SoftPhongShader,
 | 
			
		||||
    SoftSilhouetteShader,
 | 
			
		||||
    TexturedSoftPhongShader,
 | 
			
		||||
    Textures,
 | 
			
		||||
    TexturesAtlas,
 | 
			
		||||
    TexturesUV,
 | 
			
		||||
    TexturesVertex,
 | 
			
		||||
    gouraud_shading,
 | 
			
		||||
    interpolate_face_attributes,
 | 
			
		||||
    interpolate_texture_map,
 | 
			
		||||
    interpolate_vertex_colors,
 | 
			
		||||
    phong_shading,
 | 
			
		||||
    rasterize_meshes,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -1,10 +1,10 @@
 | 
			
		||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from .texturing import interpolate_texture_map, interpolate_vertex_colors  # isort:skip
 | 
			
		||||
from .rasterize_meshes import rasterize_meshes
 | 
			
		||||
from .rasterizer import MeshRasterizer, RasterizationSettings
 | 
			
		||||
from .renderer import MeshRenderer
 | 
			
		||||
from .shader import TexturedSoftPhongShader  # DEPRECATED
 | 
			
		||||
from .shader import (
 | 
			
		||||
    HardFlatShader,
 | 
			
		||||
    HardGouraudShader,
 | 
			
		||||
@ -12,10 +12,10 @@ from .shader import (
 | 
			
		||||
    SoftGouraudShader,
 | 
			
		||||
    SoftPhongShader,
 | 
			
		||||
    SoftSilhouetteShader,
 | 
			
		||||
    TexturedSoftPhongShader,
 | 
			
		||||
)
 | 
			
		||||
from .shading import gouraud_shading, phong_shading
 | 
			
		||||
from .utils import interpolate_face_attributes
 | 
			
		||||
from .textures import Textures  # DEPRECATED
 | 
			
		||||
from .textures import TexturesAtlas, TexturesUV, TexturesVertex
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
 | 
			
		||||
 | 
			
		||||
@ -1,5 +1,6 @@
 | 
			
		||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 | 
			
		||||
 | 
			
		||||
import warnings
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
@ -13,7 +14,6 @@ from ..blending import (
 | 
			
		||||
from ..lighting import PointLights
 | 
			
		||||
from ..materials import Materials
 | 
			
		||||
from .shading import flat_shading, gouraud_shading, phong_shading
 | 
			
		||||
from .texturing import interpolate_texture_map, interpolate_vertex_colors
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# A Shader should take as input fragments from the output of rasterization
 | 
			
		||||
@ -57,7 +57,7 @@ class HardPhongShader(nn.Module):
 | 
			
		||||
                or in the forward pass of HardPhongShader"
 | 
			
		||||
            raise ValueError(msg)
 | 
			
		||||
 | 
			
		||||
        texels = interpolate_vertex_colors(fragments, meshes)
 | 
			
		||||
        texels = meshes.sample_textures(fragments)
 | 
			
		||||
        lights = kwargs.get("lights", self.lights)
 | 
			
		||||
        materials = kwargs.get("materials", self.materials)
 | 
			
		||||
        blend_params = kwargs.get("blend_params", self.blend_params)
 | 
			
		||||
@ -104,9 +104,11 @@ class SoftPhongShader(nn.Module):
 | 
			
		||||
            msg = "Cameras must be specified either at initialization \
 | 
			
		||||
                or in the forward pass of SoftPhongShader"
 | 
			
		||||
            raise ValueError(msg)
 | 
			
		||||
        texels = interpolate_vertex_colors(fragments, meshes)
 | 
			
		||||
 | 
			
		||||
        texels = meshes.sample_textures(fragments)
 | 
			
		||||
        lights = kwargs.get("lights", self.lights)
 | 
			
		||||
        materials = kwargs.get("materials", self.materials)
 | 
			
		||||
        blend_params = kwargs.get("blend_params", self.blend_params)
 | 
			
		||||
        colors = phong_shading(
 | 
			
		||||
            meshes=meshes,
 | 
			
		||||
            fragments=fragments,
 | 
			
		||||
@ -115,7 +117,7 @@ class SoftPhongShader(nn.Module):
 | 
			
		||||
            cameras=cameras,
 | 
			
		||||
            materials=materials,
 | 
			
		||||
        )
 | 
			
		||||
        images = softmax_rgb_blend(colors, fragments, self.blend_params)
 | 
			
		||||
        images = softmax_rgb_blend(colors, fragments, blend_params)
 | 
			
		||||
        return images
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -154,6 +156,12 @@ class HardGouraudShader(nn.Module):
 | 
			
		||||
        lights = kwargs.get("lights", self.lights)
 | 
			
		||||
        materials = kwargs.get("materials", self.materials)
 | 
			
		||||
        blend_params = kwargs.get("blend_params", self.blend_params)
 | 
			
		||||
 | 
			
		||||
        # As Gouraud shading applies the illumination to the vertex
 | 
			
		||||
        # colors, the interpolated pixel texture is calculated in the
 | 
			
		||||
        # shading step. In comparison, for Phong shading, the pixel
 | 
			
		||||
        # textures are computed first after which the illumination is
 | 
			
		||||
        # applied.
 | 
			
		||||
        pixel_colors = gouraud_shading(
 | 
			
		||||
            meshes=meshes,
 | 
			
		||||
            fragments=fragments,
 | 
			
		||||
@ -210,54 +218,25 @@ class SoftGouraudShader(nn.Module):
 | 
			
		||||
        return images
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TexturedSoftPhongShader(nn.Module):
 | 
			
		||||
def TexturedSoftPhongShader(
 | 
			
		||||
    device="cpu", cameras=None, lights=None, materials=None, blend_params=None
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    Per pixel lighting applied to a texture map. First interpolate the vertex
 | 
			
		||||
    uv coordinates and sample from a texture map. Then apply the lighting model
 | 
			
		||||
    using the interpolated coords and normals for each pixel.
 | 
			
		||||
 | 
			
		||||
    The blending function returns the soft aggregated color using all
 | 
			
		||||
    the faces per pixel.
 | 
			
		||||
 | 
			
		||||
    To use the default values, simply initialize the shader with the desired
 | 
			
		||||
    device e.g.
 | 
			
		||||
 | 
			
		||||
    .. code-block::
 | 
			
		||||
 | 
			
		||||
        shader = TexturedPhongShader(device=torch.device("cuda:0"))
 | 
			
		||||
    TexturedSoftPhongShader class has been DEPRECATED. Use SoftPhongShader instead.
 | 
			
		||||
    Preserving TexturedSoftPhongShader as a function for backwards compatibility.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.lights = lights if lights is not None else PointLights(device=device)
 | 
			
		||||
        self.materials = (
 | 
			
		||||
            materials if materials is not None else Materials(device=device)
 | 
			
		||||
        )
 | 
			
		||||
        self.cameras = cameras
 | 
			
		||||
        self.blend_params = blend_params if blend_params is not None else BlendParams()
 | 
			
		||||
 | 
			
		||||
    def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
 | 
			
		||||
        cameras = kwargs.get("cameras", self.cameras)
 | 
			
		||||
        if cameras is None:
 | 
			
		||||
            msg = "Cameras must be specified either at initialization \
 | 
			
		||||
                or in the forward pass of TexturedSoftPhongShader"
 | 
			
		||||
            raise ValueError(msg)
 | 
			
		||||
        texels = interpolate_texture_map(fragments, meshes)
 | 
			
		||||
        lights = kwargs.get("lights", self.lights)
 | 
			
		||||
        materials = kwargs.get("materials", self.materials)
 | 
			
		||||
        blend_params = kwargs.get("blend_params", self.blend_params)
 | 
			
		||||
        colors = phong_shading(
 | 
			
		||||
            meshes=meshes,
 | 
			
		||||
            fragments=fragments,
 | 
			
		||||
            texels=texels,
 | 
			
		||||
            lights=lights,
 | 
			
		||||
            cameras=cameras,
 | 
			
		||||
            materials=materials,
 | 
			
		||||
        )
 | 
			
		||||
        images = softmax_rgb_blend(colors, fragments, blend_params)
 | 
			
		||||
        return images
 | 
			
		||||
    warnings.warn(
 | 
			
		||||
        """TexturedSoftPhongShader is now deprecated;
 | 
			
		||||
            use SoftPhongShader instead.""",
 | 
			
		||||
        PendingDeprecationWarning,
 | 
			
		||||
    )
 | 
			
		||||
    return SoftPhongShader(
 | 
			
		||||
        device=device,
 | 
			
		||||
        cameras=cameras,
 | 
			
		||||
        lights=lights,
 | 
			
		||||
        materials=materials,
 | 
			
		||||
        blend_params=blend_params,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HardFlatShader(nn.Module):
 | 
			
		||||
@ -291,7 +270,7 @@ class HardFlatShader(nn.Module):
 | 
			
		||||
            msg = "Cameras must be specified either at initialization \
 | 
			
		||||
                or in the forward pass of HardFlatShader"
 | 
			
		||||
            raise ValueError(msg)
 | 
			
		||||
        texels = interpolate_vertex_colors(fragments, meshes)
 | 
			
		||||
        texels = meshes.sample_textures(fragments)
 | 
			
		||||
        lights = kwargs.get("lights", self.lights)
 | 
			
		||||
        materials = kwargs.get("materials", self.materials)
 | 
			
		||||
        blend_params = kwargs.get("blend_params", self.blend_params)
 | 
			
		||||
 | 
			
		||||
@ -6,6 +6,8 @@ from typing import Tuple
 | 
			
		||||
import torch
 | 
			
		||||
from pytorch3d.ops import interpolate_face_attributes
 | 
			
		||||
 | 
			
		||||
from .textures import TexturesVertex
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _apply_lighting(
 | 
			
		||||
    points, normals, lights, cameras, materials
 | 
			
		||||
@ -91,6 +93,9 @@ def gouraud_shading(meshes, fragments, lights, cameras, materials) -> torch.Tens
 | 
			
		||||
    Then interpolate the vertex shaded colors using the barycentric coordinates
 | 
			
		||||
    to get a color per pixel.
 | 
			
		||||
 | 
			
		||||
    Gouraud shading is only supported for meshes with texture type `TexturesVertex`.
 | 
			
		||||
    This is because the illumination is applied to the vertex colors.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        meshes: Batch of meshes
 | 
			
		||||
        fragments: Fragments named tuple with the outputs of rasterization
 | 
			
		||||
@ -101,10 +106,13 @@ def gouraud_shading(meshes, fragments, lights, cameras, materials) -> torch.Tens
 | 
			
		||||
    Returns:
 | 
			
		||||
        colors: (N, H, W, K, 3)
 | 
			
		||||
    """
 | 
			
		||||
    if not isinstance(meshes.textures, TexturesVertex):
 | 
			
		||||
        raise ValueError("Mesh textures must be an instance of TexturesVertex")
 | 
			
		||||
 | 
			
		||||
    faces = meshes.faces_packed()  # (F, 3)
 | 
			
		||||
    verts = meshes.verts_packed()
 | 
			
		||||
    vertex_normals = meshes.verts_normals_packed()  # (V, 3)
 | 
			
		||||
    vertex_colors = meshes.textures.verts_rgb_packed()
 | 
			
		||||
    verts = meshes.verts_packed()  # (V, 3)
 | 
			
		||||
    verts_normals = meshes.verts_normals_packed()  # (V, 3)
 | 
			
		||||
    verts_colors = meshes.textures.verts_features_packed()  # (V, D)
 | 
			
		||||
    vert_to_mesh_idx = meshes.verts_packed_to_mesh_idx()
 | 
			
		||||
 | 
			
		||||
    # Format properties of lights and materials so they are compatible
 | 
			
		||||
@ -119,9 +127,10 @@ def gouraud_shading(meshes, fragments, lights, cameras, materials) -> torch.Tens
 | 
			
		||||
 | 
			
		||||
    # Calculate the illumination at each vertex
 | 
			
		||||
    ambient, diffuse, specular = _apply_lighting(
 | 
			
		||||
        verts, vertex_normals, lights, cameras, materials
 | 
			
		||||
        verts, verts_normals, lights, cameras, materials
 | 
			
		||||
    )
 | 
			
		||||
    verts_colors_shaded = vertex_colors * (ambient + diffuse) + specular
 | 
			
		||||
 | 
			
		||||
    verts_colors_shaded = verts_colors * (ambient + diffuse) + specular
 | 
			
		||||
    face_colors = verts_colors_shaded[faces]
 | 
			
		||||
    colors = interpolate_face_attributes(
 | 
			
		||||
        fragments.pix_to_face, fragments.bary_coords, face_colors
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										1049
									
								
								pytorch3d/renderer/mesh/textures.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1049
									
								
								pytorch3d/renderer/mesh/textures.py
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@ -1,113 +0,0 @@
 | 
			
		||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from pytorch3d.ops import interpolate_face_attributes
 | 
			
		||||
from pytorch3d.structures.textures import Textures
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def interpolate_texture_map(fragments, meshes) -> torch.Tensor:
 | 
			
		||||
    """
 | 
			
		||||
    Interpolate a 2D texture map using uv vertex texture coordinates for each
 | 
			
		||||
    face in the mesh. First interpolate the vertex uvs using barycentric coordinates
 | 
			
		||||
    for each pixel in the rasterized output. Then interpolate the texture map
 | 
			
		||||
    using the uv coordinate for each pixel.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        fragments:
 | 
			
		||||
            The outputs of rasterization. From this we use
 | 
			
		||||
 | 
			
		||||
            - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
 | 
			
		||||
              of the faces (in the packed representation) which
 | 
			
		||||
              overlap each pixel in the image.
 | 
			
		||||
            - barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
 | 
			
		||||
              the barycentric coordianates of each pixel
 | 
			
		||||
              relative to the faces (in the packed
 | 
			
		||||
              representation) which overlap the pixel.
 | 
			
		||||
        meshes: Meshes representing a batch of meshes. It is expected that
 | 
			
		||||
            meshes has a textures attribute which is an instance of the
 | 
			
		||||
            Textures class.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        texels: tensor of shape (N, H, W, K, C) giving the interpolated
 | 
			
		||||
        texture for each pixel in the rasterized image.
 | 
			
		||||
    """
 | 
			
		||||
    if not isinstance(meshes.textures, Textures):
 | 
			
		||||
        msg = "Expected meshes.textures to be an instance of Textures; got %r"
 | 
			
		||||
        raise ValueError(msg % type(meshes.textures))
 | 
			
		||||
 | 
			
		||||
    faces_uvs = meshes.textures.faces_uvs_packed()
 | 
			
		||||
    verts_uvs = meshes.textures.verts_uvs_packed()
 | 
			
		||||
    faces_verts_uvs = verts_uvs[faces_uvs]
 | 
			
		||||
    texture_maps = meshes.textures.maps_padded()
 | 
			
		||||
 | 
			
		||||
    # pixel_uvs: (N, H, W, K, 2)
 | 
			
		||||
    pixel_uvs = interpolate_face_attributes(
 | 
			
		||||
        fragments.pix_to_face, fragments.bary_coords, faces_verts_uvs
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    N, H_out, W_out, K = fragments.pix_to_face.shape
 | 
			
		||||
    N, H_in, W_in, C = texture_maps.shape  # 3 for RGB
 | 
			
		||||
 | 
			
		||||
    # pixel_uvs: (N, H, W, K, 2) -> (N, K, H, W, 2) -> (NK, H, W, 2)
 | 
			
		||||
    pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(N * K, H_out, W_out, 2)
 | 
			
		||||
 | 
			
		||||
    # textures.map:
 | 
			
		||||
    #   (N, H, W, C) -> (N, C, H, W) -> (1, N, C, H, W)
 | 
			
		||||
    #   -> expand (K, N, C, H, W) -> reshape (N*K, C, H, W)
 | 
			
		||||
    texture_maps = (
 | 
			
		||||
        texture_maps.permute(0, 3, 1, 2)[None, ...]
 | 
			
		||||
        .expand(K, -1, -1, -1, -1)
 | 
			
		||||
        .transpose(0, 1)
 | 
			
		||||
        .reshape(N * K, C, H_in, W_in)
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Textures: (N*K, C, H, W), pixel_uvs: (N*K, H, W, 2)
 | 
			
		||||
    # Now need to format the pixel uvs and the texture map correctly!
 | 
			
		||||
    # From pytorch docs, grid_sample takes `grid` and `input`:
 | 
			
		||||
    #   grid specifies the sampling pixel locations normalized by
 | 
			
		||||
    #   the input spatial dimensions It should have most
 | 
			
		||||
    #   values in the range of [-1, 1]. Values x = -1, y = -1
 | 
			
		||||
    #   is the left-top pixel of input, and values x = 1, y = 1 is the
 | 
			
		||||
    #   right-bottom pixel of input.
 | 
			
		||||
 | 
			
		||||
    pixel_uvs = pixel_uvs * 2.0 - 1.0
 | 
			
		||||
    texture_maps = torch.flip(texture_maps, [2])  # flip y axis of the texture map
 | 
			
		||||
    if texture_maps.device != pixel_uvs.device:
 | 
			
		||||
        texture_maps = texture_maps.to(pixel_uvs.device)
 | 
			
		||||
    texels = F.grid_sample(texture_maps, pixel_uvs, align_corners=False)
 | 
			
		||||
    texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2)
 | 
			
		||||
    return texels
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def interpolate_vertex_colors(fragments, meshes) -> torch.Tensor:
 | 
			
		||||
    """
 | 
			
		||||
    Detemine the color for each rasterized face. Interpolate the colors for
 | 
			
		||||
    vertices which form the face using the barycentric coordinates.
 | 
			
		||||
    Args:
 | 
			
		||||
        meshes: A Meshes class representing a batch of meshes.
 | 
			
		||||
        fragments:
 | 
			
		||||
            The outputs of rasterization. From this we use
 | 
			
		||||
 | 
			
		||||
            - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
 | 
			
		||||
              of the faces (in the packed representation) which
 | 
			
		||||
              overlap each pixel in the image.
 | 
			
		||||
            - barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
 | 
			
		||||
              the barycentric coordianates of each pixel
 | 
			
		||||
              relative to the faces (in the packed
 | 
			
		||||
              representation) which overlap the pixel.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        texels: An texture per pixel of shape (N, H, W, K, C).
 | 
			
		||||
        There will be one C dimensional value for each element in
 | 
			
		||||
        fragments.pix_to_face.
 | 
			
		||||
    """
 | 
			
		||||
    vertex_textures = meshes.textures.verts_rgb_padded().reshape(-1, 3)  # (V, C)
 | 
			
		||||
    vertex_textures = vertex_textures[meshes.verts_padded_to_packed_idx(), :]
 | 
			
		||||
    faces_packed = meshes.faces_packed()
 | 
			
		||||
    faces_textures = vertex_textures[faces_packed]  # (F, 3, C)
 | 
			
		||||
    texels = interpolate_face_attributes(
 | 
			
		||||
        fragments.pix_to_face, fragments.bary_coords, faces_textures
 | 
			
		||||
    )
 | 
			
		||||
    return texels
 | 
			
		||||
@ -2,7 +2,6 @@
 | 
			
		||||
 | 
			
		||||
from .meshes import Meshes, join_meshes_as_batch
 | 
			
		||||
from .pointclouds import Pointclouds
 | 
			
		||||
from .textures import Textures
 | 
			
		||||
from .utils import list_to_packed, list_to_padded, packed_to_list, padded_to_list
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -5,7 +5,6 @@ from typing import List, Union
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from . import utils as struct_utils
 | 
			
		||||
from .textures import Textures
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Meshes(object):
 | 
			
		||||
@ -234,9 +233,9 @@ class Meshes(object):
 | 
			
		||||
        Refer to comments above for descriptions of List and Padded representations.
 | 
			
		||||
        """
 | 
			
		||||
        self.device = None
 | 
			
		||||
        if textures is not None and not isinstance(textures, Textures):
 | 
			
		||||
            msg = "Expected textures to be of type Textures; got %r"
 | 
			
		||||
            raise ValueError(msg % type(textures))
 | 
			
		||||
        if textures is not None and not repr(textures) == "TexturesBase":
 | 
			
		||||
            msg = "Expected textures to be an instance of type TexturesBase; got %r"
 | 
			
		||||
            raise ValueError(msg % repr(textures))
 | 
			
		||||
        self.textures = textures
 | 
			
		||||
 | 
			
		||||
        # Indicates whether the meshes in the list/batch have the same number
 | 
			
		||||
@ -400,6 +399,8 @@ class Meshes(object):
 | 
			
		||||
        if self.textures is not None:
 | 
			
		||||
            self.textures._num_faces_per_mesh = self._num_faces_per_mesh.tolist()
 | 
			
		||||
            self.textures._num_verts_per_mesh = self._num_verts_per_mesh.tolist()
 | 
			
		||||
            self.textures._N = self._N
 | 
			
		||||
            self.textures.valid = self.valid
 | 
			
		||||
 | 
			
		||||
    def __len__(self):
 | 
			
		||||
        return self._N
 | 
			
		||||
@ -1465,6 +1466,17 @@ class Meshes(object):
 | 
			
		||||
 | 
			
		||||
        return self.__class__(verts=new_verts_list, faces=new_faces_list, textures=tex)
 | 
			
		||||
 | 
			
		||||
    def sample_textures(self, fragments):
 | 
			
		||||
        if self.textures is not None:
 | 
			
		||||
            # Pass in faces packed. If the textures are defined per
 | 
			
		||||
            # vertex, the face indices are needed in order to interpolate
 | 
			
		||||
            # the vertex attributes across the face.
 | 
			
		||||
            return self.textures.sample_textures(
 | 
			
		||||
                fragments, faces_packed=self.faces_packed()
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError("Meshes does not have textures")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True):
 | 
			
		||||
    """
 | 
			
		||||
@ -1499,44 +1511,14 @@ def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True):
 | 
			
		||||
        raise ValueError("Inconsistent textures in join_meshes_as_batch.")
 | 
			
		||||
 | 
			
		||||
    # Now we know there are multiple meshes and they have textures to merge.
 | 
			
		||||
    first = meshes[0].textures
 | 
			
		||||
    kwargs = {}
 | 
			
		||||
    if first.maps_padded() is not None:
 | 
			
		||||
        if any(mesh.textures.maps_padded() is None for mesh in meshes):
 | 
			
		||||
            raise ValueError("Inconsistent maps_padded in join_meshes_as_batch.")
 | 
			
		||||
        maps = [m for mesh in meshes for m in mesh.textures.maps_padded()]
 | 
			
		||||
        kwargs["maps"] = maps
 | 
			
		||||
    elif any(mesh.textures.maps_padded() is not None for mesh in meshes):
 | 
			
		||||
        raise ValueError("Inconsistent maps_padded in join_meshes_as_batch.")
 | 
			
		||||
    all_textures = [mesh.textures for mesh in meshes]
 | 
			
		||||
    first = all_textures[0]
 | 
			
		||||
    tex_types_same = all(type(tex) == type(first) for tex in all_textures)
 | 
			
		||||
 | 
			
		||||
    if first.verts_uvs_padded() is not None:
 | 
			
		||||
        if any(mesh.textures.verts_uvs_padded() is None for mesh in meshes):
 | 
			
		||||
            raise ValueError("Inconsistent verts_uvs_padded in join_meshes_as_batch.")
 | 
			
		||||
        uvs = [uv for mesh in meshes for uv in mesh.textures.verts_uvs_list()]
 | 
			
		||||
        V = max(uv.shape[0] for uv in uvs)
 | 
			
		||||
        kwargs["verts_uvs"] = struct_utils.list_to_padded(uvs, (V, 2), -1)
 | 
			
		||||
    elif any(mesh.textures.verts_uvs_padded() is not None for mesh in meshes):
 | 
			
		||||
        raise ValueError("Inconsistent verts_uvs_padded in join_meshes_as_batch.")
 | 
			
		||||
    if not tex_types_same:
 | 
			
		||||
        raise ValueError("All meshes in the batch must have the same type of texture.")
 | 
			
		||||
 | 
			
		||||
    if first.faces_uvs_padded() is not None:
 | 
			
		||||
        if any(mesh.textures.faces_uvs_padded() is None for mesh in meshes):
 | 
			
		||||
            raise ValueError("Inconsistent faces_uvs_padded in join_meshes_as_batch.")
 | 
			
		||||
        uvs = [uv for mesh in meshes for uv in mesh.textures.faces_uvs_list()]
 | 
			
		||||
        F = max(uv.shape[0] for uv in uvs)
 | 
			
		||||
        kwargs["faces_uvs"] = struct_utils.list_to_padded(uvs, (F, 3), -1)
 | 
			
		||||
    elif any(mesh.textures.faces_uvs_padded() is not None for mesh in meshes):
 | 
			
		||||
        raise ValueError("Inconsistent faces_uvs_padded in join_meshes_as_batch.")
 | 
			
		||||
 | 
			
		||||
    if first.verts_rgb_padded() is not None:
 | 
			
		||||
        if any(mesh.textures.verts_rgb_padded() is None for mesh in meshes):
 | 
			
		||||
            raise ValueError("Inconsistent verts_rgb_padded in join_meshes_as_batch.")
 | 
			
		||||
        rgb = [i for mesh in meshes for i in mesh.textures.verts_rgb_list()]
 | 
			
		||||
        V = max(i.shape[0] for i in rgb)
 | 
			
		||||
        kwargs["verts_rgb"] = struct_utils.list_to_padded(rgb, (V, 3))
 | 
			
		||||
    elif any(mesh.textures.verts_rgb_padded() is not None for mesh in meshes):
 | 
			
		||||
        raise ValueError("Inconsistent verts_rgb_padded in join_meshes_as_batch.")
 | 
			
		||||
 | 
			
		||||
    tex = Textures(**kwargs)
 | 
			
		||||
    tex = first.join_batch(all_textures[1:])
 | 
			
		||||
    return Meshes(verts=verts, faces=faces, textures=tex)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1544,7 +1526,7 @@ def join_mesh(meshes: Union[Meshes, List[Meshes]]) -> Meshes:
 | 
			
		||||
    """
 | 
			
		||||
    Joins a batch of meshes in the form of a Meshes object or a list of Meshes
 | 
			
		||||
    objects as a single mesh. If the input is a list, the Meshes objects in the list
 | 
			
		||||
    must all be on the same device. This version ignores all textures in the input mehses.
 | 
			
		||||
    must all be on the same device. This version ignores all textures in the input meshes.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        meshes: Meshes object that contains a batch of meshes or a list of Meshes objects
 | 
			
		||||
 | 
			
		||||
@ -1,279 +0,0 @@
 | 
			
		||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 | 
			
		||||
 | 
			
		||||
from typing import List, Optional, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from torch.nn.functional import interpolate
 | 
			
		||||
 | 
			
		||||
from .utils import padded_to_list, padded_to_packed
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
This file has functions for interpolating textures after rasterization.
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _pad_texture_maps(images: List[torch.Tensor]) -> torch.Tensor:
 | 
			
		||||
    """
 | 
			
		||||
    Pad all texture images so they have the same height and width.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        images: list of N tensors of shape (H, W, 3)
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        tex_maps: Tensor of shape (N, max_H, max_W, 3)
 | 
			
		||||
    """
 | 
			
		||||
    tex_maps = []
 | 
			
		||||
    max_H = 0
 | 
			
		||||
    max_W = 0
 | 
			
		||||
    for im in images:
 | 
			
		||||
        h, w, _3 = im.shape
 | 
			
		||||
        if h > max_H:
 | 
			
		||||
            max_H = h
 | 
			
		||||
        if w > max_W:
 | 
			
		||||
            max_W = w
 | 
			
		||||
        tex_maps.append(im)
 | 
			
		||||
    max_shape = (max_H, max_W)
 | 
			
		||||
 | 
			
		||||
    for i, image in enumerate(tex_maps):
 | 
			
		||||
        if image.shape[:2] != max_shape:
 | 
			
		||||
            image_BCHW = image.permute(2, 0, 1)[None]
 | 
			
		||||
            new_image_BCHW = interpolate(
 | 
			
		||||
                image_BCHW, size=max_shape, mode="bilinear", align_corners=False
 | 
			
		||||
            )
 | 
			
		||||
            tex_maps[i] = new_image_BCHW[0].permute(1, 2, 0)
 | 
			
		||||
    tex_maps = torch.stack(tex_maps, dim=0)  # (num_tex_maps, max_H, max_W, 3)
 | 
			
		||||
    return tex_maps
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _extend_tensor(input_tensor: torch.Tensor, N: int) -> torch.Tensor:
 | 
			
		||||
    """
 | 
			
		||||
    Extend a tensor `input_tensor` with ndim > 2, `N` times along the batch
 | 
			
		||||
    dimension. This is done in the following sequence of steps (where `B` is
 | 
			
		||||
    the batch dimension):
 | 
			
		||||
 | 
			
		||||
    .. code-block:: python
 | 
			
		||||
 | 
			
		||||
        input_tensor (B, ...)
 | 
			
		||||
        -> add leading empty dimension (1, B, ...)
 | 
			
		||||
        -> expand (N, B, ...)
 | 
			
		||||
        -> reshape (N * B, ...)
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        input_tensor: torch.Tensor with ndim > 2 representing a batched input.
 | 
			
		||||
        N: number of times to extend each element of the batch.
 | 
			
		||||
    """
 | 
			
		||||
    # pyre-fixme[16]: `Tensor` has no attribute `ndim`.
 | 
			
		||||
    if input_tensor.ndim < 2:
 | 
			
		||||
        raise ValueError("Input tensor must have ndimensions >= 2.")
 | 
			
		||||
    B = input_tensor.shape[0]
 | 
			
		||||
    non_batch_dims = tuple(input_tensor.shape[1:])
 | 
			
		||||
    constant_dims = (-1,) * input_tensor.ndim  # these dims are not expanded.
 | 
			
		||||
    return (
 | 
			
		||||
        input_tensor.clone()[None, ...]
 | 
			
		||||
        .expand(N, *constant_dims)
 | 
			
		||||
        .transpose(0, 1)
 | 
			
		||||
        .reshape(N * B, *non_batch_dims)
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Textures(object):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        maps: Union[List, torch.Tensor, None] = None,
 | 
			
		||||
        faces_uvs: Optional[torch.Tensor] = None,
 | 
			
		||||
        verts_uvs: Optional[torch.Tensor] = None,
 | 
			
		||||
        verts_rgb: Optional[torch.Tensor] = None,
 | 
			
		||||
    ):
 | 
			
		||||
        """
 | 
			
		||||
        Args:
 | 
			
		||||
            maps: texture map per mesh. This can either be a list of maps
 | 
			
		||||
              [(H, W, 3)] or a padded tensor of shape (N, H, W, 3).
 | 
			
		||||
            faces_uvs: (N, F, 3) tensor giving the index into verts_uvs for each
 | 
			
		||||
                vertex in the face. Padding value is assumed to be -1.
 | 
			
		||||
            verts_uvs: (N, V, 2) tensor giving the uv coordinate per vertex.
 | 
			
		||||
            verts_rgb: (N, V, 3) tensor giving the rgb color per vertex. Padding
 | 
			
		||||
                value is assumed to be -1.
 | 
			
		||||
 | 
			
		||||
        Note: only the padded representation of the textures is stored
 | 
			
		||||
        and the packed/list representations are computed on the fly and
 | 
			
		||||
        not cached.
 | 
			
		||||
        """
 | 
			
		||||
        # pyre-fixme[16]: `Tensor` has no attribute `ndim`.
 | 
			
		||||
        if faces_uvs is not None and faces_uvs.ndim != 3:
 | 
			
		||||
            msg = "Expected faces_uvs to be of shape (N, F, 3); got %r"
 | 
			
		||||
            raise ValueError(msg % repr(faces_uvs.shape))
 | 
			
		||||
        if verts_uvs is not None and verts_uvs.ndim != 3:
 | 
			
		||||
            msg = "Expected verts_uvs to be of shape (N, V, 2); got %r"
 | 
			
		||||
            raise ValueError(msg % repr(verts_uvs.shape))
 | 
			
		||||
        if verts_rgb is not None and verts_rgb.ndim != 3:
 | 
			
		||||
            msg = "Expected verts_rgb to be of shape (N, V, 3); got %r"
 | 
			
		||||
            raise ValueError(msg % repr(verts_rgb.shape))
 | 
			
		||||
        if maps is not None:
 | 
			
		||||
            # pyre-fixme[16]: `List` has no attribute `ndim`.
 | 
			
		||||
            if torch.is_tensor(maps) and maps.ndim != 4:
 | 
			
		||||
                msg = "Expected maps to be of shape (N, H, W, 3); got %r"
 | 
			
		||||
                # pyre-fixme[16]: `List` has no attribute `shape`.
 | 
			
		||||
                raise ValueError(msg % repr(maps.shape))
 | 
			
		||||
            elif isinstance(maps, list):
 | 
			
		||||
                maps = _pad_texture_maps(maps)
 | 
			
		||||
            if faces_uvs is None or verts_uvs is None:
 | 
			
		||||
                msg = "To use maps, faces_uvs and verts_uvs are required"
 | 
			
		||||
                raise ValueError(msg)
 | 
			
		||||
 | 
			
		||||
        self._faces_uvs_padded = faces_uvs
 | 
			
		||||
        self._verts_uvs_padded = verts_uvs
 | 
			
		||||
        self._verts_rgb_padded = verts_rgb
 | 
			
		||||
        self._maps_padded = maps
 | 
			
		||||
 | 
			
		||||
        # The number of faces/verts for each mesh is
 | 
			
		||||
        # set inside the Meshes object when textures is
 | 
			
		||||
        # passed into the Meshes constructor.
 | 
			
		||||
        self._num_faces_per_mesh = None
 | 
			
		||||
        self._num_verts_per_mesh = None
 | 
			
		||||
 | 
			
		||||
    def clone(self):
 | 
			
		||||
        other = self.__class__()
 | 
			
		||||
        for k in dir(self):
 | 
			
		||||
            v = getattr(self, k)
 | 
			
		||||
            if torch.is_tensor(v):
 | 
			
		||||
                setattr(other, k, v.clone())
 | 
			
		||||
        return other
 | 
			
		||||
 | 
			
		||||
    def to(self, device):
 | 
			
		||||
        for k in dir(self):
 | 
			
		||||
            v = getattr(self, k)
 | 
			
		||||
            if torch.is_tensor(v) and v.device != device:
 | 
			
		||||
                setattr(self, k, v.to(device))
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, index):
 | 
			
		||||
        other = self.__class__()
 | 
			
		||||
        for key in dir(self):
 | 
			
		||||
            value = getattr(self, key)
 | 
			
		||||
            if torch.is_tensor(value):
 | 
			
		||||
                if isinstance(index, int):
 | 
			
		||||
                    setattr(other, key, value[index][None])
 | 
			
		||||
                else:
 | 
			
		||||
                    setattr(other, key, value[index])
 | 
			
		||||
        return other
 | 
			
		||||
 | 
			
		||||
    def faces_uvs_padded(self) -> torch.Tensor:
 | 
			
		||||
        # pyre-fixme[7]: Expected `Tensor` but got `Optional[torch.Tensor]`.
 | 
			
		||||
        return self._faces_uvs_padded
 | 
			
		||||
 | 
			
		||||
    def faces_uvs_list(self) -> Union[List[torch.Tensor], None]:
 | 
			
		||||
        if self._faces_uvs_padded is None:
 | 
			
		||||
            return None
 | 
			
		||||
        return padded_to_list(
 | 
			
		||||
            # pyre-fixme[6]: Expected `Tensor` for 1st param but got
 | 
			
		||||
            #  `Optional[torch.Tensor]`.
 | 
			
		||||
            self._faces_uvs_padded,
 | 
			
		||||
            split_size=self._num_faces_per_mesh,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def faces_uvs_packed(self) -> Union[torch.Tensor, None]:
 | 
			
		||||
        if self._faces_uvs_padded is None:
 | 
			
		||||
            return None
 | 
			
		||||
        return padded_to_packed(
 | 
			
		||||
            # pyre-fixme[6]: Expected `Tensor` for 1st param but got
 | 
			
		||||
            #  `Optional[torch.Tensor]`.
 | 
			
		||||
            self._faces_uvs_padded,
 | 
			
		||||
            split_size=self._num_faces_per_mesh,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def verts_uvs_padded(self) -> Union[torch.Tensor, None]:
 | 
			
		||||
        return self._verts_uvs_padded
 | 
			
		||||
 | 
			
		||||
    def verts_uvs_list(self) -> Union[List[torch.Tensor], None]:
 | 
			
		||||
        if self._verts_uvs_padded is None:
 | 
			
		||||
            return None
 | 
			
		||||
        # Vertices shared between multiple faces
 | 
			
		||||
        # may have a different uv coordinate for
 | 
			
		||||
        # each face so the num_verts_uvs_per_mesh
 | 
			
		||||
        # may be different from num_verts_per_mesh.
 | 
			
		||||
        # Therefore don't use any split_size.
 | 
			
		||||
        # pyre-fixme[6]: Expected `Tensor` for 1st param but got
 | 
			
		||||
        #  `Optional[torch.Tensor]`.
 | 
			
		||||
        return padded_to_list(self._verts_uvs_padded)
 | 
			
		||||
 | 
			
		||||
    def verts_uvs_packed(self) -> Union[torch.Tensor, None]:
 | 
			
		||||
        if self._verts_uvs_padded is None:
 | 
			
		||||
            return None
 | 
			
		||||
        # Vertices shared between multiple faces
 | 
			
		||||
        # may have a different uv coordinate for
 | 
			
		||||
        # each face so the num_verts_uvs_per_mesh
 | 
			
		||||
        # may be different from num_verts_per_mesh.
 | 
			
		||||
        # Therefore don't use any split_size.
 | 
			
		||||
        # pyre-fixme[6]: Expected `Tensor` for 1st param but got
 | 
			
		||||
        #  `Optional[torch.Tensor]`.
 | 
			
		||||
        return padded_to_packed(self._verts_uvs_padded)
 | 
			
		||||
 | 
			
		||||
    def verts_rgb_padded(self) -> Union[torch.Tensor, None]:
 | 
			
		||||
        return self._verts_rgb_padded
 | 
			
		||||
 | 
			
		||||
    def verts_rgb_list(self) -> Union[List[torch.Tensor], None]:
 | 
			
		||||
        if self._verts_rgb_padded is None:
 | 
			
		||||
            return None
 | 
			
		||||
        return padded_to_list(
 | 
			
		||||
            # pyre-fixme[6]: Expected `Tensor` for 1st param but got
 | 
			
		||||
            #  `Optional[torch.Tensor]`.
 | 
			
		||||
            self._verts_rgb_padded,
 | 
			
		||||
            split_size=self._num_verts_per_mesh,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def verts_rgb_packed(self) -> Union[torch.Tensor, None]:
 | 
			
		||||
        if self._verts_rgb_padded is None:
 | 
			
		||||
            return None
 | 
			
		||||
        return padded_to_packed(
 | 
			
		||||
            # pyre-fixme[6]: Expected `Tensor` for 1st param but got
 | 
			
		||||
            #  `Optional[torch.Tensor]`.
 | 
			
		||||
            self._verts_rgb_padded,
 | 
			
		||||
            split_size=self._num_verts_per_mesh,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    # Currently only the padded maps are used.
 | 
			
		||||
    def maps_padded(self) -> Union[torch.Tensor, None]:
 | 
			
		||||
        # pyre-fixme[7]: Expected `Optional[torch.Tensor]` but got `Union[None,
 | 
			
		||||
        #  List[typing.Any], torch.Tensor]`.
 | 
			
		||||
        return self._maps_padded
 | 
			
		||||
 | 
			
		||||
    def extend(self, N: int) -> "Textures":
 | 
			
		||||
        """
 | 
			
		||||
        Create new Textures class which contains each input texture N times
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            N: number of new copies of each texture.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            new Textures object.
 | 
			
		||||
        """
 | 
			
		||||
        if not isinstance(N, int):
 | 
			
		||||
            raise ValueError("N must be an integer.")
 | 
			
		||||
        if N <= 0:
 | 
			
		||||
            raise ValueError("N must be > 0.")
 | 
			
		||||
 | 
			
		||||
        if all(
 | 
			
		||||
            v is not None
 | 
			
		||||
            for v in [self._faces_uvs_padded, self._verts_uvs_padded, self._maps_padded]
 | 
			
		||||
        ):
 | 
			
		||||
            # pyre-fixme[6]: Expected `Tensor` for 1st param but got
 | 
			
		||||
            #  `Optional[torch.Tensor]`.
 | 
			
		||||
            new_verts_uvs = _extend_tensor(self._verts_uvs_padded, N)
 | 
			
		||||
            # pyre-fixme[6]: Expected `Tensor` for 1st param but got
 | 
			
		||||
            #  `Optional[torch.Tensor]`.
 | 
			
		||||
            new_faces_uvs = _extend_tensor(self._faces_uvs_padded, N)
 | 
			
		||||
            # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[None,
 | 
			
		||||
            #  List[typing.Any], torch.Tensor]`.
 | 
			
		||||
            new_maps = _extend_tensor(self._maps_padded, N)
 | 
			
		||||
            return self.__class__(
 | 
			
		||||
                verts_uvs=new_verts_uvs, faces_uvs=new_faces_uvs, maps=new_maps
 | 
			
		||||
            )
 | 
			
		||||
        elif self._verts_rgb_padded is not None:
 | 
			
		||||
            # pyre-fixme[6]: Expected `Tensor` for 1st param but got
 | 
			
		||||
            #  `Optional[torch.Tensor]`.
 | 
			
		||||
            new_verts_rgb = _extend_tensor(self._verts_rgb_padded, N)
 | 
			
		||||
            return self.__class__(verts_rgb=new_verts_rgb)
 | 
			
		||||
        else:
 | 
			
		||||
            msg = "Either vertex colors or texture maps are required."
 | 
			
		||||
            raise ValueError(msg)
 | 
			
		||||
@ -73,6 +73,7 @@ def padded_to_list(x: torch.Tensor, split_size: Union[list, tuple, None] = None)
 | 
			
		||||
    # pyre-fixme[16]: `Tensor` has no attribute `ndim`.
 | 
			
		||||
    if x.ndim != 3:
 | 
			
		||||
        raise ValueError("Supports only 3-dimensional input tensors")
 | 
			
		||||
 | 
			
		||||
    x_list = list(x.unbind(0))
 | 
			
		||||
 | 
			
		||||
    if split_size is None:
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										
											BIN
										
									
								
								tests/data/test_texture_atlas_8x8_back.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								tests/data/test_texture_atlas_8x8_back.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 31 KiB  | 
@ -8,9 +8,9 @@ from pytorch3d.ops.interp_face_attrs import (
 | 
			
		||||
    interpolate_face_attributes,
 | 
			
		||||
    interpolate_face_attributes_python,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.renderer.mesh import TexturesVertex
 | 
			
		||||
from pytorch3d.renderer.mesh.rasterizer import Fragments
 | 
			
		||||
from pytorch3d.renderer.mesh.texturing import interpolate_vertex_colors
 | 
			
		||||
from pytorch3d.structures import Meshes, Textures
 | 
			
		||||
from pytorch3d.structures import Meshes
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
@ -96,16 +96,12 @@ class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        self.assertClose(grad_face_attrs_py, grad_face_attrs_cu, rtol=1e-3)
 | 
			
		||||
 | 
			
		||||
    def test_interpolate_attributes(self):
 | 
			
		||||
        """
 | 
			
		||||
        This tests both interpolate_vertex_colors as well as
 | 
			
		||||
        interpolate_face_attributes.
 | 
			
		||||
        """
 | 
			
		||||
        verts = torch.randn((4, 3), dtype=torch.float32)
 | 
			
		||||
        faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
 | 
			
		||||
        vert_tex = torch.tensor(
 | 
			
		||||
            [[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32
 | 
			
		||||
        )
 | 
			
		||||
        tex = Textures(verts_rgb=vert_tex[None, :])
 | 
			
		||||
        tex = TexturesVertex(verts_features=vert_tex[None, :])
 | 
			
		||||
        mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
 | 
			
		||||
        pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
 | 
			
		||||
        barycentric_coords = torch.tensor(
 | 
			
		||||
@ -120,7 +116,13 @@ class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            zbuf=torch.ones_like(pix_to_face),
 | 
			
		||||
            dists=torch.ones_like(pix_to_face),
 | 
			
		||||
        )
 | 
			
		||||
        texels = interpolate_vertex_colors(fragments, mesh)
 | 
			
		||||
 | 
			
		||||
        verts_features_packed = mesh.textures.verts_features_packed()
 | 
			
		||||
        faces_verts_features = verts_features_packed[mesh.faces_packed()]
 | 
			
		||||
 | 
			
		||||
        texels = interpolate_face_attributes(
 | 
			
		||||
            fragments.pix_to_face, fragments.bary_coords, faces_verts_features
 | 
			
		||||
        )
 | 
			
		||||
        self.assertTrue(torch.allclose(texels, expected_vals[None, :]))
 | 
			
		||||
 | 
			
		||||
    def test_interpolate_attributes_grad(self):
 | 
			
		||||
@ -131,7 +133,7 @@ class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
            requires_grad=True,
 | 
			
		||||
        )
 | 
			
		||||
        tex = Textures(verts_rgb=vert_tex[None, :])
 | 
			
		||||
        tex = TexturesVertex(verts_features=vert_tex[None, :])
 | 
			
		||||
        mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
 | 
			
		||||
        pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
 | 
			
		||||
        barycentric_coords = torch.tensor(
 | 
			
		||||
@ -147,7 +149,12 @@ class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            [[0.3, 0.3, 0.3], [0.9, 0.9, 0.9], [0.5, 0.5, 0.5], [0.3, 0.3, 0.3]],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
        )
 | 
			
		||||
        texels = interpolate_vertex_colors(fragments, mesh)
 | 
			
		||||
        verts_features_packed = mesh.textures.verts_features_packed()
 | 
			
		||||
        faces_verts_features = verts_features_packed[mesh.faces_packed()]
 | 
			
		||||
 | 
			
		||||
        texels = interpolate_face_attributes(
 | 
			
		||||
            fragments.pix_to_face, fragments.bary_coords, faces_verts_features
 | 
			
		||||
        )
 | 
			
		||||
        texels.sum().backward()
 | 
			
		||||
        self.assertTrue(hasattr(vert_tex, "grad"))
 | 
			
		||||
        self.assertTrue(torch.allclose(vert_tex.grad, grad_vert_tex[None, :]))
 | 
			
		||||
 | 
			
		||||
@ -13,8 +13,8 @@ from pytorch3d.io.mtl_io import (
 | 
			
		||||
    _bilinear_interpolation_grid_sample,
 | 
			
		||||
    _bilinear_interpolation_vectorized,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.structures import Meshes, Textures, join_meshes_as_batch
 | 
			
		||||
from pytorch3d.structures.meshes import join_mesh
 | 
			
		||||
from pytorch3d.renderer import TexturesAtlas, TexturesUV, TexturesVertex
 | 
			
		||||
from pytorch3d.structures import Meshes, join_meshes_as_batch
 | 
			
		||||
from pytorch3d.utils import torus
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -590,17 +590,29 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
            check_item(mesh.verts_padded(), mesh3.verts_padded())
 | 
			
		||||
            check_item(mesh.faces_padded(), mesh3.faces_padded())
 | 
			
		||||
 | 
			
		||||
            if mesh.textures is not None:
 | 
			
		||||
                check_item(mesh.textures.maps_padded(), mesh3.textures.maps_padded())
 | 
			
		||||
                check_item(
 | 
			
		||||
                    mesh.textures.faces_uvs_padded(), mesh3.textures.faces_uvs_padded()
 | 
			
		||||
                )
 | 
			
		||||
                check_item(
 | 
			
		||||
                    mesh.textures.verts_uvs_padded(), mesh3.textures.verts_uvs_padded()
 | 
			
		||||
                )
 | 
			
		||||
                check_item(
 | 
			
		||||
                    mesh.textures.verts_rgb_padded(), mesh3.textures.verts_rgb_padded()
 | 
			
		||||
                )
 | 
			
		||||
                if isinstance(mesh.textures, TexturesUV):
 | 
			
		||||
                    check_item(
 | 
			
		||||
                        mesh.textures.faces_uvs_padded(),
 | 
			
		||||
                        mesh3.textures.faces_uvs_padded(),
 | 
			
		||||
                    )
 | 
			
		||||
                    check_item(
 | 
			
		||||
                        mesh.textures.verts_uvs_padded(),
 | 
			
		||||
                        mesh3.textures.verts_uvs_padded(),
 | 
			
		||||
                    )
 | 
			
		||||
                    check_item(
 | 
			
		||||
                        mesh.textures.maps_padded(), mesh3.textures.maps_padded()
 | 
			
		||||
                    )
 | 
			
		||||
                elif isinstance(mesh.textures, TexturesVertex):
 | 
			
		||||
                    check_item(
 | 
			
		||||
                        mesh.textures.verts_features_padded(),
 | 
			
		||||
                        mesh3.textures.verts_features_padded(),
 | 
			
		||||
                    )
 | 
			
		||||
                elif isinstance(mesh.textures, TexturesAtlas):
 | 
			
		||||
                    check_item(
 | 
			
		||||
                        mesh.textures.atlas_padded(), mesh3.textures.atlas_padded()
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
        DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
 | 
			
		||||
        obj_filename = DATA_DIR / "cow_mesh/cow.obj"
 | 
			
		||||
@ -623,16 +635,24 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        check_triple(mesh_notex, mesh3_notex)
 | 
			
		||||
        self.assertIsNone(mesh_notex.textures)
 | 
			
		||||
 | 
			
		||||
        # meshes with vertex texture, join into a batch.
 | 
			
		||||
        verts = torch.randn((4, 3), dtype=torch.float32)
 | 
			
		||||
        faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
 | 
			
		||||
        vert_tex = torch.tensor(
 | 
			
		||||
            [[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32
 | 
			
		||||
        )
 | 
			
		||||
        tex = Textures(verts_rgb=vert_tex[None, :])
 | 
			
		||||
        mesh_rgb = Meshes(verts=[verts], faces=[faces], textures=tex)
 | 
			
		||||
        vert_tex = torch.ones_like(verts)
 | 
			
		||||
        rgb_tex = TexturesVertex(verts_features=[vert_tex])
 | 
			
		||||
        mesh_rgb = Meshes(verts=[verts], faces=[faces], textures=rgb_tex)
 | 
			
		||||
        mesh_rgb3 = join_meshes_as_batch([mesh_rgb, mesh_rgb, mesh_rgb])
 | 
			
		||||
        check_triple(mesh_rgb, mesh_rgb3)
 | 
			
		||||
 | 
			
		||||
        # meshes with texture atlas, join into a batch.
 | 
			
		||||
        device = "cuda:0"
 | 
			
		||||
        atlas = torch.rand((2, 4, 4, 3), dtype=torch.float32, device=device)
 | 
			
		||||
        atlas_tex = TexturesAtlas(atlas=[atlas])
 | 
			
		||||
        mesh_atlas = Meshes(verts=[verts], faces=[faces], textures=atlas_tex)
 | 
			
		||||
        mesh_atlas3 = join_meshes_as_batch([mesh_atlas, mesh_atlas, mesh_atlas])
 | 
			
		||||
        check_triple(mesh_atlas, mesh_atlas3)
 | 
			
		||||
 | 
			
		||||
        # Test load multiple meshes with textures into a batch.
 | 
			
		||||
        teapot_obj = DATA_DIR / "teapot.obj"
 | 
			
		||||
        mesh_teapot = load_objs_as_meshes([teapot_obj])
 | 
			
		||||
        teapot_verts, teapot_faces = mesh_teapot.get_mesh_verts_faces(0)
 | 
			
		||||
@ -649,41 +669,9 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        self.assertClose(cow3_tea.verts_list()[3], mesh_teapot.verts_list()[0])
 | 
			
		||||
        self.assertClose(cow3_tea.faces_list()[3], mesh_teapot.faces_list()[0])
 | 
			
		||||
 | 
			
		||||
    def test_join_meshes(self):
 | 
			
		||||
        """
 | 
			
		||||
        Test that join_mesh joins single meshes and the corresponding values are
 | 
			
		||||
        consistent with the single meshes.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        # Load cow mesh.
 | 
			
		||||
        DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
 | 
			
		||||
        cow_obj = DATA_DIR / "cow_mesh/cow.obj"
 | 
			
		||||
 | 
			
		||||
        cow_mesh = load_objs_as_meshes([cow_obj])
 | 
			
		||||
        cow_verts, cow_faces = cow_mesh.get_mesh_verts_faces(0)
 | 
			
		||||
        # Join a batch of three single meshes and check that the values are consistent
 | 
			
		||||
        # with the individual meshes.
 | 
			
		||||
        cow_mesh3 = join_mesh([cow_mesh, cow_mesh, cow_mesh])
 | 
			
		||||
 | 
			
		||||
        def check_item(x, y, offset):
 | 
			
		||||
            self.assertClose(torch.cat([x, x + offset, x + 2 * offset], dim=1), y)
 | 
			
		||||
 | 
			
		||||
        check_item(cow_mesh.verts_padded(), cow_mesh3.verts_padded(), 0)
 | 
			
		||||
        check_item(cow_mesh.faces_padded(), cow_mesh3.faces_padded(), cow_mesh._V)
 | 
			
		||||
 | 
			
		||||
        # Test the joining of meshes of different sizes.
 | 
			
		||||
        teapot_obj = DATA_DIR / "teapot.obj"
 | 
			
		||||
        teapot_mesh = load_objs_as_meshes([teapot_obj])
 | 
			
		||||
        teapot_verts, teapot_faces = teapot_mesh.get_mesh_verts_faces(0)
 | 
			
		||||
 | 
			
		||||
        mix_mesh = join_mesh([cow_mesh, teapot_mesh])
 | 
			
		||||
        mix_verts, mix_faces = mix_mesh.get_mesh_verts_faces(0)
 | 
			
		||||
        self.assertEqual(len(mix_mesh), 1)
 | 
			
		||||
 | 
			
		||||
        self.assertClose(mix_verts[: cow_mesh._V], cow_verts)
 | 
			
		||||
        self.assertClose(mix_faces[: cow_mesh._F], cow_faces)
 | 
			
		||||
        self.assertClose(mix_verts[cow_mesh._V :], teapot_verts)
 | 
			
		||||
        self.assertClose(mix_faces[cow_mesh._F :], teapot_faces + cow_mesh._V)
 | 
			
		||||
        # Check error raised if all meshes in the batch don't have the same texture type
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "same type of texture"):
 | 
			
		||||
            join_meshes_as_batch([mesh_atlas, mesh_rgb, mesh_atlas])
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _bm_save_obj(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
 | 
			
		||||
 | 
			
		||||
@ -11,10 +11,11 @@ import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
from common_testing import TestCaseMixin, load_rgb_image
 | 
			
		||||
from PIL import Image
 | 
			
		||||
from pytorch3d.io import load_objs_as_meshes
 | 
			
		||||
from pytorch3d.io import load_obj
 | 
			
		||||
from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
 | 
			
		||||
from pytorch3d.renderer.lighting import PointLights
 | 
			
		||||
from pytorch3d.renderer.materials import Materials
 | 
			
		||||
from pytorch3d.renderer.mesh import TexturesAtlas, TexturesUV, TexturesVertex
 | 
			
		||||
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
 | 
			
		||||
from pytorch3d.renderer.mesh.renderer import MeshRenderer
 | 
			
		||||
from pytorch3d.renderer.mesh.shader import (
 | 
			
		||||
@ -25,7 +26,6 @@ from pytorch3d.renderer.mesh.shader import (
 | 
			
		||||
    SoftSilhouetteShader,
 | 
			
		||||
    TexturedSoftPhongShader,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.renderer.mesh.texturing import Textures
 | 
			
		||||
from pytorch3d.structures.meshes import Meshes, join_mesh
 | 
			
		||||
from pytorch3d.utils.ico_sphere import ico_sphere
 | 
			
		||||
 | 
			
		||||
@ -52,7 +52,8 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        sphere_mesh = ico_sphere(5, device)
 | 
			
		||||
        verts_padded = sphere_mesh.verts_padded()
 | 
			
		||||
        faces_padded = sphere_mesh.faces_padded()
 | 
			
		||||
        textures = Textures(verts_rgb=torch.ones_like(verts_padded))
 | 
			
		||||
        feats = torch.ones_like(verts_padded, device=device)
 | 
			
		||||
        textures = TexturesVertex(verts_features=feats)
 | 
			
		||||
        sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)
 | 
			
		||||
 | 
			
		||||
        # Init rasterizer settings
 | 
			
		||||
@ -97,6 +98,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            filename = "simple_sphere_light_%s%s.png" % (name, postfix)
 | 
			
		||||
            image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
 | 
			
		||||
            rgb = images[0, ..., :3].squeeze().cpu()
 | 
			
		||||
 | 
			
		||||
            if DEBUG:
 | 
			
		||||
                filename = "DEBUG_%s" % filename
 | 
			
		||||
                Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
 | 
			
		||||
@ -145,14 +147,15 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        Test a mesh with vertex textures can be extended to form a batch, and
 | 
			
		||||
        is rendered correctly with Phong, Gouraud and Flat Shaders.
 | 
			
		||||
        """
 | 
			
		||||
        batch_size = 20
 | 
			
		||||
        batch_size = 5
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
 | 
			
		||||
        # Init mesh with vertex textures.
 | 
			
		||||
        sphere_meshes = ico_sphere(5, device).extend(batch_size)
 | 
			
		||||
        verts_padded = sphere_meshes.verts_padded()
 | 
			
		||||
        faces_padded = sphere_meshes.faces_padded()
 | 
			
		||||
        textures = Textures(verts_rgb=torch.ones_like(verts_padded))
 | 
			
		||||
        feats = torch.ones_like(verts_padded, device=device)
 | 
			
		||||
        textures = TexturesVertex(verts_features=feats)
 | 
			
		||||
        sphere_meshes = Meshes(
 | 
			
		||||
            verts=verts_padded, faces=faces_padded, textures=textures
 | 
			
		||||
        )
 | 
			
		||||
@ -194,6 +197,11 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            )
 | 
			
		||||
            for i in range(batch_size):
 | 
			
		||||
                rgb = images[i, ..., :3].squeeze().cpu()
 | 
			
		||||
                if i == 0 and DEBUG:
 | 
			
		||||
                    filename = "DEBUG_simple_sphere_batched_%s.png" % name
 | 
			
		||||
                    Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
 | 
			
		||||
                        DATA_DIR / filename
 | 
			
		||||
                    )
 | 
			
		||||
                self.assertClose(rgb, image_ref, atol=0.05)
 | 
			
		||||
 | 
			
		||||
    def test_silhouette_with_grad(self):
 | 
			
		||||
@ -233,6 +241,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
        with Image.open(image_ref_filename) as raw_image_ref:
 | 
			
		||||
            image_ref = torch.from_numpy(np.array(raw_image_ref))
 | 
			
		||||
 | 
			
		||||
        image_ref = image_ref.to(dtype=torch.float32) / 255.0
 | 
			
		||||
        self.assertClose(alpha, image_ref, atol=0.055)
 | 
			
		||||
 | 
			
		||||
@ -253,11 +262,20 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        obj_filename = obj_dir / "cow_mesh/cow.obj"
 | 
			
		||||
 | 
			
		||||
        # Load mesh + texture
 | 
			
		||||
        mesh = load_objs_as_meshes([obj_filename], device=device)
 | 
			
		||||
        verts, faces, aux = load_obj(
 | 
			
		||||
            obj_filename, device=device, load_textures=True, texture_wrap=None
 | 
			
		||||
        )
 | 
			
		||||
        tex_map = list(aux.texture_images.values())[0]
 | 
			
		||||
        tex_map = tex_map[None, ...].to(faces.textures_idx.device)
 | 
			
		||||
        textures = TexturesUV(
 | 
			
		||||
            maps=tex_map, faces_uvs=[faces.textures_idx], verts_uvs=[aux.verts_uvs]
 | 
			
		||||
        )
 | 
			
		||||
        mesh = Meshes(verts=[verts], faces=[faces.verts_idx], textures=textures)
 | 
			
		||||
 | 
			
		||||
        # Init rasterizer settings
 | 
			
		||||
        R, T = look_at_view_transform(2.7, 0, 0)
 | 
			
		||||
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
 | 
			
		||||
 | 
			
		||||
        raster_settings = RasterizationSettings(
 | 
			
		||||
            image_size=512, blur_radius=0.0, faces_per_pixel=1
 | 
			
		||||
        )
 | 
			
		||||
@ -405,8 +423,8 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
                Meshes(verts=verts, faces=sphere_list[i].faces_padded())
 | 
			
		||||
            )
 | 
			
		||||
        joined_sphere_mesh = join_mesh(sphere_mesh_list)
 | 
			
		||||
        joined_sphere_mesh.textures = Textures(
 | 
			
		||||
            verts_rgb=torch.ones_like(joined_sphere_mesh.verts_padded())
 | 
			
		||||
        joined_sphere_mesh.textures = TexturesVertex(
 | 
			
		||||
            verts_features=torch.ones_like(joined_sphere_mesh.verts_padded())
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Init rasterizer settings
 | 
			
		||||
@ -446,3 +464,61 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
                )
 | 
			
		||||
            image_ref = load_rgb_image("test_joined_spheres_%s.png" % name, DATA_DIR)
 | 
			
		||||
            self.assertClose(rgb, image_ref, atol=0.05)
 | 
			
		||||
 | 
			
		||||
    def test_texture_map_atlas(self):
 | 
			
		||||
        """
 | 
			
		||||
        Test a mesh with a texture map as a per face atlas is loaded and rendered correctly.
 | 
			
		||||
        """
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        obj_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
 | 
			
		||||
        obj_filename = obj_dir / "cow_mesh/cow.obj"
 | 
			
		||||
 | 
			
		||||
        # Load mesh and texture as a per face texture atlas.
 | 
			
		||||
        verts, faces, aux = load_obj(
 | 
			
		||||
            obj_filename,
 | 
			
		||||
            device=device,
 | 
			
		||||
            load_textures=True,
 | 
			
		||||
            create_texture_atlas=True,
 | 
			
		||||
            texture_atlas_size=8,
 | 
			
		||||
            texture_wrap=None,
 | 
			
		||||
        )
 | 
			
		||||
        mesh = Meshes(
 | 
			
		||||
            verts=[verts],
 | 
			
		||||
            faces=[faces.verts_idx],
 | 
			
		||||
            textures=TexturesAtlas(atlas=[aux.texture_atlas]),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Init rasterizer settings
 | 
			
		||||
        R, T = look_at_view_transform(2.7, 0, 0)
 | 
			
		||||
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
 | 
			
		||||
 | 
			
		||||
        raster_settings = RasterizationSettings(
 | 
			
		||||
            image_size=512, blur_radius=0.0, faces_per_pixel=1, cull_backfaces=True
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Init shader settings
 | 
			
		||||
        materials = Materials(device=device, specular_color=((0, 0, 0),), shininess=0.0)
 | 
			
		||||
        lights = PointLights(device=device)
 | 
			
		||||
 | 
			
		||||
        # Place light behind the cow in world space. The front of
 | 
			
		||||
        # the cow is facing the -z direction.
 | 
			
		||||
        lights.location = torch.tensor([0.0, 0.0, 2.0], device=device)[None]
 | 
			
		||||
 | 
			
		||||
        # The HardPhongShader can be used directly with atlas textures.
 | 
			
		||||
        renderer = MeshRenderer(
 | 
			
		||||
            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
 | 
			
		||||
            shader=HardPhongShader(lights=lights, cameras=cameras, materials=materials),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        images = renderer(mesh)
 | 
			
		||||
        rgb = images[0, ..., :3].squeeze().cpu()
 | 
			
		||||
 | 
			
		||||
        # Load reference image
 | 
			
		||||
        image_ref = load_rgb_image("test_texture_atlas_8x8_back.png", DATA_DIR)
 | 
			
		||||
 | 
			
		||||
        if DEBUG:
 | 
			
		||||
            Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
 | 
			
		||||
                DATA_DIR / "DEBUG_texture_atlas_8x8_back.png"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        self.assertClose(rgb, image_ref, atol=0.05)
 | 
			
		||||
 | 
			
		||||
@ -7,14 +7,376 @@ import torch
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from common_testing import TestCaseMixin
 | 
			
		||||
from pytorch3d.renderer.mesh.rasterizer import Fragments
 | 
			
		||||
from pytorch3d.renderer.mesh.texturing import interpolate_texture_map
 | 
			
		||||
from pytorch3d.structures import Meshes, Textures
 | 
			
		||||
from pytorch3d.structures.utils import list_to_padded
 | 
			
		||||
from pytorch3d.renderer.mesh.textures import (
 | 
			
		||||
    TexturesAtlas,
 | 
			
		||||
    TexturesUV,
 | 
			
		||||
    TexturesVertex,
 | 
			
		||||
    _list_to_padded_wrapper,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.structures import Meshes, list_to_packed, packed_to_list
 | 
			
		||||
from test_meshes import TestMeshes
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestTexturing(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
    def test_interpolate_texture_map(self):
 | 
			
		||||
def tryindex(self, index, tex, meshes, source):
 | 
			
		||||
    tex2 = tex[index]
 | 
			
		||||
    meshes2 = meshes[index]
 | 
			
		||||
    tex_from_meshes = meshes2.textures
 | 
			
		||||
    for item in source:
 | 
			
		||||
        basic = source[item][index]
 | 
			
		||||
        from_texture = getattr(tex2, item + "_padded")()
 | 
			
		||||
        from_meshes = getattr(tex_from_meshes, item + "_padded")()
 | 
			
		||||
        if isinstance(index, int):
 | 
			
		||||
            basic = basic[None]
 | 
			
		||||
 | 
			
		||||
        if len(basic) == 0:
 | 
			
		||||
            self.assertEquals(len(from_texture), 0)
 | 
			
		||||
            self.assertEquals(len(from_meshes), 0)
 | 
			
		||||
        else:
 | 
			
		||||
            self.assertClose(basic, from_texture)
 | 
			
		||||
            self.assertClose(basic, from_meshes)
 | 
			
		||||
            self.assertEqual(from_texture.ndim, getattr(tex, item + "_padded")().ndim)
 | 
			
		||||
            item_list = getattr(tex_from_meshes, item + "_list")()
 | 
			
		||||
            self.assertEqual(basic.shape[0], len(item_list))
 | 
			
		||||
            for i, elem in enumerate(item_list):
 | 
			
		||||
                self.assertClose(elem, basic[i])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestTexturesVertex(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
    def test_sample_vertex_textures(self):
 | 
			
		||||
        """
 | 
			
		||||
        This tests both interpolate_vertex_colors as well as
 | 
			
		||||
        interpolate_face_attributes.
 | 
			
		||||
        """
 | 
			
		||||
        verts = torch.randn((4, 3), dtype=torch.float32)
 | 
			
		||||
        faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
 | 
			
		||||
        vert_tex = torch.tensor(
 | 
			
		||||
            [[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32
 | 
			
		||||
        )
 | 
			
		||||
        verts_features = vert_tex
 | 
			
		||||
        tex = TexturesVertex(verts_features=[verts_features])
 | 
			
		||||
        mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
 | 
			
		||||
        pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
 | 
			
		||||
        barycentric_coords = torch.tensor(
 | 
			
		||||
            [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
 | 
			
		||||
        ).view(1, 1, 1, 2, -1)
 | 
			
		||||
        expected_vals = torch.tensor(
 | 
			
		||||
            [[0.5, 1.0, 0.3], [0.3, 1.0, 0.9]], dtype=torch.float32
 | 
			
		||||
        ).view(1, 1, 1, 2, -1)
 | 
			
		||||
        fragments = Fragments(
 | 
			
		||||
            pix_to_face=pix_to_face,
 | 
			
		||||
            bary_coords=barycentric_coords,
 | 
			
		||||
            zbuf=torch.ones_like(pix_to_face),
 | 
			
		||||
            dists=torch.ones_like(pix_to_face),
 | 
			
		||||
        )
 | 
			
		||||
        # sample_textures calls interpolate_vertex_colors
 | 
			
		||||
        texels = mesh.sample_textures(fragments)
 | 
			
		||||
        self.assertTrue(torch.allclose(texels, expected_vals[None, :]))
 | 
			
		||||
 | 
			
		||||
    def test_sample_vertex_textures_grad(self):
 | 
			
		||||
        verts = torch.randn((4, 3), dtype=torch.float32)
 | 
			
		||||
        faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
 | 
			
		||||
        vert_tex = torch.tensor(
 | 
			
		||||
            [[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
            requires_grad=True,
 | 
			
		||||
        )
 | 
			
		||||
        verts_features = vert_tex
 | 
			
		||||
        tex = TexturesVertex(verts_features=[verts_features])
 | 
			
		||||
        mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
 | 
			
		||||
        pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
 | 
			
		||||
        barycentric_coords = torch.tensor(
 | 
			
		||||
            [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
 | 
			
		||||
        ).view(1, 1, 1, 2, -1)
 | 
			
		||||
        fragments = Fragments(
 | 
			
		||||
            pix_to_face=pix_to_face,
 | 
			
		||||
            bary_coords=barycentric_coords,
 | 
			
		||||
            zbuf=torch.ones_like(pix_to_face),
 | 
			
		||||
            dists=torch.ones_like(pix_to_face),
 | 
			
		||||
        )
 | 
			
		||||
        grad_vert_tex = torch.tensor(
 | 
			
		||||
            [[0.3, 0.3, 0.3], [0.9, 0.9, 0.9], [0.5, 0.5, 0.5], [0.3, 0.3, 0.3]],
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
        )
 | 
			
		||||
        texels = mesh.sample_textures(fragments)
 | 
			
		||||
        texels.sum().backward()
 | 
			
		||||
        self.assertTrue(hasattr(vert_tex, "grad"))
 | 
			
		||||
        self.assertTrue(torch.allclose(vert_tex.grad, grad_vert_tex[None, :]))
 | 
			
		||||
 | 
			
		||||
    def test_textures_vertex_init_fail(self):
 | 
			
		||||
        # Incorrect sized tensors
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "verts_features"):
 | 
			
		||||
            TexturesVertex(verts_features=torch.rand(size=(5, 10)))
 | 
			
		||||
 | 
			
		||||
        # Not a list or a tensor
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "verts_features"):
 | 
			
		||||
            TexturesVertex(verts_features=(1, 1, 1))
 | 
			
		||||
 | 
			
		||||
    def test_clone(self):
 | 
			
		||||
        tex = TexturesVertex(verts_features=torch.rand(size=(10, 100, 128)))
 | 
			
		||||
        tex_cloned = tex.clone()
 | 
			
		||||
        self.assertSeparate(
 | 
			
		||||
            tex._verts_features_padded, tex_cloned._verts_features_padded
 | 
			
		||||
        )
 | 
			
		||||
        self.assertSeparate(tex.valid, tex_cloned.valid)
 | 
			
		||||
 | 
			
		||||
    def test_extend(self):
 | 
			
		||||
        B = 10
 | 
			
		||||
        mesh = TestMeshes.init_mesh(B, 30, 50)
 | 
			
		||||
        V = mesh._V
 | 
			
		||||
        tex_uv = TexturesVertex(verts_features=torch.randn((B, V, 3)))
 | 
			
		||||
        tex_mesh = Meshes(
 | 
			
		||||
            verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex_uv
 | 
			
		||||
        )
 | 
			
		||||
        N = 20
 | 
			
		||||
        new_mesh = tex_mesh.extend(N)
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(tex_mesh) * N, len(new_mesh))
 | 
			
		||||
 | 
			
		||||
        tex_init = tex_mesh.textures
 | 
			
		||||
        new_tex = new_mesh.textures
 | 
			
		||||
 | 
			
		||||
        for i in range(len(tex_mesh)):
 | 
			
		||||
            for n in range(N):
 | 
			
		||||
                self.assertClose(
 | 
			
		||||
                    tex_init.verts_features_list()[i],
 | 
			
		||||
                    new_tex.verts_features_list()[i * N + n],
 | 
			
		||||
                )
 | 
			
		||||
                self.assertClose(
 | 
			
		||||
                    tex_init._num_faces_per_mesh[i],
 | 
			
		||||
                    new_tex._num_faces_per_mesh[i * N + n],
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        self.assertAllSeparate(
 | 
			
		||||
            [tex_init.verts_features_padded(), new_tex.verts_features_padded()]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        with self.assertRaises(ValueError):
 | 
			
		||||
            tex_mesh.extend(N=-1)
 | 
			
		||||
 | 
			
		||||
    def test_padded_to_packed(self):
 | 
			
		||||
        # Case where each face in the mesh has 3 unique uv vertex indices
 | 
			
		||||
        # - i.e. even if a vertex is shared between multiple faces it will
 | 
			
		||||
        # have a unique uv coordinate for each face.
 | 
			
		||||
        num_verts_per_mesh = [9, 6]
 | 
			
		||||
        D = 10
 | 
			
		||||
        verts_features_list = [torch.rand(v, D) for v in num_verts_per_mesh]
 | 
			
		||||
        verts_features_packed = list_to_packed(verts_features_list)[0]
 | 
			
		||||
        verts_features_list = packed_to_list(verts_features_packed, num_verts_per_mesh)
 | 
			
		||||
        tex = TexturesVertex(verts_features=verts_features_list)
 | 
			
		||||
 | 
			
		||||
        # This is set inside Meshes when textures is passed as an input.
 | 
			
		||||
        # Here we set _num_faces_per_mesh and _num_verts_per_mesh explicity.
 | 
			
		||||
        tex1 = tex.clone()
 | 
			
		||||
        tex1._num_verts_per_mesh = num_verts_per_mesh
 | 
			
		||||
        verts_packed = tex1.verts_features_packed()
 | 
			
		||||
        verts_verts_list = tex1.verts_features_list()
 | 
			
		||||
        verts_padded = tex1.verts_features_padded()
 | 
			
		||||
 | 
			
		||||
        for f1, f2 in zip(verts_verts_list, verts_features_list):
 | 
			
		||||
            self.assertTrue((f1 == f2).all().item())
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(verts_packed.shape == (sum(num_verts_per_mesh), D))
 | 
			
		||||
        self.assertTrue(verts_padded.shape == (2, 9, D))
 | 
			
		||||
 | 
			
		||||
        # Case where num_verts_per_mesh is not set and textures
 | 
			
		||||
        # are initialized with a padded tensor.
 | 
			
		||||
        tex2 = TexturesVertex(verts_features=verts_padded)
 | 
			
		||||
        verts_packed = tex2.verts_features_packed()
 | 
			
		||||
        verts_list = tex2.verts_features_list()
 | 
			
		||||
 | 
			
		||||
        # Packed is just flattened padded as num_verts_per_mesh
 | 
			
		||||
        # has not been provided.
 | 
			
		||||
        self.assertTrue(verts_packed.shape == (9 * 2, D))
 | 
			
		||||
 | 
			
		||||
        for i, (f1, f2) in enumerate(zip(verts_list, verts_features_list)):
 | 
			
		||||
            n = num_verts_per_mesh[i]
 | 
			
		||||
            self.assertTrue((f1[:n] == f2).all().item())
 | 
			
		||||
 | 
			
		||||
    def test_getitem(self):
 | 
			
		||||
        N = 5
 | 
			
		||||
        V = 20
 | 
			
		||||
        source = {"verts_features": torch.randn(size=(N, 10, 128))}
 | 
			
		||||
        tex = TexturesVertex(verts_features=source["verts_features"])
 | 
			
		||||
 | 
			
		||||
        verts = torch.rand(size=(N, V, 3))
 | 
			
		||||
        faces = torch.randint(size=(N, 10, 3), high=V)
 | 
			
		||||
        meshes = Meshes(verts=verts, faces=faces, textures=tex)
 | 
			
		||||
 | 
			
		||||
        tryindex(self, 2, tex, meshes, source)
 | 
			
		||||
        tryindex(self, slice(0, 2, 1), tex, meshes, source)
 | 
			
		||||
        index = torch.tensor([1, 0, 1, 0, 0], dtype=torch.bool)
 | 
			
		||||
        tryindex(self, index, tex, meshes, source)
 | 
			
		||||
        index = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool)
 | 
			
		||||
        tryindex(self, index, tex, meshes, source)
 | 
			
		||||
        index = torch.tensor([1, 2], dtype=torch.int64)
 | 
			
		||||
        tryindex(self, index, tex, meshes, source)
 | 
			
		||||
        tryindex(self, [2, 4], tex, meshes, source)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestTexturesAtlas(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
    def test_sample_texture_atlas(self):
 | 
			
		||||
        N, F, R = 1, 2, 2
 | 
			
		||||
        verts = torch.randn((4, 3), dtype=torch.float32)
 | 
			
		||||
        faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
 | 
			
		||||
        faces_atlas = torch.rand(size=(N, F, R, R, 3))
 | 
			
		||||
        tex = TexturesAtlas(atlas=faces_atlas)
 | 
			
		||||
        mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
 | 
			
		||||
        pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
 | 
			
		||||
        barycentric_coords = torch.tensor(
 | 
			
		||||
            [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
 | 
			
		||||
        ).view(1, 1, 1, 2, -1)
 | 
			
		||||
        expected_vals = torch.tensor(
 | 
			
		||||
            [[0.5, 1.0, 0.3], [0.3, 1.0, 0.9]], dtype=torch.float32
 | 
			
		||||
        )
 | 
			
		||||
        expected_vals = torch.zeros((1, 1, 1, 2, 3), dtype=torch.float32)
 | 
			
		||||
        expected_vals[..., 0, :] = faces_atlas[0, 0, 0, 1, ...]
 | 
			
		||||
        expected_vals[..., 1, :] = faces_atlas[0, 1, 1, 0, ...]
 | 
			
		||||
 | 
			
		||||
        fragments = Fragments(
 | 
			
		||||
            pix_to_face=pix_to_face,
 | 
			
		||||
            bary_coords=barycentric_coords,
 | 
			
		||||
            zbuf=torch.ones_like(pix_to_face),
 | 
			
		||||
            dists=torch.ones_like(pix_to_face),
 | 
			
		||||
        )
 | 
			
		||||
        texels = mesh.textures.sample_textures(fragments)
 | 
			
		||||
        self.assertTrue(torch.allclose(texels, expected_vals))
 | 
			
		||||
 | 
			
		||||
    def test_textures_atlas_grad(self):
 | 
			
		||||
        N, F, R = 1, 2, 2
 | 
			
		||||
        verts = torch.randn((4, 3), dtype=torch.float32)
 | 
			
		||||
        faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
 | 
			
		||||
        faces_atlas = torch.rand(size=(N, F, R, R, 3), requires_grad=True)
 | 
			
		||||
        tex = TexturesAtlas(atlas=faces_atlas)
 | 
			
		||||
        mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
 | 
			
		||||
        pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
 | 
			
		||||
        barycentric_coords = torch.tensor(
 | 
			
		||||
            [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
 | 
			
		||||
        ).view(1, 1, 1, 2, -1)
 | 
			
		||||
        fragments = Fragments(
 | 
			
		||||
            pix_to_face=pix_to_face,
 | 
			
		||||
            bary_coords=barycentric_coords,
 | 
			
		||||
            zbuf=torch.ones_like(pix_to_face),
 | 
			
		||||
            dists=torch.ones_like(pix_to_face),
 | 
			
		||||
        )
 | 
			
		||||
        texels = mesh.textures.sample_textures(fragments)
 | 
			
		||||
        grad_tex = torch.rand_like(texels)
 | 
			
		||||
        grad_expected = torch.zeros_like(faces_atlas)
 | 
			
		||||
        grad_expected[0, 0, 0, 1, :] = grad_tex[..., 0:1, :]
 | 
			
		||||
        grad_expected[0, 1, 1, 0, :] = grad_tex[..., 1:2, :]
 | 
			
		||||
        texels.backward(grad_tex)
 | 
			
		||||
        self.assertTrue(hasattr(faces_atlas, "grad"))
 | 
			
		||||
        self.assertTrue(torch.allclose(faces_atlas.grad, grad_expected))
 | 
			
		||||
 | 
			
		||||
    def test_textures_atlas_init_fail(self):
 | 
			
		||||
        # Incorrect sized tensors
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "atlas"):
 | 
			
		||||
            TexturesAtlas(atlas=torch.rand(size=(5, 10, 3)))
 | 
			
		||||
 | 
			
		||||
        # Not a list or a tensor
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "atlas"):
 | 
			
		||||
            TexturesAtlas(atlas=(1, 1, 1))
 | 
			
		||||
 | 
			
		||||
    def test_clone(self):
 | 
			
		||||
        tex = TexturesAtlas(atlas=torch.rand(size=(1, 10, 2, 2, 3)))
 | 
			
		||||
        tex_cloned = tex.clone()
 | 
			
		||||
        self.assertSeparate(tex._atlas_padded, tex_cloned._atlas_padded)
 | 
			
		||||
        self.assertSeparate(tex.valid, tex_cloned.valid)
 | 
			
		||||
 | 
			
		||||
    def test_extend(self):
 | 
			
		||||
        B = 10
 | 
			
		||||
        mesh = TestMeshes.init_mesh(B, 30, 50)
 | 
			
		||||
        F = mesh._F
 | 
			
		||||
        tex_uv = TexturesAtlas(atlas=torch.randn((B, F, 2, 2, 3)))
 | 
			
		||||
        tex_mesh = Meshes(
 | 
			
		||||
            verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex_uv
 | 
			
		||||
        )
 | 
			
		||||
        N = 20
 | 
			
		||||
        new_mesh = tex_mesh.extend(N)
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(tex_mesh) * N, len(new_mesh))
 | 
			
		||||
 | 
			
		||||
        tex_init = tex_mesh.textures
 | 
			
		||||
        new_tex = new_mesh.textures
 | 
			
		||||
 | 
			
		||||
        for i in range(len(tex_mesh)):
 | 
			
		||||
            for n in range(N):
 | 
			
		||||
                self.assertClose(
 | 
			
		||||
                    tex_init.atlas_list()[i], new_tex.atlas_list()[i * N + n]
 | 
			
		||||
                )
 | 
			
		||||
                self.assertClose(
 | 
			
		||||
                    tex_init._num_faces_per_mesh[i],
 | 
			
		||||
                    new_tex._num_faces_per_mesh[i * N + n],
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        self.assertAllSeparate([tex_init.atlas_padded(), new_tex.atlas_padded()])
 | 
			
		||||
 | 
			
		||||
        with self.assertRaises(ValueError):
 | 
			
		||||
            tex_mesh.extend(N=-1)
 | 
			
		||||
 | 
			
		||||
    def test_padded_to_packed(self):
 | 
			
		||||
        # Case where each face in the mesh has 3 unique uv vertex indices
 | 
			
		||||
        # - i.e. even if a vertex is shared between multiple faces it will
 | 
			
		||||
        # have a unique uv coordinate for each face.
 | 
			
		||||
        R = 2
 | 
			
		||||
        N = 20
 | 
			
		||||
        num_faces_per_mesh = torch.randint(size=(N,), low=0, high=30)
 | 
			
		||||
        atlas_list = [torch.rand(f, R, R, 3) for f in num_faces_per_mesh]
 | 
			
		||||
        tex = TexturesAtlas(atlas=atlas_list)
 | 
			
		||||
 | 
			
		||||
        # This is set inside Meshes when textures is passed as an input.
 | 
			
		||||
        # Here we set _num_faces_per_mesh explicity.
 | 
			
		||||
        tex1 = tex.clone()
 | 
			
		||||
        tex1._num_faces_per_mesh = num_faces_per_mesh.tolist()
 | 
			
		||||
        atlas_packed = tex1.atlas_packed()
 | 
			
		||||
        atlas_list_new = tex1.atlas_list()
 | 
			
		||||
        atlas_padded = tex1.atlas_padded()
 | 
			
		||||
 | 
			
		||||
        for f1, f2 in zip(atlas_list_new, atlas_list):
 | 
			
		||||
            self.assertTrue((f1 == f2).all().item())
 | 
			
		||||
 | 
			
		||||
        sum_F = num_faces_per_mesh.sum()
 | 
			
		||||
        max_F = num_faces_per_mesh.max().item()
 | 
			
		||||
        self.assertTrue(atlas_packed.shape == (sum_F, R, R, 3))
 | 
			
		||||
        self.assertTrue(atlas_padded.shape == (N, max_F, R, R, 3))
 | 
			
		||||
 | 
			
		||||
        # Case where num_faces_per_mesh is not set and textures
 | 
			
		||||
        # are initialized with a padded tensor.
 | 
			
		||||
        atlas_list_padded = _list_to_padded_wrapper(atlas_list)
 | 
			
		||||
        tex2 = TexturesAtlas(atlas=atlas_list_padded)
 | 
			
		||||
        atlas_packed = tex2.atlas_packed()
 | 
			
		||||
        atlas_list_new = tex2.atlas_list()
 | 
			
		||||
 | 
			
		||||
        # Packed is just flattened padded as num_faces_per_mesh
 | 
			
		||||
        # has not been provided.
 | 
			
		||||
        self.assertTrue(atlas_packed.shape == (N * max_F, R, R, 3))
 | 
			
		||||
 | 
			
		||||
        for i, (f1, f2) in enumerate(zip(atlas_list_new, atlas_list)):
 | 
			
		||||
            n = num_faces_per_mesh[i]
 | 
			
		||||
            self.assertTrue((f1[:n] == f2).all().item())
 | 
			
		||||
 | 
			
		||||
    def test_getitem(self):
 | 
			
		||||
        N = 5
 | 
			
		||||
        V = 20
 | 
			
		||||
        source = {"atlas": torch.randn(size=(N, 10, 4, 4, 3))}
 | 
			
		||||
        tex = TexturesAtlas(atlas=source["atlas"])
 | 
			
		||||
 | 
			
		||||
        verts = torch.rand(size=(N, V, 3))
 | 
			
		||||
        faces = torch.randint(size=(N, 10, 3), high=V)
 | 
			
		||||
        meshes = Meshes(verts=verts, faces=faces, textures=tex)
 | 
			
		||||
 | 
			
		||||
        tryindex(self, 2, tex, meshes, source)
 | 
			
		||||
        tryindex(self, slice(0, 2, 1), tex, meshes, source)
 | 
			
		||||
        index = torch.tensor([1, 0, 1, 0, 0], dtype=torch.bool)
 | 
			
		||||
        tryindex(self, index, tex, meshes, source)
 | 
			
		||||
        index = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool)
 | 
			
		||||
        tryindex(self, index, tex, meshes, source)
 | 
			
		||||
        index = torch.tensor([1, 2], dtype=torch.int64)
 | 
			
		||||
        tryindex(self, index, tex, meshes, source)
 | 
			
		||||
        tryindex(self, [2, 4], tex, meshes, source)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestTexturesUV(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
    def test_sample_textures_uv(self):
 | 
			
		||||
        barycentric_coords = torch.tensor(
 | 
			
		||||
            [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
 | 
			
		||||
        ).view(1, 1, 1, 2, -1)
 | 
			
		||||
@ -38,11 +400,11 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            zbuf=pix_to_face,
 | 
			
		||||
            dists=pix_to_face,
 | 
			
		||||
        )
 | 
			
		||||
        tex = Textures(
 | 
			
		||||
            maps=tex_map, faces_uvs=face_uvs[None, ...], verts_uvs=vert_uvs[None, ...]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        tex = TexturesUV(maps=tex_map, faces_uvs=[face_uvs], verts_uvs=[vert_uvs])
 | 
			
		||||
        meshes = Meshes(verts=[dummy_verts], faces=[face_uvs], textures=tex)
 | 
			
		||||
        texels = interpolate_texture_map(fragments, meshes)
 | 
			
		||||
        mesh_textures = meshes.textures
 | 
			
		||||
        texels = mesh_textures.sample_textures(fragments)
 | 
			
		||||
 | 
			
		||||
        # Expected output
 | 
			
		||||
        pixel_uvs = interpolated_uvs * 2.0 - 1.0
 | 
			
		||||
@ -53,190 +415,92 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        expected_out = F.grid_sample(tex_map, pixel_uvs, align_corners=False)
 | 
			
		||||
        self.assertTrue(torch.allclose(texels.squeeze(), expected_out.squeeze()))
 | 
			
		||||
 | 
			
		||||
    def test_init_rgb_uv_fail(self):
 | 
			
		||||
        V = 20
 | 
			
		||||
    def test_textures_uv_init_fail(self):
 | 
			
		||||
        # Maps has wrong shape
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "maps"):
 | 
			
		||||
            Textures(
 | 
			
		||||
            TexturesUV(
 | 
			
		||||
                maps=torch.ones((5, 16, 16, 3, 4)),
 | 
			
		||||
                faces_uvs=torch.randint(size=(5, 10, 3), low=0, high=V),
 | 
			
		||||
                verts_uvs=torch.ones((5, V, 2)),
 | 
			
		||||
                faces_uvs=torch.rand(size=(5, 10, 3)),
 | 
			
		||||
                verts_uvs=torch.rand(size=(5, 15, 2)),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # faces_uvs has wrong shape
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "faces_uvs"):
 | 
			
		||||
            Textures(
 | 
			
		||||
            TexturesUV(
 | 
			
		||||
                maps=torch.ones((5, 16, 16, 3)),
 | 
			
		||||
                faces_uvs=torch.randint(size=(5, 10, 3, 3), low=0, high=V),
 | 
			
		||||
                verts_uvs=torch.ones((5, V, 2)),
 | 
			
		||||
                faces_uvs=torch.rand(size=(5, 10, 3, 3)),
 | 
			
		||||
                verts_uvs=torch.rand(size=(5, 15, 2)),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # verts_uvs has wrong shape
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "verts_uvs"):
 | 
			
		||||
            Textures(
 | 
			
		||||
            TexturesUV(
 | 
			
		||||
                maps=torch.ones((5, 16, 16, 3)),
 | 
			
		||||
                faces_uvs=torch.randint(size=(5, 10, 3), low=0, high=V),
 | 
			
		||||
                verts_uvs=torch.ones((5, V, 2, 3)),
 | 
			
		||||
            )
 | 
			
		||||
        # verts_rgb has wrong shape
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "verts_rgb"):
 | 
			
		||||
            Textures(verts_rgb=torch.ones((5, 16, 16, 3)))
 | 
			
		||||
 | 
			
		||||
        # maps provided without verts/faces uvs
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "faces_uvs and verts_uvs are required"):
 | 
			
		||||
            Textures(maps=torch.ones((5, 16, 16, 3)))
 | 
			
		||||
 | 
			
		||||
    def test_padded_to_packed(self):
 | 
			
		||||
        N = 2
 | 
			
		||||
        # Case where each face in the mesh has 3 unique uv vertex indices
 | 
			
		||||
        # - i.e. even if a vertex is shared between multiple faces it will
 | 
			
		||||
        # have a unique uv coordinate for each face.
 | 
			
		||||
        faces_uvs_list = [
 | 
			
		||||
            torch.tensor([[0, 1, 2], [3, 5, 4], [7, 6, 8]]),
 | 
			
		||||
            torch.tensor([[0, 1, 2], [3, 4, 5]]),
 | 
			
		||||
        ]  # (N, 3, 3)
 | 
			
		||||
        verts_uvs_list = [torch.ones(9, 2), torch.ones(6, 2)]
 | 
			
		||||
        faces_uvs_padded = list_to_padded(faces_uvs_list, pad_value=-1)
 | 
			
		||||
        verts_uvs_padded = list_to_padded(verts_uvs_list)
 | 
			
		||||
        tex = Textures(
 | 
			
		||||
            maps=torch.ones((N, 16, 16, 3)),
 | 
			
		||||
            faces_uvs=faces_uvs_padded,
 | 
			
		||||
            verts_uvs=verts_uvs_padded,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # This is set inside Meshes when textures is passed as an input.
 | 
			
		||||
        # Here we set _num_faces_per_mesh and _num_verts_per_mesh explicity.
 | 
			
		||||
        tex1 = tex.clone()
 | 
			
		||||
        tex1._num_faces_per_mesh = faces_uvs_padded.gt(-1).all(-1).sum(-1).tolist()
 | 
			
		||||
        tex1._num_verts_per_mesh = torch.tensor([5, 4])
 | 
			
		||||
        faces_packed = tex1.faces_uvs_packed()
 | 
			
		||||
        verts_packed = tex1.verts_uvs_packed()
 | 
			
		||||
        faces_list = tex1.faces_uvs_list()
 | 
			
		||||
        verts_list = tex1.verts_uvs_list()
 | 
			
		||||
 | 
			
		||||
        for f1, f2 in zip(faces_uvs_list, faces_list):
 | 
			
		||||
            self.assertTrue((f1 == f2).all().item())
 | 
			
		||||
 | 
			
		||||
        for f, v1, v2 in zip(faces_list, verts_list, verts_uvs_list):
 | 
			
		||||
            idx = f.unique()
 | 
			
		||||
            self.assertTrue((v1[idx] == v2).all().item())
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(faces_packed.shape == (3 + 2, 3))
 | 
			
		||||
 | 
			
		||||
        # verts_packed is just flattened verts_padded.
 | 
			
		||||
        # split sizes are not used for verts_uvs.
 | 
			
		||||
        self.assertTrue(verts_packed.shape == (9 * 2, 2))
 | 
			
		||||
 | 
			
		||||
        # Case where num_faces_per_mesh is not set
 | 
			
		||||
        tex2 = tex.clone()
 | 
			
		||||
        faces_packed = tex2.faces_uvs_packed()
 | 
			
		||||
        verts_packed = tex2.verts_uvs_packed()
 | 
			
		||||
        faces_list = tex2.faces_uvs_list()
 | 
			
		||||
        verts_list = tex2.verts_uvs_list()
 | 
			
		||||
 | 
			
		||||
        # Packed is just flattened padded as num_faces_per_mesh
 | 
			
		||||
        # has not been provided.
 | 
			
		||||
        self.assertTrue(verts_packed.shape == (9 * 2, 2))
 | 
			
		||||
        self.assertTrue(faces_packed.shape == (3 * 2, 3))
 | 
			
		||||
 | 
			
		||||
        for i in range(N):
 | 
			
		||||
            self.assertTrue(
 | 
			
		||||
                (faces_list[i] == faces_uvs_padded[i, ...].squeeze()).all().item()
 | 
			
		||||
                faces_uvs=torch.rand(size=(5, 10, 3)),
 | 
			
		||||
                verts_uvs=torch.rand(size=(5, 15, 2, 3)),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        for i in range(N):
 | 
			
		||||
            self.assertTrue(
 | 
			
		||||
                (verts_list[i] == verts_uvs_padded[i, ...].squeeze()).all().item()
 | 
			
		||||
        # verts has different batch dim to faces
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "verts_uvs"):
 | 
			
		||||
            TexturesUV(
 | 
			
		||||
                maps=torch.ones((5, 16, 16, 3)),
 | 
			
		||||
                faces_uvs=torch.rand(size=(5, 10, 3)),
 | 
			
		||||
                verts_uvs=torch.rand(size=(8, 15, 2)),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # maps has different batch dim to faces
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "maps"):
 | 
			
		||||
            TexturesUV(
 | 
			
		||||
                maps=torch.ones((8, 16, 16, 3)),
 | 
			
		||||
                faces_uvs=torch.rand(size=(5, 10, 3)),
 | 
			
		||||
                verts_uvs=torch.rand(size=(5, 15, 2)),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # verts on different device to faces
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "verts_uvs"):
 | 
			
		||||
            TexturesUV(
 | 
			
		||||
                maps=torch.ones((5, 16, 16, 3)),
 | 
			
		||||
                faces_uvs=torch.rand(size=(5, 10, 3)),
 | 
			
		||||
                verts_uvs=torch.rand(size=(5, 15, 2, 3), device="cuda"),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # maps on different device to faces
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "map"):
 | 
			
		||||
            TexturesUV(
 | 
			
		||||
                maps=torch.ones((5, 16, 16, 3), device="cuda"),
 | 
			
		||||
                faces_uvs=torch.rand(size=(5, 10, 3)),
 | 
			
		||||
                verts_uvs=torch.rand(size=(5, 15, 2)),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def test_clone(self):
 | 
			
		||||
        V = 20
 | 
			
		||||
        tex = Textures(
 | 
			
		||||
        tex = TexturesUV(
 | 
			
		||||
            maps=torch.ones((5, 16, 16, 3)),
 | 
			
		||||
            faces_uvs=torch.randint(size=(5, 10, 3), low=0, high=V),
 | 
			
		||||
            verts_uvs=torch.ones((5, V, 2)),
 | 
			
		||||
            faces_uvs=torch.rand(size=(5, 10, 3)),
 | 
			
		||||
            verts_uvs=torch.rand(size=(5, 15, 2)),
 | 
			
		||||
        )
 | 
			
		||||
        tex_cloned = tex.clone()
 | 
			
		||||
        self.assertSeparate(tex._faces_uvs_padded, tex_cloned._faces_uvs_padded)
 | 
			
		||||
        self.assertSeparate(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded)
 | 
			
		||||
        self.assertSeparate(tex._maps_padded, tex_cloned._maps_padded)
 | 
			
		||||
 | 
			
		||||
    def test_getitem(self):
 | 
			
		||||
        N = 5
 | 
			
		||||
        V = 20
 | 
			
		||||
        source = {
 | 
			
		||||
            "maps": torch.rand(size=(N, 16, 16, 3)),
 | 
			
		||||
            "faces_uvs": torch.randint(size=(N, 10, 3), low=0, high=V),
 | 
			
		||||
            "verts_uvs": torch.rand((N, V, 2)),
 | 
			
		||||
        }
 | 
			
		||||
        tex = Textures(
 | 
			
		||||
            maps=source["maps"],
 | 
			
		||||
            faces_uvs=source["faces_uvs"],
 | 
			
		||||
            verts_uvs=source["verts_uvs"],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        verts = torch.rand(size=(N, V, 3))
 | 
			
		||||
        faces = torch.randint(size=(N, 10, 3), high=V)
 | 
			
		||||
        meshes = Meshes(verts=verts, faces=faces, textures=tex)
 | 
			
		||||
 | 
			
		||||
        def tryindex(index):
 | 
			
		||||
            tex2 = tex[index]
 | 
			
		||||
            meshes2 = meshes[index]
 | 
			
		||||
            tex_from_meshes = meshes2.textures
 | 
			
		||||
            for item in source:
 | 
			
		||||
                basic = source[item][index]
 | 
			
		||||
                from_texture = getattr(tex2, item + "_padded")()
 | 
			
		||||
                from_meshes = getattr(tex_from_meshes, item + "_padded")()
 | 
			
		||||
                if isinstance(index, int):
 | 
			
		||||
                    basic = basic[None]
 | 
			
		||||
                self.assertClose(basic, from_texture)
 | 
			
		||||
                self.assertClose(basic, from_meshes)
 | 
			
		||||
                self.assertEqual(
 | 
			
		||||
                    from_texture.ndim, getattr(tex, item + "_padded")().ndim
 | 
			
		||||
                )
 | 
			
		||||
                if item == "faces_uvs":
 | 
			
		||||
                    faces_uvs_list = tex_from_meshes.faces_uvs_list()
 | 
			
		||||
                    self.assertEqual(basic.shape[0], len(faces_uvs_list))
 | 
			
		||||
                    for i, faces_uvs in enumerate(faces_uvs_list):
 | 
			
		||||
                        self.assertClose(faces_uvs, basic[i])
 | 
			
		||||
 | 
			
		||||
        tryindex(2)
 | 
			
		||||
        tryindex(slice(0, 2, 1))
 | 
			
		||||
        index = torch.tensor([1, 0, 1, 0, 0], dtype=torch.bool)
 | 
			
		||||
        tryindex(index)
 | 
			
		||||
        index = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool)
 | 
			
		||||
        tryindex(index)
 | 
			
		||||
        index = torch.tensor([1, 2], dtype=torch.int64)
 | 
			
		||||
        tryindex(index)
 | 
			
		||||
        tryindex([2, 4])
 | 
			
		||||
 | 
			
		||||
    def test_to(self):
 | 
			
		||||
        V = 20
 | 
			
		||||
        tex = Textures(
 | 
			
		||||
            maps=torch.ones((5, 16, 16, 3)),
 | 
			
		||||
            faces_uvs=torch.randint(size=(5, 10, 3), low=0, high=V),
 | 
			
		||||
            verts_uvs=torch.ones((5, V, 2)),
 | 
			
		||||
        )
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        tex = tex.to(device)
 | 
			
		||||
        self.assertTrue(tex._faces_uvs_padded.device == device)
 | 
			
		||||
        self.assertTrue(tex._verts_uvs_padded.device == device)
 | 
			
		||||
        self.assertTrue(tex._maps_padded.device == device)
 | 
			
		||||
        self.assertSeparate(tex.valid, tex_cloned.valid)
 | 
			
		||||
 | 
			
		||||
    def test_extend(self):
 | 
			
		||||
        B = 10
 | 
			
		||||
        B = 5
 | 
			
		||||
        mesh = TestMeshes.init_mesh(B, 30, 50)
 | 
			
		||||
        V = mesh._V
 | 
			
		||||
        F = mesh._F
 | 
			
		||||
 | 
			
		||||
        # 1. Texture uvs
 | 
			
		||||
        tex_uv = Textures(
 | 
			
		||||
            maps=torch.randn((B, 16, 16, 3)),
 | 
			
		||||
            faces_uvs=torch.randint(size=(B, F, 3), low=0, high=V),
 | 
			
		||||
            verts_uvs=torch.randn((B, V, 2)),
 | 
			
		||||
        num_faces = mesh.num_faces_per_mesh()
 | 
			
		||||
        num_verts = mesh.num_verts_per_mesh()
 | 
			
		||||
        faces_uvs_list = [torch.randint(size=(f, 3), low=0, high=V) for f in num_faces]
 | 
			
		||||
        verts_uvs_list = [torch.rand(v, 2) for v in num_verts]
 | 
			
		||||
        tex_uv = TexturesUV(
 | 
			
		||||
            maps=torch.ones((B, 16, 16, 3)),
 | 
			
		||||
            faces_uvs=faces_uvs_list,
 | 
			
		||||
            verts_uvs=verts_uvs_list,
 | 
			
		||||
        )
 | 
			
		||||
        tex_mesh = Meshes(
 | 
			
		||||
            verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex_uv
 | 
			
		||||
            verts=mesh.verts_list(), faces=mesh.faces_list(), textures=tex_uv
 | 
			
		||||
        )
 | 
			
		||||
        N = 20
 | 
			
		||||
        N = 2
 | 
			
		||||
        new_mesh = tex_mesh.extend(N)
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(tex_mesh) * N, len(new_mesh))
 | 
			
		||||
@ -246,56 +510,142 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
        for i in range(len(tex_mesh)):
 | 
			
		||||
            for n in range(N):
 | 
			
		||||
                self.assertClose(
 | 
			
		||||
                    tex_init.verts_uvs_list()[i], new_tex.verts_uvs_list()[i * N + n]
 | 
			
		||||
                )
 | 
			
		||||
                self.assertClose(
 | 
			
		||||
                    tex_init.faces_uvs_list()[i], new_tex.faces_uvs_list()[i * N + n]
 | 
			
		||||
                )
 | 
			
		||||
                self.assertClose(
 | 
			
		||||
                    tex_init.verts_uvs_list()[i], new_tex.verts_uvs_list()[i * N + n]
 | 
			
		||||
                    tex_init.maps_padded()[i, ...], new_tex.maps_padded()[i * N + n]
 | 
			
		||||
                )
 | 
			
		||||
                self.assertClose(
 | 
			
		||||
                    tex_init._num_faces_per_mesh[i],
 | 
			
		||||
                    new_tex._num_faces_per_mesh[i * N + n],
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        self.assertAllSeparate(
 | 
			
		||||
            [
 | 
			
		||||
                tex_init.faces_uvs_padded(),
 | 
			
		||||
                new_tex.faces_uvs_padded(),
 | 
			
		||||
                tex_init.faces_uvs_packed(),
 | 
			
		||||
                new_tex.faces_uvs_packed(),
 | 
			
		||||
                tex_init.verts_uvs_padded(),
 | 
			
		||||
                new_tex.verts_uvs_padded(),
 | 
			
		||||
                tex_init.verts_uvs_packed(),
 | 
			
		||||
                new_tex.verts_uvs_packed(),
 | 
			
		||||
                tex_init.maps_padded(),
 | 
			
		||||
                new_tex.maps_padded(),
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.assertIsNone(new_tex.verts_rgb_list())
 | 
			
		||||
        self.assertIsNone(new_tex.verts_rgb_padded())
 | 
			
		||||
        self.assertIsNone(new_tex.verts_rgb_packed())
 | 
			
		||||
 | 
			
		||||
        # 2. Texture vertex RGB
 | 
			
		||||
        tex_rgb = Textures(verts_rgb=torch.randn((B, V, 3)))
 | 
			
		||||
        tex_mesh_rgb = Meshes(
 | 
			
		||||
            verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex_rgb
 | 
			
		||||
        )
 | 
			
		||||
        N = 20
 | 
			
		||||
        new_mesh_rgb = tex_mesh_rgb.extend(N)
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(tex_mesh_rgb) * N, len(new_mesh_rgb))
 | 
			
		||||
 | 
			
		||||
        tex_init = tex_mesh_rgb.textures
 | 
			
		||||
        new_tex = new_mesh_rgb.textures
 | 
			
		||||
 | 
			
		||||
        for i in range(len(tex_mesh_rgb)):
 | 
			
		||||
            for n in range(N):
 | 
			
		||||
                self.assertClose(
 | 
			
		||||
                    tex_init.verts_rgb_list()[i], new_tex.verts_rgb_list()[i * N + n]
 | 
			
		||||
                )
 | 
			
		||||
        self.assertAllSeparate(
 | 
			
		||||
            [tex_init.verts_rgb_padded(), new_tex.verts_rgb_padded()]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.assertIsNone(new_tex.verts_uvs_padded())
 | 
			
		||||
        self.assertIsNone(new_tex.verts_uvs_list())
 | 
			
		||||
        self.assertIsNone(new_tex.verts_uvs_packed())
 | 
			
		||||
        self.assertIsNone(new_tex.faces_uvs_padded())
 | 
			
		||||
        self.assertIsNone(new_tex.faces_uvs_list())
 | 
			
		||||
        self.assertIsNone(new_tex.faces_uvs_packed())
 | 
			
		||||
 | 
			
		||||
        # 3. Error
 | 
			
		||||
        with self.assertRaises(ValueError):
 | 
			
		||||
            tex_mesh.extend(N=-1)
 | 
			
		||||
 | 
			
		||||
    def test_padded_to_packed(self):
 | 
			
		||||
        # Case where each face in the mesh has 3 unique uv vertex indices
 | 
			
		||||
        # - i.e. even if a vertex is shared between multiple faces it will
 | 
			
		||||
        # have a unique uv coordinate for each face.
 | 
			
		||||
        N = 2
 | 
			
		||||
        faces_uvs_list = [
 | 
			
		||||
            torch.tensor([[0, 1, 2], [3, 5, 4], [7, 6, 8]]),
 | 
			
		||||
            torch.tensor([[0, 1, 2], [3, 4, 5]]),
 | 
			
		||||
        ]  # (N, 3, 3)
 | 
			
		||||
        verts_uvs_list = [torch.ones(9, 2), torch.ones(6, 2)]
 | 
			
		||||
 | 
			
		||||
        num_faces_per_mesh = [f.shape[0] for f in faces_uvs_list]
 | 
			
		||||
        num_verts_per_mesh = [v.shape[0] for v in verts_uvs_list]
 | 
			
		||||
        tex = TexturesUV(
 | 
			
		||||
            maps=torch.ones((N, 16, 16, 3)),
 | 
			
		||||
            faces_uvs=faces_uvs_list,
 | 
			
		||||
            verts_uvs=verts_uvs_list,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # This is set inside Meshes when textures is passed as an input.
 | 
			
		||||
        # Here we set _num_faces_per_mesh and _num_verts_per_mesh explicity.
 | 
			
		||||
        tex1 = tex.clone()
 | 
			
		||||
        tex1._num_faces_per_mesh = num_faces_per_mesh
 | 
			
		||||
        tex1._num_verts_per_mesh = num_verts_per_mesh
 | 
			
		||||
        verts_packed = tex1.verts_uvs_packed()
 | 
			
		||||
        verts_list = tex1.verts_uvs_list()
 | 
			
		||||
        verts_padded = tex1.verts_uvs_padded()
 | 
			
		||||
 | 
			
		||||
        faces_packed = tex1.faces_uvs_packed()
 | 
			
		||||
        faces_list = tex1.faces_uvs_list()
 | 
			
		||||
        faces_padded = tex1.faces_uvs_padded()
 | 
			
		||||
 | 
			
		||||
        for f1, f2 in zip(faces_list, faces_uvs_list):
 | 
			
		||||
            self.assertTrue((f1 == f2).all().item())
 | 
			
		||||
 | 
			
		||||
        for f1, f2 in zip(verts_list, verts_uvs_list):
 | 
			
		||||
            self.assertTrue((f1 == f2).all().item())
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(faces_packed.shape == (3 + 2, 3))
 | 
			
		||||
        self.assertTrue(faces_padded.shape == (2, 3, 3))
 | 
			
		||||
        self.assertTrue(verts_packed.shape == (9 + 6, 2))
 | 
			
		||||
        self.assertTrue(verts_padded.shape == (2, 9, 2))
 | 
			
		||||
 | 
			
		||||
        # Case where num_faces_per_mesh is not set and faces_verts_uvs
 | 
			
		||||
        # are initialized with a padded tensor.
 | 
			
		||||
        tex2 = TexturesUV(
 | 
			
		||||
            maps=torch.ones((N, 16, 16, 3)),
 | 
			
		||||
            verts_uvs=verts_padded,
 | 
			
		||||
            faces_uvs=faces_padded,
 | 
			
		||||
        )
 | 
			
		||||
        faces_packed = tex2.faces_uvs_packed()
 | 
			
		||||
        faces_list = tex2.faces_uvs_list()
 | 
			
		||||
        verts_packed = tex2.verts_uvs_packed()
 | 
			
		||||
        verts_list = tex2.verts_uvs_list()
 | 
			
		||||
 | 
			
		||||
        # Packed is just flattened padded as num_faces_per_mesh
 | 
			
		||||
        # has not been provided.
 | 
			
		||||
        self.assertTrue(faces_packed.shape == (3 * 2, 3))
 | 
			
		||||
        self.assertTrue(verts_packed.shape == (9 * 2, 2))
 | 
			
		||||
 | 
			
		||||
        for i, (f1, f2) in enumerate(zip(faces_list, faces_uvs_list)):
 | 
			
		||||
            n = num_faces_per_mesh[i]
 | 
			
		||||
            self.assertTrue((f1[:n] == f2).all().item())
 | 
			
		||||
 | 
			
		||||
        for i, (f1, f2) in enumerate(zip(verts_list, verts_uvs_list)):
 | 
			
		||||
            n = num_verts_per_mesh[i]
 | 
			
		||||
            self.assertTrue((f1[:n] == f2).all().item())
 | 
			
		||||
 | 
			
		||||
    def test_to(self):
 | 
			
		||||
        tex = TexturesUV(
 | 
			
		||||
            maps=torch.ones((5, 16, 16, 3)),
 | 
			
		||||
            faces_uvs=torch.randint(size=(5, 10, 3), high=15),
 | 
			
		||||
            verts_uvs=torch.rand(size=(5, 15, 2)),
 | 
			
		||||
        )
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        tex = tex.to(device)
 | 
			
		||||
        self.assertTrue(tex._faces_uvs_padded.device == device)
 | 
			
		||||
        self.assertTrue(tex._verts_uvs_padded.device == device)
 | 
			
		||||
        self.assertTrue(tex._maps_padded.device == device)
 | 
			
		||||
 | 
			
		||||
    def test_getitem(self):
 | 
			
		||||
        N = 5
 | 
			
		||||
        V = 20
 | 
			
		||||
        source = {
 | 
			
		||||
            "maps": torch.rand(size=(N, 1, 1, 3)),
 | 
			
		||||
            "faces_uvs": torch.randint(size=(N, 10, 3), high=V),
 | 
			
		||||
            "verts_uvs": torch.randn(size=(N, V, 2)),
 | 
			
		||||
        }
 | 
			
		||||
        tex = TexturesUV(
 | 
			
		||||
            maps=source["maps"],
 | 
			
		||||
            faces_uvs=source["faces_uvs"],
 | 
			
		||||
            verts_uvs=source["verts_uvs"],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        verts = torch.rand(size=(N, V, 3))
 | 
			
		||||
        faces = torch.randint(size=(N, 10, 3), high=V)
 | 
			
		||||
        meshes = Meshes(verts=verts, faces=faces, textures=tex)
 | 
			
		||||
 | 
			
		||||
        tryindex(self, 2, tex, meshes, source)
 | 
			
		||||
        tryindex(self, slice(0, 2, 1), tex, meshes, source)
 | 
			
		||||
        index = torch.tensor([1, 0, 1, 0, 0], dtype=torch.bool)
 | 
			
		||||
        tryindex(self, index, tex, meshes, source)
 | 
			
		||||
        index = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool)
 | 
			
		||||
        tryindex(self, index, tex, meshes, source)
 | 
			
		||||
        index = torch.tensor([1, 2], dtype=torch.int64)
 | 
			
		||||
        tryindex(self, index, tex, meshes, source)
 | 
			
		||||
        tryindex(self, [2, 4], tex, meshes, source)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user