diff --git a/pytorch3d/ops/sample_points_from_meshes.py b/pytorch3d/ops/sample_points_from_meshes.py index 44350a65..9e6dac54 100644 --- a/pytorch3d/ops/sample_points_from_meshes.py +++ b/pytorch3d/ops/sample_points_from_meshes.py @@ -11,11 +11,19 @@ from typing import Tuple, Union import torch from pytorch3d.ops.mesh_face_areas_normals import mesh_face_areas_normals from pytorch3d.ops.packed_to_padded import packed_to_padded +from pytorch3d.renderer.mesh.rasterizer import Fragments as MeshFragments def sample_points_from_meshes( - meshes, num_samples: int = 10000, return_normals: bool = False -) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + meshes, + num_samples: int = 10000, + return_normals: bool = False, + return_textures: bool = False, +) -> Union[ + torch.Tensor, + Tuple[torch.Tensor, torch.Tensor], + Tuple[torch.Tensor, torch.Tensor, torch.Tensor], +]: """ Convert a batch of meshes to a pointcloud by uniformly sampling points on the surface of the mesh with probability proportional to the face area. @@ -24,10 +32,10 @@ def sample_points_from_meshes( meshes: A Meshes object with a batch of N meshes. num_samples: Integer giving the number of point samples per mesh. return_normals: If True, return normals for the sampled points. - eps: (float) used to clamp the norm of the normals to avoid dividing by 0. + return_textures: If True, return textures for the sampled points. Returns: - 2-element tuple containing + 3-element tuple containing - **samples**: FloatTensor of shape (N, num_samples, 3) giving the coordinates of sampled points for each mesh in the batch. For empty @@ -36,6 +44,17 @@ def sample_points_from_meshes( to each sampled point. Only returned if return_normals is True. For empty meshes the corresponding row in the normals array will be filled with 0. + - **textures**: FloatTensor of shape (N, num_samples, C) giving a C-dimensional + texture vector to each sampled point. Only returned if return_textures is True. + For empty meshes the corresponding row in the textures array will + be filled with 0. + + Note that in a future releases, we will replace the 3-element tuple output + with a `Pointclouds` datastructure, as follows + + .. code-block:: python + + Poinclouds(samples, normals=normals, features=textures) """ if meshes.isempty(): raise ValueError("Meshes are empty.") @@ -43,6 +62,10 @@ def sample_points_from_meshes( verts = meshes.verts_packed() if not torch.isfinite(verts).all(): raise ValueError("Meshes contain nan or inf.") + + if return_textures and meshes.textures is None: + raise ValueError("Meshes do not contain textures.") + faces = meshes.faces_packed() mesh_to_face = meshes.mesh_to_faces_packed_first_idx() num_meshes = len(meshes) @@ -66,7 +89,7 @@ def sample_points_from_meshes( sample_face_idxs += mesh_to_face[meshes.valid].view(num_valid_meshes, 1) # Get the vertex coordinates of the sampled faces. - face_verts = verts[faces.long()] + face_verts = verts[faces] v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2] # Randomly generate barycentric coords. @@ -92,9 +115,29 @@ def sample_points_from_meshes( vert_normals = vert_normals[sample_face_idxs] normals[meshes.valid] = vert_normals + if return_textures: + # fragment data are of shape NxHxWxK. Here H=S, W=1 & K=1. + pix_to_face = sample_face_idxs.view(len(meshes), num_samples, 1, 1) # NxSx1x1 + bary = torch.stack((w0, w1, w2), dim=2).unsqueeze(2).unsqueeze(2) # NxSx1x1x3 + # zbuf and dists are not used in `sample_textures` so we initialize them with dummy + dummy = torch.zeros( + (len(meshes), num_samples, 1, 1), device=meshes.device, dtype=torch.float32 + ) # NxSx1x1 + fragments = MeshFragments( + pix_to_face=pix_to_face, zbuf=dummy, bary_coords=bary, dists=dummy + ) + textures = meshes.sample_textures(fragments) # NxSx1x1xC + textures = textures[:, :, 0, 0, :] # NxSxC + + # return + # TODO(gkioxari) consider returning a Pointclouds instance [breaking] + if return_normals and return_textures: + return samples, normals, textures + if return_normals: # return_textures is False return samples, normals - else: - return samples + if return_textures: # return_normals is False + return samples, textures + return samples def _rand_barycentric_coords( diff --git a/tests/test_sample_points_from_meshes.py b/tests/test_sample_points_from_meshes.py index c258515c..c0b49c4a 100644 --- a/tests/test_sample_points_from_meshes.py +++ b/tests/test_sample_points_from_meshes.py @@ -4,13 +4,31 @@ import unittest from pathlib import Path +import numpy as np import torch from common_testing import TestCaseMixin, get_random_cuda_device +from PIL import Image +from pytorch3d.io import load_objs_as_meshes from pytorch3d.ops import sample_points_from_meshes -from pytorch3d.structures.meshes import Meshes +from pytorch3d.renderer import TexturesVertex +from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform +from pytorch3d.renderer.mesh.rasterize_meshes import barycentric_coordinates +from pytorch3d.renderer.points import ( + NormWeightedCompositor, + PointsRasterizationSettings, + PointsRasterizer, + PointsRenderer, +) +from pytorch3d.structures import Meshes, Pointclouds from pytorch3d.utils.ico_sphere import ico_sphere +# If DEBUG=True, save out images generated in the tests for debugging. +# All saved images have prefix DEBUG_ +DEBUG = False +DATA_DIR = Path(__file__).resolve().parent / "data" + + class TestSamplePoints(TestCaseMixin, unittest.TestCase): def setUp(self) -> None: super().setUp() @@ -22,18 +40,27 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase): num_verts: int = 1000, num_faces: int = 3000, device: str = "cpu", + add_texture: bool = False, ): device = torch.device(device) verts_list = [] faces_list = [] + texts_list = [] for _ in range(num_meshes): verts = torch.rand((num_verts, 3), dtype=torch.float32, device=device) faces = torch.randint( num_verts, size=(num_faces, 3), dtype=torch.int64, device=device ) + texts = torch.rand((num_verts, 3), dtype=torch.float32, device=device) verts_list.append(verts) faces_list.append(faces) - meshes = Meshes(verts_list, faces_list) + texts_list.append(texts) + + # create textures + textures = None + if add_texture: + textures = TexturesVertex(texts_list) + meshes = Meshes(verts=verts_list, faces=faces_list, textures=textures) return meshes @@ -264,6 +291,147 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase): meshes, num_samples=100, return_normals=True ) + def test_outputs(self): + + for add_texture in (True, False): + meshes = TestSamplePoints.init_meshes( + device=torch.device("cuda:0"), add_texture=add_texture + ) + out1 = sample_points_from_meshes(meshes, num_samples=100) + self.assertTrue(torch.is_tensor(out1)) + + out2 = sample_points_from_meshes( + meshes, num_samples=100, return_normals=True + ) + self.assertTrue(isinstance(out2, tuple) and len(out2) == 2) + + if add_texture: + out3 = sample_points_from_meshes( + meshes, num_samples=100, return_textures=True + ) + self.assertTrue(isinstance(out3, tuple) and len(out3) == 2) + + out4 = sample_points_from_meshes( + meshes, num_samples=100, return_normals=True, return_textures=True + ) + self.assertTrue(isinstance(out4, tuple) and len(out4) == 3) + else: + with self.assertRaisesRegex( + ValueError, "Meshes do not contain textures." + ): + sample_points_from_meshes( + meshes, num_samples=100, return_textures=True + ) + + with self.assertRaisesRegex( + ValueError, "Meshes do not contain textures." + ): + sample_points_from_meshes( + meshes, + num_samples=100, + return_normals=True, + return_textures=True, + ) + + def test_texture_sampling(self): + device = torch.device("cuda:0") + batch_size = 6 + # verts + verts = torch.rand((batch_size, 6, 3), device=device, dtype=torch.float32) + verts[:, :3, 2] = 1.0 + verts[:, 3:, 2] = -1.0 + # textures + texts = torch.rand((batch_size, 6, 3), device=device, dtype=torch.float32) + # faces + faces = torch.tensor([[0, 1, 2], [3, 4, 5]], device=device, dtype=torch.int64) + faces = faces.view(1, 2, 3).expand(batch_size, -1, -1) + + meshes = Meshes(verts=verts, faces=faces, textures=TexturesVertex(texts)) + + num_samples = 24 + samples, normals, textures = sample_points_from_meshes( + meshes, num_samples=num_samples, return_normals=True, return_textures=True + ) + + textures_naive = torch.zeros( + (batch_size, num_samples, 3), dtype=torch.float32, device=device + ) + for n in range(batch_size): + for i in range(num_samples): + p = samples[n, i] + if p[2] > 0.0: # sampled from 1st face + v0, v1, v2 = verts[n, 0, :2], verts[n, 1, :2], verts[n, 2, :2] + w0, w1, w2 = barycentric_coordinates(p[:2], v0, v1, v2) + t0, t1, t2 = texts[n, 0], texts[n, 1], texts[n, 2] + else: # sampled from 2nd face + v0, v1, v2 = verts[n, 3, :2], verts[n, 4, :2], verts[n, 5, :2] + w0, w1, w2 = barycentric_coordinates(p[:2], v0, v1, v2) + t0, t1, t2 = texts[n, 3], texts[n, 4], texts[n, 5] + + tt = w0 * t0 + w1 * t1 + w2 * t2 + textures_naive[n, i] = tt + + self.assertClose(textures, textures_naive) + + def test_texture_sampling_cow(self): + # test texture sampling for the cow example by converting + # the cow mesh and its texture uv to a pointcloud with texture + + device = torch.device("cuda:0") + obj_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data" + obj_filename = obj_dir / "cow_mesh/cow.obj" + + for text_type in ("uv", "atlas"): + # Load mesh + texture + if text_type == "uv": + mesh = load_objs_as_meshes( + [obj_filename], device=device, load_textures=True, texture_wrap=None + ) + elif text_type == "atlas": + mesh = load_objs_as_meshes( + [obj_filename], + device=device, + load_textures=True, + create_texture_atlas=True, + texture_atlas_size=8, + texture_wrap=None, + ) + + points, normals, textures = sample_points_from_meshes( + mesh, num_samples=50000, return_normals=True, return_textures=True + ) + pointclouds = Pointclouds(points, normals=normals, features=textures) + + for pos in ("front", "back"): + # Init rasterizer settings + if pos == "back": + azim = 0.0 + elif pos == "front": + azim = 180 + R, T = look_at_view_transform(2.7, 0, azim) + cameras = FoVPerspectiveCameras(device=device, R=R, T=T) + + raster_settings = PointsRasterizationSettings( + image_size=512, radius=1e-2, points_per_pixel=1 + ) + + rasterizer = PointsRasterizer( + cameras=cameras, raster_settings=raster_settings + ) + compositor = NormWeightedCompositor() + renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor) + images = renderer(pointclouds) + + rgb = images[0, ..., :3].squeeze().cpu() + if DEBUG: + filename = "DEBUG_cow_mesh_to_pointcloud_%s_%s.png" % ( + text_type, + pos, + ) + Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( + DATA_DIR / filename + ) + @staticmethod def sample_points_with_init( num_meshes: int,