# coding: utf-8

# In[ ]:


# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.


# # Deform a source mesh to form a target mesh using 3D loss functions

# In this tutorial, we learn to deform an initial generic shape (e.g. sphere) to fit a target shape.
#
# We will cover:
#
# - How to **load a mesh** from an `.obj` file
# - How to use the PyTorch3D **Meshes** data structure
# - How to use 4 different PyTorch3D **mesh loss functions**
# - How to set up an **optimization loop**
#
#
# Starting from a sphere mesh, we learn the offset to each vertex in the mesh such that
# the predicted mesh is closer to the target mesh at each optimization step. To achieve this we minimize:
#
# + `chamfer_distance`, the distance between the predicted (deformed) and target mesh, defined as the
#   chamfer distance between the sets of point clouds obtained by **differentiably sampling points**
#   from their surfaces.
#
# However, minimizing only the chamfer distance between the predicted and the target mesh leads to a
# non-smooth shape (verify this by setting `w_chamfer=1.0` and all other weights to `0.0`).
#
# We enforce smoothness by adding **shape regularizers** to the objective. Namely, we add:
#
# + `mesh_edge_loss`, which minimizes the length of the edges in the predicted mesh.
# + `mesh_normal_consistency`, which enforces consistency across the normals of neighboring faces.
# + `mesh_laplacian_smoothing`, which is the Laplacian regularizer.

# ## 0. Install and Import modules

# Ensure `torch` and `torchvision` are installed. If `pytorch3d` is not installed, install it using the following cell:

# In[ ]:


import os
import sys
import torch

need_pytorch3d = False
try:
    import pytorch3d
except ModuleNotFoundError:
    need_pytorch3d = True

if need_pytorch3d:
    # The prebuilt wheel tag embeds the CUDA version. torch.version.cuda is
    # None on CPU-only PyTorch builds, so check it before calling .replace()
    # on it; otherwise fall through to the source install below.
    if (
        torch.__version__.startswith("2.2.")
        and sys.platform.startswith("linux")
        and torch.version.cuda is not None
    ):
        # We try to install PyTorch3D via a released wheel.
        pyt_version_str = torch.__version__.split("+")[0].replace(".", "")
        version_str = "".join([
            f"py3{sys.version_info.minor}_cu",
            torch.version.cuda.replace(".", ""),
            f"_pyt{pyt_version_str}",
        ])
        get_ipython().system('pip install fvcore iopath')
        # IPython expands {version_str} from the calling namespace.
        get_ipython().system('pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html')
    else:
        # We try to install PyTorch3D from source.
        get_ipython().system("pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'")


# In[ ]:


import os
import torch
from pytorch3d.io import load_obj, save_obj
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.loss import (
    chamfer_distance,
    mesh_edge_loss,
    mesh_laplacian_smoothing,
    mesh_normal_consistency,
)
import numpy as np
from tqdm.notebook import tqdm
get_ipython().run_line_magic('matplotlib', 'notebook')
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['savefig.dpi'] = 80
mpl.rcParams['figure.dpi'] = 80

# Set the device
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
    print("WARNING: CPU only, this will be slow!")


# ## 1. Load an obj file and create a Meshes object
#
# Download the target 3D model of a dolphin. It will be saved locally as a file called `dolphin.obj`.

# In[ ]:


get_ipython().system('wget https://dl.fbaipublicfiles.com/pytorch3d/data/dolphin/dolphin.obj')


# In[ ]:


# Load the dolphin mesh.
trg_obj = 'dolphin.obj'


# In[ ]:


# We read the target 3D model using load_obj
verts, faces, aux = load_obj(trg_obj)

# verts is a FloatTensor of shape (V, 3) where V is the number of vertices in the mesh
# faces is an object which contains the following LongTensors: verts_idx, normals_idx and textures_idx
# For this tutorial, normals and textures are ignored.
faces_idx = faces.verts_idx.to(device)
verts = verts.to(device)

# We center the target mesh at the origin and scale it so that its largest
# absolute coordinate is 1, i.e. it fits inside the cube [-1, 1]^3.
# (scale, center) will be used to bring the predicted mesh back to the
# target's original center and scale.
# Note that normalizing the target mesh speeds up the optimization but is not necessary!
center = verts.mean(0)
verts = verts - center
scale = verts.abs().max()  # 0-d tensor: largest |coordinate| over all vertices and axes
verts = verts / scale

# We construct a Meshes structure for the target mesh
trg_mesh = Meshes(verts=[verts], faces=[faces_idx])


# In[ ]:


# We initialize the source shape to be a sphere of radius 1
src_mesh = ico_sphere(4, device)


# ## 2. Visualize the source and target meshes

# In[ ]:


def plot_pointcloud(mesh, title="", num_samples=5000):
    """Show a 3D scatter plot of points sampled from the surface of ``mesh``.

    Args:
        mesh: a PyTorch3D ``Meshes``; only the first mesh in the batch is plotted.
        title: title shown above the plot.
        num_samples: number of surface points to sample (default 5000).
    """
    # Sample points uniformly from the surface of the mesh. This is for display
    # only, so skip autograd bookkeeping (the mesh may be under optimization).
    with torch.no_grad():
        points = sample_points_from_meshes(mesh, num_samples)
    # points[0] (rather than squeeze()) keeps the (num_samples, 3) shape even
    # when num_samples == 1, and plots the first mesh if the batch has several.
    x, y, z = points[0].cpu().unbind(1)
    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(111, projection='3d')
    # Plot (x, z, -y) so the mesh's y axis appears as the vertical axis.
    ax.scatter3D(x, z, -y)
    ax.set_xlabel('x')
    ax.set_ylabel('z')
    ax.set_zlabel('y')
    ax.set_title(title)
    ax.view_init(190, 30)
    plt.show()


# In[ ]:


# %matplotlib notebook
plot_pointcloud(trg_mesh, "Target mesh")
plot_pointcloud(src_mesh, "Source mesh")
# ## 3. Optimization loop

# In[ ]:


# We will learn to deform the source mesh by offsetting its vertices.
# deform_verts has the same shape as src_mesh.verts_packed(): one 3D offset per vertex.
deform_verts = torch.full(src_mesh.verts_packed().shape, 0.0, device=device, requires_grad=True)


# In[ ]:


# The optimizer
optimizer = torch.optim.SGD([deform_verts], lr=1.0, momentum=0.9)


# In[ ]:


# Number of optimization steps
Niter = 2000
# Weight for the chamfer loss
w_chamfer = 1.0
# Weight for mesh edge loss
w_edge = 1.0
# Weight for mesh normal consistency
w_normal = 0.01
# Weight for mesh laplacian smoothing
w_laplacian = 0.1
# Plot the deformed mesh every `plot_period` iterations
plot_period = 250
loop = tqdm(range(Niter))

chamfer_losses = []
laplacian_losses = []
edge_losses = []
normal_losses = []

get_ipython().run_line_magic('matplotlib', 'inline')

for i in loop:
    # Clear the gradients accumulated by the previous step
    optimizer.zero_grad()

    # Deform the mesh
    new_src_mesh = src_mesh.offset_verts(deform_verts)

    # We sample 5k points from the surface of each mesh
    sample_trg = sample_points_from_meshes(trg_mesh, 5000)
    sample_src = sample_points_from_meshes(new_src_mesh, 5000)

    # We compare the two sets of pointclouds by computing (a) the chamfer loss
    loss_chamfer, _ = chamfer_distance(sample_trg, sample_src)

    # and (b) the edge length of the predicted mesh
    loss_edge = mesh_edge_loss(new_src_mesh)

    # mesh normal consistency
    loss_normal = mesh_normal_consistency(new_src_mesh)

    # mesh laplacian smoothing
    loss_laplacian = mesh_laplacian_smoothing(new_src_mesh, method="uniform")

    # Weighted sum of the losses
    loss = (
        loss_chamfer * w_chamfer
        + loss_edge * w_edge
        + loss_normal * w_normal
        + loss_laplacian * w_laplacian
    )

    # Print the losses
    loop.set_description('total_loss = %.6f' % loss.item())

    # Save the losses for plotting (.item() returns a plain Python float)
    chamfer_losses.append(loss_chamfer.item())
    edge_losses.append(loss_edge.item())
    normal_losses.append(loss_normal.item())
    laplacian_losses.append(loss_laplacian.item())

    # Plot mesh
    if i % plot_period == 0:
        plot_pointcloud(new_src_mesh, title="iter: %d" % i)

    # Optimization step
    loss.backward()
    optimizer.step()


# ## 4. Visualize the loss

# In[ ]:


fig = plt.figure(figsize=(13, 5))
ax = fig.gca()
ax.plot(chamfer_losses, label="chamfer loss")
ax.plot(edge_losses, label="edge loss")
ax.plot(normal_losses, label="normal loss")
ax.plot(laplacian_losses, label="laplacian loss")
ax.legend(fontsize="16")
ax.set_xlabel("Iteration", fontsize="16")
ax.set_ylabel("Loss", fontsize="16")
ax.set_title("Loss vs iterations", fontsize="16");


# ## 5. Save the predicted mesh

# In[ ]:


# Rebuild the predicted mesh from the final offsets. The `new_src_mesh` built
# inside the loop predates the last optimizer.step(), so it would miss the final
# update (and it does not exist at all when Niter == 0).
final_mesh = src_mesh.offset_verts(deform_verts.detach())

# Fetch the verts and faces of the final predicted mesh
final_verts, final_faces = final_mesh.get_mesh_verts_faces(0)

# Undo the normalization to return to the original target center and scale
final_verts = final_verts * scale + center

# Store the predicted mesh using save_obj
final_obj = 'final_model.obj'
save_obj(final_obj, final_verts, final_faces)


# ## 6. Conclusion
#
# In this tutorial we learned how to load a mesh from an obj file, initialize a PyTorch3D data structure called **Meshes**, set up an optimization loop and use four different PyTorch3D mesh loss functions.