TexturesAtlas in plotly

Summary:
Lets a K=1 textures atlas be viewed in plotly. Fixes https://github.com/facebookresearch/pytorch3d/issues/916 .

Test: Now get colored faces in
```
import torch
from pytorch3d.utils import ico_sphere
from pytorch3d.vis.plotly_vis import plot_batch_individually
from pytorch3d.renderer import TexturesAtlas

b = ico_sphere()
face_colors = torch.rand(b.faces_padded().shape)
tex = TexturesAtlas(face_colors[:,:,None,None])
b.textures=tex
plot_batch_individually(b)
```

Reviewed By: gkioxari

Differential Revision: D32190470

fbshipit-source-id: 258d30b7e9d79751a79db44684b5540657a2eff5
This commit is contained in:
Jeremy Reizenstein 2021-11-11 02:14:37 -08:00 committed by Facebook GitHub Bot
parent 5fbdb99aec
commit 7ce18f38cd

View File

@ -10,7 +10,12 @@ from typing import Dict, List, NamedTuple, Optional, Tuple, Union
import plotly.graph_objects as go
import torch
from plotly.subplots import make_subplots
from pytorch3d.renderer import RayBundle, TexturesVertex, ray_bundle_to_ray_points
from pytorch3d.renderer import (
RayBundle,
TexturesAtlas,
TexturesVertex,
ray_bundle_to_ray_points,
)
from pytorch3d.renderer.camera_utils import camera_to_eye_at_up
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Meshes, Pointclouds, join_meshes_as_scene
@ -580,13 +585,19 @@ def _add_mesh_trace(
mesh = mesh.detach().cpu()
verts = mesh.verts_packed()
faces = mesh.faces_packed()
# If mesh has vertex colors defined as texture, use vertex colors
# If mesh has vertex colors or face colors, use them
# for figure, otherwise use plotly's default colors.
verts_rgb = None
faces_rgb = None
if isinstance(mesh.textures, TexturesVertex):
verts_rgb = mesh.textures.verts_features_packed()
verts_rgb.clamp_(min=0.0, max=1.0)
verts_rgb = torch.tensor(255.0) * verts_rgb
if isinstance(mesh.textures, TexturesAtlas):
atlas = mesh.textures.atlas_packed()
# If K==1
if atlas.shape[1] == 1 and atlas.shape[3] == 3:
faces_rgb = atlas[:, 0, 0]
# Reposition the unused vertices to be "inside" the object
# (i.e. they won't be visible in the plot).
@ -602,6 +613,7 @@ def _add_mesh_trace(
y=verts[:, 1],
z=verts[:, 2],
vertexcolor=verts_rgb,
facecolor=faces_rgb,
i=faces[:, 0],
j=faces[:, 1],
k=faces[:, 2],