Tutorial texture updates and a fix for a bug in extending meshes with uv textures

Summary:
Found a bug in extending textures with vertex uv coordinates. It was caused by the padded -> list conversion of the vertex uv coordinates: the number of vertices in the mesh and the number of rows in verts_uvs can differ (e.g. a vertex shared between 3 faces can have up to 3 different uv coordinates), so we cannot convert directly from padded to list using _num_verts_per_mesh.
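
A minimal sketch of the failure mode and the fix (the shapes and numbers below are made up for illustration; only the padded -> list split by _num_verts_per_mesh and the unbind-based fix mirror the real code):

```python
import torch

# One mesh with 4 vertices but 6 rows of verts_uvs, because some vertices
# carry a different uv coordinate on each face that uses them.
verts_uvs_padded = torch.rand(1, 6, 2)   # (N, max_V_uv, 2)
num_verts_per_mesh = [4]                 # vertex count of the mesh itself

# Old behaviour: splitting the padded uvs by the mesh vertex count
# silently drops the extra uv rows.
buggy = [uv[:n] for uv, n in zip(verts_uvs_padded, num_verts_per_mesh)]
print(buggy[0].shape)   # torch.Size([4, 2]) -> two uv rows lost

# New behaviour: unbind along the batch dimension and keep every row.
fixed = list(verts_uvs_padded.unbind(0))
print(fixed[0].shape)   # torch.Size([6, 2])
```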

Reviewed By: bottler

Differential Revision: D23233595

fbshipit-source-id: 0c66d15baae697ead0bdc384f74c27d4c6539fc9
Nikhila Ravi 2020-08-21 19:18:49 -07:00 committed by Facebook GitHub Bot
parent d330765847
commit 90f6a005b0
4 changed files with 956 additions and 943 deletions

Changed file: docs/tutorials/fit_textured_mesh.ipynb

@ -1,52 +1,17 @@
{ {
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"accelerator": "GPU",
"bento_stylesheets": {
"bento/extensions/flow/main.css": true,
"bento/extensions/kernel_selector/main.css": true,
"bento/extensions/kernel_ui/main.css": true,
"bento/extensions/new_kernel/main.css": true,
"bento/extensions/system_usage/main.css": true,
"bento/extensions/theme/main.css": true
},
"colab": {
"name": "fit_textured_mesh.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code", "colab_type": "code",
"id": "_Ip8kp4TfBLZ", "id": "_Ip8kp4TfBLZ"
"colab": {}
}, },
"outputs": [],
"source": [ "source": [
"# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved." "# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved."
], ]
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
@ -86,25 +51,27 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code", "colab_type": "code",
"id": "musUWTglgxSB", "id": "musUWTglgxSB"
"colab": {}
}, },
"outputs": [],
"source": [ "source": [
"!pip install torch torchvision\n", "!pip install torch torchvision\n",
"!pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'" "!pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
], ]
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code", "colab_type": "code",
"id": "nX99zdoffBLg", "id": "nX99zdoffBLg"
"colab": {}
}, },
"outputs": [],
"source": [ "source": [
"import os\n", "import os\n",
"import torch\n", "import torch\n",
@ -126,28 +93,27 @@
")\n", ")\n",
"\n", "\n",
"# Data structures and functions for rendering\n", "# Data structures and functions for rendering\n",
"from pytorch3d.structures import Meshes, Textures\n", "from pytorch3d.structures import Meshes\n",
"from pytorch3d.renderer import (\n", "from pytorch3d.renderer import (\n",
" look_at_view_transform,\n", " look_at_view_transform,\n",
" FoVPerspectiveCameras, \n", " OpenGLPerspectiveCameras, \n",
" PointLights, \n", " PointLights, \n",
" DirectionalLights, \n", " DirectionalLights, \n",
" Materials, \n", " Materials, \n",
" RasterizationSettings, \n", " RasterizationSettings, \n",
" MeshRenderer, \n", " MeshRenderer, \n",
" MeshRasterizer, \n", " MeshRasterizer, \n",
" TexturedSoftPhongShader,\n", " SoftPhongShader,\n",
" SoftSilhouetteShader,\n", " SoftSilhouetteShader,\n",
" SoftPhongShader,\n", " SoftPhongShader,\n",
" TexturesVertex\n",
")\n", ")\n",
"\n", "\n",
"# add path for demo utils functions \n", "# add path for demo utils functions \n",
"import sys\n", "import sys\n",
"import os\n", "import os\n",
"sys.path.append(os.path.abspath(''))" "sys.path.append(os.path.abspath(''))"
], ]
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
@ -161,17 +127,17 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code", "colab_type": "code",
"id": "HZozr3Pmho-5", "id": "HZozr3Pmho-5"
"colab": {}
}, },
"outputs": [],
"source": [ "source": [
"!wget https://raw.githubusercontent.com/facebookresearch/pytorch3d/master/docs/tutorials/utils/plot_image_grid.py\n", "!wget https://raw.githubusercontent.com/facebookresearch/pytorch3d/master/docs/tutorials/utils/plot_image_grid.py\n",
"from plot_image_grid import image_grid" "from plot_image_grid import image_grid"
], ]
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
@ -185,16 +151,16 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "paJ4Im8ahl7O",
"colab": {}
},
"source": [
" # from utils.plot_image_grid import image_grid"
],
"execution_count": null, "execution_count": null,
"outputs": [] "metadata": {
"colab": {},
"colab_type": "code",
"id": "paJ4Im8ahl7O"
},
"outputs": [],
"source": [
"# from utils.plot_image_grid import image_grid"
]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
@ -210,7 +176,7 @@
"\n", "\n",
"**Meshes** is a unique datastructure provided in PyTorch3D for working with batches of meshes of different sizes. \n", "**Meshes** is a unique datastructure provided in PyTorch3D for working with batches of meshes of different sizes. \n",
"\n", "\n",
"**Textures** is an auxillary datastructure for storing texture information about meshes. \n", "**TexturesVertex** is an auxillary datastructure for storing vertex rgb texture information about meshes. \n",
"\n", "\n",
"**Meshes** has several class methods which are used throughout the rendering pipeline." "**Meshes** has several class methods which are used throughout the rendering pipeline."
] ]
@ -228,27 +194,29 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code", "colab_type": "code",
"id": "tTm0cVuOjb1W", "id": "tTm0cVuOjb1W"
"colab": {}
}, },
"outputs": [],
"source": [ "source": [
"!mkdir -p data/cow_mesh\n", "!mkdir -p data/cow_mesh\n",
"!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.obj\n", "!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.obj\n",
"!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.mtl\n", "!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.mtl\n",
"!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow_texture.png" "!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow_texture.png"
], ]
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code", "colab_type": "code",
"id": "gi5Kd0GafBLl", "id": "gi5Kd0GafBLl"
"colab": {}
}, },
"outputs": [],
"source": [ "source": [
"# Setup\n", "# Setup\n",
"if torch.cuda.is_available():\n", "if torch.cuda.is_available():\n",
@ -274,9 +242,7 @@
"scale = max((verts - center).abs().max(0)[0])\n", "scale = max((verts - center).abs().max(0)[0])\n",
"mesh.offset_verts_(-center.expand(N, 3))\n", "mesh.offset_verts_(-center.expand(N, 3))\n",
"mesh.scale_verts_((1.0 / float(scale)));" "mesh.scale_verts_((1.0 / float(scale)));"
], ]
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
@ -292,11 +258,13 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code", "colab_type": "code",
"id": "CDQKebNNfBMI", "id": "CDQKebNNfBMI"
"colab": {}
}, },
"outputs": [],
"source": [ "source": [
"# the number of different viewpoints from which we want to render the mesh.\n", "# the number of different viewpoints from which we want to render the mesh.\n",
"num_views = 20\n", "num_views = 20\n",
@ -309,16 +277,16 @@
"# the cow is facing the -z direction. \n", "# the cow is facing the -z direction. \n",
"lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])\n", "lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])\n",
"\n", "\n",
"# Initialize a camera that represents a batch of different \n", "# Initialize an OpenGL perspective camera that represents a batch of different \n",
"# viewing angles. All the cameras helper methods support mixed type inputs and \n", "# viewing angles. All the cameras helper methods support mixed type inputs and \n",
"# broadcasting. So we can view the camera from the a distance of dist=2.7, and \n", "# broadcasting. So we can view the camera from the a distance of dist=2.7, and \n",
"# then specify elevation and azimuth angles for each viewpoint as tensors. \n", "# then specify elevation and azimuth angles for each viewpoint as tensors. \n",
"R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim)\n", "R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim)\n",
"cameras = FoVPerspectiveCameras(device=device, R=R, T=T)\n", "cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)\n",
"\n", "\n",
"# We arbitrarily choose one particular view that will be used to visualize \n", "# We arbitrarily choose one particular view that will be used to visualize \n",
"# results\n", "# results\n",
"camera = FoVPerspectiveCameras(device=device, R=R[None, 1, ...], \n", "camera = OpenGLPerspectiveCameras(device=device, R=R[None, 1, ...], \n",
" T=T[None, 1, ...]) \n", " T=T[None, 1, ...]) \n",
"\n", "\n",
"# Define the settings for rasterization and shading. Here we set the output \n", "# Define the settings for rasterization and shading. Here we set the output \n",
@ -343,7 +311,7 @@
" cameras=camera, \n", " cameras=camera, \n",
" raster_settings=raster_settings\n", " raster_settings=raster_settings\n",
" ),\n", " ),\n",
" shader=TexturedSoftPhongShader(\n", " shader=SoftPhongShader(\n",
" device=device, \n", " device=device, \n",
" cameras=camera,\n", " cameras=camera,\n",
" lights=lights\n", " lights=lights\n",
@ -361,11 +329,9 @@
"# Our multi-view cow dataset will be represented by these 2 lists of tensors,\n", "# Our multi-view cow dataset will be represented by these 2 lists of tensors,\n",
"# each of length num_views.\n", "# each of length num_views.\n",
"target_rgb = [target_images[i, ..., :3] for i in range(num_views)]\n", "target_rgb = [target_images[i, ..., :3] for i in range(num_views)]\n",
"target_cameras = [FoVPerspectiveCameras(device=device, R=R[None, i, ...], \n", "target_cameras = [OpenGLPerspectiveCameras(device=device, R=R[None, i, ...], \n",
" T=T[None, i, ...]) for i in range(num_views)]" " T=T[None, i, ...]) for i in range(num_views)]"
], ]
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
@ -379,24 +345,24 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"id": "HHE0CnbVR1Rd", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "HHE0CnbVR1Rd"
}, },
"outputs": [],
"source": [ "source": [
"# RGB images\n", "# RGB images\n",
"image_grid(target_images.cpu().numpy(), rows=4, cols=5, rgb=True)\n", "image_grid(target_images.cpu().numpy(), rows=4, cols=5, rgb=True)\n",
"plt.show()" "plt.show()"
], ]
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "gOb4rYx65E8z", "colab_type": "text",
"colab_type": "text" "id": "gOb4rYx65E8z"
}, },
"source": [ "source": [
"Later in this tutorial, we will fit a mesh to the rendered RGB images, as well as to just images of just the cow silhouette. For the latter case, we will render a dataset of silhouette images. Most shaders in PyTorch3D will output an alpha channel along with the RGB image as a 4th channel in an RGBA image. The alpha channel encodes the probability that each pixel belongs to the foreground of the object. We contruct a soft silhouette shader to render this alpha channel." "Later in this tutorial, we will fit a mesh to the rendered RGB images, as well as to just images of just the cow silhouette. For the latter case, we will render a dataset of silhouette images. Most shaders in PyTorch3D will output an alpha channel along with the RGB image as a 4th channel in an RGBA image. The alpha channel encodes the probability that each pixel belongs to the foreground of the object. We contruct a soft silhouette shader to render this alpha channel."
@ -404,11 +370,13 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"id": "iP_g-nwX4exM", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "iP_g-nwX4exM"
}, },
"outputs": [],
"source": [ "source": [
"# Rasterization settings for silhouette rendering \n", "# Rasterization settings for silhouette rendering \n",
"sigma = 1e-4\n", "sigma = 1e-4\n",
@ -435,9 +403,7 @@
"# Visualize silhouette images\n", "# Visualize silhouette images\n",
"image_grid(silhouette_images.cpu().numpy(), rows=4, cols=5, rgb=False)\n", "image_grid(silhouette_images.cpu().numpy(), rows=4, cols=5, rgb=False)\n",
"plt.show()" "plt.show()"
], ]
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
@ -454,11 +420,13 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"id": "eeWYHROrR1Rh", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "eeWYHROrR1Rh"
}, },
"outputs": [],
"source": [ "source": [
"# Show a visualization comparing the rendered predicted mesh to the ground truth \n", "# Show a visualization comparing the rendered predicted mesh to the ground truth \n",
"# mesh\n", "# mesh\n",
@ -487,9 +455,7 @@
" ax.set_xlabel(\"Iteration\", fontsize=\"16\")\n", " ax.set_xlabel(\"Iteration\", fontsize=\"16\")\n",
" ax.set_ylabel(\"Loss\", fontsize=\"16\")\n", " ax.set_ylabel(\"Loss\", fontsize=\"16\")\n",
" ax.set_title(\"Loss vs iterations\", fontsize=\"16\")" " ax.set_title(\"Loss vs iterations\", fontsize=\"16\")"
], ]
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
@ -503,23 +469,23 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"id": "i989ARH1R1Rj", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "i989ARH1R1Rj"
}, },
"outputs": [],
"source": [ "source": [
"# We initialize the source shape to be a sphere of radius 1. \n", "# We initialize the source shape to be a sphere of radius 1. \n",
"src_mesh = ico_sphere(4, device)" "src_mesh = ico_sphere(4, device)"
], ]
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "f5xVtgLNDvC5", "colab_type": "text",
"colab_type": "text" "id": "f5xVtgLNDvC5"
}, },
"source": [ "source": [
"We create a new differentiable renderer for rendering the silhouette of our predicted mesh:" "We create a new differentiable renderer for rendering the silhouette of our predicted mesh:"
@ -527,11 +493,13 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"id": "sXfjzgG4DsDJ", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "sXfjzgG4DsDJ"
}, },
"outputs": [],
"source": [ "source": [
"# Rasterization settings for differentiable rendering, where the blur_radius\n", "# Rasterization settings for differentiable rendering, where the blur_radius\n",
"# initialization is based on Liu et al, 'Soft Rasterizer: A Differentiable \n", "# initialization is based on Liu et al, 'Soft Rasterizer: A Differentiable \n",
@ -551,9 +519,7 @@
" ),\n", " ),\n",
" shader=SoftSilhouetteShader()\n", " shader=SoftSilhouetteShader()\n",
")" ")"
], ]
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
@ -567,11 +533,13 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"id": "0sLrKv_MEULh", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "0sLrKv_MEULh"
}, },
"outputs": [],
"source": [ "source": [
"# Number of views to optimize over in each SGD iteration\n", "# Number of views to optimize over in each SGD iteration\n",
"num_views_per_iteration = 2\n", "num_views_per_iteration = 2\n",
@ -609,9 +577,7 @@
"\n", "\n",
"# The optimizer\n", "# The optimizer\n",
"optimizer = torch.optim.SGD([deform_verts], lr=1.0, momentum=0.9)" "optimizer = torch.optim.SGD([deform_verts], lr=1.0, momentum=0.9)"
], ]
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
@ -625,11 +591,13 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"id": "gCfepfOoR1Rl", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "gCfepfOoR1Rl"
}, },
"outputs": [],
"source": [ "source": [
"loop = tqdm(range(Niter))\n", "loop = tqdm(range(Niter))\n",
"\n", "\n",
@ -670,25 +638,23 @@
" # Optimization step\n", " # Optimization step\n",
" sum_loss.backward()\n", " sum_loss.backward()\n",
" optimizer.step()" " optimizer.step()"
], ]
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"scrolled": true, "colab": {},
"id": "CX4huayKR1Rm",
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "CX4huayKR1Rm",
"scrolled": true
}, },
"outputs": [],
"source": [ "source": [
"visualize_prediction(new_src_mesh, silhouette=True, \n", "visualize_prediction(new_src_mesh, silhouette=True, \n",
" target_image=target_silhouette[1])\n", " target_image=target_silhouette[1])\n",
"plot_losses(losses)" "plot_losses(losses)"
], ]
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
@ -698,16 +664,18 @@
}, },
"source": [ "source": [
"## 3. Mesh and texture prediction via textured rendering\n", "## 3. Mesh and texture prediction via textured rendering\n",
"We can predict both the mesh and its texture if we add an additional loss based on the comparing a predicted rendered RGB image to the target image. As before, we start with a sphere mesh. We learn both translational offsets and RGB texture colors for each vertex in the sphere mesh. Since our loss is based on rendered RGB pixel values instead of just the silhouette, we use a **SoftPhongShader** instead of a **SoftSilhouetteShader**. Note also that we use a **SoftPhongShader** instead of the **TexturedSoftPhongShader** used to generate our dataset, because we represent texture using per vertex RGB colors instead of a texture image." "We can predict both the mesh and its texture if we add an additional loss based on the comparing a predicted rendered RGB image to the target image. As before, we start with a sphere mesh. We learn both translational offsets and RGB texture colors for each vertex in the sphere mesh. Since our loss is based on rendered RGB pixel values instead of just the silhouette, we use a **SoftPhongShader** instead of a **SoftSilhouetteShader**."
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"id": "aZObyIt9R1Ro", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "aZObyIt9R1Ro"
}, },
"outputs": [],
"source": [ "source": [
"# Rasterization settings for differentiable rendering, where the blur_radius\n", "# Rasterization settings for differentiable rendering, where the blur_radius\n",
"# initialization is based on Liu et al, 'Soft Rasterizer: A Differentiable \n", "# initialization is based on Liu et al, 'Soft Rasterizer: A Differentiable \n",
@ -729,9 +697,7 @@
" cameras=camera,\n", " cameras=camera,\n",
" lights=lights)\n", " lights=lights)\n",
")" ")"
], ]
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
@ -745,11 +711,13 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"id": "BS6LAQquF3wq", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "BS6LAQquF3wq"
}, },
"outputs": [],
"source": [ "source": [
"# Number of views to optimize over in each SGD iteration\n", "# Number of views to optimize over in each SGD iteration\n",
"num_views_per_iteration = 2\n", "num_views_per_iteration = 2\n",
@ -781,9 +749,7 @@
"\n", "\n",
"# The optimizer\n", "# The optimizer\n",
"optimizer = torch.optim.SGD([deform_verts, sphere_verts_rgb], lr=1.0, momentum=0.9)" "optimizer = torch.optim.SGD([deform_verts, sphere_verts_rgb], lr=1.0, momentum=0.9)"
], ]
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
@ -797,11 +763,13 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"id": "EKEH2p8-R1Rr", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "EKEH2p8-R1Rr"
}, },
"outputs": [],
"source": [ "source": [
"loop = tqdm(range(Niter))\n", "loop = tqdm(range(Niter))\n",
"\n", "\n",
@ -813,7 +781,7 @@
" new_src_mesh = src_mesh.offset_verts(deform_verts)\n", " new_src_mesh = src_mesh.offset_verts(deform_verts)\n",
" \n", " \n",
" # Add per vertex colors to texture the mesh\n", " # Add per vertex colors to texture the mesh\n",
" new_src_mesh.textures = Textures(verts_rgb=sphere_verts_rgb) \n", " new_src_mesh.textures = TexturesVertex(verts_rgb=sphere_verts_rgb) \n",
" \n", " \n",
" # Losses to smooth /regularize the mesh shape\n", " # Losses to smooth /regularize the mesh shape\n",
" loss = {k: torch.tensor(0.0, device=device) for k in losses}\n", " loss = {k: torch.tensor(0.0, device=device) for k in losses}\n",
@ -853,30 +821,28 @@
" # Optimization step\n", " # Optimization step\n",
" sum_loss.backward()\n", " sum_loss.backward()\n",
" optimizer.step()\n" " optimizer.step()\n"
], ]
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"scrolled": true, "colab": {},
"id": "2qTcHO4rR1Rs",
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "2qTcHO4rR1Rs",
"scrolled": true
}, },
"outputs": [],
"source": [ "source": [
"visualize_prediction(new_src_mesh, renderer=renderer_textured, silhouette=False)\n", "visualize_prediction(new_src_mesh, renderer=renderer_textured, silhouette=False)\n",
"plot_losses(losses)" "plot_losses(losses)"
], ]
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "akBOm_xcNUms", "colab_type": "text",
"colab_type": "text" "id": "akBOm_xcNUms"
}, },
"source": [ "source": [
"Save the final predicted mesh:" "Save the final predicted mesh:"
@ -894,11 +860,13 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"id": "OQGhV-psKna8", "colab": {},
"colab_type": "code", "colab_type": "code",
"colab": {} "id": "OQGhV-psKna8"
}, },
"outputs": [],
"source": [ "source": [
"# Fetch the verts and faces of the final predicted mesh\n", "# Fetch the verts and faces of the final predicted mesh\n",
"final_verts, final_faces = new_src_mesh.get_mesh_verts_faces(0)\n", "final_verts, final_faces = new_src_mesh.get_mesh_verts_faces(0)\n",
@ -909,9 +877,7 @@
"# Store the predicted mesh using save_obj\n", "# Store the predicted mesh using save_obj\n",
"final_obj = os.path.join('./', 'final_model.obj')\n", "final_obj = os.path.join('./', 'final_model.obj')\n",
"save_obj(final_obj, final_verts, final_faces)" "save_obj(final_obj, final_verts, final_faces)"
], ]
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
@ -924,5 +890,46 @@
"In this tutorial, we learned how to load a textured mesh from an obj file, create a synthetic dataset by rendering the mesh from multiple viewpoints. We showed how to set up an optimization loop to fit a mesh to the observed dataset images based on a rendered silhouette loss. We then augmented this optimization loop with an additional loss based on rendered RGB images, which allowed us to predict both a mesh and its texture." "In this tutorial, we learned how to load a textured mesh from an obj file, create a synthetic dataset by rendering the mesh from multiple viewpoints. We showed how to set up an optimization loop to fit a mesh to the observed dataset images based on a rendered silhouette loss. We then augmented this optimization loop with an additional loss based on rendered RGB images, which allowed us to predict both a mesh and its texture."
] ]
} }
] ],
"metadata": {
"accelerator": "GPU",
"anp_metadata": {
"path": "fbsource/fbcode/vision/fair/pytorch3d/docs/tutorials/fit_textured_mesh.ipynb"
},
"bento_stylesheets": {
"bento/extensions/flow/main.css": true,
"bento/extensions/kernel_selector/main.css": true,
"bento/extensions/kernel_ui/main.css": true,
"bento/extensions/new_kernel/main.css": true,
"bento/extensions/system_usage/main.css": true,
"bento/extensions/theme/main.css": true
},
"colab": {
"name": "fit_textured_mesh.ipynb",
"provenance": [],
"toc_visible": true
},
"disseminate_notebook_info": {
"backup_notebook_id": "781874812352022"
},
"kernelspec": {
"display_name": "intro_to_cv",
"language": "python",
"name": "bento_kernel_intro_to_cv"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5+"
}
},
"nbformat": 4,
"nbformat_minor": 1
} }

Changed file: a second tutorial notebook (Textures -> TexturesUV)

@ -87,7 +87,7 @@
"from pytorch3d.io import load_objs_as_meshes, load_obj\n", "from pytorch3d.io import load_objs_as_meshes, load_obj\n",
"\n", "\n",
"# Data structures and functions for rendering\n", "# Data structures and functions for rendering\n",
"from pytorch3d.structures import Meshes, Textures\n", "from pytorch3d.structures import Meshes\n",
"from pytorch3d.renderer import (\n", "from pytorch3d.renderer import (\n",
" look_at_view_transform,\n", " look_at_view_transform,\n",
" FoVPerspectiveCameras, \n", " FoVPerspectiveCameras, \n",
@ -97,7 +97,8 @@
" RasterizationSettings, \n", " RasterizationSettings, \n",
" MeshRenderer, \n", " MeshRenderer, \n",
" MeshRasterizer, \n", " MeshRasterizer, \n",
" SoftPhongShader\n", " SoftPhongShader,\n",
" TexturesUV\n",
")\n", ")\n",
"\n", "\n",
"# add path for demo utils functions \n", "# add path for demo utils functions \n",
@ -170,7 +171,7 @@
"\n", "\n",
"**Meshes** is a unique datastructure provided in PyTorch3D for working with batches of meshes of different sizes. \n", "**Meshes** is a unique datastructure provided in PyTorch3D for working with batches of meshes of different sizes. \n",
"\n", "\n",
"**Textures** is an auxillary datastructure for storing texture information about meshes. \n", "**TexturesUV** is an auxillary datastructure for storing vertex uv and texture maps for meshes. \n",
"\n", "\n",
"**Meshes** has several class methods which are used throughout the rendering pipeline." "**Meshes** has several class methods which are used throughout the rendering pipeline."
] ]
@ -537,7 +538,7 @@
"source": [ "source": [
"# We can pass arbirary keyword arguments to the rasterizer/shader via the renderer\n", "# We can pass arbirary keyword arguments to the rasterizer/shader via the renderer\n",
"# so the renderer does not need to be reinitialized if any of the settings change.\n", "# so the renderer does not need to be reinitialized if any of the settings change.\n",
"images = renderer(meshes, cameras=cameras, lights=lights)" "images = renderer(mesh, cameras=cameras, lights=lights)"
] ]
}, },
{ {
@ -582,9 +583,9 @@
"backup_notebook_id": "569222367081034" "backup_notebook_id": "569222367081034"
}, },
"kernelspec": { "kernelspec": {
"display_name": "pytorch3d (local)", "display_name": "intro_to_cv",
"language": "python", "language": "python",
"name": "pytorch3d_local" "name": "bento_kernel_intro_to_cv"
}, },
"language_info": { "language_info": {
"codemirror_mode": { "codemirror_mode": {

Changed file: the TexturesUV class

@@ -599,11 +599,6 @@ class TexturesUV(TexturesBase):
             if not all(v.device == self.device for v in verts_uvs):
                 raise ValueError("verts_uvs and faces_uvs must be on the same device")
-            # These values may be overridden when textures is
-            # passed into the Meshes constructor. For more details
-            # refer to the __init__ of Meshes.
-            self._num_verts_per_mesh = [len(v) for v in verts_uvs]
         elif torch.is_tensor(verts_uvs):
             if (
                 verts_uvs.ndim != 3
@@ -621,7 +616,6 @@
             # These values may be overridden when textures is
             # passed into the Meshes constructor.
             max_V = verts_uvs.shape[1]
-            self._num_verts_per_mesh = [max_V] * self._N
         else:
             raise ValueError("Expected verts_uvs to be a tensor or list")
@@ -758,9 +752,11 @@
                     torch.empty((0, 2), dtype=torch.float32, device=self.device)
                 ] * self._N
             else:
-                self._verts_uvs_list = padded_to_list(
-                    self._verts_uvs_padded, split_size=self._num_verts_per_mesh
-                )
+                # The number of vertices in the mesh and in verts_uvs can differ
+                # e.g. if a vertex is shared between 3 faces, it can
+                # have up to 3 different uv coordinates. Therefore we cannot
+                # convert directly from padded to list using _num_verts_per_mesh
+                self._verts_uvs_list = list(self._verts_uvs_padded.unbind(0))
         return self._verts_uvs_list

     # Currently only the padded maps are used.
@@ -783,7 +779,6 @@
                 "verts_uvs_padded",
                 "faces_uvs_padded",
                 "_num_faces_per_mesh",
-                "_num_verts_per_mesh",
             ],
         )
         new_tex = TexturesUV(
@@ -791,8 +786,8 @@
             faces_uvs=new_props["faces_uvs_padded"],
             verts_uvs=new_props["verts_uvs_padded"],
         )
         new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"]
-        new_tex._num_verts_per_mesh = new_props["_num_verts_per_mesh"]
         return new_tex

     def sample_textures(self, fragments, **kwargs) -> torch.Tensor:
@@ -860,6 +855,7 @@
         # right-bottom pixel of input.
         pixel_uvs = pixel_uvs * 2.0 - 1.0
+
         texture_maps = torch.flip(texture_maps, [2])  # flip y axis of the texture map
         if texture_maps.device != pixel_uvs.device:
             texture_maps = texture_maps.to(pixel_uvs.device)
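
A small illustrative check of the new verts_uvs_list() behaviour. All shapes and values below are invented; only the TexturesUV constructor and verts_uvs_list() come from the library:

```python
import torch
from pytorch3d.renderer import TexturesUV

# Two meshes packed into padded tensors: suppose mesh 0 really has 6 uv rows
# and mesh 1 only 4, so the padded uv tensor is (N=2, max_V=6, 2).
verts_uvs_padded = torch.rand(2, 6, 2)
faces_uvs_padded = torch.zeros(2, 2, 3, dtype=torch.int64)
maps = torch.rand(2, 8, 8, 3)

tex = TexturesUV(maps=maps, faces_uvs=faces_uvs_padded, verts_uvs=verts_uvs_padded)

# With the change above, the list view is now just an unbind of the padded
# tensor, so every entry keeps the padded length (6 rows, padding included).
# Callers that need exact per-mesh lengths have to slice, as the updated
# test below does with [:tex_nv].
print([uv.shape for uv in tex.verts_uvs_list()])
# [torch.Size([6, 2]), torch.Size([6, 2])]
```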

Changed file: TestTexturesUV unit tests

@@ -588,10 +588,19 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
         tex_init = tex_mesh.textures
         new_tex = new_mesh.textures
+        new_tex_num_verts = new_mesh.num_verts_per_mesh()

         for i in range(len(tex_mesh)):
             for n in range(N):
+                tex_nv = new_tex_num_verts[i * N + n]
                 self.assertClose(
-                    tex_init.verts_uvs_list()[i], new_tex.verts_uvs_list()[i * N + n]
+                    # The original textures were initialized using
+                    # verts uvs list
+                    tex_init.verts_uvs_list()[i],
+                    # In the new textures, the verts_uvs are initialized
+                    # from padded. The verts per mesh are not used to
+                    # convert from padded to list. See TexturesUV for an
+                    # explanation.
+                    new_tex.verts_uvs_list()[i * N + n][:tex_nv, ...],
                 )
                 self.assertClose(
                     tex_init.faces_uvs_list()[i], new_tex.faces_uvs_list()[i * N + n]
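
To connect the test back to the bug in the summary, here is a hedged end-to-end sketch of the "extend a mesh with uv textures" scenario. The geometry is made up: a quad split into two triangles whose shared vertices get a different uv per face, so the mesh has 4 vertices but 6 uv rows:

```python
import torch
from pytorch3d.structures import Meshes
from pytorch3d.renderer import TexturesUV

# A quad split into two triangles; the two vertices on the shared edge are
# given a different uv in each face, so there are 6 uv rows for 4 vertices.
verts = torch.tensor(
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 1.0, 0.0]]
)
faces = torch.tensor([[0, 1, 2], [0, 2, 3]], dtype=torch.int64)
verts_uvs = torch.tensor(
    [[0.0, 0.0], [1.0, 0.0], [1.0, 1.0],   # uvs referenced by face 0
     [0.1, 0.1], [0.9, 0.9], [0.0, 1.0]]   # uvs referenced by face 1
)
faces_uvs = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.int64)
maps = torch.rand(1, 16, 16, 3)

tex = TexturesUV(maps=maps, faces_uvs=[faces_uvs], verts_uvs=[verts_uvs])
mesh = Meshes(verts=[verts], faces=[faces], textures=tex)

# Extending the mesh replicates its textures. Previously the padded uvs could
# be split by the mesh vertex counts (4 here), dropping uv rows; with the fix
# every replica keeps all 6 rows, and the test above slices down to the mesh
# vertex count when comparing against the original list-initialized textures.
new_mesh = mesh.extend(2)
print(mesh.num_verts_per_mesh())                               # tensor([4])
print([uv.shape for uv in new_mesh.textures.verts_uvs_list()])
# [torch.Size([6, 2]), torch.Size([6, 2])]
```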