mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
images for debugging TexturesUV
Summary: New methods to directly plot a TexturesUV map with its used points, using PIL and matplotlib. Reviewed By: gkioxari Differential Revision: D23782968 fbshipit-source-id: 692970857b5be13a35a3175dc82ac03963a73555
This commit is contained in:
parent
b149bbfb3c
commit
aa4cc0adbc
@ -93,7 +93,7 @@
|
||||
"\n",
|
||||
"# Data structures and functions for rendering\n",
|
||||
"from pytorch3d.structures import Meshes\n",
|
||||
"from pytorch3d.vis import AxisArgs, plot_batch_individually, plot_scene\n",
|
||||
"from pytorch3d.vis import AxisArgs, plot_batch_individually, plot_scene, texturesuv_image_matplotlib\n",
|
||||
"from pytorch3d.renderer import (\n",
|
||||
" look_at_view_transform,\n",
|
||||
" FoVPerspectiveCameras, \n",
|
||||
@ -236,8 +236,7 @@
|
||||
"obj_filename = os.path.join(DATA_DIR, \"cow_mesh/cow.obj\")\n",
|
||||
"\n",
|
||||
"# Load obj file\n",
|
||||
"mesh = load_objs_as_meshes([obj_filename], device=device)\n",
|
||||
"texture_image=mesh.textures.maps_padded()"
|
||||
"mesh = load_objs_as_meshes([obj_filename], device=device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -265,9 +264,29 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plt.figure(figsize=(7,7))\n",
|
||||
"texture_image=mesh.textures.maps_padded()\n",
|
||||
"plt.imshow(texture_image.squeeze().cpu().numpy())\n",
|
||||
"plt.grid(\"off\");\n",
|
||||
"plt.axis('off');"
|
||||
"plt.axis(\"off\");"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"PyTorch3D has a built-in way to view the texture map with matplotlib along with the points on the map corresponding to vertices. There is also a method, texturesuv_image_PIL, to get a similar image which can be saved to a file."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plt.figure(figsize=(7,7))\n",
|
||||
"texturesuv_image_matplotlib(mesh.textures, subsample=None)\n",
|
||||
"plt.grid(\"off\");\n",
|
||||
"plt.axis(\"off\");"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1174,6 +1174,42 @@ class TexturesUV(TexturesBase):
|
||||
padding_mode=self.padding_mode,
|
||||
)
|
||||
|
||||
def centers_for_image(self, index):
|
||||
"""
|
||||
Return the locations in the texture map which correspond to the given
|
||||
verts_uvs, for one of the meshes. This is potentially useful for
|
||||
visualizing the data. See the texturesuv_image_matplotlib and
|
||||
texturesuv_image_PIL functions.
|
||||
|
||||
Args:
|
||||
index: batch index of the mesh whose centers to return.
|
||||
|
||||
Returns:
|
||||
centers: coordinates of points in the texture image
|
||||
- a FloatTensor of shape (V,2)
|
||||
"""
|
||||
if self._N != 1:
|
||||
raise ValueError(
|
||||
"This function only supports plotting textures for one mesh."
|
||||
)
|
||||
texture_image = self.maps_padded()
|
||||
verts_uvs = self.verts_uvs_list()[index][None]
|
||||
_, H, W, _3 = texture_image.shape
|
||||
coord1 = torch.arange(W).expand(H, W)
|
||||
coord2 = torch.arange(H)[:, None].expand(H, W)
|
||||
coords = torch.stack([coord1, coord2])[None]
|
||||
with torch.no_grad():
|
||||
# Get xy cartesian coordinates based on the uv coordinates
|
||||
centers = F.grid_sample(
|
||||
torch.flip(coords.to(texture_image), [2]),
|
||||
# Convert from [0, 1] -> [-1, 1] range expected by grid sample
|
||||
verts_uvs[:, None] * 2.0 - 1,
|
||||
align_corners=self.align_corners,
|
||||
padding_mode=self.padding_mode,
|
||||
).cpu()
|
||||
centers = centers[0, :, 0].T
|
||||
return centers
|
||||
|
||||
|
||||
class TexturesVertex(TexturesBase):
|
||||
def __init__(
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from .plotly_vis import AxisArgs, Lighting, plot_batch_individually, plot_scene
|
||||
from .texture_vis import texturesuv_image_matplotlib, texturesuv_image_PIL
|
||||
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||
|
104
pytorch3d/vis/texture_vis.py
Normal file
104
pytorch3d/vis/texture_vis.py
Normal file
@ -0,0 +1,104 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw
|
||||
from pytorch3d.renderer.mesh import TexturesUV
|
||||
|
||||
|
||||
def texturesuv_image_matplotlib(
|
||||
texture: TexturesUV,
|
||||
*,
|
||||
texture_index: int = 0,
|
||||
radius: float = 1,
|
||||
color=(1.0, 0.0, 0.0),
|
||||
subsample: Optional[int] = 10000,
|
||||
origin: str = "upper",
|
||||
):
|
||||
"""
|
||||
Plot the texture image for one element of a TexturesUV with
|
||||
matplotlib together with verts_uvs positions circled.
|
||||
In particular a value in verts_uvs which is never referenced
|
||||
in faces_uvs will still be plotted.
|
||||
This is for debugging purposes, e.g. to align the map with
|
||||
the uv coordinates. In particular, matplotlib
|
||||
is used which is not an official dependency of PyTorch3D.
|
||||
|
||||
Args:
|
||||
texture: a TexturesUV object with one mesh
|
||||
texture_index: index in the batch to plot
|
||||
radius: plotted circle radius in pixels
|
||||
color: any matplotlib-understood color for the circles.
|
||||
subsample: if not None, number of points to plot.
|
||||
Otherwise all points are plotted.
|
||||
origin: "upper" or "lower" like matplotlib.imshow
|
||||
"""
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.patches import Circle
|
||||
|
||||
texture_image = texture.maps_padded()
|
||||
centers = texture.centers_for_image(index=texture_index).numpy()
|
||||
|
||||
ax = plt.gca()
|
||||
ax.imshow(texture_image[texture_index].detach().cpu().numpy(), origin=origin)
|
||||
|
||||
n_points = centers.shape[0]
|
||||
if subsample is None or n_points <= subsample:
|
||||
indices = range(n_points)
|
||||
else:
|
||||
indices = np.random.choice(n_points, subsample, replace=False)
|
||||
for i in indices:
|
||||
# setting clip_on=False makes it obvious when
|
||||
# we have UV coordinates outside the correct range
|
||||
ax.add_patch(Circle(centers[i], radius, color=color, clip_on=False))
|
||||
|
||||
|
||||
def texturesuv_image_PIL(
|
||||
texture: TexturesUV,
|
||||
*,
|
||||
texture_index: int = 0,
|
||||
radius: float = 1,
|
||||
color="red",
|
||||
subsample: Optional[int] = 10000,
|
||||
):
|
||||
"""
|
||||
Return a PIL image of the texture image of one element of the batch
|
||||
from a TexturesUV, together with the verts_uvs positions circled.
|
||||
In particular a value in verts_uvs which is never referenced
|
||||
in faces_uvs will still be plotted.
|
||||
This is for debugging purposes, e.g. to align the map with
|
||||
the uv coordinates. In particular, matplotlib
|
||||
is used which is not an official dependency of PyTorch3D.
|
||||
|
||||
Args:
|
||||
texture: a TexturesUV object with one mesh
|
||||
texture_index: index in the batch to plot
|
||||
radius: plotted circle radius in pixels
|
||||
color: any PIL-understood color for the circles.
|
||||
subsample: if not None, number of points to plot.
|
||||
Otherwise all points are plotted.
|
||||
|
||||
Returns:
|
||||
PIL Image object.
|
||||
"""
|
||||
|
||||
centers = texture.centers_for_image(index=texture_index).numpy()
|
||||
texture_image = texture.maps_padded()
|
||||
texture_array = (texture_image[texture_index] * 255).cpu().numpy().astype(np.uint8)
|
||||
|
||||
image = Image.fromarray(texture_array)
|
||||
draw = ImageDraw.Draw(image)
|
||||
|
||||
n_points = centers.shape[0]
|
||||
if subsample is None or n_points <= subsample:
|
||||
indices = range(n_points)
|
||||
else:
|
||||
indices = np.random.choice(n_points, subsample, replace=False)
|
||||
|
||||
for i in indices:
|
||||
x = centers[i][0]
|
||||
y = centers[i][1]
|
||||
draw.ellipse([(x - radius, y - radius), (x + radius, y + radius)], fill=color)
|
||||
|
||||
return image
|
BIN
tests/data/texturesuv_debug.png
Normal file
BIN
tests/data/texturesuv_debug.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 95 KiB |
@ -2,10 +2,13 @@
|
||||
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from common_testing import TestCaseMixin
|
||||
from PIL import Image
|
||||
from pytorch3d.renderer.mesh.rasterizer import Fragments
|
||||
from pytorch3d.renderer.mesh.textures import (
|
||||
TexturesAtlas,
|
||||
@ -15,9 +18,14 @@ from pytorch3d.renderer.mesh.textures import (
|
||||
pack_rectangles,
|
||||
)
|
||||
from pytorch3d.structures import Meshes, list_to_packed, packed_to_list
|
||||
from pytorch3d.vis import texturesuv_image_PIL
|
||||
from test_meshes import TestMeshes
|
||||
|
||||
|
||||
DEBUG = False
|
||||
DATA_DIR = Path(__file__).resolve().parent / "data"
|
||||
|
||||
|
||||
def tryindex(self, index, tex, meshes, source):
|
||||
tex2 = tex[index]
|
||||
meshes2 = meshes[index]
|
||||
@ -471,6 +479,10 @@ class TestTexturesAtlas(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
|
||||
class TestTexturesUV(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
torch.manual_seed(42)
|
||||
|
||||
def test_sample_textures_uv(self):
|
||||
barycentric_coords = torch.tensor(
|
||||
[[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
|
||||
@ -821,6 +833,22 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
|
||||
tryindex(self, index, tex, meshes, source)
|
||||
tryindex(self, [2, 4], tex, meshes, source)
|
||||
|
||||
def test_png_debug(self):
|
||||
maps = torch.rand(size=(1, 256, 128, 3)) * torch.tensor([0.8, 1, 0.8])
|
||||
verts_uvs = torch.rand(size=(1, 20, 2))
|
||||
faces_uvs = torch.zeros(size=(1, 0, 3), dtype=torch.int64)
|
||||
tex = TexturesUV(maps=maps, faces_uvs=faces_uvs, verts_uvs=verts_uvs)
|
||||
|
||||
image = texturesuv_image_PIL(tex, radius=3)
|
||||
image_out = np.array(image)
|
||||
if DEBUG:
|
||||
image.save(DATA_DIR / "texturesuv_debug_.png")
|
||||
|
||||
with Image.open(DATA_DIR / "texturesuv_debug.png") as image_ref_file:
|
||||
image_ref = np.array(image_ref_file)
|
||||
|
||||
self.assertClose(image_out, image_ref)
|
||||
|
||||
|
||||
class TestRectanglePacking(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user