Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-08-02 03:42:50 +08:00)
Texturing API updates
Summary: A fairly big refactor of the texturing API, with some breaking changes to how textures are defined.

Main changes:
- There are now three texture classes: `TexturesUV`, `TexturesAtlas` and `TexturesVertex`. Each class:
  - has a `sample_textures` function which accepts the `fragments` from rasterization and returns `texels`. The shaders therefore no longer need to know the type of the mesh texture, which resolves several issues reported on GitHub.
  - has a `join_batch` method for joining multiple textures of the same type into a batch.

Reviewed By: gkioxari

Differential Revision: D21067427

fbshipit-source-id: 4b346500a60181e72fdd1b0dd89b5505c7a33926
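For orientation, a minimal sketch of the new flow, assembled from the test changes in this diff; shapes and values are illustrative only:

    import torch
    from pytorch3d.renderer import TexturesVertex
    from pytorch3d.structures import Meshes

    # A single triangle with a constant white per-vertex feature.
    verts = torch.rand(3, 3)
    faces = torch.tensor([[0, 1, 2]], dtype=torch.int64)
    tex = TexturesVertex(verts_features=torch.ones(1, 3, 3))  # padded (N, V, C)

    mesh = Meshes(verts=[verts], faces=[faces], textures=tex)

    # After rasterization, any shader can now sample texels without knowing
    # the texture type:
    #   texels = mesh.sample_textures(fragments)  # (N, H, W, K, C)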
parent b73d3d6ed9
commit a3932960b3
@@ -520,6 +520,9 @@
],
"metadata": {
"accelerator": "GPU",
"anp_metadata": {
"path": "fbsource/fbcode/vision/fair/pytorch3d/docs/tutorials/camera_position_optimization_with_differentiable_rendering.ipynb"
},
"bento_stylesheets": {
"bento/extensions/flow/main.css": true,
"bento/extensions/kernel_selector/main.css": true,

@@ -533,6 +536,9 @@
"provenance": [],
"toc_visible": true
},
"disseminate_notebook_info": {
"backup_notebook_id": "1062179640844868"
},
"kernelspec": {
"display_name": "pytorch3d (local)",
"language": "python",

@@ -84,7 +84,7 @@
"from skimage.io import imread\n",
"\n",
"# Util function for loading meshes\n",
"from pytorch3d.io import load_objs_as_meshes\n",
"from pytorch3d.io import load_objs_as_meshes, load_obj\n",
"\n",
"# Data structures and functions for rendering\n",
"from pytorch3d.structures import Meshes, Textures\n",

@@ -97,7 +97,7 @@
" RasterizationSettings, \n",
" MeshRenderer, \n",
" MeshRasterizer, \n",
" TexturedSoftPhongShader\n",
" SoftPhongShader\n",
")\n",
"\n",
"# add path for demo utils functions \n",

@@ -316,7 +316,7 @@
" cameras=cameras, \n",
" raster_settings=raster_settings\n",
" ),\n",
" shader=TexturedSoftPhongShader(\n",
" shader=SoftPhongShader(\n",
" device=device, \n",
" cameras=cameras,\n",
" lights=lights\n",

@@ -563,6 +563,9 @@
],
"metadata": {
"accelerator": "GPU",
"anp_metadata": {
"path": "fbsource/fbcode/vision/fair/pytorch3d/docs/tutorials/render_textured_meshes.ipynb"
},
"bento_stylesheets": {
"bento/extensions/flow/main.css": true,
"bento/extensions/kernel_selector/main.css": true,

@@ -575,6 +578,9 @@
"name": "render_textured_meshes.ipynb",
"provenance": []
},
"disseminate_notebook_info": {
"backup_notebook_id": "569222367081034"
},
"kernelspec": {
"display_name": "pytorch3d (local)",
"language": "python",
@@ -13,8 +13,8 @@ from pytorch3d.renderer import (
OpenGLPerspectiveCameras,
PointLights,
RasterizationSettings,
TexturesVertex,
)
from pytorch3d.structures import Textures


class ShapeNetBase(torch.utils.data.Dataset):

@@ -113,8 +113,8 @@ class ShapeNetBase(torch.utils.data.Dataset):
"""
paths = self._handle_render_inputs(model_ids, categories, sample_nums, idxs)
meshes = load_objs_as_meshes(paths, device=device, load_textures=False)
meshes.textures = Textures(
verts_rgb=torch.ones_like(meshes.verts_padded(), device=device)
meshes.textures = TexturesVertex(
verts_features=torch.ones_like(meshes.verts_padded(), device=device)
)
cameras = kwargs.get("cameras", OpenGLPerspectiveCameras()).to(device)
renderer = MeshRenderer(
@@ -11,7 +11,8 @@ import numpy as np
import torch
from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas
from pytorch3d.io.utils import _open_file
from pytorch3d.structures import Meshes, Textures, join_meshes_as_batch
from pytorch3d.renderer import TexturesAtlas, TexturesUV
from pytorch3d.structures import Meshes, join_meshes_as_batch


def _make_tensor(data, cols: int, dtype: torch.dtype, device="cpu") -> torch.Tensor:

@@ -41,6 +42,10 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
Args:
faces_indices: List of ints of indices.
max_index: Max index for the face property.
pad_value: if any of the face_indices are padded, specify
the value of the padding (e.g. -1). This is only used
for texture indices where there might
not be texture information for all the faces.

Returns:
faces_indices: List of ints of indices.

@@ -65,7 +70,9 @@ def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
faces_indices[mask] = pad_value

# Check indices are valid.
if torch.any(faces_indices >= max_index) or torch.any(faces_indices < 0):
if torch.any(faces_indices >= max_index) or (
pad_value is None and torch.any(faces_indices < 0)
):
warnings.warn("Faces have invalid indices")

return faces_indices

@@ -227,7 +234,14 @@ def load_obj(
)


def load_objs_as_meshes(files: list, device=None, load_textures: bool = True):
def load_objs_as_meshes(
files: list,
device=None,
load_textures: bool = True,
create_texture_atlas: bool = False,
texture_atlas_size: int = 4,
texture_wrap: Optional[str] = "repeat",
):
"""
Load meshes from a list of .obj files using the load_obj function, and
return them as a Meshes object. This only works for meshes which have a

@@ -246,18 +260,31 @@ def load_objs_as_meshes(files: list, device=None, load_textures: bool = True):
"""
mesh_list = []
for f_obj in files:
# TODO: update this function to support the two texturing options.
verts, faces, aux = load_obj(f_obj, load_textures=load_textures)
verts = verts.to(device)
verts, faces, aux = load_obj(
f_obj,
load_textures=load_textures,
create_texture_atlas=create_texture_atlas,
texture_atlas_size=texture_atlas_size,
texture_wrap=texture_wrap,
)
tex = None
tex_maps = aux.texture_images
if tex_maps is not None and len(tex_maps) > 0:
verts_uvs = aux.verts_uvs[None, ...].to(device) # (1, V, 2)
faces_uvs = faces.textures_idx[None, ...].to(device) # (1, F, 3)
image = list(tex_maps.values())[0].to(device)[None]
tex = Textures(verts_uvs=verts_uvs, faces_uvs=faces_uvs, maps=image)
if create_texture_atlas:
# TexturesAtlas type
tex = TexturesAtlas(atlas=[aux.texture_atlas])
else:
# TexturesUV type
tex_maps = aux.texture_images
if tex_maps is not None and len(tex_maps) > 0:
verts_uvs = aux.verts_uvs.to(device) # (V, 2)
faces_uvs = faces.textures_idx.to(device) # (F, 3)
image = list(tex_maps.values())[0].to(device)[None]
tex = TexturesUV(
verts_uvs=[verts_uvs], faces_uvs=[faces_uvs], maps=image
)

mesh = Meshes(verts=[verts], faces=[faces.verts_idx.to(device)], textures=tex)
mesh = Meshes(
verts=[verts.to(device)], faces=[faces.verts_idx.to(device)], textures=tex
)
mesh_list.append(mesh)
if len(mesh_list) == 1:
return mesh_list[0]
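A hedged sketch of the extended loader above; the file path is hypothetical, the keyword arguments are the ones added in this hunk:

    from pytorch3d.io import load_objs_as_meshes

    # create_texture_atlas=True attaches a TexturesAtlas to the returned
    # Meshes; the default (False) keeps the TexturesUV behaviour.
    mesh = load_objs_as_meshes(
        ["cow.obj"],  # hypothetical path to an .obj with an .mtl texture
        device="cuda:0",
        load_textures=True,
        create_texture_atlas=True,
        texture_atlas_size=4,
    )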
@@ -28,11 +28,11 @@ from .mesh import (
SoftGouraudShader,
SoftPhongShader,
SoftSilhouetteShader,
TexturedSoftPhongShader,
Textures,
TexturesAtlas,
TexturesUV,
TexturesVertex,
gouraud_shading,
interpolate_face_attributes,
interpolate_texture_map,
interpolate_vertex_colors,
phong_shading,
rasterize_meshes,
)
@@ -1,10 +1,10 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


from .texturing import interpolate_texture_map, interpolate_vertex_colors # isort:skip
from .rasterize_meshes import rasterize_meshes
from .rasterizer import MeshRasterizer, RasterizationSettings
from .renderer import MeshRenderer
from .shader import TexturedSoftPhongShader # DEPRECATED
from .shader import (
HardFlatShader,
HardGouraudShader,

@@ -12,10 +12,10 @@ from .shader import (
SoftGouraudShader,
SoftPhongShader,
SoftSilhouetteShader,
TexturedSoftPhongShader,
)
from .shading import gouraud_shading, phong_shading
from .utils import interpolate_face_attributes
from .textures import Textures # DEPRECATED
from .textures import TexturesAtlas, TexturesUV, TexturesVertex


__all__ = [k for k in globals().keys() if not k.startswith("_")]
@@ -1,5 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import warnings

import torch
import torch.nn as nn

@@ -13,7 +14,6 @@ from ..blending import (
from ..lighting import PointLights
from ..materials import Materials
from .shading import flat_shading, gouraud_shading, phong_shading
from .texturing import interpolate_texture_map, interpolate_vertex_colors


# A Shader should take as input fragments from the output of rasterization

@@ -57,7 +57,7 @@ class HardPhongShader(nn.Module):
or in the forward pass of HardPhongShader"
raise ValueError(msg)

texels = interpolate_vertex_colors(fragments, meshes)
texels = meshes.sample_textures(fragments)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
blend_params = kwargs.get("blend_params", self.blend_params)

@@ -104,9 +104,11 @@ class SoftPhongShader(nn.Module):
msg = "Cameras must be specified either at initialization \
or in the forward pass of SoftPhongShader"
raise ValueError(msg)
texels = interpolate_vertex_colors(fragments, meshes)

texels = meshes.sample_textures(fragments)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
blend_params = kwargs.get("blend_params", self.blend_params)
colors = phong_shading(
meshes=meshes,
fragments=fragments,

@@ -115,7 +117,7 @@ class SoftPhongShader(nn.Module):
cameras=cameras,
materials=materials,
)
images = softmax_rgb_blend(colors, fragments, self.blend_params)
images = softmax_rgb_blend(colors, fragments, blend_params)
return images


@@ -154,6 +156,12 @@ class HardGouraudShader(nn.Module):
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
blend_params = kwargs.get("blend_params", self.blend_params)

# As Gouraud shading applies the illumination to the vertex
# colors, the interpolated pixel texture is calculated in the
# shading step. In comparison, for Phong shading, the pixel
# textures are computed first after which the illumination is
# applied.
pixel_colors = gouraud_shading(
meshes=meshes,
fragments=fragments,

@@ -210,54 +218,25 @@ class SoftGouraudShader(nn.Module):
return images


class TexturedSoftPhongShader(nn.Module):
def TexturedSoftPhongShader(
device="cpu", cameras=None, lights=None, materials=None, blend_params=None
):
"""
Per pixel lighting applied to a texture map. First interpolate the vertex
uv coordinates and sample from a texture map. Then apply the lighting model
using the interpolated coords and normals for each pixel.

The blending function returns the soft aggregated color using all
the faces per pixel.

To use the default values, simply initialize the shader with the desired
device e.g.

.. code-block::

shader = TexturedPhongShader(device=torch.device("cuda:0"))
TexturedSoftPhongShader class has been DEPRECATED. Use SoftPhongShader instead.
Preserving TexturedSoftPhongShader as a function for backwards compatibility.
"""

def __init__(
self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
):
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
self.materials = (
materials if materials is not None else Materials(device=device)
)
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams()

def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of TexturedSoftPhongShader"
raise ValueError(msg)
texels = interpolate_texture_map(fragments, meshes)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
blend_params = kwargs.get("blend_params", self.blend_params)
colors = phong_shading(
meshes=meshes,
fragments=fragments,
texels=texels,
lights=lights,
cameras=cameras,
materials=materials,
)
images = softmax_rgb_blend(colors, fragments, blend_params)
return images
warnings.warn(
"""TexturedSoftPhongShader is now deprecated;
use SoftPhongShader instead.""",
PendingDeprecationWarning,
)
return SoftPhongShader(
device=device,
cameras=cameras,
lights=lights,
materials=materials,
blend_params=blend_params,
)


class HardFlatShader(nn.Module):

@@ -291,7 +270,7 @@ class HardFlatShader(nn.Module):
msg = "Cameras must be specified either at initialization \
or in the forward pass of HardFlatShader"
raise ValueError(msg)
texels = interpolate_vertex_colors(fragments, meshes)
texels = meshes.sample_textures(fragments)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
blend_params = kwargs.get("blend_params", self.blend_params)
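Since every built-in shader now calls meshes.sample_textures(fragments), the deprecated class reduces to the one-line swap already made in the tutorial notebooks above. A hedged migration sketch, assuming device, cameras and lights are defined as in those notebooks:

    from pytorch3d.renderer import SoftPhongShader

    # Before: shader = TexturedSoftPhongShader(device=device, cameras=cameras, lights=lights)
    shader = SoftPhongShader(device=device, cameras=cameras, lights=lights)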
@@ -6,6 +6,8 @@ from typing import Tuple
import torch
from pytorch3d.ops import interpolate_face_attributes

from .textures import TexturesVertex


def _apply_lighting(
points, normals, lights, cameras, materials

@@ -91,6 +93,9 @@ def gouraud_shading(meshes, fragments, lights, cameras, materials) -> torch.Tens
Then interpolate the vertex shaded colors using the barycentric coordinates
to get a color per pixel.

Gouraud shading is only supported for meshes with texture type `TexturesVertex`.
This is because the illumination is applied to the vertex colors.

Args:
meshes: Batch of meshes
fragments: Fragments named tuple with the outputs of rasterization

@@ -101,10 +106,13 @@ def gouraud_shading(meshes, fragments, lights, cameras, materials) -> torch.Tens
Returns:
colors: (N, H, W, K, 3)
"""
if not isinstance(meshes.textures, TexturesVertex):
raise ValueError("Mesh textures must be an instance of TexturesVertex")

faces = meshes.faces_packed() # (F, 3)
verts = meshes.verts_packed()
vertex_normals = meshes.verts_normals_packed() # (V, 3)
vertex_colors = meshes.textures.verts_rgb_packed()
verts = meshes.verts_packed() # (V, 3)
verts_normals = meshes.verts_normals_packed() # (V, 3)
verts_colors = meshes.textures.verts_features_packed() # (V, D)
vert_to_mesh_idx = meshes.verts_packed_to_mesh_idx()

# Format properties of lights and materials so they are compatible

@@ -119,9 +127,10 @@ def gouraud_shading(meshes, fragments, lights, cameras, materials) -> torch.Tens

# Calculate the illumination at each vertex
ambient, diffuse, specular = _apply_lighting(
verts, vertex_normals, lights, cameras, materials
verts, verts_normals, lights, cameras, materials
)
verts_colors_shaded = vertex_colors * (ambient + diffuse) + specular

verts_colors_shaded = verts_colors * (ambient + diffuse) + specular
face_colors = verts_colors_shaded[faces]
colors = interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, face_colors
pytorch3d/renderer/mesh/textures.py — new file, 1049 lines. File diff suppressed because it is too large.
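The new textures.py is not shown here, but the constructor forms it provides are exercised elsewhere in this diff. A hedged summary; the tex_map, faces, aux and verts_features variables are assumed to come from a load_obj call as in the tests below:

    from pytorch3d.renderer import TexturesAtlas, TexturesUV, TexturesVertex

    # UV-map texture: one map per mesh plus per-face uv indices and per-vertex uvs.
    tex_uv = TexturesUV(
        maps=tex_map, faces_uvs=[faces.textures_idx], verts_uvs=[aux.verts_uvs]
    )

    # Per-face texture atlas of shape (F, R, R, 3).
    tex_atlas = TexturesAtlas(atlas=[aux.texture_atlas])

    # Per-vertex features (e.g. RGB colors) of shape (V, D).
    tex_vertex = TexturesVertex(verts_features=[verts_features])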
@@ -1,113 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import torch
import torch.nn.functional as F
from pytorch3d.ops import interpolate_face_attributes
from pytorch3d.structures.textures import Textures


def interpolate_texture_map(fragments, meshes) -> torch.Tensor:
"""
Interpolate a 2D texture map using uv vertex texture coordinates for each
face in the mesh. First interpolate the vertex uvs using barycentric coordinates
for each pixel in the rasterized output. Then interpolate the texture map
using the uv coordinate for each pixel.

Args:
fragments:
The outputs of rasterization. From this we use

- pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
of the faces (in the packed representation) which
overlap each pixel in the image.
- barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
the barycentric coordinates of each pixel
relative to the faces (in the packed
representation) which overlap the pixel.
meshes: Meshes representing a batch of meshes. It is expected that
meshes has a textures attribute which is an instance of the
Textures class.

Returns:
texels: tensor of shape (N, H, W, K, C) giving the interpolated
texture for each pixel in the rasterized image.
"""
if not isinstance(meshes.textures, Textures):
msg = "Expected meshes.textures to be an instance of Textures; got %r"
raise ValueError(msg % type(meshes.textures))

faces_uvs = meshes.textures.faces_uvs_packed()
verts_uvs = meshes.textures.verts_uvs_packed()
faces_verts_uvs = verts_uvs[faces_uvs]
texture_maps = meshes.textures.maps_padded()

# pixel_uvs: (N, H, W, K, 2)
pixel_uvs = interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, faces_verts_uvs
)

N, H_out, W_out, K = fragments.pix_to_face.shape
N, H_in, W_in, C = texture_maps.shape # 3 for RGB

# pixel_uvs: (N, H, W, K, 2) -> (N, K, H, W, 2) -> (NK, H, W, 2)
pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(N * K, H_out, W_out, 2)

# textures.map:
# (N, H, W, C) -> (N, C, H, W) -> (1, N, C, H, W)
# -> expand (K, N, C, H, W) -> reshape (N*K, C, H, W)
texture_maps = (
texture_maps.permute(0, 3, 1, 2)[None, ...]
.expand(K, -1, -1, -1, -1)
.transpose(0, 1)
.reshape(N * K, C, H_in, W_in)
)

# Textures: (N*K, C, H, W), pixel_uvs: (N*K, H, W, 2)
# Now need to format the pixel uvs and the texture map correctly!
# From pytorch docs, grid_sample takes `grid` and `input`:
# grid specifies the sampling pixel locations normalized by
# the input spatial dimensions It should have most
# values in the range of [-1, 1]. Values x = -1, y = -1
# is the left-top pixel of input, and values x = 1, y = 1 is the
# right-bottom pixel of input.

pixel_uvs = pixel_uvs * 2.0 - 1.0
texture_maps = torch.flip(texture_maps, [2]) # flip y axis of the texture map
if texture_maps.device != pixel_uvs.device:
texture_maps = texture_maps.to(pixel_uvs.device)
texels = F.grid_sample(texture_maps, pixel_uvs, align_corners=False)
texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2)
return texels


def interpolate_vertex_colors(fragments, meshes) -> torch.Tensor:
"""
Determine the color for each rasterized face. Interpolate the colors for
vertices which form the face using the barycentric coordinates.
Args:
meshes: A Meshes class representing a batch of meshes.
fragments:
The outputs of rasterization. From this we use

- pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
of the faces (in the packed representation) which
overlap each pixel in the image.
- barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
the barycentric coordinates of each pixel
relative to the faces (in the packed
representation) which overlap the pixel.

Returns:
texels: A texture per pixel of shape (N, H, W, K, C).
There will be one C dimensional value for each element in
fragments.pix_to_face.
"""
vertex_textures = meshes.textures.verts_rgb_padded().reshape(-1, 3) # (V, C)
vertex_textures = vertex_textures[meshes.verts_padded_to_packed_idx(), :]
faces_packed = meshes.faces_packed()
faces_textures = vertex_textures[faces_packed] # (F, 3, C)
texels = interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, faces_textures
)
return texels
@@ -2,7 +2,6 @@

from .meshes import Meshes, join_meshes_as_batch
from .pointclouds import Pointclouds
from .textures import Textures
from .utils import list_to_packed, list_to_padded, packed_to_list, padded_to_list
@@ -5,7 +5,6 @@ from typing import List, Union
import torch

from . import utils as struct_utils
from .textures import Textures


class Meshes(object):

@@ -234,9 +233,9 @@ class Meshes(object):
Refer to comments above for descriptions of List and Padded representations.
"""
self.device = None
if textures is not None and not isinstance(textures, Textures):
msg = "Expected textures to be of type Textures; got %r"
raise ValueError(msg % type(textures))
if textures is not None and not repr(textures) == "TexturesBase":
msg = "Expected textures to be an instance of type TexturesBase; got %r"
raise ValueError(msg % repr(textures))
self.textures = textures

# Indicates whether the meshes in the list/batch have the same number

@@ -400,6 +399,8 @@ class Meshes(object):
if self.textures is not None:
self.textures._num_faces_per_mesh = self._num_faces_per_mesh.tolist()
self.textures._num_verts_per_mesh = self._num_verts_per_mesh.tolist()
self.textures._N = self._N
self.textures.valid = self.valid

def __len__(self):
return self._N

@@ -1465,6 +1466,17 @@ class Meshes(object):

return self.__class__(verts=new_verts_list, faces=new_faces_list, textures=tex)

def sample_textures(self, fragments):
if self.textures is not None:
# Pass in faces packed. If the textures are defined per
# vertex, the face indices are needed in order to interpolate
# the vertex attributes across the face.
return self.textures.sample_textures(
fragments, faces_packed=self.faces_packed()
)
else:
raise ValueError("Meshes does not have textures")


def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True):
"""

@@ -1499,44 +1511,14 @@ def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True):
raise ValueError("Inconsistent textures in join_meshes_as_batch.")

# Now we know there are multiple meshes and they have textures to merge.
first = meshes[0].textures
kwargs = {}
if first.maps_padded() is not None:
if any(mesh.textures.maps_padded() is None for mesh in meshes):
raise ValueError("Inconsistent maps_padded in join_meshes_as_batch.")
maps = [m for mesh in meshes for m in mesh.textures.maps_padded()]
kwargs["maps"] = maps
elif any(mesh.textures.maps_padded() is not None for mesh in meshes):
raise ValueError("Inconsistent maps_padded in join_meshes_as_batch.")
all_textures = [mesh.textures for mesh in meshes]
first = all_textures[0]
tex_types_same = all(type(tex) == type(first) for tex in all_textures)

if first.verts_uvs_padded() is not None:
if any(mesh.textures.verts_uvs_padded() is None for mesh in meshes):
raise ValueError("Inconsistent verts_uvs_padded in join_meshes_as_batch.")
uvs = [uv for mesh in meshes for uv in mesh.textures.verts_uvs_list()]
V = max(uv.shape[0] for uv in uvs)
kwargs["verts_uvs"] = struct_utils.list_to_padded(uvs, (V, 2), -1)
elif any(mesh.textures.verts_uvs_padded() is not None for mesh in meshes):
raise ValueError("Inconsistent verts_uvs_padded in join_meshes_as_batch.")
if not tex_types_same:
raise ValueError("All meshes in the batch must have the same type of texture.")

if first.faces_uvs_padded() is not None:
if any(mesh.textures.faces_uvs_padded() is None for mesh in meshes):
raise ValueError("Inconsistent faces_uvs_padded in join_meshes_as_batch.")
uvs = [uv for mesh in meshes for uv in mesh.textures.faces_uvs_list()]
F = max(uv.shape[0] for uv in uvs)
kwargs["faces_uvs"] = struct_utils.list_to_padded(uvs, (F, 3), -1)
elif any(mesh.textures.faces_uvs_padded() is not None for mesh in meshes):
raise ValueError("Inconsistent faces_uvs_padded in join_meshes_as_batch.")

if first.verts_rgb_padded() is not None:
if any(mesh.textures.verts_rgb_padded() is None for mesh in meshes):
raise ValueError("Inconsistent verts_rgb_padded in join_meshes_as_batch.")
rgb = [i for mesh in meshes for i in mesh.textures.verts_rgb_list()]
V = max(i.shape[0] for i in rgb)
kwargs["verts_rgb"] = struct_utils.list_to_padded(rgb, (V, 3))
elif any(mesh.textures.verts_rgb_padded() is not None for mesh in meshes):
raise ValueError("Inconsistent verts_rgb_padded in join_meshes_as_batch.")

tex = Textures(**kwargs)
tex = first.join_batch(all_textures[1:])
return Meshes(verts=verts, faces=faces, textures=tex)


@@ -1544,7 +1526,7 @@ def join_mesh(meshes: Union[Meshes, List[Meshes]]) -> Meshes:
"""
Joins a batch of meshes in the form of a Meshes object or a list of Meshes
objects as a single mesh. If the input is a list, the Meshes objects in the list
must all be on the same device. This version ignores all textures in the input mehses.
must all be on the same device. This version ignores all textures in the input meshes.

Args:
meshes: Meshes object that contains a batch of meshes or a list of Meshes objects
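Joining is now delegated to the texture classes themselves via join_batch, so the only new constraint on callers is that every mesh in the batch must carry the same texture type. A hedged sketch; mesh_a and mesh_b are assumed to be Meshes objects with, say, TexturesVertex textures:

    from pytorch3d.structures import join_meshes_as_batch

    batch = join_meshes_as_batch([mesh_a, mesh_b])
    # Mixing texture types (e.g. a TexturesAtlas mesh with a TexturesVertex mesh)
    # raises: ValueError("All meshes in the batch must have the same type of texture.")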
@@ -1,279 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from typing import List, Optional, Union

import torch
from torch.nn.functional import interpolate

from .utils import padded_to_list, padded_to_packed


"""
This file has functions for interpolating textures after rasterization.
"""


def _pad_texture_maps(images: List[torch.Tensor]) -> torch.Tensor:
"""
Pad all texture images so they have the same height and width.

Args:
images: list of N tensors of shape (H, W, 3)

Returns:
tex_maps: Tensor of shape (N, max_H, max_W, 3)
"""
tex_maps = []
max_H = 0
max_W = 0
for im in images:
h, w, _3 = im.shape
if h > max_H:
max_H = h
if w > max_W:
max_W = w
tex_maps.append(im)
max_shape = (max_H, max_W)

for i, image in enumerate(tex_maps):
if image.shape[:2] != max_shape:
image_BCHW = image.permute(2, 0, 1)[None]
new_image_BCHW = interpolate(
image_BCHW, size=max_shape, mode="bilinear", align_corners=False
)
tex_maps[i] = new_image_BCHW[0].permute(1, 2, 0)
tex_maps = torch.stack(tex_maps, dim=0) # (num_tex_maps, max_H, max_W, 3)
return tex_maps


def _extend_tensor(input_tensor: torch.Tensor, N: int) -> torch.Tensor:
"""
Extend a tensor `input_tensor` with ndim > 2, `N` times along the batch
dimension. This is done in the following sequence of steps (where `B` is
the batch dimension):

.. code-block:: python

input_tensor (B, ...)
-> add leading empty dimension (1, B, ...)
-> expand (N, B, ...)
-> reshape (N * B, ...)

Args:
input_tensor: torch.Tensor with ndim > 2 representing a batched input.
N: number of times to extend each element of the batch.
"""
# pyre-fixme[16]: `Tensor` has no attribute `ndim`.
if input_tensor.ndim < 2:
raise ValueError("Input tensor must have ndimensions >= 2.")
B = input_tensor.shape[0]
non_batch_dims = tuple(input_tensor.shape[1:])
constant_dims = (-1,) * input_tensor.ndim # these dims are not expanded.
return (
input_tensor.clone()[None, ...]
.expand(N, *constant_dims)
.transpose(0, 1)
.reshape(N * B, *non_batch_dims)
)


class Textures(object):
def __init__(
self,
maps: Union[List, torch.Tensor, None] = None,
faces_uvs: Optional[torch.Tensor] = None,
verts_uvs: Optional[torch.Tensor] = None,
verts_rgb: Optional[torch.Tensor] = None,
):
"""
Args:
maps: texture map per mesh. This can either be a list of maps
[(H, W, 3)] or a padded tensor of shape (N, H, W, 3).
faces_uvs: (N, F, 3) tensor giving the index into verts_uvs for each
vertex in the face. Padding value is assumed to be -1.
verts_uvs: (N, V, 2) tensor giving the uv coordinate per vertex.
verts_rgb: (N, V, 3) tensor giving the rgb color per vertex. Padding
value is assumed to be -1.

Note: only the padded representation of the textures is stored
and the packed/list representations are computed on the fly and
not cached.
"""
# pyre-fixme[16]: `Tensor` has no attribute `ndim`.
if faces_uvs is not None and faces_uvs.ndim != 3:
msg = "Expected faces_uvs to be of shape (N, F, 3); got %r"
raise ValueError(msg % repr(faces_uvs.shape))
if verts_uvs is not None and verts_uvs.ndim != 3:
msg = "Expected verts_uvs to be of shape (N, V, 2); got %r"
raise ValueError(msg % repr(verts_uvs.shape))
if verts_rgb is not None and verts_rgb.ndim != 3:
msg = "Expected verts_rgb to be of shape (N, V, 3); got %r"
raise ValueError(msg % repr(verts_rgb.shape))
if maps is not None:
# pyre-fixme[16]: `List` has no attribute `ndim`.
if torch.is_tensor(maps) and maps.ndim != 4:
msg = "Expected maps to be of shape (N, H, W, 3); got %r"
# pyre-fixme[16]: `List` has no attribute `shape`.
raise ValueError(msg % repr(maps.shape))
elif isinstance(maps, list):
maps = _pad_texture_maps(maps)
if faces_uvs is None or verts_uvs is None:
msg = "To use maps, faces_uvs and verts_uvs are required"
raise ValueError(msg)

self._faces_uvs_padded = faces_uvs
self._verts_uvs_padded = verts_uvs
self._verts_rgb_padded = verts_rgb
self._maps_padded = maps

# The number of faces/verts for each mesh is
# set inside the Meshes object when textures is
# passed into the Meshes constructor.
self._num_faces_per_mesh = None
self._num_verts_per_mesh = None

def clone(self):
other = self.__class__()
for k in dir(self):
v = getattr(self, k)
if torch.is_tensor(v):
setattr(other, k, v.clone())
return other

def to(self, device):
for k in dir(self):
v = getattr(self, k)
if torch.is_tensor(v) and v.device != device:
setattr(self, k, v.to(device))
return self

def __getitem__(self, index):
other = self.__class__()
for key in dir(self):
value = getattr(self, key)
if torch.is_tensor(value):
if isinstance(index, int):
setattr(other, key, value[index][None])
else:
setattr(other, key, value[index])
return other

def faces_uvs_padded(self) -> torch.Tensor:
# pyre-fixme[7]: Expected `Tensor` but got `Optional[torch.Tensor]`.
return self._faces_uvs_padded

def faces_uvs_list(self) -> Union[List[torch.Tensor], None]:
if self._faces_uvs_padded is None:
return None
return padded_to_list(
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
self._faces_uvs_padded,
split_size=self._num_faces_per_mesh,
)

def faces_uvs_packed(self) -> Union[torch.Tensor, None]:
if self._faces_uvs_padded is None:
return None
return padded_to_packed(
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
self._faces_uvs_padded,
split_size=self._num_faces_per_mesh,
)

def verts_uvs_padded(self) -> Union[torch.Tensor, None]:
return self._verts_uvs_padded

def verts_uvs_list(self) -> Union[List[torch.Tensor], None]:
if self._verts_uvs_padded is None:
return None
# Vertices shared between multiple faces
# may have a different uv coordinate for
# each face so the num_verts_uvs_per_mesh
# may be different from num_verts_per_mesh.
# Therefore don't use any split_size.
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
return padded_to_list(self._verts_uvs_padded)

def verts_uvs_packed(self) -> Union[torch.Tensor, None]:
if self._verts_uvs_padded is None:
return None
# Vertices shared between multiple faces
# may have a different uv coordinate for
# each face so the num_verts_uvs_per_mesh
# may be different from num_verts_per_mesh.
# Therefore don't use any split_size.
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
return padded_to_packed(self._verts_uvs_padded)

def verts_rgb_padded(self) -> Union[torch.Tensor, None]:
return self._verts_rgb_padded

def verts_rgb_list(self) -> Union[List[torch.Tensor], None]:
if self._verts_rgb_padded is None:
return None
return padded_to_list(
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
self._verts_rgb_padded,
split_size=self._num_verts_per_mesh,
)

def verts_rgb_packed(self) -> Union[torch.Tensor, None]:
if self._verts_rgb_padded is None:
return None
return padded_to_packed(
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
self._verts_rgb_padded,
split_size=self._num_verts_per_mesh,
)

# Currently only the padded maps are used.
def maps_padded(self) -> Union[torch.Tensor, None]:
# pyre-fixme[7]: Expected `Optional[torch.Tensor]` but got `Union[None,
# List[typing.Any], torch.Tensor]`.
return self._maps_padded

def extend(self, N: int) -> "Textures":
"""
Create new Textures class which contains each input texture N times

Args:
N: number of new copies of each texture.

Returns:
new Textures object.
"""
if not isinstance(N, int):
raise ValueError("N must be an integer.")
if N <= 0:
raise ValueError("N must be > 0.")

if all(
v is not None
for v in [self._faces_uvs_padded, self._verts_uvs_padded, self._maps_padded]
):
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
new_verts_uvs = _extend_tensor(self._verts_uvs_padded, N)
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
new_faces_uvs = _extend_tensor(self._faces_uvs_padded, N)
# pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[None,
# List[typing.Any], torch.Tensor]`.
new_maps = _extend_tensor(self._maps_padded, N)
return self.__class__(
verts_uvs=new_verts_uvs, faces_uvs=new_faces_uvs, maps=new_maps
)
elif self._verts_rgb_padded is not None:
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
new_verts_rgb = _extend_tensor(self._verts_rgb_padded, N)
return self.__class__(verts_rgb=new_verts_rgb)
else:
msg = "Either vertex colors or texture maps are required."
raise ValueError(msg)
@@ -73,6 +73,7 @@ def padded_to_list(x: torch.Tensor, split_size: Union[list, tuple, None] = None)
# pyre-fixme[16]: `Tensor` has no attribute `ndim`.
if x.ndim != 3:
raise ValueError("Supports only 3-dimensional input tensors")

x_list = list(x.unbind(0))

if split_size is None:
BIN tests/data/test_texture_atlas_8x8_back.png — new binary file, 31 KiB (binary file not shown).
@@ -8,9 +8,9 @@ from pytorch3d.ops.interp_face_attrs import (
interpolate_face_attributes,
interpolate_face_attributes_python,
)
from pytorch3d.renderer.mesh import TexturesVertex
from pytorch3d.renderer.mesh.rasterizer import Fragments
from pytorch3d.renderer.mesh.texturing import interpolate_vertex_colors
from pytorch3d.structures import Meshes, Textures
from pytorch3d.structures import Meshes


class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):

@@ -96,16 +96,12 @@ class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):
self.assertClose(grad_face_attrs_py, grad_face_attrs_cu, rtol=1e-3)

def test_interpolate_attributes(self):
"""
This tests both interpolate_vertex_colors as well as
interpolate_face_attributes.
"""
verts = torch.randn((4, 3), dtype=torch.float32)
faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
vert_tex = torch.tensor(
[[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32
)
tex = Textures(verts_rgb=vert_tex[None, :])
tex = TexturesVertex(verts_features=vert_tex[None, :])
mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
barycentric_coords = torch.tensor(

@@ -120,7 +116,13 @@ class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):
zbuf=torch.ones_like(pix_to_face),
dists=torch.ones_like(pix_to_face),
)
texels = interpolate_vertex_colors(fragments, mesh)

verts_features_packed = mesh.textures.verts_features_packed()
faces_verts_features = verts_features_packed[mesh.faces_packed()]

texels = interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, faces_verts_features
)
self.assertTrue(torch.allclose(texels, expected_vals[None, :]))

def test_interpolate_attributes_grad(self):

@@ -131,7 +133,7 @@ class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):
dtype=torch.float32,
requires_grad=True,
)
tex = Textures(verts_rgb=vert_tex[None, :])
tex = TexturesVertex(verts_features=vert_tex[None, :])
mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
barycentric_coords = torch.tensor(

@@ -147,7 +149,12 @@ class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):
[[0.3, 0.3, 0.3], [0.9, 0.9, 0.9], [0.5, 0.5, 0.5], [0.3, 0.3, 0.3]],
dtype=torch.float32,
)
texels = interpolate_vertex_colors(fragments, mesh)
verts_features_packed = mesh.textures.verts_features_packed()
faces_verts_features = verts_features_packed[mesh.faces_packed()]

texels = interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, faces_verts_features
)
texels.sum().backward()
self.assertTrue(hasattr(vert_tex, "grad"))
self.assertTrue(torch.allclose(vert_tex.grad, grad_vert_tex[None, :]))
@@ -13,8 +13,8 @@ from pytorch3d.io.mtl_io import (
_bilinear_interpolation_grid_sample,
_bilinear_interpolation_vectorized,
)
from pytorch3d.structures import Meshes, Textures, join_meshes_as_batch
from pytorch3d.structures.meshes import join_mesh
from pytorch3d.renderer import TexturesAtlas, TexturesUV, TexturesVertex
from pytorch3d.structures import Meshes, join_meshes_as_batch
from pytorch3d.utils import torus


@@ -590,17 +590,29 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):

check_item(mesh.verts_padded(), mesh3.verts_padded())
check_item(mesh.faces_padded(), mesh3.faces_padded())

if mesh.textures is not None:
check_item(mesh.textures.maps_padded(), mesh3.textures.maps_padded())
check_item(
mesh.textures.faces_uvs_padded(), mesh3.textures.faces_uvs_padded()
)
check_item(
mesh.textures.verts_uvs_padded(), mesh3.textures.verts_uvs_padded()
)
check_item(
mesh.textures.verts_rgb_padded(), mesh3.textures.verts_rgb_padded()
)
if isinstance(mesh.textures, TexturesUV):
check_item(
mesh.textures.faces_uvs_padded(),
mesh3.textures.faces_uvs_padded(),
)
check_item(
mesh.textures.verts_uvs_padded(),
mesh3.textures.verts_uvs_padded(),
)
check_item(
mesh.textures.maps_padded(), mesh3.textures.maps_padded()
)
elif isinstance(mesh.textures, TexturesVertex):
check_item(
mesh.textures.verts_features_padded(),
mesh3.textures.verts_features_padded(),
)
elif isinstance(mesh.textures, TexturesAtlas):
check_item(
mesh.textures.atlas_padded(), mesh3.textures.atlas_padded()
)

DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
obj_filename = DATA_DIR / "cow_mesh/cow.obj"

@@ -623,16 +635,24 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
check_triple(mesh_notex, mesh3_notex)
self.assertIsNone(mesh_notex.textures)

# meshes with vertex texture, join into a batch.
verts = torch.randn((4, 3), dtype=torch.float32)
faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
vert_tex = torch.tensor(
[[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32
)
tex = Textures(verts_rgb=vert_tex[None, :])
mesh_rgb = Meshes(verts=[verts], faces=[faces], textures=tex)
vert_tex = torch.ones_like(verts)
rgb_tex = TexturesVertex(verts_features=[vert_tex])
mesh_rgb = Meshes(verts=[verts], faces=[faces], textures=rgb_tex)
mesh_rgb3 = join_meshes_as_batch([mesh_rgb, mesh_rgb, mesh_rgb])
check_triple(mesh_rgb, mesh_rgb3)

# meshes with texture atlas, join into a batch.
device = "cuda:0"
atlas = torch.rand((2, 4, 4, 3), dtype=torch.float32, device=device)
atlas_tex = TexturesAtlas(atlas=[atlas])
mesh_atlas = Meshes(verts=[verts], faces=[faces], textures=atlas_tex)
mesh_atlas3 = join_meshes_as_batch([mesh_atlas, mesh_atlas, mesh_atlas])
check_triple(mesh_atlas, mesh_atlas3)

# Test load multiple meshes with textures into a batch.
teapot_obj = DATA_DIR / "teapot.obj"
mesh_teapot = load_objs_as_meshes([teapot_obj])
teapot_verts, teapot_faces = mesh_teapot.get_mesh_verts_faces(0)

@@ -649,41 +669,9 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
self.assertClose(cow3_tea.verts_list()[3], mesh_teapot.verts_list()[0])
self.assertClose(cow3_tea.faces_list()[3], mesh_teapot.faces_list()[0])

def test_join_meshes(self):
"""
Test that join_mesh joins single meshes and the corresponding values are
consistent with the single meshes.
"""

# Load cow mesh.
DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
cow_obj = DATA_DIR / "cow_mesh/cow.obj"

cow_mesh = load_objs_as_meshes([cow_obj])
cow_verts, cow_faces = cow_mesh.get_mesh_verts_faces(0)
# Join a batch of three single meshes and check that the values are consistent
# with the individual meshes.
cow_mesh3 = join_mesh([cow_mesh, cow_mesh, cow_mesh])

def check_item(x, y, offset):
self.assertClose(torch.cat([x, x + offset, x + 2 * offset], dim=1), y)

check_item(cow_mesh.verts_padded(), cow_mesh3.verts_padded(), 0)
check_item(cow_mesh.faces_padded(), cow_mesh3.faces_padded(), cow_mesh._V)

# Test the joining of meshes of different sizes.
teapot_obj = DATA_DIR / "teapot.obj"
teapot_mesh = load_objs_as_meshes([teapot_obj])
teapot_verts, teapot_faces = teapot_mesh.get_mesh_verts_faces(0)

mix_mesh = join_mesh([cow_mesh, teapot_mesh])
mix_verts, mix_faces = mix_mesh.get_mesh_verts_faces(0)
self.assertEqual(len(mix_mesh), 1)

self.assertClose(mix_verts[: cow_mesh._V], cow_verts)
self.assertClose(mix_faces[: cow_mesh._F], cow_faces)
self.assertClose(mix_verts[cow_mesh._V :], teapot_verts)
self.assertClose(mix_faces[cow_mesh._F :], teapot_faces + cow_mesh._V)
# Check error raised if all meshes in the batch don't have the same texture type
with self.assertRaisesRegex(ValueError, "same type of texture"):
join_meshes_as_batch([mesh_atlas, mesh_rgb, mesh_atlas])

@staticmethod
def _bm_save_obj(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
@@ -11,10 +11,11 @@ import numpy as np
import torch
from common_testing import TestCaseMixin, load_rgb_image
from PIL import Image
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.io import load_obj
from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.lighting import PointLights
from pytorch3d.renderer.materials import Materials
from pytorch3d.renderer.mesh import TexturesAtlas, TexturesUV, TexturesVertex
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
from pytorch3d.renderer.mesh.renderer import MeshRenderer
from pytorch3d.renderer.mesh.shader import (

@@ -25,7 +26,6 @@ from pytorch3d.renderer.mesh.shader import (
SoftSilhouetteShader,
TexturedSoftPhongShader,
)
from pytorch3d.renderer.mesh.texturing import Textures
from pytorch3d.structures.meshes import Meshes, join_mesh
from pytorch3d.utils.ico_sphere import ico_sphere

@@ -52,7 +52,8 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
sphere_mesh = ico_sphere(5, device)
verts_padded = sphere_mesh.verts_padded()
faces_padded = sphere_mesh.faces_padded()
textures = Textures(verts_rgb=torch.ones_like(verts_padded))
feats = torch.ones_like(verts_padded, device=device)
textures = TexturesVertex(verts_features=feats)
sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)

# Init rasterizer settings

@@ -97,6 +98,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
filename = "simple_sphere_light_%s%s.png" % (name, postfix)
image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
rgb = images[0, ..., :3].squeeze().cpu()

if DEBUG:
filename = "DEBUG_%s" % filename
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(

@@ -145,14 +147,15 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
Test a mesh with vertex textures can be extended to form a batch, and
is rendered correctly with Phong, Gouraud and Flat Shaders.
"""
batch_size = 20
batch_size = 5
device = torch.device("cuda:0")

# Init mesh with vertex textures.
sphere_meshes = ico_sphere(5, device).extend(batch_size)
verts_padded = sphere_meshes.verts_padded()
faces_padded = sphere_meshes.faces_padded()
textures = Textures(verts_rgb=torch.ones_like(verts_padded))
feats = torch.ones_like(verts_padded, device=device)
textures = TexturesVertex(verts_features=feats)
sphere_meshes = Meshes(
verts=verts_padded, faces=faces_padded, textures=textures
)

@@ -194,6 +197,11 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
)
for i in range(batch_size):
rgb = images[i, ..., :3].squeeze().cpu()
if i == 0 and DEBUG:
filename = "DEBUG_simple_sphere_batched_%s.png" % name
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / filename
)
self.assertClose(rgb, image_ref, atol=0.05)

def test_silhouette_with_grad(self):

@@ -233,6 +241,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):

with Image.open(image_ref_filename) as raw_image_ref:
image_ref = torch.from_numpy(np.array(raw_image_ref))

image_ref = image_ref.to(dtype=torch.float32) / 255.0
self.assertClose(alpha, image_ref, atol=0.055)

@@ -253,11 +262,20 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
obj_filename = obj_dir / "cow_mesh/cow.obj"

# Load mesh + texture
mesh = load_objs_as_meshes([obj_filename], device=device)
verts, faces, aux = load_obj(
obj_filename, device=device, load_textures=True, texture_wrap=None
)
tex_map = list(aux.texture_images.values())[0]
tex_map = tex_map[None, ...].to(faces.textures_idx.device)
textures = TexturesUV(
maps=tex_map, faces_uvs=[faces.textures_idx], verts_uvs=[aux.verts_uvs]
)
mesh = Meshes(verts=[verts], faces=[faces.verts_idx], textures=textures)

# Init rasterizer settings
R, T = look_at_view_transform(2.7, 0, 0)
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)

raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1
)

@@ -405,8 +423,8 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
Meshes(verts=verts, faces=sphere_list[i].faces_padded())
)
joined_sphere_mesh = join_mesh(sphere_mesh_list)
joined_sphere_mesh.textures = Textures(
verts_rgb=torch.ones_like(joined_sphere_mesh.verts_padded())
joined_sphere_mesh.textures = TexturesVertex(
verts_features=torch.ones_like(joined_sphere_mesh.verts_padded())
)

# Init rasterizer settings

@@ -446,3 +464,61 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
)
image_ref = load_rgb_image("test_joined_spheres_%s.png" % name, DATA_DIR)
self.assertClose(rgb, image_ref, atol=0.05)

def test_texture_map_atlas(self):
"""
Test a mesh with a texture map as a per face atlas is loaded and rendered correctly.
"""
device = torch.device("cuda:0")
obj_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
obj_filename = obj_dir / "cow_mesh/cow.obj"

# Load mesh and texture as a per face texture atlas.
verts, faces, aux = load_obj(
obj_filename,
device=device,
load_textures=True,
create_texture_atlas=True,
texture_atlas_size=8,
texture_wrap=None,
)
mesh = Meshes(
verts=[verts],
faces=[faces.verts_idx],
textures=TexturesAtlas(atlas=[aux.texture_atlas]),
)

# Init rasterizer settings
R, T = look_at_view_transform(2.7, 0, 0)
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)

raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1, cull_backfaces=True
)

# Init shader settings
materials = Materials(device=device, specular_color=((0, 0, 0),), shininess=0.0)
lights = PointLights(device=device)

# Place light behind the cow in world space. The front of
# the cow is facing the -z direction.
lights.location = torch.tensor([0.0, 0.0, 2.0], device=device)[None]

# The HardPhongShader can be used directly with atlas textures.
renderer = MeshRenderer(
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
shader=HardPhongShader(lights=lights, cameras=cameras, materials=materials),
)

images = renderer(mesh)
rgb = images[0, ..., :3].squeeze().cpu()

# Load reference image
image_ref = load_rgb_image("test_texture_atlas_8x8_back.png", DATA_DIR)

if DEBUG:
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_texture_atlas_8x8_back.png"
)

self.assertClose(rgb, image_ref, atol=0.05)
@ -7,14 +7,376 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.renderer.mesh.rasterizer import Fragments
|
||||
from pytorch3d.renderer.mesh.texturing import interpolate_texture_map
|
||||
from pytorch3d.structures import Meshes, Textures
|
||||
from pytorch3d.structures.utils import list_to_padded
|
||||
from pytorch3d.renderer.mesh.textures import (
|
||||
TexturesAtlas,
|
||||
TexturesUV,
|
||||
TexturesVertex,
|
||||
_list_to_padded_wrapper,
|
||||
)
|
||||
from pytorch3d.structures import Meshes, list_to_packed, packed_to_list
|
||||
from test_meshes import TestMeshes
|
||||
|
||||
|
||||


class TestTexturing(TestCaseMixin, unittest.TestCase):
    def test_interpolate_texture_map(self):
def tryindex(self, index, tex, meshes, source):
    tex2 = tex[index]
    meshes2 = meshes[index]
    tex_from_meshes = meshes2.textures
    for item in source:
        basic = source[item][index]
        from_texture = getattr(tex2, item + "_padded")()
        from_meshes = getattr(tex_from_meshes, item + "_padded")()
        if isinstance(index, int):
            basic = basic[None]

        if len(basic) == 0:
            self.assertEquals(len(from_texture), 0)
            self.assertEquals(len(from_meshes), 0)
        else:
            self.assertClose(basic, from_texture)
            self.assertClose(basic, from_meshes)
            self.assertEqual(from_texture.ndim, getattr(tex, item + "_padded")().ndim)
            item_list = getattr(tex_from_meshes, item + "_list")()
            self.assertEqual(basic.shape[0], len(item_list))
            for i, elem in enumerate(item_list):
                self.assertClose(elem, basic[i])


class TestTexturesVertex(TestCaseMixin, unittest.TestCase):
    def test_sample_vertex_textures(self):
        """
        This tests both interpolate_vertex_colors as well as
        interpolate_face_attributes.
        """
        verts = torch.randn((4, 3), dtype=torch.float32)
        faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
        vert_tex = torch.tensor(
            [[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32
        )
        verts_features = vert_tex
        tex = TexturesVertex(verts_features=[verts_features])
        mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
        pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
        barycentric_coords = torch.tensor(
            [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
        ).view(1, 1, 1, 2, -1)
        expected_vals = torch.tensor(
            [[0.5, 1.0, 0.3], [0.3, 1.0, 0.9]], dtype=torch.float32
        ).view(1, 1, 1, 2, -1)
        fragments = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=barycentric_coords,
            zbuf=torch.ones_like(pix_to_face),
            dists=torch.ones_like(pix_to_face),
        )
        # sample_textures calls interpolate_vertex_colors
        texels = mesh.sample_textures(fragments)
        self.assertTrue(torch.allclose(texels, expected_vals[None, :]))
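
For reference, a minimal standalone sketch of the barycentric interpolation this test exercises, i.e. `texel = w0 * c0 + w1 * c1 + w2 * c2` per face, which reproduces `expected_vals` above (illustrative values only):

import torch

vert_tex = torch.tensor(
    [[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32
)
faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
bary = torch.tensor([[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]])
face_colors = vert_tex[faces]                        # (2, 3, 3): colors of each face's verts
texels = (bary[..., None] * face_colors).sum(dim=1)  # -> [[0.5, 1.0, 0.3], [0.3, 1.0, 0.9]]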

    def test_sample_vertex_textures_grad(self):
        verts = torch.randn((4, 3), dtype=torch.float32)
        faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
        vert_tex = torch.tensor(
            [[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]],
            dtype=torch.float32,
            requires_grad=True,
        )
        verts_features = vert_tex
        tex = TexturesVertex(verts_features=[verts_features])
        mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
        pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
        barycentric_coords = torch.tensor(
            [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
        ).view(1, 1, 1, 2, -1)
        fragments = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=barycentric_coords,
            zbuf=torch.ones_like(pix_to_face),
            dists=torch.ones_like(pix_to_face),
        )
        grad_vert_tex = torch.tensor(
            [[0.3, 0.3, 0.3], [0.9, 0.9, 0.9], [0.5, 0.5, 0.5], [0.3, 0.3, 0.3]],
            dtype=torch.float32,
        )
        texels = mesh.sample_textures(fragments)
        texels.sum().backward()
        self.assertTrue(hasattr(vert_tex, "grad"))
        self.assertTrue(torch.allclose(vert_tex.grad, grad_vert_tex[None, :]))

    def test_textures_vertex_init_fail(self):
        # Incorrect sized tensors
        with self.assertRaisesRegex(ValueError, "verts_features"):
            TexturesVertex(verts_features=torch.rand(size=(5, 10)))

        # Not a list or a tensor
        with self.assertRaisesRegex(ValueError, "verts_features"):
            TexturesVertex(verts_features=(1, 1, 1))

    def test_clone(self):
        tex = TexturesVertex(verts_features=torch.rand(size=(10, 100, 128)))
        tex_cloned = tex.clone()
        self.assertSeparate(
            tex._verts_features_padded, tex_cloned._verts_features_padded
        )
        self.assertSeparate(tex.valid, tex_cloned.valid)

    def test_extend(self):
        B = 10
        mesh = TestMeshes.init_mesh(B, 30, 50)
        V = mesh._V
        tex_uv = TexturesVertex(verts_features=torch.randn((B, V, 3)))
        tex_mesh = Meshes(
            verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex_uv
        )
        N = 20
        new_mesh = tex_mesh.extend(N)

        self.assertEqual(len(tex_mesh) * N, len(new_mesh))

        tex_init = tex_mesh.textures
        new_tex = new_mesh.textures

        for i in range(len(tex_mesh)):
            for n in range(N):
                self.assertClose(
                    tex_init.verts_features_list()[i],
                    new_tex.verts_features_list()[i * N + n],
                )
                self.assertClose(
                    tex_init._num_faces_per_mesh[i],
                    new_tex._num_faces_per_mesh[i * N + n],
                )

        self.assertAllSeparate(
            [tex_init.verts_features_padded(), new_tex.verts_features_padded()]
        )

        with self.assertRaises(ValueError):
            tex_mesh.extend(N=-1)
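
The `i * N + n` indexing checked above assumes `extend(N)` repeats each element N times in order; a small standalone sketch of that interleaving with plain tensors:

import torch

features = torch.rand(3, 7, 3)                           # batch of 3 per-mesh features
extended = torch.repeat_interleave(features, 2, dim=0)   # batch of 6: order [0, 0, 1, 1, 2, 2]
assert torch.equal(extended[1 * 2 + 1], features[1])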

    def test_padded_to_packed(self):
        # Case where each face in the mesh has 3 unique uv vertex indices
        # - i.e. even if a vertex is shared between multiple faces it will
        # have a unique uv coordinate for each face.
        num_verts_per_mesh = [9, 6]
        D = 10
        verts_features_list = [torch.rand(v, D) for v in num_verts_per_mesh]
        verts_features_packed = list_to_packed(verts_features_list)[0]
        verts_features_list = packed_to_list(verts_features_packed, num_verts_per_mesh)
        tex = TexturesVertex(verts_features=verts_features_list)

        # This is set inside Meshes when textures is passed as an input.
        # Here we set _num_faces_per_mesh and _num_verts_per_mesh explicitly.
        tex1 = tex.clone()
        tex1._num_verts_per_mesh = num_verts_per_mesh
        verts_packed = tex1.verts_features_packed()
        verts_verts_list = tex1.verts_features_list()
        verts_padded = tex1.verts_features_padded()

        for f1, f2 in zip(verts_verts_list, verts_features_list):
            self.assertTrue((f1 == f2).all().item())

        self.assertTrue(verts_packed.shape == (sum(num_verts_per_mesh), D))
        self.assertTrue(verts_padded.shape == (2, 9, D))

        # Case where num_verts_per_mesh is not set and textures
        # are initialized with a padded tensor.
        tex2 = TexturesVertex(verts_features=verts_padded)
        verts_packed = tex2.verts_features_packed()
        verts_list = tex2.verts_features_list()

        # Packed is just flattened padded as num_verts_per_mesh
        # has not been provided.
        self.assertTrue(verts_packed.shape == (9 * 2, D))

        for i, (f1, f2) in enumerate(zip(verts_list, verts_features_list)):
            n = num_verts_per_mesh[i]
            self.assertTrue((f1[:n] == f2).all().item())
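
A small standalone sketch of the padded vs. packed behaviour asserted above, assuming the convention the test checks: without per-mesh counts, packing simply flattens the padded tensor (padding rows included); with counts, the padding rows are dropped per mesh:

import torch

verts_features_padded = torch.rand(2, 9, 10)               # (N, max_V, D)
packed_no_counts = verts_features_padded.reshape(-1, 10)   # (2 * 9, D)

num_verts_per_mesh = [9, 6]
packed_with_counts = torch.cat(
    [verts_features_padded[i, :n] for i, n in enumerate(num_verts_per_mesh)]
)  # (9 + 6, D)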

    def test_getitem(self):
        N = 5
        V = 20
        source = {"verts_features": torch.randn(size=(N, 10, 128))}
        tex = TexturesVertex(verts_features=source["verts_features"])

        verts = torch.rand(size=(N, V, 3))
        faces = torch.randint(size=(N, 10, 3), high=V)
        meshes = Meshes(verts=verts, faces=faces, textures=tex)

        tryindex(self, 2, tex, meshes, source)
        tryindex(self, slice(0, 2, 1), tex, meshes, source)
        index = torch.tensor([1, 0, 1, 0, 0], dtype=torch.bool)
        tryindex(self, index, tex, meshes, source)
        index = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool)
        tryindex(self, index, tex, meshes, source)
        index = torch.tensor([1, 2], dtype=torch.int64)
        tryindex(self, index, tex, meshes, source)
        tryindex(self, [2, 4], tex, meshes, source)


class TestTexturesAtlas(TestCaseMixin, unittest.TestCase):
    def test_sample_texture_atlas(self):
        N, F, R = 1, 2, 2
        verts = torch.randn((4, 3), dtype=torch.float32)
        faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
        faces_atlas = torch.rand(size=(N, F, R, R, 3))
        tex = TexturesAtlas(atlas=faces_atlas)
        mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
        pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
        barycentric_coords = torch.tensor(
            [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
        ).view(1, 1, 1, 2, -1)
        expected_vals = torch.tensor(
            [[0.5, 1.0, 0.3], [0.3, 1.0, 0.9]], dtype=torch.float32
        )
        expected_vals = torch.zeros((1, 1, 1, 2, 3), dtype=torch.float32)
        expected_vals[..., 0, :] = faces_atlas[0, 0, 0, 1, ...]
        expected_vals[..., 1, :] = faces_atlas[0, 1, 1, 0, ...]

        fragments = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=barycentric_coords,
            zbuf=torch.ones_like(pix_to_face),
            dists=torch.ones_like(pix_to_face),
        )
        texels = mesh.textures.sample_textures(fragments)
        self.assertTrue(torch.allclose(texels, expected_vals))
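
The expected texels above index the R x R per-face atlas from the first two barycentric coordinates; a rough standalone sketch of that lookup, assuming grid cell (y, x) = (floor(w1 * R), floor(w0 * R)) for points below the cell diagonal, which matches the indices asserted above:

import torch

R = 2
bary = torch.tensor([[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]])
x = (bary[:, 0] * R).long()   # tensor([1, 0])
y = (bary[:, 1] * R).long()   # tensor([0, 1])
# pixel 0 reads faces_atlas[0, face 0, y=0, x=1]
# pixel 1 reads faces_atlas[0, face 1, y=1, x=0]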

    def test_textures_atlas_grad(self):
        N, F, R = 1, 2, 2
        verts = torch.randn((4, 3), dtype=torch.float32)
        faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
        faces_atlas = torch.rand(size=(N, F, R, R, 3), requires_grad=True)
        tex = TexturesAtlas(atlas=faces_atlas)
        mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
        pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
        barycentric_coords = torch.tensor(
            [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
        ).view(1, 1, 1, 2, -1)
        fragments = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=barycentric_coords,
            zbuf=torch.ones_like(pix_to_face),
            dists=torch.ones_like(pix_to_face),
        )
        texels = mesh.textures.sample_textures(fragments)
        grad_tex = torch.rand_like(texels)
        grad_expected = torch.zeros_like(faces_atlas)
        grad_expected[0, 0, 0, 1, :] = grad_tex[..., 0:1, :]
        grad_expected[0, 1, 1, 0, :] = grad_tex[..., 1:2, :]
        texels.backward(grad_tex)
        self.assertTrue(hasattr(faces_atlas, "grad"))
        self.assertTrue(torch.allclose(faces_atlas.grad, grad_expected))

    def test_textures_atlas_init_fail(self):
        # Incorrect sized tensors
        with self.assertRaisesRegex(ValueError, "atlas"):
            TexturesAtlas(atlas=torch.rand(size=(5, 10, 3)))

        # Not a list or a tensor
        with self.assertRaisesRegex(ValueError, "atlas"):
            TexturesAtlas(atlas=(1, 1, 1))

    def test_clone(self):
        tex = TexturesAtlas(atlas=torch.rand(size=(1, 10, 2, 2, 3)))
        tex_cloned = tex.clone()
        self.assertSeparate(tex._atlas_padded, tex_cloned._atlas_padded)
        self.assertSeparate(tex.valid, tex_cloned.valid)

    def test_extend(self):
        B = 10
        mesh = TestMeshes.init_mesh(B, 30, 50)
        F = mesh._F
        tex_uv = TexturesAtlas(atlas=torch.randn((B, F, 2, 2, 3)))
        tex_mesh = Meshes(
            verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex_uv
        )
        N = 20
        new_mesh = tex_mesh.extend(N)

        self.assertEqual(len(tex_mesh) * N, len(new_mesh))

        tex_init = tex_mesh.textures
        new_tex = new_mesh.textures

        for i in range(len(tex_mesh)):
            for n in range(N):
                self.assertClose(
                    tex_init.atlas_list()[i], new_tex.atlas_list()[i * N + n]
                )
                self.assertClose(
                    tex_init._num_faces_per_mesh[i],
                    new_tex._num_faces_per_mesh[i * N + n],
                )

        self.assertAllSeparate([tex_init.atlas_padded(), new_tex.atlas_padded()])

        with self.assertRaises(ValueError):
            tex_mesh.extend(N=-1)

    def test_padded_to_packed(self):
        # Case where each face in the mesh has 3 unique uv vertex indices
        # - i.e. even if a vertex is shared between multiple faces it will
        # have a unique uv coordinate for each face.
        R = 2
        N = 20
        num_faces_per_mesh = torch.randint(size=(N,), low=0, high=30)
        atlas_list = [torch.rand(f, R, R, 3) for f in num_faces_per_mesh]
        tex = TexturesAtlas(atlas=atlas_list)

        # This is set inside Meshes when textures is passed as an input.
        # Here we set _num_faces_per_mesh explicitly.
        tex1 = tex.clone()
        tex1._num_faces_per_mesh = num_faces_per_mesh.tolist()
        atlas_packed = tex1.atlas_packed()
        atlas_list_new = tex1.atlas_list()
        atlas_padded = tex1.atlas_padded()

        for f1, f2 in zip(atlas_list_new, atlas_list):
            self.assertTrue((f1 == f2).all().item())

        sum_F = num_faces_per_mesh.sum()
        max_F = num_faces_per_mesh.max().item()
        self.assertTrue(atlas_packed.shape == (sum_F, R, R, 3))
        self.assertTrue(atlas_padded.shape == (N, max_F, R, R, 3))

        # Case where num_faces_per_mesh is not set and textures
        # are initialized with a padded tensor.
        atlas_list_padded = _list_to_padded_wrapper(atlas_list)
        tex2 = TexturesAtlas(atlas=atlas_list_padded)
        atlas_packed = tex2.atlas_packed()
        atlas_list_new = tex2.atlas_list()

        # Packed is just flattened padded as num_faces_per_mesh
        # has not been provided.
        self.assertTrue(atlas_packed.shape == (N * max_F, R, R, 3))

        for i, (f1, f2) in enumerate(zip(atlas_list_new, atlas_list)):
            n = num_faces_per_mesh[i]
            self.assertTrue((f1[:n] == f2).all().item())

    def test_getitem(self):
        N = 5
        V = 20
        source = {"atlas": torch.randn(size=(N, 10, 4, 4, 3))}
        tex = TexturesAtlas(atlas=source["atlas"])

        verts = torch.rand(size=(N, V, 3))
        faces = torch.randint(size=(N, 10, 3), high=V)
        meshes = Meshes(verts=verts, faces=faces, textures=tex)

        tryindex(self, 2, tex, meshes, source)
        tryindex(self, slice(0, 2, 1), tex, meshes, source)
        index = torch.tensor([1, 0, 1, 0, 0], dtype=torch.bool)
        tryindex(self, index, tex, meshes, source)
        index = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool)
        tryindex(self, index, tex, meshes, source)
        index = torch.tensor([1, 2], dtype=torch.int64)
        tryindex(self, index, tex, meshes, source)
        tryindex(self, [2, 4], tex, meshes, source)


class TestTexturesUV(TestCaseMixin, unittest.TestCase):
    def test_sample_textures_uv(self):
        barycentric_coords = torch.tensor(
            [[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
        ).view(1, 1, 1, 2, -1)
@@ -38,11 +400,11 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
            zbuf=pix_to_face,
            dists=pix_to_face,
        )
        tex = Textures(
            maps=tex_map, faces_uvs=face_uvs[None, ...], verts_uvs=vert_uvs[None, ...]
        )

        tex = TexturesUV(maps=tex_map, faces_uvs=[face_uvs], verts_uvs=[vert_uvs])
        meshes = Meshes(verts=[dummy_verts], faces=[face_uvs], textures=tex)
        texels = interpolate_texture_map(fragments, meshes)
        mesh_textures = meshes.textures
        texels = mesh_textures.sample_textures(fragments)

        # Expected output
        pixel_uvs = interpolated_uvs * 2.0 - 1.0
@@ -53,190 +415,92 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
        expected_out = F.grid_sample(tex_map, pixel_uvs, align_corners=False)
        self.assertTrue(torch.allclose(texels.squeeze(), expected_out.squeeze()))
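
The `pixel_uvs = interpolated_uvs * 2.0 - 1.0` step above maps uv values in [0, 1] into the [-1, 1] range that `grid_sample` expects; a minimal standalone sketch of that call (texture v-axis orientation details omitted, names illustrative):

import torch
import torch.nn.functional as F

tex_map = torch.rand(1, 3, 16, 16)               # (N, C, H, W), as grid_sample expects
uvs = torch.tensor([[0.25, 0.75], [0.5, 0.5]])   # per-pixel uv in [0, 1]
grid = (uvs * 2.0 - 1.0).view(1, 1, 2, 2)        # (N, H_out, W_out, 2) in [-1, 1]
texels = F.grid_sample(tex_map, grid, align_corners=False)  # (1, 3, 1, 2)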

    def test_init_rgb_uv_fail(self):
        V = 20
    def test_textures_uv_init_fail(self):
        # Maps has wrong shape
        with self.assertRaisesRegex(ValueError, "maps"):
            Textures(
            TexturesUV(
                maps=torch.ones((5, 16, 16, 3, 4)),
                faces_uvs=torch.randint(size=(5, 10, 3), low=0, high=V),
                verts_uvs=torch.ones((5, V, 2)),
                faces_uvs=torch.rand(size=(5, 10, 3)),
                verts_uvs=torch.rand(size=(5, 15, 2)),
            )

        # faces_uvs has wrong shape
        with self.assertRaisesRegex(ValueError, "faces_uvs"):
            Textures(
            TexturesUV(
                maps=torch.ones((5, 16, 16, 3)),
                faces_uvs=torch.randint(size=(5, 10, 3, 3), low=0, high=V),
                verts_uvs=torch.ones((5, V, 2)),
                faces_uvs=torch.rand(size=(5, 10, 3, 3)),
                verts_uvs=torch.rand(size=(5, 15, 2)),
            )

        # verts_uvs has wrong shape
        with self.assertRaisesRegex(ValueError, "verts_uvs"):
            Textures(
            TexturesUV(
                maps=torch.ones((5, 16, 16, 3)),
                faces_uvs=torch.randint(size=(5, 10, 3), low=0, high=V),
                verts_uvs=torch.ones((5, V, 2, 3)),
            )
        # verts_rgb has wrong shape
        with self.assertRaisesRegex(ValueError, "verts_rgb"):
            Textures(verts_rgb=torch.ones((5, 16, 16, 3)))

        # maps provided without verts/faces uvs
        with self.assertRaisesRegex(ValueError, "faces_uvs and verts_uvs are required"):
            Textures(maps=torch.ones((5, 16, 16, 3)))

    def test_padded_to_packed(self):
        N = 2
        # Case where each face in the mesh has 3 unique uv vertex indices
        # - i.e. even if a vertex is shared between multiple faces it will
        # have a unique uv coordinate for each face.
        faces_uvs_list = [
            torch.tensor([[0, 1, 2], [3, 5, 4], [7, 6, 8]]),
            torch.tensor([[0, 1, 2], [3, 4, 5]]),
        ]  # (N, 3, 3)
        verts_uvs_list = [torch.ones(9, 2), torch.ones(6, 2)]
        faces_uvs_padded = list_to_padded(faces_uvs_list, pad_value=-1)
        verts_uvs_padded = list_to_padded(verts_uvs_list)
        tex = Textures(
            maps=torch.ones((N, 16, 16, 3)),
            faces_uvs=faces_uvs_padded,
            verts_uvs=verts_uvs_padded,
        )

        # This is set inside Meshes when textures is passed as an input.
        # Here we set _num_faces_per_mesh and _num_verts_per_mesh explicitly.
        tex1 = tex.clone()
        tex1._num_faces_per_mesh = faces_uvs_padded.gt(-1).all(-1).sum(-1).tolist()
        tex1._num_verts_per_mesh = torch.tensor([5, 4])
        faces_packed = tex1.faces_uvs_packed()
        verts_packed = tex1.verts_uvs_packed()
        faces_list = tex1.faces_uvs_list()
        verts_list = tex1.verts_uvs_list()

        for f1, f2 in zip(faces_uvs_list, faces_list):
            self.assertTrue((f1 == f2).all().item())

        for f, v1, v2 in zip(faces_list, verts_list, verts_uvs_list):
            idx = f.unique()
            self.assertTrue((v1[idx] == v2).all().item())

        self.assertTrue(faces_packed.shape == (3 + 2, 3))

        # verts_packed is just flattened verts_padded.
        # split sizes are not used for verts_uvs.
        self.assertTrue(verts_packed.shape == (9 * 2, 2))

        # Case where num_faces_per_mesh is not set
        tex2 = tex.clone()
        faces_packed = tex2.faces_uvs_packed()
        verts_packed = tex2.verts_uvs_packed()
        faces_list = tex2.faces_uvs_list()
        verts_list = tex2.verts_uvs_list()

        # Packed is just flattened padded as num_faces_per_mesh
        # has not been provided.
        self.assertTrue(verts_packed.shape == (9 * 2, 2))
        self.assertTrue(faces_packed.shape == (3 * 2, 3))

        for i in range(N):
            self.assertTrue(
                (faces_list[i] == faces_uvs_padded[i, ...].squeeze()).all().item()
                faces_uvs=torch.rand(size=(5, 10, 3)),
                verts_uvs=torch.rand(size=(5, 15, 2, 3)),
            )

        for i in range(N):
            self.assertTrue(
                (verts_list[i] == verts_uvs_padded[i, ...].squeeze()).all().item()
        # verts has different batch dim to faces
        with self.assertRaisesRegex(ValueError, "verts_uvs"):
            TexturesUV(
                maps=torch.ones((5, 16, 16, 3)),
                faces_uvs=torch.rand(size=(5, 10, 3)),
                verts_uvs=torch.rand(size=(8, 15, 2)),
            )

        # maps has different batch dim to faces
        with self.assertRaisesRegex(ValueError, "maps"):
            TexturesUV(
                maps=torch.ones((8, 16, 16, 3)),
                faces_uvs=torch.rand(size=(5, 10, 3)),
                verts_uvs=torch.rand(size=(5, 15, 2)),
            )

        # verts on different device to faces
        with self.assertRaisesRegex(ValueError, "verts_uvs"):
            TexturesUV(
                maps=torch.ones((5, 16, 16, 3)),
                faces_uvs=torch.rand(size=(5, 10, 3)),
                verts_uvs=torch.rand(size=(5, 15, 2, 3), device="cuda"),
            )

        # maps on different device to faces
        with self.assertRaisesRegex(ValueError, "map"):
            TexturesUV(
                maps=torch.ones((5, 16, 16, 3), device="cuda"),
                faces_uvs=torch.rand(size=(5, 10, 3)),
                verts_uvs=torch.rand(size=(5, 15, 2)),
            )

    def test_clone(self):
        V = 20
        tex = Textures(
        tex = TexturesUV(
            maps=torch.ones((5, 16, 16, 3)),
            faces_uvs=torch.randint(size=(5, 10, 3), low=0, high=V),
            verts_uvs=torch.ones((5, V, 2)),
            faces_uvs=torch.rand(size=(5, 10, 3)),
            verts_uvs=torch.rand(size=(5, 15, 2)),
        )
        tex_cloned = tex.clone()
        self.assertSeparate(tex._faces_uvs_padded, tex_cloned._faces_uvs_padded)
        self.assertSeparate(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded)
        self.assertSeparate(tex._maps_padded, tex_cloned._maps_padded)

    def test_getitem(self):
        N = 5
        V = 20
        source = {
            "maps": torch.rand(size=(N, 16, 16, 3)),
            "faces_uvs": torch.randint(size=(N, 10, 3), low=0, high=V),
            "verts_uvs": torch.rand((N, V, 2)),
        }
        tex = Textures(
            maps=source["maps"],
            faces_uvs=source["faces_uvs"],
            verts_uvs=source["verts_uvs"],
        )

        verts = torch.rand(size=(N, V, 3))
        faces = torch.randint(size=(N, 10, 3), high=V)
        meshes = Meshes(verts=verts, faces=faces, textures=tex)

        def tryindex(index):
            tex2 = tex[index]
            meshes2 = meshes[index]
            tex_from_meshes = meshes2.textures
            for item in source:
                basic = source[item][index]
                from_texture = getattr(tex2, item + "_padded")()
                from_meshes = getattr(tex_from_meshes, item + "_padded")()
                if isinstance(index, int):
                    basic = basic[None]
                self.assertClose(basic, from_texture)
                self.assertClose(basic, from_meshes)
                self.assertEqual(
                    from_texture.ndim, getattr(tex, item + "_padded")().ndim
                )
                if item == "faces_uvs":
                    faces_uvs_list = tex_from_meshes.faces_uvs_list()
                    self.assertEqual(basic.shape[0], len(faces_uvs_list))
                    for i, faces_uvs in enumerate(faces_uvs_list):
                        self.assertClose(faces_uvs, basic[i])

        tryindex(2)
        tryindex(slice(0, 2, 1))
        index = torch.tensor([1, 0, 1, 0, 0], dtype=torch.bool)
        tryindex(index)
        index = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool)
        tryindex(index)
        index = torch.tensor([1, 2], dtype=torch.int64)
        tryindex(index)
        tryindex([2, 4])

    def test_to(self):
        V = 20
        tex = Textures(
            maps=torch.ones((5, 16, 16, 3)),
            faces_uvs=torch.randint(size=(5, 10, 3), low=0, high=V),
            verts_uvs=torch.ones((5, V, 2)),
        )
        device = torch.device("cuda:0")
        tex = tex.to(device)
        self.assertTrue(tex._faces_uvs_padded.device == device)
        self.assertTrue(tex._verts_uvs_padded.device == device)
        self.assertTrue(tex._maps_padded.device == device)
        self.assertSeparate(tex.valid, tex_cloned.valid)

    def test_extend(self):
        B = 10
        B = 5
        mesh = TestMeshes.init_mesh(B, 30, 50)
        V = mesh._V
        F = mesh._F

        # 1. Texture uvs
        tex_uv = Textures(
            maps=torch.randn((B, 16, 16, 3)),
            faces_uvs=torch.randint(size=(B, F, 3), low=0, high=V),
            verts_uvs=torch.randn((B, V, 2)),
        num_faces = mesh.num_faces_per_mesh()
        num_verts = mesh.num_verts_per_mesh()
        faces_uvs_list = [torch.randint(size=(f, 3), low=0, high=V) for f in num_faces]
        verts_uvs_list = [torch.rand(v, 2) for v in num_verts]
        tex_uv = TexturesUV(
            maps=torch.ones((B, 16, 16, 3)),
            faces_uvs=faces_uvs_list,
            verts_uvs=verts_uvs_list,
        )
        tex_mesh = Meshes(
            verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex_uv
            verts=mesh.verts_list(), faces=mesh.faces_list(), textures=tex_uv
        )
        N = 20
        N = 2
        new_mesh = tex_mesh.extend(N)

        self.assertEqual(len(tex_mesh) * N, len(new_mesh))
@@ -246,56 +510,142 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):

        for i in range(len(tex_mesh)):
            for n in range(N):
                self.assertClose(
                    tex_init.verts_uvs_list()[i], new_tex.verts_uvs_list()[i * N + n]
                )
                self.assertClose(
                    tex_init.faces_uvs_list()[i], new_tex.faces_uvs_list()[i * N + n]
                )
                self.assertClose(
                    tex_init.verts_uvs_list()[i], new_tex.verts_uvs_list()[i * N + n]
                    tex_init.maps_padded()[i, ...], new_tex.maps_padded()[i * N + n]
                )
                self.assertClose(
                    tex_init._num_faces_per_mesh[i],
                    new_tex._num_faces_per_mesh[i * N + n],
                )

        self.assertAllSeparate(
            [
                tex_init.faces_uvs_padded(),
                new_tex.faces_uvs_padded(),
                tex_init.faces_uvs_packed(),
                new_tex.faces_uvs_packed(),
                tex_init.verts_uvs_padded(),
                new_tex.verts_uvs_padded(),
                tex_init.verts_uvs_packed(),
                new_tex.verts_uvs_packed(),
                tex_init.maps_padded(),
                new_tex.maps_padded(),
            ]
        )

        self.assertIsNone(new_tex.verts_rgb_list())
        self.assertIsNone(new_tex.verts_rgb_padded())
        self.assertIsNone(new_tex.verts_rgb_packed())

        # 2. Texture vertex RGB
        tex_rgb = Textures(verts_rgb=torch.randn((B, V, 3)))
        tex_mesh_rgb = Meshes(
            verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex_rgb
        )
        N = 20
        new_mesh_rgb = tex_mesh_rgb.extend(N)

        self.assertEqual(len(tex_mesh_rgb) * N, len(new_mesh_rgb))

        tex_init = tex_mesh_rgb.textures
        new_tex = new_mesh_rgb.textures

        for i in range(len(tex_mesh_rgb)):
            for n in range(N):
                self.assertClose(
                    tex_init.verts_rgb_list()[i], new_tex.verts_rgb_list()[i * N + n]
                )
        self.assertAllSeparate(
            [tex_init.verts_rgb_padded(), new_tex.verts_rgb_padded()]
        )

        self.assertIsNone(new_tex.verts_uvs_padded())
        self.assertIsNone(new_tex.verts_uvs_list())
        self.assertIsNone(new_tex.verts_uvs_packed())
        self.assertIsNone(new_tex.faces_uvs_padded())
        self.assertIsNone(new_tex.faces_uvs_list())
        self.assertIsNone(new_tex.faces_uvs_packed())

        # 3. Error
        with self.assertRaises(ValueError):
            tex_mesh.extend(N=-1)

    def test_padded_to_packed(self):
        # Case where each face in the mesh has 3 unique uv vertex indices
        # - i.e. even if a vertex is shared between multiple faces it will
        # have a unique uv coordinate for each face.
        N = 2
        faces_uvs_list = [
            torch.tensor([[0, 1, 2], [3, 5, 4], [7, 6, 8]]),
            torch.tensor([[0, 1, 2], [3, 4, 5]]),
        ]  # (N, 3, 3)
        verts_uvs_list = [torch.ones(9, 2), torch.ones(6, 2)]

        num_faces_per_mesh = [f.shape[0] for f in faces_uvs_list]
        num_verts_per_mesh = [v.shape[0] for v in verts_uvs_list]
        tex = TexturesUV(
            maps=torch.ones((N, 16, 16, 3)),
            faces_uvs=faces_uvs_list,
            verts_uvs=verts_uvs_list,
        )

        # This is set inside Meshes when textures is passed as an input.
        # Here we set _num_faces_per_mesh and _num_verts_per_mesh explicitly.
        tex1 = tex.clone()
        tex1._num_faces_per_mesh = num_faces_per_mesh
        tex1._num_verts_per_mesh = num_verts_per_mesh
        verts_packed = tex1.verts_uvs_packed()
        verts_list = tex1.verts_uvs_list()
        verts_padded = tex1.verts_uvs_padded()

        faces_packed = tex1.faces_uvs_packed()
        faces_list = tex1.faces_uvs_list()
        faces_padded = tex1.faces_uvs_padded()

        for f1, f2 in zip(faces_list, faces_uvs_list):
            self.assertTrue((f1 == f2).all().item())

        for f1, f2 in zip(verts_list, verts_uvs_list):
            self.assertTrue((f1 == f2).all().item())

        self.assertTrue(faces_packed.shape == (3 + 2, 3))
        self.assertTrue(faces_padded.shape == (2, 3, 3))
        self.assertTrue(verts_packed.shape == (9 + 6, 2))
        self.assertTrue(verts_padded.shape == (2, 9, 2))

        # Case where num_faces_per_mesh is not set and faces_verts_uvs
        # are initialized with a padded tensor.
        tex2 = TexturesUV(
            maps=torch.ones((N, 16, 16, 3)),
            verts_uvs=verts_padded,
            faces_uvs=faces_padded,
        )
        faces_packed = tex2.faces_uvs_packed()
        faces_list = tex2.faces_uvs_list()
        verts_packed = tex2.verts_uvs_packed()
        verts_list = tex2.verts_uvs_list()

        # Packed is just flattened padded as num_faces_per_mesh
        # has not been provided.
        self.assertTrue(faces_packed.shape == (3 * 2, 3))
        self.assertTrue(verts_packed.shape == (9 * 2, 2))

        for i, (f1, f2) in enumerate(zip(faces_list, faces_uvs_list)):
            n = num_faces_per_mesh[i]
            self.assertTrue((f1[:n] == f2).all().item())

        for i, (f1, f2) in enumerate(zip(verts_list, verts_uvs_list)):
            n = num_verts_per_mesh[i]
            self.assertTrue((f1[:n] == f2).all().item())

    def test_to(self):
        tex = TexturesUV(
            maps=torch.ones((5, 16, 16, 3)),
            faces_uvs=torch.randint(size=(5, 10, 3), high=15),
            verts_uvs=torch.rand(size=(5, 15, 2)),
        )
        device = torch.device("cuda:0")
        tex = tex.to(device)
        self.assertTrue(tex._faces_uvs_padded.device == device)
        self.assertTrue(tex._verts_uvs_padded.device == device)
        self.assertTrue(tex._maps_padded.device == device)

    def test_getitem(self):
        N = 5
        V = 20
        source = {
            "maps": torch.rand(size=(N, 1, 1, 3)),
            "faces_uvs": torch.randint(size=(N, 10, 3), high=V),
            "verts_uvs": torch.randn(size=(N, V, 2)),
        }
        tex = TexturesUV(
            maps=source["maps"],
            faces_uvs=source["faces_uvs"],
            verts_uvs=source["verts_uvs"],
        )

        verts = torch.rand(size=(N, V, 3))
        faces = torch.randint(size=(N, 10, 3), high=V)
        meshes = Meshes(verts=verts, faces=faces, textures=tex)

        tryindex(self, 2, tex, meshes, source)
        tryindex(self, slice(0, 2, 1), tex, meshes, source)
        index = torch.tensor([1, 0, 1, 0, 0], dtype=torch.bool)
        tryindex(self, index, tex, meshes, source)
        index = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool)
        tryindex(self, index, tex, meshes, source)
        index = torch.tensor([1, 2], dtype=torch.int64)
        tryindex(self, index, tex, meshes, source)
        tryindex(self, [2, 4], tex, meshes, source)