amalgamate meshes with texture into a single scene

Summary:
Add a join_scene method to all the texture classes, allowing the join_mesh function to include textures. Rename the join_mesh function to join_meshes_as_scene.

For TexturesAtlas, we now interpolate if the user attempts to have the resolution vary across the batch. This doesn't look great if the resolution is already very low.

For TexturesUV, a rectangle packing function is required; the one added here does something simple.
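
As a quick usage sketch (illustration only, not part of the diff; it assumes two meshes with the same texture type on the same device):

import torch
from pytorch3d.renderer.mesh.textures import TexturesVertex
from pytorch3d.structures import Meshes, join_meshes_as_scene

verts = torch.rand(4, 3)
faces = torch.tensor([[0, 1, 2], [0, 2, 3]])
tex1 = TexturesVertex(verts_features=[torch.rand(4, 3)])
tex2 = TexturesVertex(verts_features=[torch.rand(4, 3)])
mesh1 = Meshes(verts=[verts], faces=[faces], textures=tex1)
mesh2 = Meshes(verts=[verts + 2.0], faces=[faces], textures=tex2)

# The batch is flattened to one mesh whose texture is joined via join_scene.
scene = join_meshes_as_scene([mesh1, mesh2])
assert len(scene) == 1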

Reviewed By: gkioxari

Differential Revision: D23188773

fbshipit-source-id: c013db061a04076e13e90ccc168a7913e933a9c5
Jeremy Reizenstein 2020-08-25 11:26:58 -07:00 committed by Facebook GitHub Bot
parent e25ccab3d9
commit 909dc83505
14 changed files with 741 additions and 23 deletions

pytorch3d/renderer/mesh/textures.py

@@ -2,7 +2,7 @@
import itertools
import warnings
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
@@ -10,6 +10,8 @@ from pytorch3d.ops import interpolate_face_attributes
from pytorch3d.structures.utils import list_to_packed, list_to_padded, padded_to_list
from torch.nn.functional import interpolate

+from .utils import pack_rectangles
+

# This file contains classes and helper functions for texturing.
# There are three types of textures: TexturesVertex, TexturesAtlas
@@ -329,6 +331,7 @@ class TexturesAtlas(TexturesBase):
        [1] Liu et al, 'Soft Rasterizer: A Differentiable Renderer for Image-based
            3D Reasoning', ICCV 2019
+       See also https://github.com/ShichenLiu/SoftRas/issues/21
        """
        if isinstance(atlas, (list, tuple)):
            correct_format = all(
@@ -336,11 +339,15 @@ class TexturesAtlas(TexturesBase):
                    torch.is_tensor(elem)
                    and elem.ndim == 4
                    and elem.shape[1] == elem.shape[2]
+                   and elem.shape[1] == atlas[0].shape[1]
                )
                for elem in atlas
            )
            if not correct_format:
-               msg = "Expected atlas to be a list of tensors of shape (F, R, R, D)"
+               msg = (
+                   "Expected atlas to be a list of tensors of shape (F, R, R, D) "
+                   "with the same value of R."
+               )
                raise ValueError(msg)
            self._atlas_list = atlas
            self._atlas_padded = None
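
For reference, a minimal sketch (not part of the diff) of the format this check now enforces, per-mesh atlases of shape (F, R, R, D) with a common R across the batch:

import torch
from pytorch3d.renderer.mesh.textures import TexturesAtlas

F1, F2, R, D = 10, 7, 4, 3  # faces per mesh, per-face resolution, channels
atlas = [torch.rand(F1, R, R, D), torch.rand(F2, R, R, D)]  # same R required
tex = TexturesAtlas(atlas=atlas)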
@@ -529,6 +536,12 @@ class TexturesAtlas(TexturesBase):
        new_tex._num_faces_per_mesh = num_faces_per_mesh
        return new_tex

+   def join_scene(self) -> "TexturesAtlas":
+       """
+       Return a new TexturesAtlas amalgamating the batch.
+       """
+       return self.__class__(atlas=[torch.cat(self.atlas_list())])
+

class TexturesUV(TexturesBase):
    def __init__(
@@ -560,7 +573,7 @@ class TexturesUV(TexturesBase):
            the two align_corners options at
            https://discuss.pytorch.org/t/22663/9 .

-       An example of how the indexing into the maps, with align_corners=True
+       An example of how the indexing into the maps, with align_corners=True,
        works is as follows.
        If maps[i] has shape [101, 1001] and the value of verts_uvs[i][j]
        is [0.4, 0.3], then a value of j in faces_uvs[i] means a vertex
@@ -574,10 +587,11 @@ class TexturesUV(TexturesBase):
        If maps[i] has shape [100, 1000] and the value of verts_uvs[i][j]
        is [0.405, 0.2995], then a value of j in faces_uvs[i] means a vertex
        whose color is given by maps[i][700, 40].
-       In this case, padding_mode even matters for values in verts_uvs
-       slightly above 0 or slightly below 1. In this case, it matters if the
-       first value is outside the interval [0.0005, 0.9995] or if the second
-       is outside the interval [0.005, 0.995].
+       When align_corners=False, padding_mode even matters for values in
+       verts_uvs slightly above 0 or slightly below 1. In this case, the
+       padding_mode matters if the first value is outside the interval
+       [0.0005, 0.9995] or if the second is outside the interval
+       [0.005, 0.995].
        """
        super().__init__()
        self.padding_mode = padding_mode
@@ -805,12 +819,9 @@ class TexturesUV(TexturesBase):
    def maps_padded(self) -> torch.Tensor:
        return self._maps_padded

-   def maps_list(self) -> torch.Tensor:
-       # maps_list is not used anywhere currently - maps
-       # are padded to ensure the (H, W) of all maps is the
-       # same across the batch and we don't store the
-       # unpadded sizes of the maps. Therefore just return
-       # the unbinded padded tensor.
+   def maps_list(self) -> List[torch.Tensor]:
+       if self._maps_list is not None:
+           return self._maps_list
        return self._maps_padded.unbind(0)

    def extend(self, N: int) -> "TexturesUV":
@@ -965,6 +976,143 @@ class TexturesUV(TexturesBase):
        new_tex._num_faces_per_mesh = num_faces_per_mesh
        return new_tex
def _place_map_into_single_map(
self,
single_map: torch.Tensor,
map_: torch.Tensor,
location: Tuple[int, int, bool], # (x,y) and whether flipped
) -> None:
"""
Copy map into a larger tensor single_map at the destination specified by location.
If align_corners is False, we add the needed border around the destination.
Used by join_scene.
Args:
single_map: (total_H, total_W, 3)
map_: (H, W, 3) source data
location: where to place map
"""
do_flip = location[2]
source = map_.transpose(0, 1) if do_flip else map_
border_width = 0 if self.align_corners else 1
lower_u = location[0] + border_width
lower_v = location[1] + border_width
upper_u = lower_u + source.shape[0]
upper_v = lower_v + source.shape[1]
single_map[lower_u:upper_u, lower_v:upper_v] = source
if self.padding_mode != "zeros" and not self.align_corners:
single_map[lower_u - 1, lower_v:upper_v] = single_map[
lower_u, lower_v:upper_v
]
single_map[upper_u, lower_v:upper_v] = single_map[
upper_u - 1, lower_v:upper_v
]
single_map[lower_u:upper_u, lower_v - 1] = single_map[
lower_u:upper_u, lower_v
]
single_map[lower_u:upper_u, upper_v] = single_map[
lower_u:upper_u, upper_v - 1
]
single_map[lower_u - 1, lower_v - 1] = single_map[lower_u, lower_v]
single_map[lower_u - 1, upper_v] = single_map[lower_u, upper_v - 1]
single_map[upper_u, lower_v - 1] = single_map[upper_u - 1, lower_v]
single_map[upper_u, upper_v] = single_map[upper_u - 1, upper_v - 1]
def join_scene(self) -> "TexturesUV":
"""
Return a new TexturesUV amalgamating the batch.
We calculate a large single map which contains the original maps,
and find verts_uvs to point into it. This will not replicate the
behavior of padding for verts_uvs values outside [0, 1].
If align_corners=False, we need to add an artificial border around
every map.
We use the function `pack_rectangles` to provide a layout for the
single map. _place_map_into_single_map is used to copy the maps
into the single map. The merging of verts_uvs and faces_uvs is
handled locally in this function.
"""
maps = self.maps_list()
heights_and_widths = []
extra_border = 0 if self.align_corners else 2
for map_ in maps:
heights_and_widths.append(
(map_.shape[0] + extra_border, map_.shape[1] + extra_border)
)
merging_plan = pack_rectangles(heights_and_widths)
# pyre-fixme[16]: `Tensor` has no attribute `new_zeros`.
single_map = maps[0].new_zeros((*merging_plan.total_size, 3))
verts_uvs = self.verts_uvs_list()
verts_uvs_merged = []
for map_, loc, uvs in zip(maps, merging_plan.locations, verts_uvs):
new_uvs = uvs.clone()
self._place_map_into_single_map(single_map, map_, loc)
do_flip = loc[2]
x_shape = map_.shape[1] if do_flip else map_.shape[0]
y_shape = map_.shape[0] if do_flip else map_.shape[1]
if do_flip:
# Here we have flipped / transposed the map.
# In uvs, the y values are decreasing from 1 to 0 and the x
# values increase from 0 to 1. We subtract all values from 1
# as the x's become y's and the y's become x's.
new_uvs = 1.0 - new_uvs[:, [1, 0]]
if TYPE_CHECKING:
new_uvs = torch.Tensor(new_uvs)
# If align_corners is True, then an index of x (where x is in
# the range 0 .. map_.shape[]-1) in one of the input maps
# was hit by a u of x/(map_.shape[]-1).
# That x is located at the index loc[] + x in the single_map, and
# to hit that we need u to equal (loc[] + x) / (total_size[]-1)
# so the old u should be mapped to
# { u*(map_.shape[]-1) + loc[] } / (total_size[]-1)
# If align_corners is False, then an index of x (where x is in
# the range 1 .. map_.shape[]-2) in one of the input maps
# was hit by a u of (x+0.5)/(map_.shape[]).
# That x is located at the index loc[] + 1 + x in the single_map,
# (where the 1 is for the border)
# and to hit that we need u to equal (loc[] + 1 + x + 0.5) / (total_size[])
# so the old u should be mapped to
# { loc[] + 1 + u*map_.shape[]-0.5 + 0.5 } / (total_size[])
# = { loc[] + 1 + u*map_.shape[] } / (total_size[])
# We change the y's in new_uvs for the scaling of height,
# and the x's for the scaling of width.
# That is why the 1's and 0's are mismatched in these lines.
one_if_align = 1 if self.align_corners else 0
one_if_not_align = 1 - one_if_align
denom_x = merging_plan.total_size[0] - one_if_align
scale_x = x_shape - one_if_align
denom_y = merging_plan.total_size[1] - one_if_align
scale_y = y_shape - one_if_align
new_uvs[:, 1] *= scale_x / denom_x
new_uvs[:, 1] += (loc[0] + one_if_not_align) / denom_x
new_uvs[:, 0] *= scale_y / denom_y
new_uvs[:, 0] += (loc[1] + one_if_not_align) / denom_y
verts_uvs_merged.append(new_uvs)
faces_uvs_merged = []
offset = 0
for faces_uvs_, verts_uvs_ in zip(self.faces_uvs_list(), verts_uvs):
faces_uvs_merged.append(offset + faces_uvs_)
offset += verts_uvs_.shape[0]
return self.__class__(
maps=[single_map],
verts_uvs=[torch.cat(verts_uvs_merged)],
faces_uvs=[torch.cat(faces_uvs_merged)],
align_corners=self.align_corners,
padding_mode=self.padding_mode,
)
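
A small numeric illustration of the align_corners=True remapping derived in the comment above (hypothetical sizes, not part of the diff):

# An input map of width 100 placed at x offset 10 in a single map of width 200.
map_size, loc, total = 100, 10, 200
u = 0.5                                  # hits input index u * (map_size - 1) = 49.5
target_index = loc + u * (map_size - 1)  # the same texel inside the single map: 59.5
new_u = (u * (map_size - 1) + loc) / (total - 1)  # the formula from the comment
assert abs(new_u * (total - 1) - target_index) < 1e-9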
class TexturesVertex(TexturesBase):
    def __init__(

@@ -1156,3 +1304,9 @@ class TexturesVertex(TexturesBase):
        new_tex = self.__class__(verts_features=verts_features_list)
        new_tex._num_verts_per_mesh = num_faces_per_mesh
        return new_tex
def join_scene(self) -> "TexturesVertex":
"""
Return a new TexturesVertex amalgamating the batch.
"""
return self.__class__(verts_features=[torch.cat(self.verts_features_list())])

pytorch3d/renderer/mesh/utils.py

@@ -1,6 +1,8 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

+from typing import List, NamedTuple, Tuple
+
import torch
from pytorch3d.ops import interpolate_face_attributes
@@ -58,3 +60,184 @@ def _interpolate_zbuf(
    ]  # (1, H, W, K)
    zbuf[pix_to_face == -1] = -1
    return zbuf
# ----------- Rectangle Packing -------------------- #
# Note the order of members matters here because it determines the queue order.
# We want to place longer rectangles first.
class _UnplacedRectangle(NamedTuple):
size: Tuple[int, int]
ind: int
flipped: bool
def _try_place_rectangle(
rect: _UnplacedRectangle,
placed_so_far: List[Tuple[int, int, bool]],
occupied: List[Tuple[int, int]],
) -> bool:
"""
Try to place rect within the current bounding box.
Part of the implementation of pack_rectangles.
Note that the arguments `placed_so_far` and `occupied` are modified.
Args:
rect: rectangle to place
placed_so_far: the locations decided upon so far - a list of
(x, y, whether flipped). The nth element is the
location of the nth rectangle if it has been decided.
(modified in place)
occupied: the nodes of the graph of extents of rightmost placed
rectangles - (modified in place)
Returns:
True on success.
Example:
(We always have placed the first rectangle horizontally and other
rectangles above it.)
Let's say the placed boxes 1-4 are laid out like this.
The coordinates of the points marked X are stored in occupied.
It is to the right of the X's that we seek to place rect.
+-----------------------X
|2 |
| +---X
| |4 |
| | |
| +---+X
| |3 |
| | |
+-----------------------+----+------X
y |1 |
^ | --->x |
| +-----------------------------------+
We want to place this rectangle.
+-+
|5|
| |
| | = rect
| |
| |
| |
+-+
The call will succeed, returning True, leaving us with
+-----------------------X
|2 | +-X
| +---+|5|
| |4 || |
| | || |
| +---++ |
| |3 | |
| | | |
+-----------------------+----+-+----X
|1 |
| |
+-----------------------------------+ .
"""
total_width = occupied[0][0]
needed_height = rect.size[1]
current_start_idx = None
current_max_width = 0
previous_height = 0
currently_packed = 0
for idx, interval in enumerate(occupied):
if interval[0] <= total_width - rect.size[0]:
currently_packed += interval[1] - previous_height
current_max_width = max(interval[0], current_max_width)
if current_start_idx is None:
current_start_idx = idx
if currently_packed >= needed_height:
current_max_width = max(interval[0], current_max_width)
placed_so_far[rect.ind] = (
current_max_width,
occupied[current_start_idx - 1][1],
rect.flipped,
)
new_occupied = (
current_max_width + rect.size[0],
occupied[current_start_idx - 1][1] + needed_height,
)
if currently_packed == needed_height:
occupied[idx] = new_occupied
del occupied[current_start_idx:idx]
elif idx > current_start_idx:
occupied[idx - 1] = new_occupied
del occupied[current_start_idx : (idx - 1)]
else:
occupied.insert(idx, new_occupied)
return True
else:
current_start_idx = None
current_max_width = 0
currently_packed = 0
previous_height = interval[1]
return False
class PackedRectangles(NamedTuple):
total_size: Tuple[int, int]
locations: List[Tuple[int, int, bool]] # (x,y) and whether flipped
def pack_rectangles(sizes: List[Tuple[int, int]]) -> PackedRectangles:
"""
Naive rectangle packing into a large rectangle. Flipping (i.e. rotating
a rectangle by 90 degrees) is allowed.
This is used to join several uv maps into a single scene, see
TexturesUV.join_scene.
Args:
sizes: List of sizes of rectangles to pack
Returns:
total_size: size of total large rectangle
locations: location for each of the input rectangles
"""
if len(sizes) < 2:
raise ValueError("Cannot pack less than two boxes")
queue = []
for i, size in enumerate(sizes):
if size[0] < size[1]:
queue.append(_UnplacedRectangle((size[1], size[0]), i, True))
else:
queue.append(_UnplacedRectangle((size[0], size[1]), i, False))
queue.sort()
placed_so_far = [(-1, -1, False)] * len(sizes)
biggest = queue.pop()
total_width, current_height = biggest.size
placed_so_far[biggest.ind] = (0, 0, biggest.flipped)
second = queue.pop()
placed_so_far[second.ind] = (0, current_height, second.flipped)
current_height += second.size[1]
occupied = [biggest.size, (second.size[0], current_height)]
for rect in reversed(queue):
if _try_place_rectangle(rect, placed_so_far, occupied):
continue
rotated = _UnplacedRectangle(
(rect.size[1], rect.size[0]), rect.ind, not rect.flipped
)
if _try_place_rectangle(rotated, placed_so_far, occupied):
continue
# rect wasn't placed in the current bounding box,
# so we add extra space to fit it in.
placed_so_far[rect.ind] = (0, current_height, rect.flipped)
current_height += rect.size[1]
occupied.append((rect.size[0], current_height))
return PackedRectangles((total_width, current_height), placed_so_far)
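
For illustration, a minimal call (the bounding box matches the test_simple case in the new tests; the exact locations shown are an assumption about tie-breaking):

from pytorch3d.renderer.mesh.utils import pack_rectangles

res = pack_rectangles([(3, 4), (4, 3)])
# Both rectangles are normalized to 4 x 3 (the first is flipped), giving a
# 4 x 6 bounding box: res.total_size == (4, 6), and res.locations is a list
# of (x, y, flipped) per input, such as [(0, 3, True), (0, 0, False)].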

pytorch3d/structures/__init__.py

@@ -1,6 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

-from .meshes import Meshes, join_meshes_as_batch
+from .meshes import Meshes, join_meshes_as_batch, join_meshes_as_scene
from .pointclouds import Pointclouds
from .utils import list_to_packed, list_to_padded, packed_to_list, padded_to_list

pytorch3d/structures/meshes.py

@@ -1254,7 +1254,7 @@ class Meshes(object):
        """
        verts_packed = self.verts_packed()
        if vert_offsets_packed.shape != verts_packed.shape:
-           raise ValueError("Verts offsets must have dimension (all_v, 2).")
+           raise ValueError("Verts offsets must have dimension (all_v, 3).")
        # update verts packed
        self._verts_packed = verts_packed + vert_offsets_packed
        new_verts_list = list(
@@ -1548,26 +1548,43 @@ def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True):
    return Meshes(verts=verts, faces=faces, textures=tex)

-def join_mesh(meshes: Union[Meshes, List[Meshes]]) -> Meshes:
+def join_meshes_as_scene(
+    meshes: Union[Meshes, List[Meshes]], include_textures: bool = True
+) -> Meshes:
    """
    Joins a batch of meshes in the form of a Meshes object or a list of Meshes
-   objects as a single mesh. If the input is a list, the Meshes objects in the list
-   must all be on the same device. This version ignores all textures in the input meshes.
+   objects as a single mesh. If the input is a list, the Meshes objects in the
+   list must all be on the same device. Unless include_textures is False, the
+   meshes must all have the same type of texture or must all not have textures.
+
+   If textures are included, then the textures are joined as a single scene in
+   addition to the meshes. For this, texture types have an appropriate method
+   called join_scene which joins mesh textures into a single texture.
+   If the textures are TexturesAtlas then they must have the same resolution.
+   If they are TexturesUV then they must have the same align_corners and
+   padding_mode. Values in verts_uvs outside [0, 1] will not be respected.

    Args:
-       meshes: Meshes object that contains a batch of meshes or a list of Meshes objects
+       meshes: Meshes object that contains a batch of meshes, or a list of
+           Meshes objects.
+       include_textures: (bool) whether to try to join the textures.

    Returns:
        new Meshes object containing a single mesh
    """
    if isinstance(meshes, List):
-       meshes = join_meshes_as_batch(meshes, include_textures=False)
+       meshes = join_meshes_as_batch(meshes, include_textures=include_textures)

    if len(meshes) == 1:
        return meshes
    verts = meshes.verts_packed()  # (sum(V_n), 3)
    # Offset automatically done by faces_packed
    faces = meshes.faces_packed()  # (sum(F_n), 3)
+   textures = None

-   mesh = Meshes(verts=verts.unsqueeze(0), faces=faces.unsqueeze(0))
+   if include_textures and meshes.textures is not None:
+       textures = meshes.textures.join_scene()
+
+   mesh = Meshes(verts=verts.unsqueeze(0), faces=faces.unsqueeze(0), textures=textures)
    return mesh
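
Note that the behavior of the removed join_mesh (geometry only, textures ignored) is still available via the new flag, e.g. (where mesh_list is a hypothetical list of Meshes):

# Equivalent to the removed join_mesh: join the geometry, drop the textures.
scene = join_meshes_as_scene(mesh_list, include_textures=False)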

(8 binary files not shown: new PNG reference images for the tests below, ranging from 806 B to 25 KiB.)

tests/test_render_meshes.py

@@ -33,7 +33,11 @@ from pytorch3d.renderer.mesh.shader import (
    SoftSilhouetteShader,
    TexturedSoftPhongShader,
)
-from pytorch3d.structures.meshes import Meshes, join_mesh, join_meshes_as_batch
+from pytorch3d.structures.meshes import (
+    Meshes,
+    join_meshes_as_batch,
+    join_meshes_as_scene,
+)
from pytorch3d.utils.ico_sphere import ico_sphere
from pytorch3d.utils.torus import torus
@@ -571,6 +575,288 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
        self.assertClose(outputs[0][0, ..., :3], outputs[1][0, ..., :3], atol=1e-5)
        self.assertClose(outputs[0][1, ..., :3], outputs[2][0, ..., :3], atol=1e-5)
def test_join_uvs(self):
"""Meshes with TexturesUV joined into a scene"""
# Test the result of rendering three tori with separate textures.
# The expected result is consistent with rendering them each alone.
# This tests TexturesUV.join_scene with rectangle flipping,
# and we check the form of the merged map as well.
torch.manual_seed(1)
device = torch.device("cuda:0")
R, T = look_at_view_transform(18, 0, 0)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
raster_settings = RasterizationSettings(
image_size=256, blur_radius=0.0, faces_per_pixel=1
)
lights = PointLights(
device=device,
ambient_color=((1.0, 1.0, 1.0),),
diffuse_color=((0.0, 0.0, 0.0),),
specular_color=((0.0, 0.0, 0.0),),
)
blend_params = BlendParams(
sigma=1e-1,
gamma=1e-4,
background_color=torch.tensor([1.0, 1.0, 1.0], device=device),
)
renderer = MeshRenderer(
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
shader=HardPhongShader(
device=device, blend_params=blend_params, cameras=cameras, lights=lights
),
)
plain_torus = torus(r=1, R=4, sides=5, rings=6, device=device)
[verts] = plain_torus.verts_list()
verts_shifted1 = verts.clone()
verts_shifted1 *= 0.5
verts_shifted1[:, 1] += 7
verts_shifted2 = verts.clone()
verts_shifted2 *= 0.5
verts_shifted2[:, 1] -= 7
[faces] = plain_torus.faces_list()
nocolor = torch.zeros((100, 100), device=device)
color_gradient = torch.linspace(0, 1, steps=100, device=device)
color_gradient1 = color_gradient[None].expand_as(nocolor)
color_gradient2 = color_gradient[:, None].expand_as(nocolor)
colors1 = torch.stack([nocolor, color_gradient1, color_gradient2], dim=2)
colors2 = torch.stack([color_gradient1, color_gradient2, nocolor], dim=2)
verts_uvs1 = torch.rand(size=(verts.shape[0], 2), device=device)
verts_uvs2 = torch.rand(size=(verts.shape[0], 2), device=device)
for i, align_corners, padding_mode in [
(0, True, "border"),
(1, False, "border"),
(2, False, "zeros"),
]:
textures1 = TexturesUV(
maps=[colors1],
faces_uvs=[faces],
verts_uvs=[verts_uvs1],
align_corners=align_corners,
padding_mode=padding_mode,
)
# These downsamplings of colors2 are chosen to ensure a flip and a non-flip
# when the maps are merged.
# We have maps of size (100, 100), (50, 99) and (99, 50).
textures2 = TexturesUV(
maps=[colors2[::2, :-1]],
faces_uvs=[faces],
verts_uvs=[verts_uvs2],
align_corners=align_corners,
padding_mode=padding_mode,
)
offset = torch.tensor([0, 0, 0.5], device=device)
textures3 = TexturesUV(
maps=[colors2[:-1, ::2] + offset],
faces_uvs=[faces],
verts_uvs=[verts_uvs2],
align_corners=align_corners,
padding_mode=padding_mode,
)
mesh1 = Meshes(verts=[verts], faces=[faces], textures=textures1)
mesh2 = Meshes(verts=[verts_shifted1], faces=[faces], textures=textures2)
mesh3 = Meshes(verts=[verts_shifted2], faces=[faces], textures=textures3)
mesh = join_meshes_as_scene([mesh1, mesh2, mesh3])
output = renderer(mesh)[0, ..., :3].cpu()
output1 = renderer(mesh1)[0, ..., :3].cpu()
output2 = renderer(mesh2)[0, ..., :3].cpu()
output3 = renderer(mesh3)[0, ..., :3].cpu()
# The background color is white and the objects do not overlap, so we can
# predict the merged image by taking the minimum over every channel
merged = torch.min(torch.min(output1, output2), output3)
image_ref = load_rgb_image(f"test_joinuvs{i}_final.png", DATA_DIR)
map_ref = load_rgb_image(f"test_joinuvs{i}_map.png", DATA_DIR)
if DEBUG:
Image.fromarray((output.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / f"test_joinuvs{i}_final_.png"
)
Image.fromarray((merged.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / f"test_joinuvs{i}_merged.png"
)
Image.fromarray((output1.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / f"test_joinuvs{i}_1.png"
)
Image.fromarray((output2.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / f"test_joinuvs{i}_2.png"
)
Image.fromarray((output3.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / f"test_joinuvs{i}_3.png"
)
Image.fromarray(
(mesh.textures.maps_padded()[0].cpu().numpy() * 255).astype(
np.uint8
)
).save(DATA_DIR / f"test_joinuvs{i}_map_.png")
Image.fromarray(
(mesh2.textures.maps_padded()[0].cpu().numpy() * 255).astype(
np.uint8
)
).save(DATA_DIR / f"test_joinuvs{i}_map2.png")
Image.fromarray(
(mesh3.textures.maps_padded()[0].cpu().numpy() * 255).astype(
np.uint8
)
).save(DATA_DIR / f"test_joinuvs{i}_map3.png")
self.assertClose(output, merged, atol=0.015)
self.assertClose(output, image_ref, atol=0.05)
self.assertClose(mesh.textures.maps_padded()[0].cpu(), map_ref, atol=0.05)
def test_join_verts(self):
"""Meshes with TexturesVertex joined into a scene"""
# Test the result of rendering two tori with separate textures.
# The expected result is consistent with rendering them each alone.
torch.manual_seed(1)
device = torch.device("cuda:0")
plain_torus = torus(r=1, R=4, sides=5, rings=6, device=device)
[verts] = plain_torus.verts_list()
verts_shifted1 = verts.clone()
verts_shifted1 *= 0.5
verts_shifted1[:, 1] += 7
faces = plain_torus.faces_list()
textures1 = TexturesVertex(verts_features=[torch.rand_like(verts)])
textures2 = TexturesVertex(verts_features=[torch.rand_like(verts)])
mesh1 = Meshes(verts=[verts], faces=faces, textures=textures1)
mesh2 = Meshes(verts=[verts_shifted1], faces=faces, textures=textures2)
mesh = join_meshes_as_scene([mesh1, mesh2])
R, T = look_at_view_transform(18, 0, 0)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
raster_settings = RasterizationSettings(
image_size=256, blur_radius=0.0, faces_per_pixel=1
)
lights = PointLights(
device=device,
ambient_color=((1.0, 1.0, 1.0),),
diffuse_color=((0.0, 0.0, 0.0),),
specular_color=((0.0, 0.0, 0.0),),
)
blend_params = BlendParams(
sigma=1e-1,
gamma=1e-4,
background_color=torch.tensor([1.0, 1.0, 1.0], device=device),
)
renderer = MeshRenderer(
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
shader=HardPhongShader(
device=device, blend_params=blend_params, cameras=cameras, lights=lights
),
)
output = renderer(mesh)
image_ref = load_rgb_image("test_joinverts_final.png", DATA_DIR)
if DEBUG:
debugging_outputs = []
for mesh_ in [mesh1, mesh2]:
debugging_outputs.append(renderer(mesh_))
Image.fromarray(
(output[0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
).save(DATA_DIR / "test_joinverts_final_.png")
Image.fromarray(
(debugging_outputs[0][0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
).save(DATA_DIR / "test_joinverts_1.png")
Image.fromarray(
(debugging_outputs[1][0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
).save(DATA_DIR / "test_joinverts_2.png")
result = output[0, ..., :3].cpu()
self.assertClose(result, image_ref, atol=0.05)
def test_join_atlas(self):
"""Meshes with TexturesAtlas joined into a scene"""
# Test the result of rendering two tori with separate textures.
# The expected result is consistent with rendering them each alone.
torch.manual_seed(1)
device = torch.device("cuda:0")
plain_torus = torus(r=1, R=4, sides=5, rings=6, device=device)
[verts] = plain_torus.verts_list()
verts_shifted1 = verts.clone()
verts_shifted1 *= 1.2
verts_shifted1[:, 0] += 4
verts_shifted1[:, 1] += 5
verts[:, 0] -= 4
verts[:, 1] -= 4
[faces] = plain_torus.faces_list()
map_size = 3
# Two random atlases.
# The averaging of the random numbers here is not consistent with the
# meaning of the atlases, but makes each face a bit smoother than
# if everything had a random color.
atlas1 = torch.rand(size=(faces.shape[0], map_size, map_size, 3), device=device)
atlas1[:, 1] = 0.5 * atlas1[:, 0] + 0.5 * atlas1[:, 2]
atlas1[:, :, 1] = 0.5 * atlas1[:, :, 0] + 0.5 * atlas1[:, :, 2]
atlas2 = torch.rand(size=(faces.shape[0], map_size, map_size, 3), device=device)
atlas2[:, 1] = 0.5 * atlas2[:, 0] + 0.5 * atlas2[:, 2]
atlas2[:, :, 1] = 0.5 * atlas2[:, :, 0] + 0.5 * atlas2[:, :, 2]
textures1 = TexturesAtlas(atlas=[atlas1])
textures2 = TexturesAtlas(atlas=[atlas2])
mesh1 = Meshes(verts=[verts], faces=[faces], textures=textures1)
mesh2 = Meshes(verts=[verts_shifted1], faces=[faces], textures=textures2)
mesh_joined = join_meshes_as_scene([mesh1, mesh2])
R, T = look_at_view_transform(18, 0, 0)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1
)
lights = PointLights(
device=device,
ambient_color=((1.0, 1.0, 1.0),),
diffuse_color=((0.0, 0.0, 0.0),),
specular_color=((0.0, 0.0, 0.0),),
)
blend_params = BlendParams(
sigma=1e-1,
gamma=1e-4,
background_color=torch.tensor([1.0, 1.0, 1.0], device=device),
)
renderer = MeshRenderer(
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
shader=HardPhongShader(
device=device, blend_params=blend_params, cameras=cameras, lights=lights
),
)
output = renderer(mesh_joined)
image_ref = load_rgb_image("test_joinatlas_final.png", DATA_DIR)
if DEBUG:
debugging_outputs = []
for mesh_ in [mesh1, mesh2]:
debugging_outputs.append(renderer(mesh_))
Image.fromarray(
(output[0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
).save(DATA_DIR / "test_joinatlas_final_.png")
Image.fromarray(
(debugging_outputs[0][0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
).save(DATA_DIR / "test_joinatlas_1.png")
Image.fromarray(
(debugging_outputs[1][0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
).save(DATA_DIR / "test_joinatlas_2.png")
result = output[0, ..., :3].cpu()
self.assertClose(result, image_ref, atol=0.05)
    def test_joined_spheres(self):
        """
        Test a list of Meshes can be joined as a single mesh and

@@ -595,7 +881,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
            sphere_mesh_list.append(
                Meshes(verts=verts, faces=sphere_list[i].faces_padded())
            )
-       joined_sphere_mesh = join_mesh(sphere_mesh_list)
+       joined_sphere_mesh = join_meshes_as_scene(sphere_mesh_list)
        joined_sphere_mesh.textures = TexturesVertex(
            verts_features=torch.ones_like(joined_sphere_mesh.verts_padded())
        )

tests/test_texturing.py

@@ -12,6 +12,7 @@ from pytorch3d.renderer.mesh.textures import (
    TexturesUV,
    TexturesVertex,
    _list_to_padded_wrapper,
+   pack_rectangles,
)
from pytorch3d.structures import Meshes, list_to_packed, packed_to_list
from test_meshes import TestMeshes
@@ -730,3 +731,80 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
        index = torch.tensor([1, 2], dtype=torch.int64)
        tryindex(self, index, tex, meshes, source)
        tryindex(self, [2, 4], tex, meshes, source)
class TestRectanglePacking(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(42)
def wrap_pack(self, sizes):
"""
Call the pack_rectangles function, which we want to test,
and return its outputs.
Additionally makes some sanity checks on the output.
"""
res = pack_rectangles(sizes)
total = res.total_size
self.assertGreaterEqual(total[0], 0)
self.assertGreaterEqual(total[1], 0)
mask = torch.zeros(total, dtype=torch.bool)
seen_x_bound = False
seen_y_bound = False
for (in_x, in_y), loc in zip(sizes, res.locations):
self.assertGreaterEqual(loc[0], 0)
self.assertGreaterEqual(loc[1], 0)
placed_x, placed_y = (in_y, in_x) if loc[2] else (in_x, in_y)
upper_x = placed_x + loc[0]
upper_y = placed_y + loc[1]
self.assertGreaterEqual(total[0], upper_x)
if total[0] == upper_x:
seen_x_bound = True
self.assertGreaterEqual(total[1], upper_y)
if total[1] == upper_y:
seen_y_bound = True
already_taken = torch.sum(mask[loc[0] : upper_x, loc[1] : upper_y])
self.assertEqual(already_taken, 0)
mask[loc[0] : upper_x, loc[1] : upper_y] = 1
self.assertTrue(seen_x_bound)
self.assertTrue(seen_y_bound)
self.assertTrue(torch.all(torch.sum(mask, dim=0, dtype=torch.int32) > 0))
self.assertTrue(torch.all(torch.sum(mask, dim=1, dtype=torch.int32) > 0))
return res
def assert_bb(self, sizes, expected):
"""
Apply the pack_rectangles function to sizes and verify the
bounding box dimensions are expected.
"""
self.assertSetEqual(set(self.wrap_pack(sizes).total_size), expected)
def test_simple(self):
self.assert_bb([(3, 4), (4, 3)], {6, 4})
self.assert_bb([(2, 2), (2, 4), (2, 2)], {4, 4})
# many squares
self.assert_bb([(2, 2)] * 9, {2, 18})
# One big square and many small ones.
self.assert_bb([(3, 3)] + [(1, 1)] * 2, {3, 4})
self.assert_bb([(3, 3)] + [(1, 1)] * 3, {3, 4})
self.assert_bb([(3, 3)] + [(1, 1)] * 4, {3, 5})
self.assert_bb([(3, 3)] + [(1, 1)] * 5, {3, 5})
self.assert_bb([(1, 1)] * 6 + [(3, 3)], {3, 5})
self.assert_bb([(3, 3)] + [(1, 1)] * 7, {3, 6})
# many identical rectangles
self.assert_bb([(7, 190)] * 4 + [(190, 7)] * 4, {190, 56})
# require placing the flipped version of a rectangle
self.assert_bb([(1, 100), (5, 96), (4, 5)], {100, 6})
def test_random(self):
for _ in range(5):
vals = torch.randint(size=(20, 2), low=1, high=18)
sizes = []
for j in range(vals.shape[0]):
sizes.append((int(vals[j, 0]), int(vals[j, 1])))
self.wrap_pack(sizes)