Deduplicate texture maps when joining

Summary:
If you join several meshes which have TexturesUV textures using join_meshes_as_scene then we amalgamate all the texture images in to a single one. This now checks if some of the images are equal (i.e. the tensors are the same tensor, in the `is` sense; they have the same `id` in Python) and only uses one copy if they are.

I have an example of a massive scene made of several textured meshes with some shared, where this makes the difference between fitting the data on the GPU and not.

Reviewed By: theschnitz

Differential Revision: D25982364

fbshipit-source-id: a8228805f38475c796302e27328a340d9b56c8ef
This commit is contained in:
Jeremy Reizenstein 2021-05-26 04:52:46 -07:00 committed by Facebook GitHub Bot
parent cd5af2521a
commit e12a08133f
4 changed files with 155 additions and 42 deletions

View File

@ -10,7 +10,7 @@ from pytorch3d.ops import interpolate_face_attributes
from pytorch3d.structures.utils import list_to_packed, list_to_padded, padded_to_list
from torch.nn.functional import interpolate
from .utils import pack_rectangles
from .utils import PackedRectangle, Rectangle, pack_unique_rectangles
# This file contains classes and helper functions for texturing.
@ -1028,14 +1028,13 @@ class TexturesUV(TexturesBase):
maps_list = []
faces_uvs_list += self.faces_uvs_list()
verts_uvs_list += self.verts_uvs_list()
maps_list += list(self.maps_padded().unbind(0))
maps_list += self.maps_list()
num_faces_per_mesh = self._num_faces_per_mesh
for tex in textures:
verts_uvs_list += tex.verts_uvs_list()
faces_uvs_list += tex.faces_uvs_list()
num_faces_per_mesh += tex._num_faces_per_mesh
tex_map_list = list(tex.maps_padded().unbind(0))
maps_list += tex_map_list
maps_list += tex.maps_list()
new_tex = self.__class__(
maps=maps_list,
@ -1048,10 +1047,7 @@ class TexturesUV(TexturesBase):
return new_tex
def _place_map_into_single_map(
self,
single_map: torch.Tensor,
map_: torch.Tensor,
location: Tuple[int, int, bool], # (x,y) and whether flipped
self, single_map: torch.Tensor, map_: torch.Tensor, location: PackedRectangle
) -> None:
"""
Copy map into a larger tensor single_map at the destination specified by location.
@ -1064,11 +1060,11 @@ class TexturesUV(TexturesBase):
map_: (H, W, 3) source data
location: where to place map
"""
do_flip = location[2]
do_flip = location.flipped
source = map_.transpose(0, 1) if do_flip else map_
border_width = 0 if self.align_corners else 1
lower_u = location[0] + border_width
lower_v = location[1] + border_width
lower_u = location.x + border_width
lower_v = location.y + border_width
upper_u = lower_u + source.shape[0]
upper_v = lower_v + source.shape[1]
single_map[lower_u:upper_u, lower_v:upper_v] = source
@ -1102,19 +1098,23 @@ class TexturesUV(TexturesBase):
If align_corners=False, we need to add an artificial border around
every map.
We use the function `pack_rectangles` to provide a layout for the
single map. _place_map_into_single_map is used to copy the maps
into the single map. The merging of verts_uvs and faces_uvs are
handled locally in this function.
We use the function `pack_unique_rectangles` to provide a layout for
the single map. This means that if self was created with a list of maps,
and to() has not been called, and there were two maps which were exactly
the same tensor object, then they will become the same data in the unified map.
_place_map_into_single_map is used to copy the maps into the single map.
The merging of verts_uvs and faces_uvs is handled locally in this function.
"""
maps = self.maps_list()
heights_and_widths = []
extra_border = 0 if self.align_corners else 2
for map_ in maps:
heights_and_widths.append(
(map_.shape[0] + extra_border, map_.shape[1] + extra_border)
Rectangle(
map_.shape[0] + extra_border, map_.shape[1] + extra_border, id(map_)
)
)
merging_plan = pack_rectangles(heights_and_widths)
merging_plan = pack_unique_rectangles(heights_and_widths)
# pyre-fixme[16]: `Tensor` has no attribute `new_zeros`.
single_map = maps[0].new_zeros((*merging_plan.total_size, 3))
verts_uvs = self.verts_uvs_list()
@ -1122,8 +1122,9 @@ class TexturesUV(TexturesBase):
for map_, loc, uvs in zip(maps, merging_plan.locations, verts_uvs):
new_uvs = uvs.clone()
self._place_map_into_single_map(single_map, map_, loc)
do_flip = loc[2]
if loc.is_first:
self._place_map_into_single_map(single_map, map_, loc)
do_flip = loc.flipped
x_shape = map_.shape[1] if do_flip else map_.shape[0]
y_shape = map_.shape[0] if do_flip else map_.shape[1]
@ -1164,9 +1165,9 @@ class TexturesUV(TexturesBase):
denom_y = merging_plan.total_size[1] - one_if_align
scale_y = y_shape - one_if_align
new_uvs[:, 1] *= scale_x / denom_x
new_uvs[:, 1] += (loc[0] + one_if_not_align) / denom_x
new_uvs[:, 1] += (loc.x + one_if_not_align) / denom_x
new_uvs[:, 0] *= scale_y / denom_y
new_uvs[:, 0] += (loc[1] + one_if_not_align) / denom_y
new_uvs[:, 0] += (loc.y + one_if_not_align) / denom_y
verts_uvs_merged.append(new_uvs)

View File

@ -64,6 +64,25 @@ def _interpolate_zbuf(
# ----------- Rectangle Packing -------------------- #
class Rectangle(NamedTuple):
xsize: int
ysize: int
identifier: int
class PackedRectangle(NamedTuple):
x: int
y: int
flipped: bool
is_first: bool
class PackedRectangles(NamedTuple):
total_size: Tuple[int, int]
locations: List[PackedRectangle]
# Note the order of members matters here because it determines the queue order.
# We want to place longer rectangles first.
class _UnplacedRectangle(NamedTuple):
@ -74,7 +93,7 @@ class _UnplacedRectangle(NamedTuple):
def _try_place_rectangle(
rect: _UnplacedRectangle,
placed_so_far: List[Tuple[int, int, bool]],
placed_so_far: List[PackedRectangle],
occupied: List[Tuple[int, int]],
) -> bool:
"""
@ -156,10 +175,11 @@ def _try_place_rectangle(
current_start_idx = idx
if currently_packed >= needed_height:
current_max_width = max(interval[0], current_max_width)
placed_so_far[rect.ind] = (
placed_so_far[rect.ind] = PackedRectangle(
current_max_width,
occupied[current_start_idx - 1][1],
rect.flipped,
True,
)
new_occupied = (
current_max_width + rect.size[0],
@ -182,11 +202,6 @@ def _try_place_rectangle(
return False
class PackedRectangles(NamedTuple):
total_size: Tuple[int, int]
locations: List[Tuple[int, int, bool]] # (x,y) and whether flipped
def pack_rectangles(sizes: List[Tuple[int, int]]) -> PackedRectangles:
"""
Naive rectangle packing in to a large rectangle. Flipping (i.e. rotating
@ -200,7 +215,9 @@ def pack_rectangles(sizes: List[Tuple[int, int]]) -> PackedRectangles:
Returns:
total_size: size of total large rectangle
rectangles: location for each of the input rectangles
rectangles: location for each of the input rectangles.
This includes whether they are flipped.
The is_first field is always True.
"""
if len(sizes) < 2:
@ -213,14 +230,14 @@ def pack_rectangles(sizes: List[Tuple[int, int]]) -> PackedRectangles:
else:
queue.append(_UnplacedRectangle((size[0], size[1]), i, False))
queue.sort()
placed_so_far = [(-1, -1, False)] * len(sizes)
placed_so_far = [PackedRectangle(-1, -1, False, False)] * len(sizes)
biggest = queue.pop()
total_width, current_height = biggest.size
placed_so_far[biggest.ind] = (0, 0, biggest.flipped)
placed_so_far[biggest.ind] = PackedRectangle(0, 0, biggest.flipped, True)
second = queue.pop()
placed_so_far[second.ind] = (0, current_height, second.flipped)
placed_so_far[second.ind] = PackedRectangle(0, current_height, second.flipped, True)
current_height += second.size[1]
occupied = [biggest.size, (second.size[0], current_height)]
@ -236,8 +253,63 @@ def pack_rectangles(sizes: List[Tuple[int, int]]) -> PackedRectangles:
# rect wasn't placed in the current bounding box,
# so we add extra space to fit it in.
placed_so_far[rect.ind] = (0, current_height, rect.flipped)
placed_so_far[rect.ind] = PackedRectangle(0, current_height, rect.flipped, True)
current_height += rect.size[1]
occupied.append((rect.size[0], current_height))
return PackedRectangles((total_width, current_height), placed_so_far)
def pack_unique_rectangles(rectangles: List[Rectangle]) -> PackedRectangles:
"""
Naive rectangle packing in to a large rectangle. Flipping (i.e. rotating
a rectangle by 90 degrees) is allowed. Inputs are deduplicated by their
identifier.
This is a wrapper around pack_rectangles, where inputs come with an
identifier. In particular, it calls pack_rectangles for the deduplicated inputs,
then returns the values for all the inputs. The output for all rectangles with
the same identifier will be the same, except that only the first one will have
the is_first field True.
This is used to join several uv maps into a single scene, see
TexturesUV.join_scene.
Args:
rectangles: List of sizes of rectangles to pack
Returns:
total_size: size of total large rectangle
rectangles: location for each of the input rectangles.
This includes whether they are flipped.
The is_first field is true for the first rectangle
with each identifier.
"""
if len(rectangles) < 2:
raise ValueError("Cannot pack less than two boxes")
input_map = {}
input_indices: List[Tuple[int, bool]] = []
unique_input_sizes: List[Tuple[int, int]] = []
for rectangle in rectangles:
if rectangle.identifier not in input_map:
unique_index = len(unique_input_sizes)
unique_input_sizes.append((rectangle.xsize, rectangle.ysize))
input_map[rectangle.identifier] = unique_index
input_indices.append((unique_index, True))
else:
unique_index = input_map[rectangle.identifier]
input_indices.append((unique_index, False))
if len(unique_input_sizes) == 1:
first = [PackedRectangle(0, 0, False, True)]
rest = (len(rectangles) - 1) * [PackedRectangle(0, 0, False, False)]
return PackedRectangles(unique_input_sizes[0], first + rest)
total_size, unique_locations = pack_rectangles(unique_input_sizes)
full_locations = []
for input_index, first in input_indices:
full_locations.append(unique_locations[input_index]._replace(is_first=first))
return PackedRectangles(total_size, full_locations)

View File

@ -652,6 +652,9 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
verts_shifted2 = verts.clone()
verts_shifted2 *= 0.5
verts_shifted2[:, 1] -= 7
verts_shifted3 = verts.clone()
verts_shifted3 *= 0.5
verts_shifted3[:, 1] -= 700
[faces] = plain_torus.faces_list()
nocolor = torch.zeros((100, 100), device=device)
@ -697,7 +700,11 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
mesh1 = Meshes(verts=[verts], faces=[faces], textures=textures1)
mesh2 = Meshes(verts=[verts_shifted1], faces=[faces], textures=textures2)
mesh3 = Meshes(verts=[verts_shifted2], faces=[faces], textures=textures3)
mesh = join_meshes_as_scene([mesh1, mesh2, mesh3])
# mesh4 is like mesh1 but outside the field of view. It is here to test
# that having another texture with the same map doesn't produce
# two copies in the joined map.
mesh4 = Meshes(verts=[verts_shifted3], faces=[faces], textures=textures1)
mesh = join_meshes_as_scene([mesh1, mesh2, mesh3, mesh4])
output = renderer(mesh)[0, ..., :3].cpu()
output1 = renderer(mesh1)[0, ..., :3].cpu()

View File

@ -12,7 +12,11 @@ from pytorch3d.renderer.mesh.textures import (
TexturesUV,
TexturesVertex,
_list_to_padded_wrapper,
)
from pytorch3d.renderer.mesh.utils import (
Rectangle,
pack_rectangles,
pack_unique_rectangles,
)
from pytorch3d.structures import Meshes, list_to_packed, packed_to_list
from test_meshes import init_mesh
@ -873,21 +877,24 @@ class TestRectanglePacking(TestCaseMixin, unittest.TestCase):
mask = torch.zeros(total, dtype=torch.bool)
seen_x_bound = False
seen_y_bound = False
for (in_x, in_y), loc in zip(sizes, res.locations):
self.assertGreaterEqual(loc[0], 0)
self.assertGreaterEqual(loc[1], 0)
placed_x, placed_y = (in_y, in_x) if loc[2] else (in_x, in_y)
upper_x = placed_x + loc[0]
upper_y = placed_y + loc[1]
for (in_x, in_y), (out_x, out_y, flipped, is_first) in zip(
sizes, res.locations
):
self.assertTrue(is_first)
self.assertGreaterEqual(out_x, 0)
self.assertGreaterEqual(out_y, 0)
placed_x, placed_y = (in_y, in_x) if flipped else (in_x, in_y)
upper_x = placed_x + out_x
upper_y = placed_y + out_y
self.assertGreaterEqual(total[0], upper_x)
if total[0] == upper_x:
seen_x_bound = True
self.assertGreaterEqual(total[1], upper_y)
if total[1] == upper_y:
seen_y_bound = True
already_taken = torch.sum(mask[loc[0] : upper_x, loc[1] : upper_y])
already_taken = torch.sum(mask[out_x:upper_x, out_y:upper_y])
self.assertEqual(already_taken, 0)
mask[loc[0] : upper_x, loc[1] : upper_y] = 1
mask[out_x:upper_x, out_y:upper_y] = 1
self.assertTrue(seen_x_bound)
self.assertTrue(seen_y_bound)
@ -930,3 +937,29 @@ class TestRectanglePacking(TestCaseMixin, unittest.TestCase):
for j in range(vals.shape[0]):
sizes.append((int(vals[j, 0]), int(vals[j, 1])))
self.wrap_pack(sizes)
def test_all_identical(self):
sizes = [Rectangle(xsize=61, ysize=82, identifier=1729)] * 3
total_size, locations = pack_unique_rectangles(sizes)
self.assertEqual(total_size, (61, 82))
self.assertEqual(len(locations), 3)
for i, (x, y, is_flipped, is_first) in enumerate(locations):
self.assertEqual(x, 0)
self.assertEqual(y, 0)
self.assertFalse(is_flipped)
self.assertEqual(is_first, i == 0)
def test_one_different_id(self):
sizes = [Rectangle(xsize=61, ysize=82, identifier=220)] * 3
sizes.extend([Rectangle(xsize=61, ysize=82, identifier=284)] * 3)
total_size, locations = pack_unique_rectangles(sizes)
self.assertEqual(total_size, (82, 122))
self.assertEqual(len(locations), 6)
for i, (x, y, is_flipped, is_first) in enumerate(locations):
self.assertTrue(is_flipped)
self.assertEqual(is_first, i % 3 == 0)
self.assertEqual(x, 0)
if i < 3:
self.assertEqual(y, 61)
else:
self.assertEqual(y, 0)