diff --git a/pytorch3d/renderer/mesh/textures.py b/pytorch3d/renderer/mesh/textures.py index 5c2976b0..abfc0a54 100644 --- a/pytorch3d/renderer/mesh/textures.py +++ b/pytorch3d/renderer/mesh/textures.py @@ -10,7 +10,7 @@ from pytorch3d.ops import interpolate_face_attributes from pytorch3d.structures.utils import list_to_packed, list_to_padded, padded_to_list from torch.nn.functional import interpolate -from .utils import pack_rectangles +from .utils import PackedRectangle, Rectangle, pack_unique_rectangles # This file contains classes and helper functions for texturing. @@ -1028,14 +1028,13 @@ class TexturesUV(TexturesBase): maps_list = [] faces_uvs_list += self.faces_uvs_list() verts_uvs_list += self.verts_uvs_list() - maps_list += list(self.maps_padded().unbind(0)) + maps_list += self.maps_list() num_faces_per_mesh = self._num_faces_per_mesh for tex in textures: verts_uvs_list += tex.verts_uvs_list() faces_uvs_list += tex.faces_uvs_list() num_faces_per_mesh += tex._num_faces_per_mesh - tex_map_list = list(tex.maps_padded().unbind(0)) - maps_list += tex_map_list + maps_list += tex.maps_list() new_tex = self.__class__( maps=maps_list, @@ -1048,10 +1047,7 @@ class TexturesUV(TexturesBase): return new_tex def _place_map_into_single_map( - self, - single_map: torch.Tensor, - map_: torch.Tensor, - location: Tuple[int, int, bool], # (x,y) and whether flipped + self, single_map: torch.Tensor, map_: torch.Tensor, location: PackedRectangle ) -> None: """ Copy map into a larger tensor single_map at the destination specified by location. 
@@ -1064,11 +1060,11 @@ class TexturesUV(TexturesBase): map_: (H, W, 3) source data location: where to place map """ - do_flip = location[2] + do_flip = location.flipped source = map_.transpose(0, 1) if do_flip else map_ border_width = 0 if self.align_corners else 1 - lower_u = location[0] + border_width - lower_v = location[1] + border_width + lower_u = location.x + border_width + lower_v = location.y + border_width upper_u = lower_u + source.shape[0] upper_v = lower_v + source.shape[1] single_map[lower_u:upper_u, lower_v:upper_v] = source @@ -1102,19 +1098,23 @@ class TexturesUV(TexturesBase): If align_corners=False, we need to add an artificial border around every map. - We use the function `pack_rectangles` to provide a layout for the - single map. _place_map_into_single_map is used to copy the maps - into the single map. The merging of verts_uvs and faces_uvs are - handled locally in this function. + We use the function `pack_unique_rectangles` to provide a layout for + the single map. This means that if self was created with a list of maps, + and to() has not been called, and there were two maps which were exactly + the same tensor object, then they will become the same data in the unified map. + _place_map_into_single_map is used to copy the maps into the single map. + The merging of verts_uvs and faces_uvs is handled locally in this function. """ maps = self.maps_list() heights_and_widths = [] extra_border = 0 if self.align_corners else 2 for map_ in maps: heights_and_widths.append( - (map_.shape[0] + extra_border, map_.shape[1] + extra_border) + Rectangle( + map_.shape[0] + extra_border, map_.shape[1] + extra_border, id(map_) + ) ) - merging_plan = pack_rectangles(heights_and_widths) + merging_plan = pack_unique_rectangles(heights_and_widths) # pyre-fixme[16]: `Tensor` has no attribute `new_zeros`. 
# ----------- Rectangle Packing -------------------- #


class Rectangle(NamedTuple):
    """
    An input to pack_unique_rectangles: the size of a rectangle to be packed,
    together with an identifier used to deduplicate inputs.
    """

    # Extent of the rectangle along the first axis.
    xsize: int
    # Extent of the rectangle along the second axis.
    ysize: int
    # Inputs carrying the same identifier are packed only once and end up
    # sharing one location in the output.
    identifier: int


class PackedRectangle(NamedTuple):
    """
    Where a single input rectangle has been placed inside the large rectangle.
    """

    # Offset of the placed rectangle along the first axis.
    x: int
    # Offset of the placed rectangle along the second axis.
    y: int
    # Whether the rectangle was flipped (i.e. rotated by 90 degrees) when placed.
    flipped: bool
    # True for the first input with a given identifier. pack_rectangles always
    # sets this to True; pack_unique_rectangles sets it to False for repeats.
    is_first: bool


class PackedRectangles(NamedTuple):
    """
    The result of a packing: the size of the large rectangle and the
    location of every input rectangle.
    """

    # Size of the large rectangle which contains all the placed rectangles.
    total_size: Tuple[int, int]
    # One entry per input rectangle, in the order the inputs were given.
    locations: List[PackedRectangle]


# Note the order of members matters here because it determines the queue order.
# We want to place longer rectangles first.
class _UnplacedRectangle(NamedTuple): @@ -74,7 +93,7 @@ class _UnplacedRectangle(NamedTuple): def _try_place_rectangle( rect: _UnplacedRectangle, - placed_so_far: List[Tuple[int, int, bool]], + placed_so_far: List[PackedRectangle], occupied: List[Tuple[int, int]], ) -> bool: """ @@ -156,10 +175,11 @@ def _try_place_rectangle( current_start_idx = idx if currently_packed >= needed_height: current_max_width = max(interval[0], current_max_width) - placed_so_far[rect.ind] = ( + placed_so_far[rect.ind] = PackedRectangle( current_max_width, occupied[current_start_idx - 1][1], rect.flipped, + True, ) new_occupied = ( current_max_width + rect.size[0], @@ -182,11 +202,6 @@ def _try_place_rectangle( return False -class PackedRectangles(NamedTuple): - total_size: Tuple[int, int] - locations: List[Tuple[int, int, bool]] # (x,y) and whether flipped - - def pack_rectangles(sizes: List[Tuple[int, int]]) -> PackedRectangles: """ Naive rectangle packing in to a large rectangle. Flipping (i.e. rotating @@ -200,7 +215,9 @@ def pack_rectangles(sizes: List[Tuple[int, int]]) -> PackedRectangles: Returns: total_size: size of total large rectangle - rectangles: location for each of the input rectangles + rectangles: location for each of the input rectangles. + This includes whether they are flipped. + The is_first field is always True. 
def pack_unique_rectangles(rectangles: List[Rectangle]) -> PackedRectangles:
    """
    Naive packing of rectangles in to a large rectangle, where inputs which
    share an identifier also share a single location. Flipping (i.e. rotating
    a rectangle by 90 degrees) is allowed.

    This wraps pack_rectangles: the inputs are deduplicated by identifier
    (the size of the first input with each identifier is the one which is
    used), pack_rectangles lays out the deduplicated sizes, and the resulting
    location is then reported for every input. Inputs with the same identifier
    get the same location, and differ only in that is_first is True for just
    the first of them.

    This is used to join several uv maps into a single scene, see
    TexturesUV.join_scene.

    Args:
        rectangles: List of sizes of rectangles to pack, each with the
            identifier used for deduplication.

    Returns:
        total_size: size of total large rectangle
        rectangles: location for each of the input rectangles.
                    This includes whether they are flipped.
                    The is_first field is true for the first rectangle
                    with each identifier.

    Raises:
        ValueError: if fewer than two rectangles are given.
    """
    if len(rectangles) < 2:
        raise ValueError("Cannot pack less than two boxes")

    slot_of_identifier = {}
    distinct_sizes: List[Tuple[int, int]] = []
    # For every input, in order: (index into distinct_sizes, whether this is
    # the first input seen with its identifier).
    assignments: List[Tuple[int, bool]] = []
    for rect in rectangles:
        slot = slot_of_identifier.get(rect.identifier)
        newly_seen = slot is None
        if newly_seen:
            slot = len(distinct_sizes)
            slot_of_identifier[rect.identifier] = slot
            distinct_sizes.append((rect.xsize, rect.ysize))
        assignments.append((slot, newly_seen))

    if len(distinct_sizes) == 1:
        # Only one distinct rectangle: it sits unflipped at the origin and
        # every input shares that placement, so pack_rectangles is not needed.
        origin = PackedRectangle(0, 0, False, False)
        placements = [origin._replace(is_first=True)]
        placements.extend([origin] * (len(rectangles) - 1))
        return PackedRectangles(distinct_sizes[0], placements)

    total_size, slot_locations = pack_rectangles(distinct_sizes)
    placements = [
        slot_locations[slot]._replace(is_first=newly_seen)
        for slot, newly_seen in assignments
    ]
    return PackedRectangles(total_size, placements)
textures=textures1) mesh2 = Meshes(verts=[verts_shifted1], faces=[faces], textures=textures2) mesh3 = Meshes(verts=[verts_shifted2], faces=[faces], textures=textures3) - mesh = join_meshes_as_scene([mesh1, mesh2, mesh3]) + # mesh4 is like mesh1 but outside the field of view. It is here to test + # that having another texture with the same map doesn't produce + # two copies in the joined map. + mesh4 = Meshes(verts=[verts_shifted3], faces=[faces], textures=textures1) + mesh = join_meshes_as_scene([mesh1, mesh2, mesh3, mesh4]) output = renderer(mesh)[0, ..., :3].cpu() output1 = renderer(mesh1)[0, ..., :3].cpu() diff --git a/tests/test_texturing.py b/tests/test_texturing.py index 3f9fddd7..c82377c4 100644 --- a/tests/test_texturing.py +++ b/tests/test_texturing.py @@ -12,7 +12,11 @@ from pytorch3d.renderer.mesh.textures import ( TexturesUV, TexturesVertex, _list_to_padded_wrapper, +) +from pytorch3d.renderer.mesh.utils import ( + Rectangle, pack_rectangles, + pack_unique_rectangles, ) from pytorch3d.structures import Meshes, list_to_packed, packed_to_list from test_meshes import init_mesh @@ -873,21 +877,24 @@ class TestRectanglePacking(TestCaseMixin, unittest.TestCase): mask = torch.zeros(total, dtype=torch.bool) seen_x_bound = False seen_y_bound = False - for (in_x, in_y), loc in zip(sizes, res.locations): - self.assertGreaterEqual(loc[0], 0) - self.assertGreaterEqual(loc[1], 0) - placed_x, placed_y = (in_y, in_x) if loc[2] else (in_x, in_y) - upper_x = placed_x + loc[0] - upper_y = placed_y + loc[1] + for (in_x, in_y), (out_x, out_y, flipped, is_first) in zip( + sizes, res.locations + ): + self.assertTrue(is_first) + self.assertGreaterEqual(out_x, 0) + self.assertGreaterEqual(out_y, 0) + placed_x, placed_y = (in_y, in_x) if flipped else (in_x, in_y) + upper_x = placed_x + out_x + upper_y = placed_y + out_y self.assertGreaterEqual(total[0], upper_x) if total[0] == upper_x: seen_x_bound = True self.assertGreaterEqual(total[1], upper_y) if total[1] == upper_y: 
def test_all_identical(self):
    """Three inputs sharing one identifier all land on the same single slot."""
    rect = Rectangle(xsize=61, ysize=82, identifier=1729)
    packed = pack_unique_rectangles([rect, rect, rect])

    self.assertEqual(packed.total_size, (61, 82))
    self.assertEqual(len(packed.locations), 3)
    for index, location in enumerate(packed.locations):
        self.assertEqual(location.x, 0)
        self.assertEqual(location.y, 0)
        self.assertFalse(location.flipped)
        # Only the very first input is marked as the first of its identifier.
        self.assertEqual(location.is_first, index == 0)


def test_one_different_id(self):
    """Two identifiers, three inputs each: two slots, each shared by a group."""
    first_group = [Rectangle(xsize=61, ysize=82, identifier=220)] * 3
    second_group = [Rectangle(xsize=61, ysize=82, identifier=284)] * 3
    packed = pack_unique_rectangles(first_group + second_group)

    self.assertEqual(packed.total_size, (82, 122))
    self.assertEqual(len(packed.locations), 6)
    for index, location in enumerate(packed.locations):
        self.assertTrue(location.flipped)
        # The first input of each group of three is the first of its identifier.
        self.assertEqual(location.is_first, index % 3 == 0)
        self.assertEqual(location.x, 0)
        expected_y = 61 if index < 3 else 0
        self.assertEqual(location.y, expected_y)