extend sample_points_from_meshes with texture

Summary:
Enhanced `sample_points_from_meshes` with texture sampling

* The new optional `return_textures` flag makes `sample_points_from_meshes` also return the texture value corresponding to each sampled point
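
Below is a minimal usage sketch (not part of the commit) of the new flag. It assumes the
`sample_points_from_meshes` signature introduced in this diff; the `TexturesVertex` setup
mirrors the added unit tests:

  import torch
  from pytorch3d.ops import sample_points_from_meshes
  from pytorch3d.renderer import TexturesVertex
  from pytorch3d.structures import Meshes, Pointclouds
  from pytorch3d.utils import ico_sphere

  # Build a mesh with a random per-vertex RGB texture.
  sphere = ico_sphere(level=2)
  verts, faces = sphere.verts_packed(), sphere.faces_packed()
  textures = TexturesVertex(verts_features=[torch.rand(verts.shape[0], 3)])
  mesh = Meshes(verts=[verts], faces=[faces], textures=textures)

  # Sample points together with normals and interpolated textures.
  points, normals, point_texs = sample_points_from_meshes(
      mesh, num_samples=10000, return_normals=True, return_textures=True
  )
  # points, normals: (1, 10000, 3); point_texs: (1, 10000, 3) for RGB.
  pointcloud = Pointclouds(points, normals=normals, features=point_texs)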

Reviewed By: nikhilaravi

Differential Revision: D24031525

fbshipit-source-id: 8e5d8f784cc38aa391aa8e84e54423bd9fad7ad1
Author: Georgia Gkioxari, 2020-10-06 09:16:32 -07:00 (committed by Facebook GitHub Bot)
parent 5c9485c7be
commit 327bd2b976
2 changed files with 220 additions and 9 deletions

diff --git a/pytorch3d/ops/sample_points_from_meshes.py b/pytorch3d/ops/sample_points_from_meshes.py

@@ -11,11 +11,19 @@ from typing import Tuple, Union
 import torch
 from pytorch3d.ops.mesh_face_areas_normals import mesh_face_areas_normals
 from pytorch3d.ops.packed_to_padded import packed_to_padded
+from pytorch3d.renderer.mesh.rasterizer import Fragments as MeshFragments
 
 
 def sample_points_from_meshes(
-    meshes, num_samples: int = 10000, return_normals: bool = False
-) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    meshes,
+    num_samples: int = 10000,
+    return_normals: bool = False,
+    return_textures: bool = False,
+) -> Union[
+    torch.Tensor,
+    Tuple[torch.Tensor, torch.Tensor],
+    Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+]:
     """
     Convert a batch of meshes to a pointcloud by uniformly sampling points on
     the surface of the mesh with probability proportional to the face area.
@@ -24,10 +32,10 @@ def sample_points_from_meshes(
         meshes: A Meshes object with a batch of N meshes.
         num_samples: Integer giving the number of point samples per mesh.
         return_normals: If True, return normals for the sampled points.
-        eps: (float) used to clamp the norm of the normals to avoid dividing by 0.
+        return_textures: If True, return textures for the sampled points.
 
     Returns:
-        2-element tuple containing
+        3-element tuple containing
 
         - **samples**: FloatTensor of shape (N, num_samples, 3) giving the
           coordinates of sampled points for each mesh in the batch. For empty
@@ -36,6 +44,17 @@ def sample_points_from_meshes(
           to each sampled point. Only returned if return_normals is True.
           For empty meshes the corresponding row in the normals array will
           be filled with 0.
+        - **textures**: FloatTensor of shape (N, num_samples, C) giving a C-dimensional
+          texture vector to each sampled point. Only returned if return_textures is True.
+          For empty meshes the corresponding row in the textures array will
+          be filled with 0.
+
+        Note that in a future release, we will replace the 3-element tuple output
+        with a `Pointclouds` datastructure, as follows
+
+        .. code-block:: python
+
+            Pointclouds(samples, normals=normals, features=textures)
     """
     if meshes.isempty():
         raise ValueError("Meshes are empty.")
@@ -43,6 +62,10 @@ def sample_points_from_meshes(
     verts = meshes.verts_packed()
     if not torch.isfinite(verts).all():
         raise ValueError("Meshes contain nan or inf.")
+
+    if return_textures and meshes.textures is None:
+        raise ValueError("Meshes do not contain textures.")
+
     faces = meshes.faces_packed()
     mesh_to_face = meshes.mesh_to_faces_packed_first_idx()
     num_meshes = len(meshes)
@@ -66,7 +89,7 @@ def sample_points_from_meshes(
     sample_face_idxs += mesh_to_face[meshes.valid].view(num_valid_meshes, 1)
 
     # Get the vertex coordinates of the sampled faces.
-    face_verts = verts[faces.long()]
+    face_verts = verts[faces]
     v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2]
 
     # Randomly generate barycentric coords.
@@ -92,9 +115,29 @@ def sample_points_from_meshes(
         vert_normals = vert_normals[sample_face_idxs]
         normals[meshes.valid] = vert_normals
 
-    if return_normals:
-        return samples, normals
-    else:
-        return samples
+    if return_textures:
+        # fragment data are of shape NxHxWxK. Here H=S, W=1 & K=1.
+        pix_to_face = sample_face_idxs.view(len(meshes), num_samples, 1, 1)  # NxSx1x1
+        bary = torch.stack((w0, w1, w2), dim=2).unsqueeze(2).unsqueeze(2)  # NxSx1x1x3
+        # zbuf and dists are not used in `sample_textures` so we initialize them with dummy
+        dummy = torch.zeros(
+            (len(meshes), num_samples, 1, 1), device=meshes.device, dtype=torch.float32
+        )  # NxSx1x1
+        fragments = MeshFragments(
+            pix_to_face=pix_to_face, zbuf=dummy, bary_coords=bary, dists=dummy
+        )
+        textures = meshes.sample_textures(fragments)  # NxSx1x1xC
+        textures = textures[:, :, 0, 0, :]  # NxSxC
+
+    # return
+    # TODO(gkioxari) consider returning a Pointclouds instance [breaking]
+    if return_normals and return_textures:
+        return samples, normals, textures
+    if return_normals:  # return_textures is False
+        return samples, normals
+    if return_textures:  # return_normals is False
+        return samples, textures
+    return samples
 
 
 def _rand_barycentric_coords(

diff --git a/tests/test_sample_points_from_meshes.py b/tests/test_sample_points_from_meshes.py

@@ -4,13 +4,31 @@
 import unittest
 from pathlib import Path
 
+import numpy as np
 import torch
 from common_testing import TestCaseMixin, get_random_cuda_device
+from PIL import Image
+from pytorch3d.io import load_objs_as_meshes
 from pytorch3d.ops import sample_points_from_meshes
-from pytorch3d.structures.meshes import Meshes
+from pytorch3d.renderer import TexturesVertex
+from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
+from pytorch3d.renderer.mesh.rasterize_meshes import barycentric_coordinates
+from pytorch3d.renderer.points import (
+    NormWeightedCompositor,
+    PointsRasterizationSettings,
+    PointsRasterizer,
+    PointsRenderer,
+)
+from pytorch3d.structures import Meshes, Pointclouds
 from pytorch3d.utils.ico_sphere import ico_sphere
 
 
+# If DEBUG=True, save out images generated in the tests for debugging.
+# All saved images have prefix DEBUG_
+DEBUG = False
+DATA_DIR = Path(__file__).resolve().parent / "data"
+
+
 class TestSamplePoints(TestCaseMixin, unittest.TestCase):
     def setUp(self) -> None:
         super().setUp()
@@ -22,18 +40,27 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase):
         num_verts: int = 1000,
         num_faces: int = 3000,
         device: str = "cpu",
+        add_texture: bool = False,
     ):
         device = torch.device(device)
         verts_list = []
         faces_list = []
+        texts_list = []
         for _ in range(num_meshes):
             verts = torch.rand((num_verts, 3), dtype=torch.float32, device=device)
             faces = torch.randint(
                 num_verts, size=(num_faces, 3), dtype=torch.int64, device=device
             )
+            texts = torch.rand((num_verts, 3), dtype=torch.float32, device=device)
             verts_list.append(verts)
             faces_list.append(faces)
-        meshes = Meshes(verts_list, faces_list)
+            texts_list.append(texts)
+
+        # create textures
+        textures = None
+        if add_texture:
+            textures = TexturesVertex(texts_list)
+        meshes = Meshes(verts=verts_list, faces=faces_list, textures=textures)
 
         return meshes
@@ -264,6 +291,147 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase):
             meshes, num_samples=100, return_normals=True
         )
 
+    def test_outputs(self):
+        for add_texture in (True, False):
+            meshes = TestSamplePoints.init_meshes(
+                device=torch.device("cuda:0"), add_texture=add_texture
+            )
+            out1 = sample_points_from_meshes(meshes, num_samples=100)
+            self.assertTrue(torch.is_tensor(out1))
+
+            out2 = sample_points_from_meshes(
+                meshes, num_samples=100, return_normals=True
+            )
+            self.assertTrue(isinstance(out2, tuple) and len(out2) == 2)
+
+            if add_texture:
+                out3 = sample_points_from_meshes(
+                    meshes, num_samples=100, return_textures=True
+                )
+                self.assertTrue(isinstance(out3, tuple) and len(out3) == 2)
+
+                out4 = sample_points_from_meshes(
+                    meshes, num_samples=100, return_normals=True, return_textures=True
+                )
+                self.assertTrue(isinstance(out4, tuple) and len(out4) == 3)
+            else:
+                with self.assertRaisesRegex(
+                    ValueError, "Meshes do not contain textures."
+                ):
+                    sample_points_from_meshes(
+                        meshes, num_samples=100, return_textures=True
+                    )
+
+                with self.assertRaisesRegex(
+                    ValueError, "Meshes do not contain textures."
+                ):
+                    sample_points_from_meshes(
+                        meshes,
+                        num_samples=100,
+                        return_normals=True,
+                        return_textures=True,
+                    )
+
+    def test_texture_sampling(self):
+        device = torch.device("cuda:0")
+        batch_size = 6
+        # verts
+        verts = torch.rand((batch_size, 6, 3), device=device, dtype=torch.float32)
+        verts[:, :3, 2] = 1.0
+        verts[:, 3:, 2] = -1.0
+        # textures
+        texts = torch.rand((batch_size, 6, 3), device=device, dtype=torch.float32)
+        # faces
+        faces = torch.tensor([[0, 1, 2], [3, 4, 5]], device=device, dtype=torch.int64)
+        faces = faces.view(1, 2, 3).expand(batch_size, -1, -1)
+
+        meshes = Meshes(verts=verts, faces=faces, textures=TexturesVertex(texts))
+
+        num_samples = 24
+        samples, normals, textures = sample_points_from_meshes(
+            meshes, num_samples=num_samples, return_normals=True, return_textures=True
+        )
+
+        textures_naive = torch.zeros(
+            (batch_size, num_samples, 3), dtype=torch.float32, device=device
+        )
+        for n in range(batch_size):
+            for i in range(num_samples):
+                p = samples[n, i]
+                if p[2] > 0.0:  # sampled from 1st face
+                    v0, v1, v2 = verts[n, 0, :2], verts[n, 1, :2], verts[n, 2, :2]
+                    w0, w1, w2 = barycentric_coordinates(p[:2], v0, v1, v2)
+                    t0, t1, t2 = texts[n, 0], texts[n, 1], texts[n, 2]
+                else:  # sampled from 2nd face
+                    v0, v1, v2 = verts[n, 3, :2], verts[n, 4, :2], verts[n, 5, :2]
+                    w0, w1, w2 = barycentric_coordinates(p[:2], v0, v1, v2)
+                    t0, t1, t2 = texts[n, 3], texts[n, 4], texts[n, 5]
+
+                tt = w0 * t0 + w1 * t1 + w2 * t2
+                textures_naive[n, i] = tt
+
+        self.assertClose(textures, textures_naive)
+
+    def test_texture_sampling_cow(self):
+        # test texture sampling for the cow example by converting
+        # the cow mesh and its texture uv to a pointcloud with texture
+        device = torch.device("cuda:0")
+        obj_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
+        obj_filename = obj_dir / "cow_mesh/cow.obj"
+
+        for text_type in ("uv", "atlas"):
+            # Load mesh + texture
+            if text_type == "uv":
+                mesh = load_objs_as_meshes(
+                    [obj_filename], device=device, load_textures=True, texture_wrap=None
+                )
+            elif text_type == "atlas":
+                mesh = load_objs_as_meshes(
+                    [obj_filename],
+                    device=device,
+                    load_textures=True,
+                    create_texture_atlas=True,
+                    texture_atlas_size=8,
+                    texture_wrap=None,
+                )
+
+            points, normals, textures = sample_points_from_meshes(
+                mesh, num_samples=50000, return_normals=True, return_textures=True
+            )
+            pointclouds = Pointclouds(points, normals=normals, features=textures)
+
+            for pos in ("front", "back"):
+                # Init rasterizer settings
+                if pos == "back":
+                    azim = 0.0
+                elif pos == "front":
+                    azim = 180
+                R, T = look_at_view_transform(2.7, 0, azim)
+                cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
+
+                raster_settings = PointsRasterizationSettings(
+                    image_size=512, radius=1e-2, points_per_pixel=1
+                )
+
+                rasterizer = PointsRasterizer(
+                    cameras=cameras, raster_settings=raster_settings
+                )
+                compositor = NormWeightedCompositor()
+                renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)
+
+                images = renderer(pointclouds)
+                rgb = images[0, ..., :3].squeeze().cpu()
+                if DEBUG:
+                    filename = "DEBUG_cow_mesh_to_pointcloud_%s_%s.png" % (
+                        text_type,
+                        pos,
+                    )
+                    Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
+                        DATA_DIR / filename
+                    )
+
     @staticmethod
     def sample_points_with_init(
         num_meshes: int,