mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Summary: Applies new import merging and sorting from µsort v1.0. When merging imports, µsort will make a best effort to move associated comments to match merged elements, but there are known limitations due to the dynamic nature of Python and developer tooling. These changes should not produce any dangerous runtime changes, but may require touch-ups to satisfy linters and other tooling. Note that µsort uses case-insensitive, lexicographical sorting, which results in a different ordering compared to isort. This provides a more consistent sorting order, matching the case-insensitive order used when sorting import statements by module name, and ensures that "frog", "FROG", and "Frog" always sort next to each other. For details on µsort's sorting and merging semantics, see the user guide: https://usort.readthedocs.io/en/stable/guide.html#sorting

Reviewed By: bottler

Differential Revision: D35553814

fbshipit-source-id: be49bdb6a4c25264ff8d4db3a601f18736d17be1
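As an illustrative sketch (not part of this diff), µsort's case-insensitive ordering is roughly what a case-folding sort key gives you in Python; the actual comparator lives in µsort itself, so this only mimics the idea:

    names = ["FROG", "zebra", "Frog", "frog", "ant"]
    print(sorted(names, key=str.casefold))
    # -> ['ant', 'FROG', 'Frog', 'frog', 'zebra']  (all "frog" variants stay adjacent)

This is also why TestCaseMixin sorts after the get_* helpers in the merged common_testing import below.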
468 lines
17 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import unittest

import numpy as np
import torch
from common_testing import (
    get_pytorch3d_dir,
    get_random_cuda_device,
    get_tests_dir,
    TestCaseMixin,
)
from PIL import Image
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.renderer import TexturesVertex
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.mesh.rasterize_meshes import barycentric_coordinates
from pytorch3d.renderer.points import (
    NormWeightedCompositor,
    PointsRasterizationSettings,
    PointsRasterizer,
    PointsRenderer,
)
from pytorch3d.structures import Meshes, Pointclouds
from pytorch3d.utils.ico_sphere import ico_sphere


# If DEBUG=True, save out images generated in the tests for debugging.
# All saved images have prefix DEBUG_
DEBUG = False
DATA_DIR = get_tests_dir() / "data"


def init_meshes(
    num_meshes: int = 10,
    num_verts: int = 1000,
    num_faces: int = 3000,
    device: str = "cpu",
    add_texture: bool = False,
):
    device = torch.device(device)
    verts_list = []
    faces_list = []
    texts_list = []
    for _ in range(num_meshes):
        verts = torch.rand((num_verts, 3), dtype=torch.float32, device=device)
        faces = torch.randint(
            num_verts, size=(num_faces, 3), dtype=torch.int64, device=device
        )
        texts = torch.rand((num_verts, 3), dtype=torch.float32, device=device)
        verts_list.append(verts)
        faces_list.append(faces)
        texts_list.append(texts)

    # create textures
    textures = None
    if add_texture:
        textures = TexturesVertex(texts_list)
    meshes = Meshes(verts=verts_list, faces=faces_list, textures=textures)

    return meshes


class TestSamplePoints(TestCaseMixin, unittest.TestCase):
    def setUp(self) -> None:
        super().setUp()
        torch.manual_seed(1)

    def test_all_empty_meshes(self):
        """
        Check sample_points_from_meshes raises an exception if all meshes are
        invalid.
        """
        device = get_random_cuda_device()
        verts1 = torch.tensor([], dtype=torch.float32, device=device)
        faces1 = torch.tensor([], dtype=torch.int64, device=device)
        meshes = Meshes(verts=[verts1, verts1, verts1], faces=[faces1, faces1, faces1])
        with self.assertRaises(ValueError) as err:
            sample_points_from_meshes(meshes, num_samples=100, return_normals=True)
        self.assertTrue("Meshes are empty." in str(err.exception))

    def test_sampling_output(self):
        """
        Check outputs of sampling are correct for different meshes.
        For an ico_sphere, the sampled vertices should lie on a unit sphere.
        For an empty mesh, the samples and normals should be 0.
        """
        device = get_random_cuda_device()

        # Unit simplex.
        verts_pyramid = torch.tensor(
            [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
            dtype=torch.float32,
            device=device,
        )
        faces_pyramid = torch.tensor(
            [[0, 1, 2], [0, 2, 3], [0, 1, 3], [1, 2, 3]],
            dtype=torch.int64,
            device=device,
        )
        sphere_mesh = ico_sphere(9, device)
        verts_sphere, faces_sphere = sphere_mesh.get_mesh_verts_faces(0)
        verts_empty = torch.tensor([], dtype=torch.float32, device=device)
        faces_empty = torch.tensor([], dtype=torch.int64, device=device)
        num_samples = 10
        meshes = Meshes(
            verts=[verts_empty, verts_sphere, verts_pyramid],
            faces=[faces_empty, faces_sphere, faces_pyramid],
        )
        samples, normals = sample_points_from_meshes(
            meshes, num_samples=num_samples, return_normals=True
        )
        samples = samples.cpu()
        normals = normals.cpu()

        self.assertEqual(samples.shape, (3, num_samples, 3))
        self.assertEqual(normals.shape, (3, num_samples, 3))

        # Empty meshes: should have all zeros for samples and normals.
        self.assertClose(samples[0, :], torch.zeros((num_samples, 3)))
        self.assertClose(normals[0, :], torch.zeros((num_samples, 3)))

        # Sphere: points should have radius 1.
        x, y, z = samples[1, :].unbind(1)
        radius = torch.sqrt(x ** 2 + y ** 2 + z ** 2)

        self.assertClose(radius, torch.ones(num_samples))

        # Pyramid: points should lie on one of the faces.
        pyramid_verts = samples[2, :]
        pyramid_normals = normals[2, :]

        self.assertClose(pyramid_verts.lt(1).float(), torch.ones_like(pyramid_verts))
        self.assertClose((pyramid_verts >= 0).float(), torch.ones_like(pyramid_verts))

        # Face 1: z = 0, x + y <= 1, normals = (0, 0, 1).
        face_1_idxs = pyramid_verts[:, 2] == 0
        face_1_verts, face_1_normals = (
            pyramid_verts[face_1_idxs, :],
            pyramid_normals[face_1_idxs, :],
        )
        self.assertTrue(torch.all((face_1_verts[:, 0] + face_1_verts[:, 1]) <= 1))
        self.assertClose(
            face_1_normals,
            torch.tensor([0, 0, 1], dtype=torch.float32).expand(face_1_normals.size()),
        )

        # Face 2: x = 0, z + y <= 1, normals = (1, 0, 0).
        face_2_idxs = pyramid_verts[:, 0] == 0
        face_2_verts, face_2_normals = (
            pyramid_verts[face_2_idxs, :],
            pyramid_normals[face_2_idxs, :],
        )
        self.assertTrue(torch.all((face_2_verts[:, 1] + face_2_verts[:, 2]) <= 1))
        self.assertClose(
            face_2_normals,
            torch.tensor([1, 0, 0], dtype=torch.float32).expand(face_2_normals.size()),
        )

        # Face 3: y = 0, x + z <= 1, normals = (0, -1, 0).
        face_3_idxs = pyramid_verts[:, 1] == 0
        face_3_verts, face_3_normals = (
            pyramid_verts[face_3_idxs, :],
            pyramid_normals[face_3_idxs, :],
        )
        self.assertTrue(torch.all((face_3_verts[:, 0] + face_3_verts[:, 2]) <= 1))
        self.assertClose(
            face_3_normals,
            torch.tensor([0, -1, 0], dtype=torch.float32).expand(face_3_normals.size()),
        )

        # Face 4: x + y + z = 1, normals = (1, 1, 1)/sqrt(3).
        face_4_idxs = pyramid_verts.gt(0).all(1)
        face_4_verts, face_4_normals = (
            pyramid_verts[face_4_idxs, :],
            pyramid_normals[face_4_idxs, :],
        )
        self.assertClose(face_4_verts.sum(1), torch.ones(face_4_verts.size(0)))
        self.assertClose(
            face_4_normals,
            (
                torch.tensor([1, 1, 1], dtype=torch.float32)
                / torch.sqrt(torch.tensor(3, dtype=torch.float32))
            ).expand(face_4_normals.size()),
        )

    def test_multinomial(self):
        """
        Confirm that torch.multinomial does not sample elements which have
        zero probability.
        """
        freqs = torch.cuda.FloatTensor(
            [
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.03178183361887932,
                0.027680952101945877,
                0.033176131546497345,
                0.046052902936935425,
                0.07742464542388916,
                0.11543981730937958,
                0.14148041605949402,
                0.15784293413162231,
                0.13180233538150787,
                0.08271478116512299,
                0.049702685326337814,
                0.027557924389839172,
                0.018125897273421288,
                0.011851548217236996,
                0.010252203792333603,
                0.007422595750540495,
                0.005372154992073774,
                0.0045109698548913,
                0.0036087757907807827,
                0.0035267581697553396,
                0.0018864056328311563,
                0.0024605290964245796,
                0.0022964938543736935,
                0.0018453967059031129,
                0.0010662291897460818,
                0.0009842115687206388,
                0.00045109697384759784,
                0.0007791675161570311,
                0.00020504408166743815,
                0.00020504408166743815,
                0.00020504408166743815,
                0.00012302644609007984,
                0.0,
                0.00012302644609007984,
                4.100881778867915e-05,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
            ]
        )

        sample = []
        for _ in range(1000):
            torch.cuda.get_rng_state()
            sample = torch.multinomial(freqs, 1000, True)
            if freqs[sample].min() == 0:
                sample_idx = (freqs[sample] == 0).nonzero()[0][0]
                sampled = sample[sample_idx]
                print(
                    "%s th element of last sample was %s, which has probability %s"
                    % (sample_idx, sampled, freqs[sampled])
                )
                return False
        return True

    def test_multinomial_weights(self):
        """
        Confirm that torch.multinomial does not sample elements which have
        zero probability using a real example of input from a training run.
        """
        weights = torch.load(get_tests_dir() / "weights.pt")
        S = 4096
        num_trials = 100
        for _ in range(0, num_trials):
            weights[weights < 0] = 0.0
            samples = weights.multinomial(S, replacement=True)
            sampled_weights = weights[samples]
            assert sampled_weights.min() > 0
            if sampled_weights.min() <= 0:
                return False
        return True

    def test_verts_nan(self):
        num_verts = 30
        num_faces = 50
        for device in ["cpu", "cuda:0"]:
            for invalid in ["nan", "inf"]:
                verts = torch.rand((num_verts, 3), dtype=torch.float32, device=device)
                # randomly assign an invalid value (nan or inf) to some vertices
                verts[torch.randperm(num_verts)[:10]] = float(invalid)
                faces = torch.randint(
                    num_verts, size=(num_faces, 3), dtype=torch.int64, device=device
                )
                meshes = Meshes(verts=[verts], faces=[faces])

                with self.assertRaisesRegex(ValueError, "Meshes contain nan or inf."):
                    sample_points_from_meshes(
                        meshes, num_samples=100, return_normals=True
                    )

    def test_outputs(self):

        for add_texture in (True, False):
            meshes = init_meshes(device=torch.device("cuda:0"), add_texture=add_texture)
            out1 = sample_points_from_meshes(meshes, num_samples=100)
            self.assertTrue(torch.is_tensor(out1))

            out2 = sample_points_from_meshes(
                meshes, num_samples=100, return_normals=True
            )
            self.assertTrue(isinstance(out2, tuple) and len(out2) == 2)

            if add_texture:
                out3 = sample_points_from_meshes(
                    meshes, num_samples=100, return_textures=True
                )
                self.assertTrue(isinstance(out3, tuple) and len(out3) == 2)

                out4 = sample_points_from_meshes(
                    meshes, num_samples=100, return_normals=True, return_textures=True
                )
                self.assertTrue(isinstance(out4, tuple) and len(out4) == 3)
            else:
                with self.assertRaisesRegex(
                    ValueError, "Meshes do not contain textures."
                ):
                    sample_points_from_meshes(
                        meshes, num_samples=100, return_textures=True
                    )

                with self.assertRaisesRegex(
                    ValueError, "Meshes do not contain textures."
                ):
                    sample_points_from_meshes(
                        meshes,
                        num_samples=100,
                        return_normals=True,
                        return_textures=True,
                    )

    def test_texture_sampling(self):
        device = torch.device("cuda:0")
        batch_size = 6
        # verts
        verts = torch.rand((batch_size, 6, 3), device=device, dtype=torch.float32)
        verts[:, :3, 2] = 1.0
        verts[:, 3:, 2] = -1.0
        # textures
        texts = torch.rand((batch_size, 6, 3), device=device, dtype=torch.float32)
        # faces
        faces = torch.tensor([[0, 1, 2], [3, 4, 5]], device=device, dtype=torch.int64)
        faces = faces.view(1, 2, 3).expand(batch_size, -1, -1)

        meshes = Meshes(verts=verts, faces=faces, textures=TexturesVertex(texts))

        num_samples = 24
        samples, normals, textures = sample_points_from_meshes(
            meshes, num_samples=num_samples, return_normals=True, return_textures=True
        )

        textures_naive = torch.zeros(
            (batch_size, num_samples, 3), dtype=torch.float32, device=device
        )
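        # Naive reference: every sampled point lies on one of the two planar faces
        # (z = +1 or z = -1), so its barycentric weights can be recovered from the
        # xy coordinates alone and used to interpolate the per-vertex colors.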
        for n in range(batch_size):
            for i in range(num_samples):
                p = samples[n, i]
                if p[2] > 0.0:  # sampled from 1st face
                    v0, v1, v2 = verts[n, 0, :2], verts[n, 1, :2], verts[n, 2, :2]
                    w0, w1, w2 = barycentric_coordinates(p[:2], v0, v1, v2)
                    t0, t1, t2 = texts[n, 0], texts[n, 1], texts[n, 2]
                else:  # sampled from 2nd face
                    v0, v1, v2 = verts[n, 3, :2], verts[n, 4, :2], verts[n, 5, :2]
                    w0, w1, w2 = barycentric_coordinates(p[:2], v0, v1, v2)
                    t0, t1, t2 = texts[n, 3], texts[n, 4], texts[n, 5]

                tt = w0 * t0 + w1 * t1 + w2 * t2
                textures_naive[n, i] = tt

        self.assertClose(textures, textures_naive)

    def test_texture_sampling_cow(self):
        # test texture sampling for the cow example by converting
        # the cow mesh and its texture uv to a pointcloud with texture

        device = torch.device("cuda:0")
        obj_dir = get_pytorch3d_dir() / "docs/tutorials/data"
        obj_filename = obj_dir / "cow_mesh/cow.obj"

        for text_type in ("uv", "atlas"):
            # Load mesh + texture
            if text_type == "uv":
                mesh = load_objs_as_meshes(
                    [obj_filename], device=device, load_textures=True, texture_wrap=None
                )
            elif text_type == "atlas":
                mesh = load_objs_as_meshes(
                    [obj_filename],
                    device=device,
                    load_textures=True,
                    create_texture_atlas=True,
                    texture_atlas_size=8,
                    texture_wrap=None,
                )

            points, normals, textures = sample_points_from_meshes(
                mesh, num_samples=50000, return_normals=True, return_textures=True
            )
            pointclouds = Pointclouds(points, normals=normals, features=textures)

            for pos in ("front", "back"):
                # Init rasterizer settings
                if pos == "back":
                    azim = 0.0
                elif pos == "front":
                    azim = 180
                R, T = look_at_view_transform(2.7, 0, azim)
                cameras = FoVPerspectiveCameras(device=device, R=R, T=T)

                raster_settings = PointsRasterizationSettings(
                    image_size=512, radius=1e-2, points_per_pixel=1
                )

                rasterizer = PointsRasterizer(
                    cameras=cameras, raster_settings=raster_settings
                )
                compositor = NormWeightedCompositor()
                renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)
                images = renderer(pointclouds)

                rgb = images[0, ..., :3].squeeze().cpu()
                if DEBUG:
                    filename = "DEBUG_cow_mesh_to_pointcloud_%s_%s.png" % (
                        text_type,
                        pos,
                    )
                    Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                        DATA_DIR / filename
                    )

    @staticmethod
    def sample_points_with_init(
        num_meshes: int,
        num_verts: int,
        num_faces: int,
        num_samples: int,
        device: str = "cpu",
    ):
        verts_list = []
        faces_list = []
        for _ in range(num_meshes):
            verts = torch.rand((num_verts, 3), dtype=torch.float32, device=device)
            faces = torch.randint(
                num_verts, size=(num_faces, 3), dtype=torch.int64, device=device
            )
            verts_list.append(verts)
            faces_list.append(faces)
        meshes = Meshes(verts_list, faces_list)
        torch.cuda.synchronize()

        def sample_points():
            sample_points_from_meshes(
                meshes, num_samples=num_samples, return_normals=True
            )
            torch.cuda.synchronize()

        return sample_points
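

# Illustrative usage (an assumption about how the helper above is consumed, e.g.
# by a benchmark harness): the returned closure is the unit of work to be timed.
#
#     fn = TestSamplePoints.sample_points_with_init(
#         num_meshes=2, num_verts=100, num_faces=300, num_samples=500, device="cuda:0"
#     )
#     fn()  # samples points and normals from the freshly built meshes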