diff --git a/pytorch3d/renderer/mesh/textures.py b/pytorch3d/renderer/mesh/textures.py index 0d69450f..cccb3d75 100644 --- a/pytorch3d/renderer/mesh/textures.py +++ b/pytorch3d/renderer/mesh/textures.py @@ -2,7 +2,7 @@ import itertools import warnings -from typing import Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -10,6 +10,8 @@ from pytorch3d.ops import interpolate_face_attributes from pytorch3d.structures.utils import list_to_packed, list_to_padded, padded_to_list from torch.nn.functional import interpolate +from .utils import pack_rectangles + # This file contains classes and helper functions for texturing. # There are three types of textures: TexturesVertex, TexturesAtlas @@ -329,6 +331,7 @@ class TexturesAtlas(TexturesBase): [1] Liu et al, 'Soft Rasterizer: A Differentiable Renderer for Image-based 3D Reasoning', ICCV 2019 + See also https://github.com/ShichenLiu/SoftRas/issues/21 """ if isinstance(atlas, (list, tuple)): correct_format = all( @@ -336,11 +339,15 @@ class TexturesAtlas(TexturesBase): torch.is_tensor(elem) and elem.ndim == 4 and elem.shape[1] == elem.shape[2] + and elem.shape[1] == atlas[0].shape[1] ) for elem in atlas ) if not correct_format: - msg = "Expected atlas to be a list of tensors of shape (F, R, R, D)" + msg = ( + "Expected atlas to be a list of tensors of shape (F, R, R, D) " + "with the same value of R." + ) raise ValueError(msg) self._atlas_list = atlas self._atlas_padded = None @@ -529,6 +536,12 @@ class TexturesAtlas(TexturesBase): new_tex._num_faces_per_mesh = num_faces_per_mesh return new_tex + def join_scene(self) -> "TexturesAtlas": + """ + Return a new TexturesAtlas amalgamating the batch. + """ + return self.__class__(atlas=[torch.cat(self.atlas_list())]) + class TexturesUV(TexturesBase): def __init__( @@ -560,7 +573,7 @@ class TexturesUV(TexturesBase): the two align_corners options at https://discuss.pytorch.org/t/22663/9 . - An example of how the indexing into the maps, with align_corners=True + An example of how the indexing into the maps, with align_corners=True, works is as follows. If maps[i] has shape [101, 1001] and the value of verts_uvs[i][j] is [0.4, 0.3], then a value of j in faces_uvs[i] means a vertex @@ -574,10 +587,11 @@ class TexturesUV(TexturesBase): If maps[i] has shape [100, 1000] and the value of verts_uvs[i][j] is [0.405, 0.2995], then a value of j in faces_uvs[i] means a vertex whose color is given by maps[i][700, 40]. - In this case, padding_mode even matters for values in verts_uvs - slightly above 0 or slightly below 1. In this case, it matters if the - first value is outside the interval [0.0005, 0.9995] or if the second - is outside the interval [0.005, 0.995]. + When align_corners=False, padding_mode even matters for values in + verts_uvs slightly above 0 or slightly below 1. In this case, the + padding_mode matters if the first value is outside the interval + [0.0005, 0.9995] or if the second is outside the interval + [0.005, 0.995]. """ super().__init__() self.padding_mode = padding_mode @@ -805,12 +819,9 @@ class TexturesUV(TexturesBase): def maps_padded(self) -> torch.Tensor: return self._maps_padded - def maps_list(self) -> torch.Tensor: - # maps_list is not used anywhere currently - maps - # are padded to ensure the (H, W) of all maps is the - # same across the batch and we don't store the - # unpadded sizes of the maps. Therefore just return - # the unbinded padded tensor. + def maps_list(self) -> List[torch.Tensor]: + if self._maps_list is not None: + return self._maps_list return self._maps_padded.unbind(0) def extend(self, N: int) -> "TexturesUV": @@ -965,6 +976,143 @@ class TexturesUV(TexturesBase): new_tex._num_faces_per_mesh = num_faces_per_mesh return new_tex + def _place_map_into_single_map( + self, + single_map: torch.Tensor, + map_: torch.Tensor, + location: Tuple[int, int, bool], # (x,y) and whether flipped + ) -> None: + """ + Copy map into a larger tensor single_map at the destination specified by location. + If align_corners is False, we add the needed border around the destination. + + Used by join_scene. + + Args: + single_map: (total_H, total_W, 3) + map_: (H, W, 3) source data + location: where to place map + """ + do_flip = location[2] + source = map_.transpose(0, 1) if do_flip else map_ + border_width = 0 if self.align_corners else 1 + lower_u = location[0] + border_width + lower_v = location[1] + border_width + upper_u = lower_u + source.shape[0] + upper_v = lower_v + source.shape[1] + single_map[lower_u:upper_u, lower_v:upper_v] = source + + if self.padding_mode != "zeros" and not self.align_corners: + single_map[lower_u - 1, lower_v:upper_v] = single_map[ + lower_u, lower_v:upper_v + ] + single_map[upper_u, lower_v:upper_v] = single_map[ + upper_u - 1, lower_v:upper_v + ] + single_map[lower_u:upper_u, lower_v - 1] = single_map[ + lower_u:upper_u, lower_v + ] + single_map[lower_u:upper_u, upper_v] = single_map[ + lower_u:upper_u, upper_v - 1 + ] + single_map[lower_u - 1, lower_v - 1] = single_map[lower_u, lower_v] + single_map[lower_u - 1, upper_v] = single_map[lower_u, upper_v - 1] + single_map[upper_u, lower_v - 1] = single_map[upper_u - 1, lower_v] + single_map[upper_u, upper_v] = single_map[upper_u - 1, upper_v - 1] + + def join_scene(self) -> "TexturesUV": + """ + Return a new TexturesUV amalgamating the batch. + + We calculate a large single map which contains the original maps, + and find verts_uvs to point into it. This will not replicate + behavior of padding for verts_uvs values outside [0,1]. + + If align_corners=False, we need to add an artificial border around + every map. + + We use the function `pack_rectangles` to provide a layout for the + single map. _place_map_into_single_map is used to copy the maps + into the single map. The merging of verts_uvs and faces_uvs are + handled locally in this function. + """ + maps = self.maps_list() + heights_and_widths = [] + extra_border = 0 if self.align_corners else 2 + for map_ in maps: + heights_and_widths.append( + (map_.shape[0] + extra_border, map_.shape[1] + extra_border) + ) + merging_plan = pack_rectangles(heights_and_widths) + # pyre-fixme[16]: `Tensor` has no attribute `new_zeros`. + single_map = maps[0].new_zeros((*merging_plan.total_size, 3)) + verts_uvs = self.verts_uvs_list() + verts_uvs_merged = [] + + for map_, loc, uvs in zip(maps, merging_plan.locations, verts_uvs): + new_uvs = uvs.clone() + self._place_map_into_single_map(single_map, map_, loc) + do_flip = loc[2] + x_shape = map_.shape[1] if do_flip else map_.shape[0] + y_shape = map_.shape[0] if do_flip else map_.shape[1] + + if do_flip: + # Here we have flipped / transposed the map. + # In uvs, the y values are decreasing from 1 to 0 and the x + # values increase from 0 to 1. We subtract all values from 1 + # as the x's become y's and the y's become x's. + new_uvs = 1.0 - new_uvs[:, [1, 0]] + if TYPE_CHECKING: + new_uvs = torch.Tensor(new_uvs) + + # If align_corners is True, then an index of x (where x is in + # the range 0 .. map_.shape[]-1) in one of the input maps + # was hit by a u of x/(map_.shape[]-1). + # That x is located at the index loc[] + x in the single_map, and + # to hit that we need u to equal (loc[] + x) / (total_size[]-1) + # so the old u should be mapped to + # { u*(map_.shape[]-1) + loc[] } / (total_size[]-1) + + # If align_corners is False, then an index of x (where x is in + # the range 1 .. map_.shape[]-2) in one of the input maps + # was hit by a u of (x+0.5)/(map_.shape[]). + # That x is located at the index loc[] + 1 + x in the single_map, + # (where the 1 is for the border) + # and to hit that we need u to equal (loc[] + 1 + x + 0.5) / (total_size[]) + # so the old u should be mapped to + # { loc[] + 1 + u*map_.shape[]-0.5 + 0.5 } / (total_size[]) + # = { loc[] + 1 + u*map_.shape[] } / (total_size[]) + + # We change the y's in new_uvs for the scaling of height, + # and the x's for the scaling of width. + # That is why the 1's and 0's are mismatched in these lines. + one_if_align = 1 if self.align_corners else 0 + one_if_not_align = 1 - one_if_align + denom_x = merging_plan.total_size[0] - one_if_align + scale_x = x_shape - one_if_align + denom_y = merging_plan.total_size[1] - one_if_align + scale_y = y_shape - one_if_align + new_uvs[:, 1] *= scale_x / denom_x + new_uvs[:, 1] += (loc[0] + one_if_not_align) / denom_x + new_uvs[:, 0] *= scale_y / denom_y + new_uvs[:, 0] += (loc[1] + one_if_not_align) / denom_y + + verts_uvs_merged.append(new_uvs) + + faces_uvs_merged = [] + offset = 0 + for faces_uvs_, verts_uvs_ in zip(self.faces_uvs_list(), verts_uvs): + faces_uvs_merged.append(offset + faces_uvs_) + offset += verts_uvs_.shape[0] + + return self.__class__( + maps=[single_map], + verts_uvs=[torch.cat(verts_uvs_merged)], + faces_uvs=[torch.cat(faces_uvs_merged)], + align_corners=self.align_corners, + padding_mode=self.padding_mode, + ) + class TexturesVertex(TexturesBase): def __init__( @@ -1156,3 +1304,9 @@ class TexturesVertex(TexturesBase): new_tex = self.__class__(verts_features=verts_features_list) new_tex._num_verts_per_mesh = num_faces_per_mesh return new_tex + + def join_scene(self) -> "TexturesVertex": + """ + Return a new TexturesVertex amalgamating the batch. + """ + return self.__class__(verts_features=[torch.cat(self.verts_features_list())]) diff --git a/pytorch3d/renderer/mesh/utils.py b/pytorch3d/renderer/mesh/utils.py index 749a746d..81ce0f1f 100644 --- a/pytorch3d/renderer/mesh/utils.py +++ b/pytorch3d/renderer/mesh/utils.py @@ -1,6 +1,8 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +from typing import List, NamedTuple, Tuple + import torch from pytorch3d.ops import interpolate_face_attributes @@ -58,3 +60,184 @@ def _interpolate_zbuf( ] # (1, H, W, K) zbuf[pix_to_face == -1] = -1 return zbuf + + +# ----------- Rectangle Packing -------------------- # + +# Note the order of members matters here because it determines the queue order. +# We want to place longer rectangles first. +class _UnplacedRectangle(NamedTuple): + size: Tuple[int, int] + ind: int + flipped: bool + + +def _try_place_rectangle( + rect: _UnplacedRectangle, + placed_so_far: List[Tuple[int, int, bool]], + occupied: List[Tuple[int, int]], +) -> bool: + """ + Try to place rect within the current bounding box. + Part of the implementation of pack_rectangles. + + Note that the arguments `placed_so_far` and `occupied` are modified. + + Args: + rect: rectangle to place + placed_so_far: the locations decided upon so far - a list of + (x, y, whether flipped). The nth element is the + location of the nth rectangle if it has been decided. + (modified in place) + occupied: the nodes of the graph of extents of rightmost placed + rectangles - (modified in place) + + Returns: + True on success. + + Example: + (We always have placed the first rectangle horizontally and other + rectangles above it.) + Let's say the placed boxes 1-4 are layed out like this. + The coordinates of the points marked X are stored in occupied. + It is to the right of the X's that we seek to place rect. + + +-----------------------X + |2 | + | +---X + | |4 | + | | | + | +---+X + | |3 | + | | | + +-----------------------+----+------X +y |1 | +^ | --->x | +| +-----------------------------------+ + + We want to place this rectangle. + + +-+ + |5| + | | + | | = rect + | | + | | + | | + +-+ + + The call will succeed, returning True, leaving us with + + +-----------------------X + |2 | +-X + | +---+|5| + | |4 || | + | | || | + | +---++ | + | |3 | | + | | | | + +-----------------------+----+-+----X + |1 | + | | + +-----------------------------------+ . + + """ + total_width = occupied[0][0] + needed_height = rect.size[1] + current_start_idx = None + current_max_width = 0 + previous_height = 0 + currently_packed = 0 + for idx, interval in enumerate(occupied): + if interval[0] <= total_width - rect.size[0]: + currently_packed += interval[1] - previous_height + current_max_width = max(interval[0], current_max_width) + if current_start_idx is None: + current_start_idx = idx + if currently_packed >= needed_height: + current_max_width = max(interval[0], current_max_width) + placed_so_far[rect.ind] = ( + current_max_width, + occupied[current_start_idx - 1][1], + rect.flipped, + ) + new_occupied = ( + current_max_width + rect.size[0], + occupied[current_start_idx - 1][1] + needed_height, + ) + if currently_packed == needed_height: + occupied[idx] = new_occupied + del occupied[current_start_idx:idx] + elif idx > current_start_idx: + occupied[idx - 1] = new_occupied + del occupied[current_start_idx : (idx - 1)] + else: + occupied.insert(idx, new_occupied) + return True + else: + current_start_idx = None + current_max_width = 0 + currently_packed = 0 + previous_height = interval[1] + return False + + +class PackedRectangles(NamedTuple): + total_size: Tuple[int, int] + locations: List[Tuple[int, int, bool]] # (x,y) and whether flipped + + +def pack_rectangles(sizes: List[Tuple[int, int]]) -> PackedRectangles: + """ + Naive rectangle packing in to a large rectangle. Flipping (i.e. rotating + a rectangle by 90 degrees) is allowed. + + This is used to join several uv maps into a single scene, see + TexturesUV.join_scene. + + Args: + sizes: List of sizes of rectangles to pack + + Returns: + total_size: size of total large rectangle + rectangles: location for each of the input rectangles + """ + + if len(sizes) < 2: + raise ValueError("Cannot pack less than two boxes") + + queue = [] + for i, size in enumerate(sizes): + if size[0] < size[1]: + queue.append(_UnplacedRectangle((size[1], size[0]), i, True)) + else: + queue.append(_UnplacedRectangle((size[0], size[1]), i, False)) + queue.sort() + placed_so_far = [(-1, -1, False)] * len(sizes) + + biggest = queue.pop() + total_width, current_height = biggest.size + placed_so_far[biggest.ind] = (0, 0, biggest.flipped) + + second = queue.pop() + placed_so_far[second.ind] = (0, current_height, second.flipped) + current_height += second.size[1] + occupied = [biggest.size, (second.size[0], current_height)] + + for rect in reversed(queue): + if _try_place_rectangle(rect, placed_so_far, occupied): + continue + + rotated = _UnplacedRectangle( + (rect.size[1], rect.size[0]), rect.ind, not rect.flipped + ) + if _try_place_rectangle(rotated, placed_so_far, occupied): + continue + + # rect wasn't placed in the current bounding box, + # so we add extra space to fit it in. + placed_so_far[rect.ind] = (0, current_height, rect.flipped) + current_height += rect.size[1] + occupied.append((rect.size[0], current_height)) + + return PackedRectangles((total_width, current_height), placed_so_far) diff --git a/pytorch3d/structures/__init__.py b/pytorch3d/structures/__init__.py index 78d24a26..e83db39e 100644 --- a/pytorch3d/structures/__init__.py +++ b/pytorch3d/structures/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -from .meshes import Meshes, join_meshes_as_batch +from .meshes import Meshes, join_meshes_as_batch, join_meshes_as_scene from .pointclouds import Pointclouds from .utils import list_to_packed, list_to_padded, packed_to_list, padded_to_list diff --git a/pytorch3d/structures/meshes.py b/pytorch3d/structures/meshes.py index c621b6a7..42a1ed81 100644 --- a/pytorch3d/structures/meshes.py +++ b/pytorch3d/structures/meshes.py @@ -1254,7 +1254,7 @@ class Meshes(object): """ verts_packed = self.verts_packed() if vert_offsets_packed.shape != verts_packed.shape: - raise ValueError("Verts offsets must have dimension (all_v, 2).") + raise ValueError("Verts offsets must have dimension (all_v, 3).") # update verts packed self._verts_packed = verts_packed + vert_offsets_packed new_verts_list = list( @@ -1548,26 +1548,43 @@ def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True): return Meshes(verts=verts, faces=faces, textures=tex) -def join_mesh(meshes: Union[Meshes, List[Meshes]]) -> Meshes: +def join_meshes_as_scene( + meshes: Union[Meshes, List[Meshes]], include_textures: bool = True +) -> Meshes: """ Joins a batch of meshes in the form of a Meshes object or a list of Meshes - objects as a single mesh. If the input is a list, the Meshes objects in the list - must all be on the same device. This version ignores all textures in the input meshes. + objects as a single mesh. If the input is a list, the Meshes objects in the + list must all be on the same device. Unless include_textures is False, the + meshes must all have the same type of texture or must all not have textures. + + If textures are included, then the textures are joined as a single scene in + addition to the meshes. For this, texture types have an appropriate method + called join_scene which joins mesh textures into a single texture. + If the textures are TexturesAtlas then they must have the same resolution. + If they are TexturesUV then they must have the same align_corners and + padding_mode. Values in verts_uvs outside [0, 1] will not + be respected. Args: - meshes: Meshes object that contains a batch of meshes or a list of Meshes objects + meshes: Meshes object that contains a batch of meshes, or a list of + Meshes objects. + include_textures: (bool) whether to try to join the textures. Returns: new Meshes object containing a single mesh """ if isinstance(meshes, List): - meshes = join_meshes_as_batch(meshes, include_textures=False) + meshes = join_meshes_as_batch(meshes, include_textures=include_textures) if len(meshes) == 1: return meshes verts = meshes.verts_packed() # (sum(V_n), 3) # Offset automatically done by faces_packed faces = meshes.faces_packed() # (sum(F_n), 3) + textures = None - mesh = Meshes(verts=verts.unsqueeze(0), faces=faces.unsqueeze(0)) + if include_textures and meshes.textures is not None: + textures = meshes.textures.join_scene() + + mesh = Meshes(verts=verts.unsqueeze(0), faces=faces.unsqueeze(0), textures=textures) return mesh diff --git a/tests/data/test_joinatlas_final.png b/tests/data/test_joinatlas_final.png new file mode 100644 index 00000000..3f760555 Binary files /dev/null and b/tests/data/test_joinatlas_final.png differ diff --git a/tests/data/test_joinuvs0_final.png b/tests/data/test_joinuvs0_final.png new file mode 100644 index 00000000..3394f7f5 Binary files /dev/null and b/tests/data/test_joinuvs0_final.png differ diff --git a/tests/data/test_joinuvs0_map.png b/tests/data/test_joinuvs0_map.png new file mode 100644 index 00000000..163afcf6 Binary files /dev/null and b/tests/data/test_joinuvs0_map.png differ diff --git a/tests/data/test_joinuvs1_final.png b/tests/data/test_joinuvs1_final.png new file mode 100644 index 00000000..b624ef75 Binary files /dev/null and b/tests/data/test_joinuvs1_final.png differ diff --git a/tests/data/test_joinuvs1_map.png b/tests/data/test_joinuvs1_map.png new file mode 100644 index 00000000..616d08e8 Binary files /dev/null and b/tests/data/test_joinuvs1_map.png differ diff --git a/tests/data/test_joinuvs2_final.png b/tests/data/test_joinuvs2_final.png new file mode 100644 index 00000000..bce00ed0 Binary files /dev/null and b/tests/data/test_joinuvs2_final.png differ diff --git a/tests/data/test_joinuvs2_map.png b/tests/data/test_joinuvs2_map.png new file mode 100644 index 00000000..ac9c43b0 Binary files /dev/null and b/tests/data/test_joinuvs2_map.png differ diff --git a/tests/data/test_joinverts_final.png b/tests/data/test_joinverts_final.png new file mode 100644 index 00000000..9f85e06f Binary files /dev/null and b/tests/data/test_joinverts_final.png differ diff --git a/tests/test_render_meshes.py b/tests/test_render_meshes.py index e0535846..0b4ab155 100644 --- a/tests/test_render_meshes.py +++ b/tests/test_render_meshes.py @@ -33,7 +33,11 @@ from pytorch3d.renderer.mesh.shader import ( SoftSilhouetteShader, TexturedSoftPhongShader, ) -from pytorch3d.structures.meshes import Meshes, join_mesh, join_meshes_as_batch +from pytorch3d.structures.meshes import ( + Meshes, + join_meshes_as_batch, + join_meshes_as_scene, +) from pytorch3d.utils.ico_sphere import ico_sphere from pytorch3d.utils.torus import torus @@ -571,6 +575,288 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): self.assertClose(outputs[0][0, ..., :3], outputs[1][0, ..., :3], atol=1e-5) self.assertClose(outputs[0][1, ..., :3], outputs[2][0, ..., :3], atol=1e-5) + def test_join_uvs(self): + """Meshes with TexturesUV joined into a scene""" + # Test the result of rendering three tori with separate textures. + # The expected result is consistent with rendering them each alone. + # This tests TexturesUV.join_scene with rectangle flipping, + # and we check the form of the merged map as well. + torch.manual_seed(1) + device = torch.device("cuda:0") + + R, T = look_at_view_transform(18, 0, 0) + cameras = FoVPerspectiveCameras(device=device, R=R, T=T) + + raster_settings = RasterizationSettings( + image_size=256, blur_radius=0.0, faces_per_pixel=1 + ) + + lights = PointLights( + device=device, + ambient_color=((1.0, 1.0, 1.0),), + diffuse_color=((0.0, 0.0, 0.0),), + specular_color=((0.0, 0.0, 0.0),), + ) + blend_params = BlendParams( + sigma=1e-1, + gamma=1e-4, + background_color=torch.tensor([1.0, 1.0, 1.0], device=device), + ) + renderer = MeshRenderer( + rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings), + shader=HardPhongShader( + device=device, blend_params=blend_params, cameras=cameras, lights=lights + ), + ) + + plain_torus = torus(r=1, R=4, sides=5, rings=6, device=device) + [verts] = plain_torus.verts_list() + verts_shifted1 = verts.clone() + verts_shifted1 *= 0.5 + verts_shifted1[:, 1] += 7 + verts_shifted2 = verts.clone() + verts_shifted2 *= 0.5 + verts_shifted2[:, 1] -= 7 + + [faces] = plain_torus.faces_list() + nocolor = torch.zeros((100, 100), device=device) + color_gradient = torch.linspace(0, 1, steps=100, device=device) + color_gradient1 = color_gradient[None].expand_as(nocolor) + color_gradient2 = color_gradient[:, None].expand_as(nocolor) + colors1 = torch.stack([nocolor, color_gradient1, color_gradient2], dim=2) + colors2 = torch.stack([color_gradient1, color_gradient2, nocolor], dim=2) + verts_uvs1 = torch.rand(size=(verts.shape[0], 2), device=device) + verts_uvs2 = torch.rand(size=(verts.shape[0], 2), device=device) + + for i, align_corners, padding_mode in [ + (0, True, "border"), + (1, False, "border"), + (2, False, "zeros"), + ]: + textures1 = TexturesUV( + maps=[colors1], + faces_uvs=[faces], + verts_uvs=[verts_uvs1], + align_corners=align_corners, + padding_mode=padding_mode, + ) + + # These downsamplings of colors2 are chosen to ensure a flip and a non flip + # when the maps are merged. + # We have maps of size (100, 100), (50, 99) and (99, 50). + textures2 = TexturesUV( + maps=[colors2[::2, :-1]], + faces_uvs=[faces], + verts_uvs=[verts_uvs2], + align_corners=align_corners, + padding_mode=padding_mode, + ) + offset = torch.tensor([0, 0, 0.5], device=device) + textures3 = TexturesUV( + maps=[colors2[:-1, ::2] + offset], + faces_uvs=[faces], + verts_uvs=[verts_uvs2], + align_corners=align_corners, + padding_mode=padding_mode, + ) + mesh1 = Meshes(verts=[verts], faces=[faces], textures=textures1) + mesh2 = Meshes(verts=[verts_shifted1], faces=[faces], textures=textures2) + mesh3 = Meshes(verts=[verts_shifted2], faces=[faces], textures=textures3) + mesh = join_meshes_as_scene([mesh1, mesh2, mesh3]) + + output = renderer(mesh)[0, ..., :3].cpu() + output1 = renderer(mesh1)[0, ..., :3].cpu() + output2 = renderer(mesh2)[0, ..., :3].cpu() + output3 = renderer(mesh3)[0, ..., :3].cpu() + # The background color is white and the objects do not overlap, so we can + # predict the merged image by taking the minimum over every channel + merged = torch.min(torch.min(output1, output2), output3) + + image_ref = load_rgb_image(f"test_joinuvs{i}_final.png", DATA_DIR) + map_ref = load_rgb_image(f"test_joinuvs{i}_map.png", DATA_DIR) + + if DEBUG: + Image.fromarray((output.numpy() * 255).astype(np.uint8)).save( + DATA_DIR / f"test_joinuvs{i}_final_.png" + ) + Image.fromarray((output.numpy() * 255).astype(np.uint8)).save( + DATA_DIR / f"test_joinuvs{i}_merged.png" + ) + + Image.fromarray((output1.numpy() * 255).astype(np.uint8)).save( + DATA_DIR / f"test_joinuvs{i}_1.png" + ) + Image.fromarray((output2.numpy() * 255).astype(np.uint8)).save( + DATA_DIR / f"test_joinuvs{i}_2.png" + ) + Image.fromarray((output3.numpy() * 255).astype(np.uint8)).save( + DATA_DIR / f"test_joinuvs{i}_3.png" + ) + Image.fromarray( + (mesh.textures.maps_padded()[0].cpu().numpy() * 255).astype( + np.uint8 + ) + ).save(DATA_DIR / f"test_joinuvs{i}_map_.png") + Image.fromarray( + (mesh2.textures.maps_padded()[0].cpu().numpy() * 255).astype( + np.uint8 + ) + ).save(DATA_DIR / f"test_joinuvs{i}_map2.png") + Image.fromarray( + (mesh3.textures.maps_padded()[0].cpu().numpy() * 255).astype( + np.uint8 + ) + ).save(DATA_DIR / f"test_joinuvs{i}_map3.png") + + self.assertClose(output, merged, atol=0.015) + self.assertClose(output, image_ref, atol=0.05) + self.assertClose(mesh.textures.maps_padded()[0].cpu(), map_ref, atol=0.05) + + def test_join_verts(self): + """Meshes with TexturesVertex joined into a scene""" + # Test the result of rendering two tori with separate textures. + # The expected result is consistent with rendering them each alone. + torch.manual_seed(1) + device = torch.device("cuda:0") + plain_torus = torus(r=1, R=4, sides=5, rings=6, device=device) + [verts] = plain_torus.verts_list() + verts_shifted1 = verts.clone() + verts_shifted1 *= 0.5 + verts_shifted1[:, 1] += 7 + + faces = plain_torus.faces_list() + textures1 = TexturesVertex(verts_features=[torch.rand_like(verts)]) + textures2 = TexturesVertex(verts_features=[torch.rand_like(verts)]) + mesh1 = Meshes(verts=[verts], faces=faces, textures=textures1) + mesh2 = Meshes(verts=[verts_shifted1], faces=faces, textures=textures2) + mesh = join_meshes_as_scene([mesh1, mesh2]) + + R, T = look_at_view_transform(18, 0, 0) + cameras = FoVPerspectiveCameras(device=device, R=R, T=T) + + raster_settings = RasterizationSettings( + image_size=256, blur_radius=0.0, faces_per_pixel=1 + ) + + lights = PointLights( + device=device, + ambient_color=((1.0, 1.0, 1.0),), + diffuse_color=((0.0, 0.0, 0.0),), + specular_color=((0.0, 0.0, 0.0),), + ) + blend_params = BlendParams( + sigma=1e-1, + gamma=1e-4, + background_color=torch.tensor([1.0, 1.0, 1.0], device=device), + ) + renderer = MeshRenderer( + rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings), + shader=HardPhongShader( + device=device, blend_params=blend_params, cameras=cameras, lights=lights + ), + ) + + output = renderer(mesh) + + image_ref = load_rgb_image("test_joinverts_final.png", DATA_DIR) + + if DEBUG: + debugging_outputs = [] + for mesh_ in [mesh1, mesh2]: + debugging_outputs.append(renderer(mesh_)) + Image.fromarray( + (output[0, ..., :3].cpu().numpy() * 255).astype(np.uint8) + ).save(DATA_DIR / "test_joinverts_final_.png") + Image.fromarray( + (debugging_outputs[0][0, ..., :3].cpu().numpy() * 255).astype(np.uint8) + ).save(DATA_DIR / "test_joinverts_1.png") + Image.fromarray( + (debugging_outputs[1][0, ..., :3].cpu().numpy() * 255).astype(np.uint8) + ).save(DATA_DIR / "test_joinverts_2.png") + + result = output[0, ..., :3].cpu() + self.assertClose(result, image_ref, atol=0.05) + + def test_join_atlas(self): + """Meshes with TexturesAtlas joined into a scene""" + # Test the result of rendering two tori with separate textures. + # The expected result is consistent with rendering them each alone. + torch.manual_seed(1) + device = torch.device("cuda:0") + plain_torus = torus(r=1, R=4, sides=5, rings=6, device=device) + [verts] = plain_torus.verts_list() + verts_shifted1 = verts.clone() + verts_shifted1 *= 1.2 + verts_shifted1[:, 0] += 4 + verts_shifted1[:, 1] += 5 + verts[:, 0] -= 4 + verts[:, 1] -= 4 + + [faces] = plain_torus.faces_list() + map_size = 3 + # Two random atlases. + # The averaging of the random numbers here is not consistent with the + # meaning of the atlases, but makes each face a bit smoother than + # if everything had a random color. + atlas1 = torch.rand(size=(faces.shape[0], map_size, map_size, 3), device=device) + atlas1[:, 1] = 0.5 * atlas1[:, 0] + 0.5 * atlas1[:, 2] + atlas1[:, :, 1] = 0.5 * atlas1[:, :, 0] + 0.5 * atlas1[:, :, 2] + atlas2 = torch.rand(size=(faces.shape[0], map_size, map_size, 3), device=device) + atlas2[:, 1] = 0.5 * atlas2[:, 0] + 0.5 * atlas2[:, 2] + atlas2[:, :, 1] = 0.5 * atlas2[:, :, 0] + 0.5 * atlas2[:, :, 2] + + textures1 = TexturesAtlas(atlas=[atlas1]) + textures2 = TexturesAtlas(atlas=[atlas2]) + mesh1 = Meshes(verts=[verts], faces=[faces], textures=textures1) + mesh2 = Meshes(verts=[verts_shifted1], faces=[faces], textures=textures2) + mesh_joined = join_meshes_as_scene([mesh1, mesh2]) + + R, T = look_at_view_transform(18, 0, 0) + cameras = FoVPerspectiveCameras(device=device, R=R, T=T) + + raster_settings = RasterizationSettings( + image_size=512, blur_radius=0.0, faces_per_pixel=1 + ) + + lights = PointLights( + device=device, + ambient_color=((1.0, 1.0, 1.0),), + diffuse_color=((0.0, 0.0, 0.0),), + specular_color=((0.0, 0.0, 0.0),), + ) + blend_params = BlendParams( + sigma=1e-1, + gamma=1e-4, + background_color=torch.tensor([1.0, 1.0, 1.0], device=device), + ) + renderer = MeshRenderer( + rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings), + shader=HardPhongShader( + device=device, blend_params=blend_params, cameras=cameras, lights=lights + ), + ) + + output = renderer(mesh_joined) + + image_ref = load_rgb_image("test_joinatlas_final.png", DATA_DIR) + + if DEBUG: + debugging_outputs = [] + for mesh_ in [mesh1, mesh2]: + debugging_outputs.append(renderer(mesh_)) + Image.fromarray( + (output[0, ..., :3].cpu().numpy() * 255).astype(np.uint8) + ).save(DATA_DIR / "test_joinatlas_final_.png") + Image.fromarray( + (debugging_outputs[0][0, ..., :3].cpu().numpy() * 255).astype(np.uint8) + ).save(DATA_DIR / "test_joinatlas_1.png") + Image.fromarray( + (debugging_outputs[1][0, ..., :3].cpu().numpy() * 255).astype(np.uint8) + ).save(DATA_DIR / "test_joinatlas_2.png") + + result = output[0, ..., :3].cpu() + self.assertClose(result, image_ref, atol=0.05) + def test_joined_spheres(self): """ Test a list of Meshes can be joined as a single mesh and @@ -595,7 +881,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): sphere_mesh_list.append( Meshes(verts=verts, faces=sphere_list[i].faces_padded()) ) - joined_sphere_mesh = join_mesh(sphere_mesh_list) + joined_sphere_mesh = join_meshes_as_scene(sphere_mesh_list) joined_sphere_mesh.textures = TexturesVertex( verts_features=torch.ones_like(joined_sphere_mesh.verts_padded()) ) diff --git a/tests/test_texturing.py b/tests/test_texturing.py index 5e847073..543291d0 100644 --- a/tests/test_texturing.py +++ b/tests/test_texturing.py @@ -12,6 +12,7 @@ from pytorch3d.renderer.mesh.textures import ( TexturesUV, TexturesVertex, _list_to_padded_wrapper, + pack_rectangles, ) from pytorch3d.structures import Meshes, list_to_packed, packed_to_list from test_meshes import TestMeshes @@ -730,3 +731,80 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase): index = torch.tensor([1, 2], dtype=torch.int64) tryindex(self, index, tex, meshes, source) tryindex(self, [2, 4], tex, meshes, source) + + +class TestRectanglePacking(TestCaseMixin, unittest.TestCase): + def setUp(self) -> None: + super().setUp() + torch.manual_seed(42) + + def wrap_pack(self, sizes): + """ + Call the pack_rectangles function, which we want to test, + and return its outputs. + Additionally makes some sanity checks on the output. + """ + res = pack_rectangles(sizes) + total = res.total_size + self.assertGreaterEqual(total[0], 0) + self.assertGreaterEqual(total[1], 0) + mask = torch.zeros(total, dtype=torch.bool) + seen_x_bound = False + seen_y_bound = False + for (in_x, in_y), loc in zip(sizes, res.locations): + self.assertGreaterEqual(loc[0], 0) + self.assertGreaterEqual(loc[1], 0) + placed_x, placed_y = (in_y, in_x) if loc[2] else (in_x, in_y) + upper_x = placed_x + loc[0] + upper_y = placed_y + loc[1] + self.assertGreaterEqual(total[0], upper_x) + if total[0] == upper_x: + seen_x_bound = True + self.assertGreaterEqual(total[1], upper_y) + if total[1] == upper_y: + seen_y_bound = True + already_taken = torch.sum(mask[loc[0] : upper_x, loc[1] : upper_y]) + self.assertEqual(already_taken, 0) + mask[loc[0] : upper_x, loc[1] : upper_y] = 1 + self.assertTrue(seen_x_bound) + self.assertTrue(seen_y_bound) + + self.assertTrue(torch.all(torch.sum(mask, dim=0, dtype=torch.int32) > 0)) + self.assertTrue(torch.all(torch.sum(mask, dim=1, dtype=torch.int32) > 0)) + return res + + def assert_bb(self, sizes, expected): + """ + Apply the pack_rectangles function to sizes and verify the + bounding box dimensions are expected. + """ + self.assertSetEqual(set(self.wrap_pack(sizes).total_size), expected) + + def test_simple(self): + self.assert_bb([(3, 4), (4, 3)], {6, 4}) + self.assert_bb([(2, 2), (2, 4), (2, 2)], {4, 4}) + + # many squares + self.assert_bb([(2, 2)] * 9, {2, 18}) + + # One big square and many small ones. + self.assert_bb([(3, 3)] + [(1, 1)] * 2, {3, 4}) + self.assert_bb([(3, 3)] + [(1, 1)] * 3, {3, 4}) + self.assert_bb([(3, 3)] + [(1, 1)] * 4, {3, 5}) + self.assert_bb([(3, 3)] + [(1, 1)] * 5, {3, 5}) + self.assert_bb([(1, 1)] * 6 + [(3, 3)], {3, 5}) + self.assert_bb([(3, 3)] + [(1, 1)] * 7, {3, 6}) + + # many identical rectangles + self.assert_bb([(7, 190)] * 4 + [(190, 7)] * 4, {190, 56}) + + # require placing the flipped version of a rectangle + self.assert_bb([(1, 100), (5, 96), (4, 5)], {100, 6}) + + def test_random(self): + for _ in range(5): + vals = torch.randint(size=(20, 2), low=1, high=18) + sizes = [] + for j in range(vals.shape[0]): + sizes.append((int(vals[j, 0]), int(vals[j, 1]))) + self.wrap_pack(sizes)