diff --git a/pytorch3d/vis/plotly_vis.py b/pytorch3d/vis/plotly_vis.py index 05749219..a2c7dfbe 100644 --- a/pytorch3d/vis/plotly_vis.py +++ b/pytorch3d/vis/plotly_vis.py @@ -10,7 +10,12 @@ from typing import Dict, List, NamedTuple, Optional, Tuple, Union import plotly.graph_objects as go import torch from plotly.subplots import make_subplots -from pytorch3d.renderer import RayBundle, TexturesVertex, ray_bundle_to_ray_points +from pytorch3d.renderer import ( + RayBundle, + TexturesAtlas, + TexturesVertex, + ray_bundle_to_ray_points, +) from pytorch3d.renderer.camera_utils import camera_to_eye_at_up from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.structures import Meshes, Pointclouds, join_meshes_as_scene @@ -580,13 +585,19 @@ def _add_mesh_trace( mesh = mesh.detach().cpu() verts = mesh.verts_packed() faces = mesh.faces_packed() - # If mesh has vertex colors defined as texture, use vertex colors + # If mesh has vertex colors or face colors, use them # for figure, otherwise use plotly's default colors. verts_rgb = None + faces_rgb = None if isinstance(mesh.textures, TexturesVertex): verts_rgb = mesh.textures.verts_features_packed() verts_rgb.clamp_(min=0.0, max=1.0) verts_rgb = torch.tensor(255.0) * verts_rgb + if isinstance(mesh.textures, TexturesAtlas): + atlas = mesh.textures.atlas_packed() + # If K==1 + if atlas.shape[1] == 1 and atlas.shape[3] == 3: + faces_rgb = atlas[:, 0, 0] # Reposition the unused vertices to be "inside" the object # (i.e. they won't be visible in the plot). @@ -602,6 +613,7 @@ def _add_mesh_trace( y=verts[:, 1], z=verts[:, 2], vertexcolor=verts_rgb, + facecolor=faces_rgb, i=faces[:, 0], j=faces[:, 1], k=faces[:, 2],