extend sample_points_from_meshes with texture

Summary:
Enhanced `sample_points_from_meshes` with texture sampling

* This new feature is used to return textures corresponding to the sampled points in `sample_points_from_meshes`

Reviewed By: nikhilaravi

Differential Revision: D24031525

fbshipit-source-id: 8e5d8f784cc38aa391aa8e84e54423bd9fad7ad1
This commit is contained in:
Georgia Gkioxari
2020-10-06 09:16:32 -07:00
committed by Facebook GitHub Bot
parent 5c9485c7be
commit 327bd2b976
2 changed files with 220 additions and 9 deletions

View File

@@ -4,13 +4,31 @@
import unittest
from pathlib import Path
import numpy as np
import torch
from common_testing import TestCaseMixin, get_random_cuda_device
from PIL import Image
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.structures.meshes import Meshes
from pytorch3d.renderer import TexturesVertex
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.mesh.rasterize_meshes import barycentric_coordinates
from pytorch3d.renderer.points import (
NormWeightedCompositor,
PointsRasterizationSettings,
PointsRasterizer,
PointsRenderer,
)
from pytorch3d.structures import Meshes, Pointclouds
from pytorch3d.utils.ico_sphere import ico_sphere
# If DEBUG=True, save out images generated in the tests for debugging.
# All saved images have prefix DEBUG_
DEBUG = False
DATA_DIR = Path(__file__).resolve().parent / "data"
class TestSamplePoints(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
@@ -22,18 +40,27 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase):
num_verts: int = 1000,
num_faces: int = 3000,
device: str = "cpu",
add_texture: bool = False,
):
device = torch.device(device)
verts_list = []
faces_list = []
texts_list = []
for _ in range(num_meshes):
verts = torch.rand((num_verts, 3), dtype=torch.float32, device=device)
faces = torch.randint(
num_verts, size=(num_faces, 3), dtype=torch.int64, device=device
)
texts = torch.rand((num_verts, 3), dtype=torch.float32, device=device)
verts_list.append(verts)
faces_list.append(faces)
meshes = Meshes(verts_list, faces_list)
texts_list.append(texts)
# create textures
textures = None
if add_texture:
textures = TexturesVertex(texts_list)
meshes = Meshes(verts=verts_list, faces=faces_list, textures=textures)
return meshes
@@ -264,6 +291,147 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase):
meshes, num_samples=100, return_normals=True
)
def test_outputs(self):
for add_texture in (True, False):
meshes = TestSamplePoints.init_meshes(
device=torch.device("cuda:0"), add_texture=add_texture
)
out1 = sample_points_from_meshes(meshes, num_samples=100)
self.assertTrue(torch.is_tensor(out1))
out2 = sample_points_from_meshes(
meshes, num_samples=100, return_normals=True
)
self.assertTrue(isinstance(out2, tuple) and len(out2) == 2)
if add_texture:
out3 = sample_points_from_meshes(
meshes, num_samples=100, return_textures=True
)
self.assertTrue(isinstance(out3, tuple) and len(out3) == 2)
out4 = sample_points_from_meshes(
meshes, num_samples=100, return_normals=True, return_textures=True
)
self.assertTrue(isinstance(out4, tuple) and len(out4) == 3)
else:
with self.assertRaisesRegex(
ValueError, "Meshes do not contain textures."
):
sample_points_from_meshes(
meshes, num_samples=100, return_textures=True
)
with self.assertRaisesRegex(
ValueError, "Meshes do not contain textures."
):
sample_points_from_meshes(
meshes,
num_samples=100,
return_normals=True,
return_textures=True,
)
def test_texture_sampling(self):
device = torch.device("cuda:0")
batch_size = 6
# verts
verts = torch.rand((batch_size, 6, 3), device=device, dtype=torch.float32)
verts[:, :3, 2] = 1.0
verts[:, 3:, 2] = -1.0
# textures
texts = torch.rand((batch_size, 6, 3), device=device, dtype=torch.float32)
# faces
faces = torch.tensor([[0, 1, 2], [3, 4, 5]], device=device, dtype=torch.int64)
faces = faces.view(1, 2, 3).expand(batch_size, -1, -1)
meshes = Meshes(verts=verts, faces=faces, textures=TexturesVertex(texts))
num_samples = 24
samples, normals, textures = sample_points_from_meshes(
meshes, num_samples=num_samples, return_normals=True, return_textures=True
)
textures_naive = torch.zeros(
(batch_size, num_samples, 3), dtype=torch.float32, device=device
)
for n in range(batch_size):
for i in range(num_samples):
p = samples[n, i]
if p[2] > 0.0: # sampled from 1st face
v0, v1, v2 = verts[n, 0, :2], verts[n, 1, :2], verts[n, 2, :2]
w0, w1, w2 = barycentric_coordinates(p[:2], v0, v1, v2)
t0, t1, t2 = texts[n, 0], texts[n, 1], texts[n, 2]
else: # sampled from 2nd face
v0, v1, v2 = verts[n, 3, :2], verts[n, 4, :2], verts[n, 5, :2]
w0, w1, w2 = barycentric_coordinates(p[:2], v0, v1, v2)
t0, t1, t2 = texts[n, 3], texts[n, 4], texts[n, 5]
tt = w0 * t0 + w1 * t1 + w2 * t2
textures_naive[n, i] = tt
self.assertClose(textures, textures_naive)
def test_texture_sampling_cow(self):
# test texture sampling for the cow example by converting
# the cow mesh and its texture uv to a pointcloud with texture
device = torch.device("cuda:0")
obj_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
obj_filename = obj_dir / "cow_mesh/cow.obj"
for text_type in ("uv", "atlas"):
# Load mesh + texture
if text_type == "uv":
mesh = load_objs_as_meshes(
[obj_filename], device=device, load_textures=True, texture_wrap=None
)
elif text_type == "atlas":
mesh = load_objs_as_meshes(
[obj_filename],
device=device,
load_textures=True,
create_texture_atlas=True,
texture_atlas_size=8,
texture_wrap=None,
)
points, normals, textures = sample_points_from_meshes(
mesh, num_samples=50000, return_normals=True, return_textures=True
)
pointclouds = Pointclouds(points, normals=normals, features=textures)
for pos in ("front", "back"):
# Init rasterizer settings
if pos == "back":
azim = 0.0
elif pos == "front":
azim = 180
R, T = look_at_view_transform(2.7, 0, azim)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
raster_settings = PointsRasterizationSettings(
image_size=512, radius=1e-2, points_per_pixel=1
)
rasterizer = PointsRasterizer(
cameras=cameras, raster_settings=raster_settings
)
compositor = NormWeightedCompositor()
renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)
images = renderer(pointclouds)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
filename = "DEBUG_cow_mesh_to_pointcloud_%s_%s.png" % (
text_type,
pos,
)
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / filename
)
@staticmethod
def sample_points_with_init(
num_meshes: int,