#!/usr/bin/env python
# coding: utf-8

# In[ ]:


# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

# # Camera position optimization using differentiable rendering
#
# In this tutorial we will learn the [x, y, z] position of a camera given a reference image using differentiable rendering.
#
# We will first initialize a renderer with a starting position for the camera. We will then use this renderer to generate an image, compute a loss against the reference image, and finally backpropagate through the entire pipeline to update the position of the camera.
#
# This tutorial shows how to:
# - load a mesh from an `.obj` file
# - initialize a `Camera`, `Shader` and `Renderer`
# - render a mesh
# - set up an optimization loop with a loss function and optimizer

# ## 0. Install and import modules
#
# If `torch`, `torchvision` and `pytorch3d` are not installed, run the following cell:

# In[ ]:


get_ipython().system('pip install torch torchvision')
import sys
import torch
if torch.__version__ == '1.6.0+cu101' and sys.platform.startswith('linux'):
    get_ipython().system('pip install pytorch3d')
else:
    get_ipython().system("pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'")

# In[ ]:


import os
import torch
import numpy as np
from tqdm.notebook import tqdm
import imageio
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from skimage import img_as_ubyte

# io utils
from pytorch3d.io import load_obj

# datastructures
from pytorch3d.structures import Meshes

# 3D transformations functions
from pytorch3d.transforms import Rotate, Translate

# rendering components
from pytorch3d.renderer import (
    FoVPerspectiveCameras, look_at_view_transform, look_at_rotation,
    RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,
    SoftSilhouetteShader, HardPhongShader, PointLights, TexturesVertex,
)

# ## 1. Load the Obj
#
# We will load an obj file and create a **Meshes** object. **Meshes** is a unique datastructure provided in PyTorch3D for working with **batches of meshes of different sizes**. It has several useful class methods which are used in the rendering pipeline.
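#
# As a brief illustrative aside (not part of the original tutorial), the toy batch below sketches how **Meshes** handles meshes with different numbers of vertices and faces; the mesh data and variable names here are made up purely for illustration.

# In[ ]:


# Two tiny meshes: a single triangle and a two-triangle quad.
toy_tri_verts = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
toy_quad_verts = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 1.0, 0.0]])
toy_batch = Meshes(
    verts=[toy_tri_verts, toy_quad_verts],
    faces=[torch.tensor([[0, 1, 2]]), torch.tensor([[0, 1, 2], [0, 2, 3]])],
)
print(toy_batch.num_verts_per_mesh())   # tensor([3, 4])
print(toy_batch.verts_padded().shape)   # (2, 4, 3): padded to the largest mesh in the batch
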

# If you are running this notebook locally after cloning the PyTorch3D repository, the mesh will already be available. **If using Google Colab, fetch the mesh and save it at the path `data/`**:

# In[ ]:


get_ipython().system('mkdir -p data')
get_ipython().system('wget -P data https://dl.fbaipublicfiles.com/pytorch3d/data/teapot/teapot.obj')


# In[ ]:


# Set the cuda device
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")

# Load the obj and ignore the textures and materials.
verts, faces_idx, _ = load_obj("./data/teapot.obj")
faces = faces_idx.verts_idx

# Initialize each vertex to be white in color.
verts_rgb = torch.ones_like(verts)[None]  # (1, V, 3)
textures = TexturesVertex(verts_features=verts_rgb.to(device))

# Create a Meshes object for the teapot. Here we have only one mesh in the batch.
teapot_mesh = Meshes(
    verts=[verts.to(device)],
    faces=[faces.to(device)],
    textures=textures
)


# ## 2. Optimization setup
#
# ### Create a renderer
#
# A **renderer** in PyTorch3D is composed of a **rasterizer** and a **shader** which each have a number of subcomponents such as a **camera** (orthographic/perspective). Here we initialize some of these components and use default values for the rest.
#
# For optimizing the camera position we will use a renderer which produces a **silhouette** of the object only and does not apply any **lighting** or **shading**. We will also initialize another renderer which applies full **Phong shading** and use this for visualizing the outputs.

# In[ ]:


# Initialize a perspective camera.
cameras = FoVPerspectiveCameras(device=device)

# To blend the 100 faces we set a few parameters which control the opacity and the sharpness of
# edges. Refer to blending.py for more details.
blend_params = BlendParams(sigma=1e-4, gamma=1e-4)

# Define the settings for rasterization and shading. Here we set the output image to be of size
# 256x256. To form the blended image we use 100 faces for each pixel. Leaving bin_size and
# max_faces_per_bin at their default of None ensures that the faster coarse-to-fine
# rasterization method is used. Refer to rasterize_meshes.py for explanations of these
# parameters, and to docs/notes/renderer.md for the difference between naive and
# coarse-to-fine rasterization.
raster_settings = RasterizationSettings(
    image_size=256,
    blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma,
    faces_per_pixel=100,
)

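# As a quick numerical aside of our own (not part of the original tutorial): with the soft
# rasterizer's sigmoid blending, a face's contribution roughly follows sigmoid(-d / sigma),
# where d is the squared distance of a pixel from the face boundary. Choosing
# blur_radius = log(1 / 1e-4 - 1) * sigma therefore means a face's influence has decayed to
# about 1e-4 at the edge of the blur region:
d = np.log(1. / 1e-4 - 1.) * blend_params.sigma
print(1.0 / (1.0 + np.exp(d / blend_params.sigma)))  # ~1e-4
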
# Create a silhouette mesh renderer by composing a rasterizer and a shader.
silhouette_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras,
        raster_settings=raster_settings
    ),
    shader=SoftSilhouetteShader(blend_params=blend_params)
)


# We will also create a Phong renderer. This is simpler and only needs to render one face per pixel.
raster_settings = RasterizationSettings(
    image_size=256,
    blur_radius=0.0,
    faces_per_pixel=1,
)
# We can add a point light in front of the object.
lights = PointLights(device=device, location=((2.0, 2.0, -2.0),))
phong_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras,
        raster_settings=raster_settings
    ),
    shader=HardPhongShader(device=device, cameras=cameras, lights=lights)
)


# ### Create a reference image
#
# We will first position the teapot and generate an image. We use helper functions to rotate the teapot to a desired viewpoint. Then we can use the renderers to produce an image. Here we will use both renderers and visualize the silhouette and full shaded image.
#
# The world coordinate system is defined as +Y up, +X left and +Z in. The teapot in world coordinates has the spout pointing to the left.
#
# We defined a camera which is positioned on the positive z axis, and hence it sees the spout to the right.

# In[ ]:


# Select the viewpoint using spherical angles
distance = 3      # distance from camera to the object
elevation = 50.0  # angle of elevation in degrees
azimuth = 0.0     # No rotation so the camera is positioned on the +Z axis.

# Get the position of the camera based on the spherical angles
R, T = look_at_view_transform(distance, elevation, azimuth, device=device)

# Render the teapot providing the values of R and T.
silhouette = silhouette_renderer(meshes_world=teapot_mesh, R=R, T=T)
image_ref = phong_renderer(meshes_world=teapot_mesh, R=R, T=T)

silhouette = silhouette.cpu().numpy()
image_ref = image_ref.cpu().numpy()

plt.figure(figsize=(10, 10))
plt.subplot(1, 2, 1)
plt.imshow(silhouette.squeeze()[..., 3])  # only plot the alpha channel of the RGBA image
plt.grid(False)
plt.subplot(1, 2, 2)
plt.imshow(image_ref.squeeze())
plt.grid(False)

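# As a small sanity check of our own (an aside, not part of the original tutorial), we can
# recover the camera centre in world coordinates implied by this R, T. With azimuth = 0 the
# camera lies in the y-z plane (x is approximately 0), consistent with the description above.
check_cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
print(check_cameras.get_camera_center())  # roughly (0, 3*sin(50 deg), 3*cos(50 deg))
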

# ### Set up a basic model
#
# Here we create a simple model class and initialize a parameter for the camera position.

# In[ ]:


class Model(nn.Module):
    def __init__(self, meshes, renderer, image_ref):
        super().__init__()
        self.meshes = meshes
        self.device = meshes.device
        self.renderer = renderer

        # Get the silhouette of the reference RGB image by finding all the non-zero values.
        image_ref = torch.from_numpy((image_ref[..., :3].max(-1) != 0).astype(np.float32))
        self.register_buffer('image_ref', image_ref)

        # Create an optimizable parameter for the x, y, z position of the camera.
        self.camera_position = nn.Parameter(
            torch.from_numpy(np.array([3.0, 6.9, +2.5], dtype=np.float32)).to(meshes.device))

    def forward(self):
        # Render the image using the updated camera position. Based on the new position of the
        # camera we calculate the rotation and translation matrices.
        R = look_at_rotation(self.camera_position[None, :], device=self.device)  # (1, 3, 3)
        T = -torch.bmm(R.transpose(1, 2), self.camera_position[None, :, None])[:, :, 0]  # (1, 3)

        image = self.renderer(meshes_world=self.meshes.clone(), R=R, T=T)

        # Calculate the silhouette loss
        loss = torch.sum((image[..., 3] - self.image_ref) ** 2)
        return loss, image

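# A brief aside of our own (not from the original tutorial) on why T = -R^T C in forward():
# PyTorch3D uses a row-vector convention, so a world point X maps to the view frame as
# X @ R + T. Choosing T = -R^T @ C therefore sends the camera position C to the view-space
# origin, which is exactly what we want. A quick numerical check:
C_check = torch.tensor([[3.0, 6.9, 2.5]], device=device)
R_check = look_at_rotation(C_check, device=device)                           # (1, 3, 3)
T_check = -torch.bmm(R_check.transpose(1, 2), C_check[:, :, None])[:, :, 0]  # (1, 3)
print(torch.bmm(C_check[:, None, :], R_check)[:, 0] + T_check)  # approximately (0, 0, 0)
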

# ## 3. Initialize the model and optimizer
#
# Now we can create an instance of the **model** above and set up an **optimizer** for the camera position parameter.

# In[ ]:


# We will save images periodically and compose them into a GIF.
filename_output = "./teapot_optimization_demo.gif"
writer = imageio.get_writer(filename_output, mode='I', duration=0.3)

# Initialize a model using the renderer, mesh and reference image
model = Model(meshes=teapot_mesh, renderer=silhouette_renderer, image_ref=image_ref).to(device)

# Create an optimizer. Here we are using Adam and we pass in the parameters of the model.
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)


# ### Visualize the starting position and the reference position

# In[ ]:


plt.figure(figsize=(10, 10))

_, image_init = model()
plt.subplot(1, 2, 1)
plt.imshow(image_init.detach().squeeze().cpu().numpy()[..., 3])
plt.grid(False)
plt.title("Starting position")

plt.subplot(1, 2, 2)
plt.imshow(model.image_ref.cpu().numpy().squeeze())
plt.grid(False)
plt.title("Reference silhouette");


# ## 4. Run the optimization
#
# We run several iterations of the forward and backward pass and save outputs every 10 iterations. When this has finished, take a look at `./teapot_optimization_demo.gif` for a cool gif of the optimization process!

# In[ ]:


loop = tqdm(range(200))
for i in loop:
    optimizer.zero_grad()
    loss, _ = model()
    loss.backward()
    optimizer.step()

    loop.set_description('Optimizing (loss %.4f)' % loss.data)

    if loss.item() < 200:
        break

    # Save outputs to create a GIF.
    if i % 10 == 0:
        R = look_at_rotation(model.camera_position[None, :], device=model.device)
        T = -torch.bmm(R.transpose(1, 2), model.camera_position[None, :, None])[:, :, 0]  # (1, 3)
        image = phong_renderer(meshes_world=model.meshes.clone(), R=R, T=T)
        image = image[0, ..., :3].detach().squeeze().cpu().numpy()
        image = img_as_ubyte(image)
        writer.append_data(image)

        plt.figure()
        plt.imshow(image[..., :3])
        plt.title("iter: %d, loss: %0.2f" % (i, loss.data))
        plt.grid(False)
        plt.axis("off")

writer.close()


# ## 5. Conclusion
#
# In this tutorial we learnt how to **load** a mesh from an obj file, initialize a PyTorch3D datastructure called **Meshes**, set up a **Renderer** consisting of a **Rasterizer** and a **Shader**, set up an optimization loop including a **Model** and a **loss function**, and run the optimization.