amalgamate meshes with texture into a single scene
Summary: Add a join_scene method to all the texture classes, so that meshes with textures can be joined into a single scene. Rename the join_mesh function to join_meshes_as_scene. For TexturesAtlas, we now interpolate if the resolution varies across the batch; this doesn't look great if the resolution is already very low. For TexturesUV, a rectangle-packing function is required to merge the texture maps; this one does something simple.

Reviewed By: gkioxari

Differential Revision: D23188773

fbshipit-source-id: c013db061a04076e13e90ccc168a7913e933a9c5
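For orientation, here is a minimal usage sketch of the renamed entry point, distilled from the tests in this diff. The import paths match the updated test imports; the shapes and the shift are illustrative only:

```python
import torch

from pytorch3d.renderer.mesh.textures import TexturesVertex
from pytorch3d.structures.meshes import Meshes, join_meshes_as_scene
from pytorch3d.utils.torus import torus

device = torch.device("cpu")
plain_torus = torus(r=1, R=4, sides=5, rings=6, device=device)
[verts] = plain_torus.verts_list()
faces = plain_torus.faces_list()

# Two copies of the torus, each carrying its own per-vertex texture.
mesh1 = Meshes(
    verts=[verts],
    faces=faces,
    textures=TexturesVertex(verts_features=[torch.rand_like(verts)]),
)
mesh2 = Meshes(
    verts=[verts + torch.tensor([0.0, 7.0, 0.0], device=device)],
    faces=faces,
    textures=TexturesVertex(verts_features=[torch.rand_like(verts)]),
)

# join_meshes_as_scene (formerly join_mesh) merges the meshes into one
# scene and, with this change, also merges their textures by calling the
# new join_scene method on the texture objects.
scene = join_meshes_as_scene([mesh1, mesh2])
```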
BIN  tests/data/test_joinatlas_final.png  (new file, 25 KiB)
BIN  tests/data/test_joinuvs0_final.png   (new file, 12 KiB)
BIN  tests/data/test_joinuvs0_map.png     (new file, 807 B)
BIN  tests/data/test_joinuvs1_final.png   (new file, 12 KiB)
BIN  tests/data/test_joinuvs1_map.png     (new file, 819 B)
BIN  tests/data/test_joinuvs2_final.png   (new file, 11 KiB)
BIN  tests/data/test_joinuvs2_map.png     (new file, 806 B)
BIN  tests/data/test_joinverts_final.png  (new file, 11 KiB)
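The summary notes that TexturesAtlas now interpolates when per-face atlas resolutions differ across the batch. As a hedged illustration of that resizing step (an assumed sketch, not the library's join_scene code), one could resample each atlas to a common resolution before concatenating:

```python
import torch
import torch.nn.functional as F

def resize_atlas(atlas: torch.Tensor, res: int) -> torch.Tensor:
    # atlas has shape (F, H, W, 3); F.interpolate expects (N, C, H, W).
    x = atlas.permute(0, 3, 1, 2)
    x = F.interpolate(x, size=(res, res), mode="bilinear", align_corners=False)
    return x.permute(0, 2, 3, 1)

# Two per-face atlases with mismatched resolutions.
atlas_a = torch.rand(10, 4, 4, 3)  # 10 faces, 4x4 texels per face
atlas_b = torch.rand(6, 2, 2, 3)   # 6 faces, 2x2 texels per face

# Resample to the largest resolution in the batch, then concatenate along
# the face dimension. At very low resolutions this smoothing is visibly
# lossy, as the summary warns.
target = 4
joined = torch.cat([resize_atlas(atlas_a, target), resize_atlas(atlas_b, target)])
assert joined.shape == (16, 4, 4, 3)
```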
tests/test_render_meshes.py

@@ -33,7 +33,11 @@ from pytorch3d.renderer.mesh.shader import (
     SoftSilhouetteShader,
     TexturedSoftPhongShader,
 )
-from pytorch3d.structures.meshes import Meshes, join_mesh, join_meshes_as_batch
+from pytorch3d.structures.meshes import (
+    Meshes,
+    join_meshes_as_batch,
+    join_meshes_as_scene,
+)
 from pytorch3d.utils.ico_sphere import ico_sphere
 from pytorch3d.utils.torus import torus
 
@@ -571,6 +575,288 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
         self.assertClose(outputs[0][0, ..., :3], outputs[1][0, ..., :3], atol=1e-5)
         self.assertClose(outputs[0][1, ..., :3], outputs[2][0, ..., :3], atol=1e-5)
 
+    def test_join_uvs(self):
+        """Meshes with TexturesUV joined into a scene"""
+        # Test the result of rendering three tori with separate textures.
+        # The expected result is consistent with rendering them each alone.
+        # This tests TexturesUV.join_scene with rectangle flipping,
+        # and we check the form of the merged map as well.
+        torch.manual_seed(1)
+        device = torch.device("cuda:0")
+
+        R, T = look_at_view_transform(18, 0, 0)
+        cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
+
+        raster_settings = RasterizationSettings(
+            image_size=256, blur_radius=0.0, faces_per_pixel=1
+        )
+
+        lights = PointLights(
+            device=device,
+            ambient_color=((1.0, 1.0, 1.0),),
+            diffuse_color=((0.0, 0.0, 0.0),),
+            specular_color=((0.0, 0.0, 0.0),),
+        )
+        blend_params = BlendParams(
+            sigma=1e-1,
+            gamma=1e-4,
+            background_color=torch.tensor([1.0, 1.0, 1.0], device=device),
+        )
+        renderer = MeshRenderer(
+            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
+            shader=HardPhongShader(
+                device=device, blend_params=blend_params, cameras=cameras, lights=lights
+            ),
+        )
+
+        plain_torus = torus(r=1, R=4, sides=5, rings=6, device=device)
+        [verts] = plain_torus.verts_list()
+        verts_shifted1 = verts.clone()
+        verts_shifted1 *= 0.5
+        verts_shifted1[:, 1] += 7
+        verts_shifted2 = verts.clone()
+        verts_shifted2 *= 0.5
+        verts_shifted2[:, 1] -= 7
+
+        [faces] = plain_torus.faces_list()
+        nocolor = torch.zeros((100, 100), device=device)
+        color_gradient = torch.linspace(0, 1, steps=100, device=device)
+        color_gradient1 = color_gradient[None].expand_as(nocolor)
+        color_gradient2 = color_gradient[:, None].expand_as(nocolor)
+        colors1 = torch.stack([nocolor, color_gradient1, color_gradient2], dim=2)
+        colors2 = torch.stack([color_gradient1, color_gradient2, nocolor], dim=2)
+        verts_uvs1 = torch.rand(size=(verts.shape[0], 2), device=device)
+        verts_uvs2 = torch.rand(size=(verts.shape[0], 2), device=device)
+
+        for i, align_corners, padding_mode in [
+            (0, True, "border"),
+            (1, False, "border"),
+            (2, False, "zeros"),
+        ]:
+            textures1 = TexturesUV(
+                maps=[colors1],
+                faces_uvs=[faces],
+                verts_uvs=[verts_uvs1],
+                align_corners=align_corners,
+                padding_mode=padding_mode,
+            )
+
+            # These downsamplings of colors2 are chosen to ensure a flip and a non flip
+            # when the maps are merged.
+            # We have maps of size (100, 100), (50, 99) and (99, 50).
+            textures2 = TexturesUV(
+                maps=[colors2[::2, :-1]],
+                faces_uvs=[faces],
+                verts_uvs=[verts_uvs2],
+                align_corners=align_corners,
+                padding_mode=padding_mode,
+            )
+            offset = torch.tensor([0, 0, 0.5], device=device)
+            textures3 = TexturesUV(
+                maps=[colors2[:-1, ::2] + offset],
+                faces_uvs=[faces],
+                verts_uvs=[verts_uvs2],
+                align_corners=align_corners,
+                padding_mode=padding_mode,
+            )
+            mesh1 = Meshes(verts=[verts], faces=[faces], textures=textures1)
+            mesh2 = Meshes(verts=[verts_shifted1], faces=[faces], textures=textures2)
+            mesh3 = Meshes(verts=[verts_shifted2], faces=[faces], textures=textures3)
+            mesh = join_meshes_as_scene([mesh1, mesh2, mesh3])
+
+            output = renderer(mesh)[0, ..., :3].cpu()
+            output1 = renderer(mesh1)[0, ..., :3].cpu()
+            output2 = renderer(mesh2)[0, ..., :3].cpu()
+            output3 = renderer(mesh3)[0, ..., :3].cpu()
+            # The background color is white and the objects do not overlap, so we can
+            # predict the merged image by taking the minimum over every channel
+            merged = torch.min(torch.min(output1, output2), output3)
+
+            image_ref = load_rgb_image(f"test_joinuvs{i}_final.png", DATA_DIR)
+            map_ref = load_rgb_image(f"test_joinuvs{i}_map.png", DATA_DIR)
+
+            if DEBUG:
+                Image.fromarray((output.numpy() * 255).astype(np.uint8)).save(
+                    DATA_DIR / f"test_joinuvs{i}_final_.png"
+                )
+                Image.fromarray((merged.numpy() * 255).astype(np.uint8)).save(
+                    DATA_DIR / f"test_joinuvs{i}_merged.png"
+                )
+
+                Image.fromarray((output1.numpy() * 255).astype(np.uint8)).save(
+                    DATA_DIR / f"test_joinuvs{i}_1.png"
+                )
+                Image.fromarray((output2.numpy() * 255).astype(np.uint8)).save(
+                    DATA_DIR / f"test_joinuvs{i}_2.png"
+                )
+                Image.fromarray((output3.numpy() * 255).astype(np.uint8)).save(
+                    DATA_DIR / f"test_joinuvs{i}_3.png"
+                )
+                Image.fromarray(
+                    (mesh.textures.maps_padded()[0].cpu().numpy() * 255).astype(
+                        np.uint8
+                    )
+                ).save(DATA_DIR / f"test_joinuvs{i}_map_.png")
+                Image.fromarray(
+                    (mesh2.textures.maps_padded()[0].cpu().numpy() * 255).astype(
+                        np.uint8
+                    )
+                ).save(DATA_DIR / f"test_joinuvs{i}_map2.png")
+                Image.fromarray(
+                    (mesh3.textures.maps_padded()[0].cpu().numpy() * 255).astype(
+                        np.uint8
+                    )
+                ).save(DATA_DIR / f"test_joinuvs{i}_map3.png")
+
+            self.assertClose(output, merged, atol=0.015)
+            self.assertClose(output, image_ref, atol=0.05)
+            self.assertClose(mesh.textures.maps_padded()[0].cpu(), map_ref, atol=0.05)
+
+    def test_join_verts(self):
+        """Meshes with TexturesVertex joined into a scene"""
+        # Test the result of rendering two tori with separate textures.
+        # The expected result is consistent with rendering them each alone.
+        torch.manual_seed(1)
+        device = torch.device("cuda:0")
+        plain_torus = torus(r=1, R=4, sides=5, rings=6, device=device)
+        [verts] = plain_torus.verts_list()
+        verts_shifted1 = verts.clone()
+        verts_shifted1 *= 0.5
+        verts_shifted1[:, 1] += 7
+
+        faces = plain_torus.faces_list()
+        textures1 = TexturesVertex(verts_features=[torch.rand_like(verts)])
+        textures2 = TexturesVertex(verts_features=[torch.rand_like(verts)])
+        mesh1 = Meshes(verts=[verts], faces=faces, textures=textures1)
+        mesh2 = Meshes(verts=[verts_shifted1], faces=faces, textures=textures2)
+        mesh = join_meshes_as_scene([mesh1, mesh2])
+
+        R, T = look_at_view_transform(18, 0, 0)
+        cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
+
+        raster_settings = RasterizationSettings(
+            image_size=256, blur_radius=0.0, faces_per_pixel=1
+        )
+
+        lights = PointLights(
+            device=device,
+            ambient_color=((1.0, 1.0, 1.0),),
+            diffuse_color=((0.0, 0.0, 0.0),),
+            specular_color=((0.0, 0.0, 0.0),),
+        )
+        blend_params = BlendParams(
+            sigma=1e-1,
+            gamma=1e-4,
+            background_color=torch.tensor([1.0, 1.0, 1.0], device=device),
+        )
+        renderer = MeshRenderer(
+            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
+            shader=HardPhongShader(
+                device=device, blend_params=blend_params, cameras=cameras, lights=lights
+            ),
+        )
+
+        output = renderer(mesh)
+
+        image_ref = load_rgb_image("test_joinverts_final.png", DATA_DIR)
+
+        if DEBUG:
+            debugging_outputs = []
+            for mesh_ in [mesh1, mesh2]:
+                debugging_outputs.append(renderer(mesh_))
+            Image.fromarray(
+                (output[0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
+            ).save(DATA_DIR / "test_joinverts_final_.png")
+            Image.fromarray(
+                (debugging_outputs[0][0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
+            ).save(DATA_DIR / "test_joinverts_1.png")
+            Image.fromarray(
+                (debugging_outputs[1][0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
+            ).save(DATA_DIR / "test_joinverts_2.png")
+
+        result = output[0, ..., :3].cpu()
+        self.assertClose(result, image_ref, atol=0.05)
+
+    def test_join_atlas(self):
+        """Meshes with TexturesAtlas joined into a scene"""
+        # Test the result of rendering two tori with separate textures.
+        # The expected result is consistent with rendering them each alone.
+        torch.manual_seed(1)
+        device = torch.device("cuda:0")
+        plain_torus = torus(r=1, R=4, sides=5, rings=6, device=device)
+        [verts] = plain_torus.verts_list()
+        verts_shifted1 = verts.clone()
+        verts_shifted1 *= 1.2
+        verts_shifted1[:, 0] += 4
+        verts_shifted1[:, 1] += 5
+        verts[:, 0] -= 4
+        verts[:, 1] -= 4
+
+        [faces] = plain_torus.faces_list()
+        map_size = 3
+        # Two random atlases.
+        # The averaging of the random numbers here is not consistent with the
+        # meaning of the atlases, but makes each face a bit smoother than
+        # if everything had a random color.
+        atlas1 = torch.rand(size=(faces.shape[0], map_size, map_size, 3), device=device)
+        atlas1[:, 1] = 0.5 * atlas1[:, 0] + 0.5 * atlas1[:, 2]
+        atlas1[:, :, 1] = 0.5 * atlas1[:, :, 0] + 0.5 * atlas1[:, :, 2]
+        atlas2 = torch.rand(size=(faces.shape[0], map_size, map_size, 3), device=device)
+        atlas2[:, 1] = 0.5 * atlas2[:, 0] + 0.5 * atlas2[:, 2]
+        atlas2[:, :, 1] = 0.5 * atlas2[:, :, 0] + 0.5 * atlas2[:, :, 2]
+
+        textures1 = TexturesAtlas(atlas=[atlas1])
+        textures2 = TexturesAtlas(atlas=[atlas2])
+        mesh1 = Meshes(verts=[verts], faces=[faces], textures=textures1)
+        mesh2 = Meshes(verts=[verts_shifted1], faces=[faces], textures=textures2)
+        mesh_joined = join_meshes_as_scene([mesh1, mesh2])
+
+        R, T = look_at_view_transform(18, 0, 0)
+        cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
+
+        raster_settings = RasterizationSettings(
+            image_size=512, blur_radius=0.0, faces_per_pixel=1
+        )
+
+        lights = PointLights(
+            device=device,
+            ambient_color=((1.0, 1.0, 1.0),),
+            diffuse_color=((0.0, 0.0, 0.0),),
+            specular_color=((0.0, 0.0, 0.0),),
+        )
+        blend_params = BlendParams(
+            sigma=1e-1,
+            gamma=1e-4,
+            background_color=torch.tensor([1.0, 1.0, 1.0], device=device),
+        )
+        renderer = MeshRenderer(
+            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
+            shader=HardPhongShader(
+                device=device, blend_params=blend_params, cameras=cameras, lights=lights
+            ),
+        )
+
+        output = renderer(mesh_joined)
+
+        image_ref = load_rgb_image("test_joinatlas_final.png", DATA_DIR)
+
+        if DEBUG:
+            debugging_outputs = []
+            for mesh_ in [mesh1, mesh2]:
+                debugging_outputs.append(renderer(mesh_))
+            Image.fromarray(
+                (output[0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
+            ).save(DATA_DIR / "test_joinatlas_final_.png")
+            Image.fromarray(
+                (debugging_outputs[0][0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
+            ).save(DATA_DIR / "test_joinatlas_1.png")
+            Image.fromarray(
+                (debugging_outputs[1][0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
+            ).save(DATA_DIR / "test_joinatlas_2.png")
+
+        result = output[0, ..., :3].cpu()
+        self.assertClose(result, image_ref, atol=0.05)
+
     def test_joined_spheres(self):
         """
         Test a list of Meshes can be joined as a single mesh and
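The comment in test_join_uvs above relies on a compositing identity: with a white background and non-overlapping objects, the render of the joined scene equals the channelwise minimum of the individual renders. A tiny self-contained check of that identity (toy tensors, not the renderer):

```python
import torch

white = torch.ones(4, 4, 3)  # white background, values in [0, 1]
render1 = white.clone()
render1[0, 0] = torch.tensor([0.2, 0.4, 0.6])  # object A covers one pixel
render2 = white.clone()
render2[3, 3] = torch.tensor([0.5, 0.1, 0.3])  # object B covers another

# Wherever an object is present its colour is <= 1.0, so the channelwise
# minimum recovers the composite of the two non-overlapping renders.
merged = torch.min(render1, render2)
assert torch.equal(merged[0, 0], torch.tensor([0.2, 0.4, 0.6]))
assert torch.equal(merged[3, 3], torch.tensor([0.5, 0.1, 0.3]))
assert torch.equal(merged[1, 1], torch.ones(3))  # background stays white
```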
@@ -595,7 +881,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
             sphere_mesh_list.append(
                 Meshes(verts=verts, faces=sphere_list[i].faces_padded())
             )
-        joined_sphere_mesh = join_mesh(sphere_mesh_list)
+        joined_sphere_mesh = join_meshes_as_scene(sphere_mesh_list)
         joined_sphere_mesh.textures = TexturesVertex(
             verts_features=torch.ones_like(joined_sphere_mesh.verts_padded())
         )
tests/test_texturing.py

@@ -12,6 +12,7 @@ from pytorch3d.renderer.mesh.textures import (
     TexturesUV,
     TexturesVertex,
     _list_to_padded_wrapper,
+    pack_rectangles,
 )
 from pytorch3d.structures import Meshes, list_to_packed, packed_to_list
 from test_meshes import TestMeshes
@@ -730,3 +731,80 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
         index = torch.tensor([1, 2], dtype=torch.int64)
         tryindex(self, index, tex, meshes, source)
         tryindex(self, [2, 4], tex, meshes, source)
+
+
+class TestRectanglePacking(TestCaseMixin, unittest.TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        torch.manual_seed(42)
+
+    def wrap_pack(self, sizes):
+        """
+        Call the pack_rectangles function, which we want to test,
+        and return its outputs.
+        Additionally makes some sanity checks on the output.
+        """
+        res = pack_rectangles(sizes)
+        total = res.total_size
+        self.assertGreaterEqual(total[0], 0)
+        self.assertGreaterEqual(total[1], 0)
+        mask = torch.zeros(total, dtype=torch.bool)
+        seen_x_bound = False
+        seen_y_bound = False
+        for (in_x, in_y), loc in zip(sizes, res.locations):
+            self.assertGreaterEqual(loc[0], 0)
+            self.assertGreaterEqual(loc[1], 0)
+            placed_x, placed_y = (in_y, in_x) if loc[2] else (in_x, in_y)
+            upper_x = placed_x + loc[0]
+            upper_y = placed_y + loc[1]
+            self.assertGreaterEqual(total[0], upper_x)
+            if total[0] == upper_x:
+                seen_x_bound = True
+            self.assertGreaterEqual(total[1], upper_y)
+            if total[1] == upper_y:
+                seen_y_bound = True
+            already_taken = torch.sum(mask[loc[0] : upper_x, loc[1] : upper_y])
+            self.assertEqual(already_taken, 0)
+            mask[loc[0] : upper_x, loc[1] : upper_y] = 1
+        self.assertTrue(seen_x_bound)
+        self.assertTrue(seen_y_bound)
+
+        self.assertTrue(torch.all(torch.sum(mask, dim=0, dtype=torch.int32) > 0))
+        self.assertTrue(torch.all(torch.sum(mask, dim=1, dtype=torch.int32) > 0))
+        return res
+
+    def assert_bb(self, sizes, expected):
+        """
+        Apply the pack_rectangles function to sizes and verify the
+        bounding box dimensions are expected.
+        """
+        self.assertSetEqual(set(self.wrap_pack(sizes).total_size), expected)
+
+    def test_simple(self):
+        self.assert_bb([(3, 4), (4, 3)], {6, 4})
+        self.assert_bb([(2, 2), (2, 4), (2, 2)], {4, 4})
+
+        # many squares
+        self.assert_bb([(2, 2)] * 9, {2, 18})
+
+        # One big square and many small ones.
+        self.assert_bb([(3, 3)] + [(1, 1)] * 2, {3, 4})
+        self.assert_bb([(3, 3)] + [(1, 1)] * 3, {3, 4})
+        self.assert_bb([(3, 3)] + [(1, 1)] * 4, {3, 5})
+        self.assert_bb([(3, 3)] + [(1, 1)] * 5, {3, 5})
+        self.assert_bb([(1, 1)] * 6 + [(3, 3)], {3, 5})
+        self.assert_bb([(3, 3)] + [(1, 1)] * 7, {3, 6})
+
+        # many identical rectangles
+        self.assert_bb([(7, 190)] * 4 + [(190, 7)] * 4, {190, 56})
+
+        # require placing the flipped version of a rectangle
+        self.assert_bb([(1, 100), (5, 96), (4, 5)], {100, 6})
+
+    def test_random(self):
+        for _ in range(5):
+            vals = torch.randint(size=(20, 2), low=1, high=18)
+            sizes = []
+            for j in range(vals.shape[0]):
+                sizes.append((int(vals[j, 0]), int(vals[j, 1])))
+            self.wrap_pack(sizes)
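The summary describes the TexturesUV packing as doing "something simple". For readers new to the idea, here is a minimal shelf-packing sketch in the same spirit. It is hypothetical: shelf_pack and its greedy strategy are not the pack_rectangles implementation under test, and its bounding boxes will generally differ from the expected values above, but it produces the same kind of result, a total size plus one (x, y, flipped) location per input rectangle.

```python
from typing import List, NamedTuple, Tuple

class PackedRectangles(NamedTuple):
    total_size: Tuple[int, int]
    locations: List[Tuple[int, int, bool]]  # (x, y, flipped) per input

def shelf_pack(sizes: List[Tuple[int, int]]) -> PackedRectangles:
    """Greedy shelf packing: flip every rectangle to landscape, place
    tallest first left-to-right, and open a new shelf when a row fills."""
    strip = max(max(s) for s in sizes)  # strip wide enough for any rectangle
    order = sorted(range(len(sizes)), key=lambda i: -min(sizes[i]))
    locations: List[Tuple[int, int, bool]] = [(0, 0, False)] * len(sizes)
    x = y = shelf_h = used_w = 0
    for i in order:
        w, h = sizes[i]
        flipped = h > w  # rotate 90 degrees so that w >= h
        if flipped:
            w, h = h, w
        if x + w > strip:  # no room on this shelf: open a new one
            y += shelf_h
            x = shelf_h = 0
        locations[i] = (x, y, flipped)
        x += w
        shelf_h = max(shelf_h, h)
        used_w = max(used_w, x)
    return PackedRectangles((used_w, y + shelf_h), locations)

res = shelf_pack([(3, 4), (4, 3), (2, 2)])
print(res.total_size, res.locations)
```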