mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
Deduplicate texture maps when joining
Summary: If you join several meshes which have TexturesUV textures using join_meshes_as_scene then we amalgamate all the texture images in to a single one. This now checks if some of the images are equal (i.e. the tensors are the same tensor, in the `is` sense; they have the same `id` in Python) and only uses one copy if they are. I have an example of a massive scene made of several textured meshes with some shared, where this makes the difference between fitting the data on the GPU and not. Reviewed By: theschnitz Differential Revision: D25982364 fbshipit-source-id: a8228805f38475c796302e27328a340d9b56c8ef
This commit is contained in:
parent
cd5af2521a
commit
e12a08133f
@ -10,7 +10,7 @@ from pytorch3d.ops import interpolate_face_attributes
|
|||||||
from pytorch3d.structures.utils import list_to_packed, list_to_padded, padded_to_list
|
from pytorch3d.structures.utils import list_to_packed, list_to_padded, padded_to_list
|
||||||
from torch.nn.functional import interpolate
|
from torch.nn.functional import interpolate
|
||||||
|
|
||||||
from .utils import pack_rectangles
|
from .utils import PackedRectangle, Rectangle, pack_unique_rectangles
|
||||||
|
|
||||||
|
|
||||||
# This file contains classes and helper functions for texturing.
|
# This file contains classes and helper functions for texturing.
|
||||||
@ -1028,14 +1028,13 @@ class TexturesUV(TexturesBase):
|
|||||||
maps_list = []
|
maps_list = []
|
||||||
faces_uvs_list += self.faces_uvs_list()
|
faces_uvs_list += self.faces_uvs_list()
|
||||||
verts_uvs_list += self.verts_uvs_list()
|
verts_uvs_list += self.verts_uvs_list()
|
||||||
maps_list += list(self.maps_padded().unbind(0))
|
maps_list += self.maps_list()
|
||||||
num_faces_per_mesh = self._num_faces_per_mesh
|
num_faces_per_mesh = self._num_faces_per_mesh
|
||||||
for tex in textures:
|
for tex in textures:
|
||||||
verts_uvs_list += tex.verts_uvs_list()
|
verts_uvs_list += tex.verts_uvs_list()
|
||||||
faces_uvs_list += tex.faces_uvs_list()
|
faces_uvs_list += tex.faces_uvs_list()
|
||||||
num_faces_per_mesh += tex._num_faces_per_mesh
|
num_faces_per_mesh += tex._num_faces_per_mesh
|
||||||
tex_map_list = list(tex.maps_padded().unbind(0))
|
maps_list += tex.maps_list()
|
||||||
maps_list += tex_map_list
|
|
||||||
|
|
||||||
new_tex = self.__class__(
|
new_tex = self.__class__(
|
||||||
maps=maps_list,
|
maps=maps_list,
|
||||||
@ -1048,10 +1047,7 @@ class TexturesUV(TexturesBase):
|
|||||||
return new_tex
|
return new_tex
|
||||||
|
|
||||||
def _place_map_into_single_map(
|
def _place_map_into_single_map(
|
||||||
self,
|
self, single_map: torch.Tensor, map_: torch.Tensor, location: PackedRectangle
|
||||||
single_map: torch.Tensor,
|
|
||||||
map_: torch.Tensor,
|
|
||||||
location: Tuple[int, int, bool], # (x,y) and whether flipped
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Copy map into a larger tensor single_map at the destination specified by location.
|
Copy map into a larger tensor single_map at the destination specified by location.
|
||||||
@ -1064,11 +1060,11 @@ class TexturesUV(TexturesBase):
|
|||||||
map_: (H, W, 3) source data
|
map_: (H, W, 3) source data
|
||||||
location: where to place map
|
location: where to place map
|
||||||
"""
|
"""
|
||||||
do_flip = location[2]
|
do_flip = location.flipped
|
||||||
source = map_.transpose(0, 1) if do_flip else map_
|
source = map_.transpose(0, 1) if do_flip else map_
|
||||||
border_width = 0 if self.align_corners else 1
|
border_width = 0 if self.align_corners else 1
|
||||||
lower_u = location[0] + border_width
|
lower_u = location.x + border_width
|
||||||
lower_v = location[1] + border_width
|
lower_v = location.y + border_width
|
||||||
upper_u = lower_u + source.shape[0]
|
upper_u = lower_u + source.shape[0]
|
||||||
upper_v = lower_v + source.shape[1]
|
upper_v = lower_v + source.shape[1]
|
||||||
single_map[lower_u:upper_u, lower_v:upper_v] = source
|
single_map[lower_u:upper_u, lower_v:upper_v] = source
|
||||||
@ -1102,19 +1098,23 @@ class TexturesUV(TexturesBase):
|
|||||||
If align_corners=False, we need to add an artificial border around
|
If align_corners=False, we need to add an artificial border around
|
||||||
every map.
|
every map.
|
||||||
|
|
||||||
We use the function `pack_rectangles` to provide a layout for the
|
We use the function `pack_unique_rectangles` to provide a layout for
|
||||||
single map. _place_map_into_single_map is used to copy the maps
|
the single map. This means that if self was created with a list of maps,
|
||||||
into the single map. The merging of verts_uvs and faces_uvs are
|
and to() has not been called, and there were two maps which were exactly
|
||||||
handled locally in this function.
|
the same tensor object, then they will become the same data in the unified map.
|
||||||
|
_place_map_into_single_map is used to copy the maps into the single map.
|
||||||
|
The merging of verts_uvs and faces_uvs is handled locally in this function.
|
||||||
"""
|
"""
|
||||||
maps = self.maps_list()
|
maps = self.maps_list()
|
||||||
heights_and_widths = []
|
heights_and_widths = []
|
||||||
extra_border = 0 if self.align_corners else 2
|
extra_border = 0 if self.align_corners else 2
|
||||||
for map_ in maps:
|
for map_ in maps:
|
||||||
heights_and_widths.append(
|
heights_and_widths.append(
|
||||||
(map_.shape[0] + extra_border, map_.shape[1] + extra_border)
|
Rectangle(
|
||||||
|
map_.shape[0] + extra_border, map_.shape[1] + extra_border, id(map_)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
merging_plan = pack_rectangles(heights_and_widths)
|
merging_plan = pack_unique_rectangles(heights_and_widths)
|
||||||
# pyre-fixme[16]: `Tensor` has no attribute `new_zeros`.
|
# pyre-fixme[16]: `Tensor` has no attribute `new_zeros`.
|
||||||
single_map = maps[0].new_zeros((*merging_plan.total_size, 3))
|
single_map = maps[0].new_zeros((*merging_plan.total_size, 3))
|
||||||
verts_uvs = self.verts_uvs_list()
|
verts_uvs = self.verts_uvs_list()
|
||||||
@ -1122,8 +1122,9 @@ class TexturesUV(TexturesBase):
|
|||||||
|
|
||||||
for map_, loc, uvs in zip(maps, merging_plan.locations, verts_uvs):
|
for map_, loc, uvs in zip(maps, merging_plan.locations, verts_uvs):
|
||||||
new_uvs = uvs.clone()
|
new_uvs = uvs.clone()
|
||||||
self._place_map_into_single_map(single_map, map_, loc)
|
if loc.is_first:
|
||||||
do_flip = loc[2]
|
self._place_map_into_single_map(single_map, map_, loc)
|
||||||
|
do_flip = loc.flipped
|
||||||
x_shape = map_.shape[1] if do_flip else map_.shape[0]
|
x_shape = map_.shape[1] if do_flip else map_.shape[0]
|
||||||
y_shape = map_.shape[0] if do_flip else map_.shape[1]
|
y_shape = map_.shape[0] if do_flip else map_.shape[1]
|
||||||
|
|
||||||
@ -1164,9 +1165,9 @@ class TexturesUV(TexturesBase):
|
|||||||
denom_y = merging_plan.total_size[1] - one_if_align
|
denom_y = merging_plan.total_size[1] - one_if_align
|
||||||
scale_y = y_shape - one_if_align
|
scale_y = y_shape - one_if_align
|
||||||
new_uvs[:, 1] *= scale_x / denom_x
|
new_uvs[:, 1] *= scale_x / denom_x
|
||||||
new_uvs[:, 1] += (loc[0] + one_if_not_align) / denom_x
|
new_uvs[:, 1] += (loc.x + one_if_not_align) / denom_x
|
||||||
new_uvs[:, 0] *= scale_y / denom_y
|
new_uvs[:, 0] *= scale_y / denom_y
|
||||||
new_uvs[:, 0] += (loc[1] + one_if_not_align) / denom_y
|
new_uvs[:, 0] += (loc.y + one_if_not_align) / denom_y
|
||||||
|
|
||||||
verts_uvs_merged.append(new_uvs)
|
verts_uvs_merged.append(new_uvs)
|
||||||
|
|
||||||
|
@ -64,6 +64,25 @@ def _interpolate_zbuf(
|
|||||||
|
|
||||||
# ----------- Rectangle Packing -------------------- #
|
# ----------- Rectangle Packing -------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
class Rectangle(NamedTuple):
|
||||||
|
xsize: int
|
||||||
|
ysize: int
|
||||||
|
identifier: int
|
||||||
|
|
||||||
|
|
||||||
|
class PackedRectangle(NamedTuple):
|
||||||
|
x: int
|
||||||
|
y: int
|
||||||
|
flipped: bool
|
||||||
|
is_first: bool
|
||||||
|
|
||||||
|
|
||||||
|
class PackedRectangles(NamedTuple):
|
||||||
|
total_size: Tuple[int, int]
|
||||||
|
locations: List[PackedRectangle]
|
||||||
|
|
||||||
|
|
||||||
# Note the order of members matters here because it determines the queue order.
|
# Note the order of members matters here because it determines the queue order.
|
||||||
# We want to place longer rectangles first.
|
# We want to place longer rectangles first.
|
||||||
class _UnplacedRectangle(NamedTuple):
|
class _UnplacedRectangle(NamedTuple):
|
||||||
@ -74,7 +93,7 @@ class _UnplacedRectangle(NamedTuple):
|
|||||||
|
|
||||||
def _try_place_rectangle(
|
def _try_place_rectangle(
|
||||||
rect: _UnplacedRectangle,
|
rect: _UnplacedRectangle,
|
||||||
placed_so_far: List[Tuple[int, int, bool]],
|
placed_so_far: List[PackedRectangle],
|
||||||
occupied: List[Tuple[int, int]],
|
occupied: List[Tuple[int, int]],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
@ -156,10 +175,11 @@ def _try_place_rectangle(
|
|||||||
current_start_idx = idx
|
current_start_idx = idx
|
||||||
if currently_packed >= needed_height:
|
if currently_packed >= needed_height:
|
||||||
current_max_width = max(interval[0], current_max_width)
|
current_max_width = max(interval[0], current_max_width)
|
||||||
placed_so_far[rect.ind] = (
|
placed_so_far[rect.ind] = PackedRectangle(
|
||||||
current_max_width,
|
current_max_width,
|
||||||
occupied[current_start_idx - 1][1],
|
occupied[current_start_idx - 1][1],
|
||||||
rect.flipped,
|
rect.flipped,
|
||||||
|
True,
|
||||||
)
|
)
|
||||||
new_occupied = (
|
new_occupied = (
|
||||||
current_max_width + rect.size[0],
|
current_max_width + rect.size[0],
|
||||||
@ -182,11 +202,6 @@ def _try_place_rectangle(
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class PackedRectangles(NamedTuple):
|
|
||||||
total_size: Tuple[int, int]
|
|
||||||
locations: List[Tuple[int, int, bool]] # (x,y) and whether flipped
|
|
||||||
|
|
||||||
|
|
||||||
def pack_rectangles(sizes: List[Tuple[int, int]]) -> PackedRectangles:
|
def pack_rectangles(sizes: List[Tuple[int, int]]) -> PackedRectangles:
|
||||||
"""
|
"""
|
||||||
Naive rectangle packing in to a large rectangle. Flipping (i.e. rotating
|
Naive rectangle packing in to a large rectangle. Flipping (i.e. rotating
|
||||||
@ -200,7 +215,9 @@ def pack_rectangles(sizes: List[Tuple[int, int]]) -> PackedRectangles:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
total_size: size of total large rectangle
|
total_size: size of total large rectangle
|
||||||
rectangles: location for each of the input rectangles
|
rectangles: location for each of the input rectangles.
|
||||||
|
This includes whether they are flipped.
|
||||||
|
The is_first field is always True.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if len(sizes) < 2:
|
if len(sizes) < 2:
|
||||||
@ -213,14 +230,14 @@ def pack_rectangles(sizes: List[Tuple[int, int]]) -> PackedRectangles:
|
|||||||
else:
|
else:
|
||||||
queue.append(_UnplacedRectangle((size[0], size[1]), i, False))
|
queue.append(_UnplacedRectangle((size[0], size[1]), i, False))
|
||||||
queue.sort()
|
queue.sort()
|
||||||
placed_so_far = [(-1, -1, False)] * len(sizes)
|
placed_so_far = [PackedRectangle(-1, -1, False, False)] * len(sizes)
|
||||||
|
|
||||||
biggest = queue.pop()
|
biggest = queue.pop()
|
||||||
total_width, current_height = biggest.size
|
total_width, current_height = biggest.size
|
||||||
placed_so_far[biggest.ind] = (0, 0, biggest.flipped)
|
placed_so_far[biggest.ind] = PackedRectangle(0, 0, biggest.flipped, True)
|
||||||
|
|
||||||
second = queue.pop()
|
second = queue.pop()
|
||||||
placed_so_far[second.ind] = (0, current_height, second.flipped)
|
placed_so_far[second.ind] = PackedRectangle(0, current_height, second.flipped, True)
|
||||||
current_height += second.size[1]
|
current_height += second.size[1]
|
||||||
occupied = [biggest.size, (second.size[0], current_height)]
|
occupied = [biggest.size, (second.size[0], current_height)]
|
||||||
|
|
||||||
@ -236,8 +253,63 @@ def pack_rectangles(sizes: List[Tuple[int, int]]) -> PackedRectangles:
|
|||||||
|
|
||||||
# rect wasn't placed in the current bounding box,
|
# rect wasn't placed in the current bounding box,
|
||||||
# so we add extra space to fit it in.
|
# so we add extra space to fit it in.
|
||||||
placed_so_far[rect.ind] = (0, current_height, rect.flipped)
|
placed_so_far[rect.ind] = PackedRectangle(0, current_height, rect.flipped, True)
|
||||||
current_height += rect.size[1]
|
current_height += rect.size[1]
|
||||||
occupied.append((rect.size[0], current_height))
|
occupied.append((rect.size[0], current_height))
|
||||||
|
|
||||||
return PackedRectangles((total_width, current_height), placed_so_far)
|
return PackedRectangles((total_width, current_height), placed_so_far)
|
||||||
|
|
||||||
|
|
||||||
|
def pack_unique_rectangles(rectangles: List[Rectangle]) -> PackedRectangles:
|
||||||
|
"""
|
||||||
|
Naive rectangle packing in to a large rectangle. Flipping (i.e. rotating
|
||||||
|
a rectangle by 90 degrees) is allowed. Inputs are deduplicated by their
|
||||||
|
identifier.
|
||||||
|
|
||||||
|
This is a wrapper around pack_rectangles, where inputs come with an
|
||||||
|
identifier. In particular, it calls pack_rectangles for the deduplicated inputs,
|
||||||
|
then returns the values for all the inputs. The output for all rectangles with
|
||||||
|
the same identifier will be the same, except that only the first one will have
|
||||||
|
the is_first field True.
|
||||||
|
|
||||||
|
This is used to join several uv maps into a single scene, see
|
||||||
|
TexturesUV.join_scene.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rectangles: List of sizes of rectangles to pack
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
total_size: size of total large rectangle
|
||||||
|
rectangles: location for each of the input rectangles.
|
||||||
|
This includes whether they are flipped.
|
||||||
|
The is_first field is true for the first rectangle
|
||||||
|
with each identifier.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if len(rectangles) < 2:
|
||||||
|
raise ValueError("Cannot pack less than two boxes")
|
||||||
|
|
||||||
|
input_map = {}
|
||||||
|
input_indices: List[Tuple[int, bool]] = []
|
||||||
|
unique_input_sizes: List[Tuple[int, int]] = []
|
||||||
|
for rectangle in rectangles:
|
||||||
|
if rectangle.identifier not in input_map:
|
||||||
|
unique_index = len(unique_input_sizes)
|
||||||
|
unique_input_sizes.append((rectangle.xsize, rectangle.ysize))
|
||||||
|
input_map[rectangle.identifier] = unique_index
|
||||||
|
input_indices.append((unique_index, True))
|
||||||
|
else:
|
||||||
|
unique_index = input_map[rectangle.identifier]
|
||||||
|
input_indices.append((unique_index, False))
|
||||||
|
|
||||||
|
if len(unique_input_sizes) == 1:
|
||||||
|
first = [PackedRectangle(0, 0, False, True)]
|
||||||
|
rest = (len(rectangles) - 1) * [PackedRectangle(0, 0, False, False)]
|
||||||
|
return PackedRectangles(unique_input_sizes[0], first + rest)
|
||||||
|
|
||||||
|
total_size, unique_locations = pack_rectangles(unique_input_sizes)
|
||||||
|
full_locations = []
|
||||||
|
for input_index, first in input_indices:
|
||||||
|
full_locations.append(unique_locations[input_index]._replace(is_first=first))
|
||||||
|
|
||||||
|
return PackedRectangles(total_size, full_locations)
|
||||||
|
@ -652,6 +652,9 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
|||||||
verts_shifted2 = verts.clone()
|
verts_shifted2 = verts.clone()
|
||||||
verts_shifted2 *= 0.5
|
verts_shifted2 *= 0.5
|
||||||
verts_shifted2[:, 1] -= 7
|
verts_shifted2[:, 1] -= 7
|
||||||
|
verts_shifted3 = verts.clone()
|
||||||
|
verts_shifted3 *= 0.5
|
||||||
|
verts_shifted3[:, 1] -= 700
|
||||||
|
|
||||||
[faces] = plain_torus.faces_list()
|
[faces] = plain_torus.faces_list()
|
||||||
nocolor = torch.zeros((100, 100), device=device)
|
nocolor = torch.zeros((100, 100), device=device)
|
||||||
@ -697,7 +700,11 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
|||||||
mesh1 = Meshes(verts=[verts], faces=[faces], textures=textures1)
|
mesh1 = Meshes(verts=[verts], faces=[faces], textures=textures1)
|
||||||
mesh2 = Meshes(verts=[verts_shifted1], faces=[faces], textures=textures2)
|
mesh2 = Meshes(verts=[verts_shifted1], faces=[faces], textures=textures2)
|
||||||
mesh3 = Meshes(verts=[verts_shifted2], faces=[faces], textures=textures3)
|
mesh3 = Meshes(verts=[verts_shifted2], faces=[faces], textures=textures3)
|
||||||
mesh = join_meshes_as_scene([mesh1, mesh2, mesh3])
|
# mesh4 is like mesh1 but outside the field of view. It is here to test
|
||||||
|
# that having another texture with the same map doesn't produce
|
||||||
|
# two copies in the joined map.
|
||||||
|
mesh4 = Meshes(verts=[verts_shifted3], faces=[faces], textures=textures1)
|
||||||
|
mesh = join_meshes_as_scene([mesh1, mesh2, mesh3, mesh4])
|
||||||
|
|
||||||
output = renderer(mesh)[0, ..., :3].cpu()
|
output = renderer(mesh)[0, ..., :3].cpu()
|
||||||
output1 = renderer(mesh1)[0, ..., :3].cpu()
|
output1 = renderer(mesh1)[0, ..., :3].cpu()
|
||||||
|
@ -12,7 +12,11 @@ from pytorch3d.renderer.mesh.textures import (
|
|||||||
TexturesUV,
|
TexturesUV,
|
||||||
TexturesVertex,
|
TexturesVertex,
|
||||||
_list_to_padded_wrapper,
|
_list_to_padded_wrapper,
|
||||||
|
)
|
||||||
|
from pytorch3d.renderer.mesh.utils import (
|
||||||
|
Rectangle,
|
||||||
pack_rectangles,
|
pack_rectangles,
|
||||||
|
pack_unique_rectangles,
|
||||||
)
|
)
|
||||||
from pytorch3d.structures import Meshes, list_to_packed, packed_to_list
|
from pytorch3d.structures import Meshes, list_to_packed, packed_to_list
|
||||||
from test_meshes import init_mesh
|
from test_meshes import init_mesh
|
||||||
@ -873,21 +877,24 @@ class TestRectanglePacking(TestCaseMixin, unittest.TestCase):
|
|||||||
mask = torch.zeros(total, dtype=torch.bool)
|
mask = torch.zeros(total, dtype=torch.bool)
|
||||||
seen_x_bound = False
|
seen_x_bound = False
|
||||||
seen_y_bound = False
|
seen_y_bound = False
|
||||||
for (in_x, in_y), loc in zip(sizes, res.locations):
|
for (in_x, in_y), (out_x, out_y, flipped, is_first) in zip(
|
||||||
self.assertGreaterEqual(loc[0], 0)
|
sizes, res.locations
|
||||||
self.assertGreaterEqual(loc[1], 0)
|
):
|
||||||
placed_x, placed_y = (in_y, in_x) if loc[2] else (in_x, in_y)
|
self.assertTrue(is_first)
|
||||||
upper_x = placed_x + loc[0]
|
self.assertGreaterEqual(out_x, 0)
|
||||||
upper_y = placed_y + loc[1]
|
self.assertGreaterEqual(out_y, 0)
|
||||||
|
placed_x, placed_y = (in_y, in_x) if flipped else (in_x, in_y)
|
||||||
|
upper_x = placed_x + out_x
|
||||||
|
upper_y = placed_y + out_y
|
||||||
self.assertGreaterEqual(total[0], upper_x)
|
self.assertGreaterEqual(total[0], upper_x)
|
||||||
if total[0] == upper_x:
|
if total[0] == upper_x:
|
||||||
seen_x_bound = True
|
seen_x_bound = True
|
||||||
self.assertGreaterEqual(total[1], upper_y)
|
self.assertGreaterEqual(total[1], upper_y)
|
||||||
if total[1] == upper_y:
|
if total[1] == upper_y:
|
||||||
seen_y_bound = True
|
seen_y_bound = True
|
||||||
already_taken = torch.sum(mask[loc[0] : upper_x, loc[1] : upper_y])
|
already_taken = torch.sum(mask[out_x:upper_x, out_y:upper_y])
|
||||||
self.assertEqual(already_taken, 0)
|
self.assertEqual(already_taken, 0)
|
||||||
mask[loc[0] : upper_x, loc[1] : upper_y] = 1
|
mask[out_x:upper_x, out_y:upper_y] = 1
|
||||||
self.assertTrue(seen_x_bound)
|
self.assertTrue(seen_x_bound)
|
||||||
self.assertTrue(seen_y_bound)
|
self.assertTrue(seen_y_bound)
|
||||||
|
|
||||||
@ -930,3 +937,29 @@ class TestRectanglePacking(TestCaseMixin, unittest.TestCase):
|
|||||||
for j in range(vals.shape[0]):
|
for j in range(vals.shape[0]):
|
||||||
sizes.append((int(vals[j, 0]), int(vals[j, 1])))
|
sizes.append((int(vals[j, 0]), int(vals[j, 1])))
|
||||||
self.wrap_pack(sizes)
|
self.wrap_pack(sizes)
|
||||||
|
|
||||||
|
def test_all_identical(self):
|
||||||
|
sizes = [Rectangle(xsize=61, ysize=82, identifier=1729)] * 3
|
||||||
|
total_size, locations = pack_unique_rectangles(sizes)
|
||||||
|
self.assertEqual(total_size, (61, 82))
|
||||||
|
self.assertEqual(len(locations), 3)
|
||||||
|
for i, (x, y, is_flipped, is_first) in enumerate(locations):
|
||||||
|
self.assertEqual(x, 0)
|
||||||
|
self.assertEqual(y, 0)
|
||||||
|
self.assertFalse(is_flipped)
|
||||||
|
self.assertEqual(is_first, i == 0)
|
||||||
|
|
||||||
|
def test_one_different_id(self):
|
||||||
|
sizes = [Rectangle(xsize=61, ysize=82, identifier=220)] * 3
|
||||||
|
sizes.extend([Rectangle(xsize=61, ysize=82, identifier=284)] * 3)
|
||||||
|
total_size, locations = pack_unique_rectangles(sizes)
|
||||||
|
self.assertEqual(total_size, (82, 122))
|
||||||
|
self.assertEqual(len(locations), 6)
|
||||||
|
for i, (x, y, is_flipped, is_first) in enumerate(locations):
|
||||||
|
self.assertTrue(is_flipped)
|
||||||
|
self.assertEqual(is_first, i % 3 == 0)
|
||||||
|
self.assertEqual(x, 0)
|
||||||
|
if i < 3:
|
||||||
|
self.assertEqual(y, 61)
|
||||||
|
else:
|
||||||
|
self.assertEqual(y, 0)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user