amalgamate meshes with texture into a single scene
Summary: Add a join_scene method to all the textures to allow the join_mesh function to include textures. Rename the join_mesh function to join_meshes_as_scene. For TexturesAtlas, we now interpolate if the user attempts to have the resolution vary across the batch. This doesn't look great if the resolution is already very low. For TexturesUV, a rectangle packing function is required, this does something simple. Reviewed By: gkioxari Differential Revision: D23188773 fbshipit-source-id: c013db061a04076e13e90ccc168a7913e933a9c5
@ -2,7 +2,7 @@
 | 
			
		||||
 | 
			
		||||
import itertools
 | 
			
		||||
import warnings
 | 
			
		||||
from typing import Dict, List, Optional, Tuple, Union
 | 
			
		||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
@ -10,6 +10,8 @@ from pytorch3d.ops import interpolate_face_attributes
 | 
			
		||||
from pytorch3d.structures.utils import list_to_packed, list_to_padded, padded_to_list
 | 
			
		||||
from torch.nn.functional import interpolate
 | 
			
		||||
 | 
			
		||||
from .utils import pack_rectangles
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# This file contains classes and helper functions for texturing.
 | 
			
		||||
# There are three types of textures: TexturesVertex, TexturesAtlas
 | 
			
		||||
@ -329,6 +331,7 @@ class TexturesAtlas(TexturesBase):
 | 
			
		||||
 | 
			
		||||
        [1] Liu et al, 'Soft Rasterizer: A Differentiable Renderer for Image-based
 | 
			
		||||
            3D Reasoning', ICCV 2019
 | 
			
		||||
            See also https://github.com/ShichenLiu/SoftRas/issues/21
 | 
			
		||||
        """
 | 
			
		||||
        if isinstance(atlas, (list, tuple)):
 | 
			
		||||
            correct_format = all(
 | 
			
		||||
@ -336,11 +339,15 @@ class TexturesAtlas(TexturesBase):
 | 
			
		||||
                    torch.is_tensor(elem)
 | 
			
		||||
                    and elem.ndim == 4
 | 
			
		||||
                    and elem.shape[1] == elem.shape[2]
 | 
			
		||||
                    and elem.shape[1] == atlas[0].shape[1]
 | 
			
		||||
                )
 | 
			
		||||
                for elem in atlas
 | 
			
		||||
            )
 | 
			
		||||
            if not correct_format:
 | 
			
		||||
                msg = "Expected atlas to be a list of tensors of shape (F, R, R, D)"
 | 
			
		||||
                msg = (
 | 
			
		||||
                    "Expected atlas to be a list of tensors of shape (F, R, R, D) "
 | 
			
		||||
                    "with the same value of R."
 | 
			
		||||
                )
 | 
			
		||||
                raise ValueError(msg)
 | 
			
		||||
            self._atlas_list = atlas
 | 
			
		||||
            self._atlas_padded = None
 | 
			
		||||
@ -529,6 +536,12 @@ class TexturesAtlas(TexturesBase):
 | 
			
		||||
        new_tex._num_faces_per_mesh = num_faces_per_mesh
 | 
			
		||||
        return new_tex
 | 
			
		||||
 | 
			
		||||
    def join_scene(self) -> "TexturesAtlas":
 | 
			
		||||
        """
 | 
			
		||||
        Return a new TexturesAtlas amalgamating the batch.
 | 
			
		||||
        """
 | 
			
		||||
        return self.__class__(atlas=[torch.cat(self.atlas_list())])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TexturesUV(TexturesBase):
 | 
			
		||||
    def __init__(
 | 
			
		||||
@ -560,7 +573,7 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
        the two align_corners options at
 | 
			
		||||
        https://discuss.pytorch.org/t/22663/9 .
 | 
			
		||||
 | 
			
		||||
        An example of how the indexing into the maps, with align_corners=True
 | 
			
		||||
        An example of how the indexing into the maps, with align_corners=True,
 | 
			
		||||
        works is as follows.
 | 
			
		||||
        If maps[i] has shape [101, 1001] and the value of verts_uvs[i][j]
 | 
			
		||||
        is [0.4, 0.3], then a value of j in faces_uvs[i] means a vertex
 | 
			
		||||
@ -574,10 +587,11 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
        If maps[i] has shape [100, 1000] and the value of verts_uvs[i][j]
 | 
			
		||||
        is [0.405, 0.2995], then a value of j in faces_uvs[i] means a vertex
 | 
			
		||||
        whose color is given by maps[i][700, 40].
 | 
			
		||||
        In this case, padding_mode even matters for values in verts_uvs
 | 
			
		||||
        slightly above 0 or slightly below 1. In this case, it matters if the
 | 
			
		||||
        first value is outside the interval [0.0005, 0.9995] or if the second
 | 
			
		||||
        is outside the interval [0.005, 0.995].
 | 
			
		||||
        When align_corners=False, padding_mode even matters for values in
 | 
			
		||||
        verts_uvs slightly above 0 or slightly below 1. In this case, the
 | 
			
		||||
        padding_mode matters if the first value is outside the interval
 | 
			
		||||
        [0.0005, 0.9995] or if the second is outside the interval
 | 
			
		||||
        [0.005, 0.995].
 | 
			
		||||
        """
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.padding_mode = padding_mode
 | 
			
		||||
@ -805,12 +819,9 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
    def maps_padded(self) -> torch.Tensor:
 | 
			
		||||
        return self._maps_padded
 | 
			
		||||
 | 
			
		||||
    def maps_list(self) -> torch.Tensor:
 | 
			
		||||
        # maps_list is not used anywhere currently - maps
 | 
			
		||||
        # are padded to ensure the (H, W) of all maps is the
 | 
			
		||||
        # same across the batch and we don't store the
 | 
			
		||||
        # unpadded sizes of the maps. Therefore just return
 | 
			
		||||
        # the unbinded padded tensor.
 | 
			
		||||
    def maps_list(self) -> List[torch.Tensor]:
 | 
			
		||||
        if self._maps_list is not None:
 | 
			
		||||
            return self._maps_list
 | 
			
		||||
        return self._maps_padded.unbind(0)
 | 
			
		||||
 | 
			
		||||
    def extend(self, N: int) -> "TexturesUV":
 | 
			
		||||
@ -965,6 +976,143 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
        new_tex._num_faces_per_mesh = num_faces_per_mesh
 | 
			
		||||
        return new_tex
 | 
			
		||||
 | 
			
		||||
    def _place_map_into_single_map(
 | 
			
		||||
        self,
 | 
			
		||||
        single_map: torch.Tensor,
 | 
			
		||||
        map_: torch.Tensor,
 | 
			
		||||
        location: Tuple[int, int, bool],  # (x,y) and whether flipped
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Copy map into a larger tensor single_map at the destination specified by location.
 | 
			
		||||
        If align_corners is False, we add the needed border around the destination.
 | 
			
		||||
 | 
			
		||||
        Used by join_scene.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            single_map: (total_H, total_W, 3)
 | 
			
		||||
            map_: (H, W, 3) source data
 | 
			
		||||
            location: where to place map
 | 
			
		||||
        """
 | 
			
		||||
        do_flip = location[2]
 | 
			
		||||
        source = map_.transpose(0, 1) if do_flip else map_
 | 
			
		||||
        border_width = 0 if self.align_corners else 1
 | 
			
		||||
        lower_u = location[0] + border_width
 | 
			
		||||
        lower_v = location[1] + border_width
 | 
			
		||||
        upper_u = lower_u + source.shape[0]
 | 
			
		||||
        upper_v = lower_v + source.shape[1]
 | 
			
		||||
        single_map[lower_u:upper_u, lower_v:upper_v] = source
 | 
			
		||||
 | 
			
		||||
        if self.padding_mode != "zeros" and not self.align_corners:
 | 
			
		||||
            single_map[lower_u - 1, lower_v:upper_v] = single_map[
 | 
			
		||||
                lower_u, lower_v:upper_v
 | 
			
		||||
            ]
 | 
			
		||||
            single_map[upper_u, lower_v:upper_v] = single_map[
 | 
			
		||||
                upper_u - 1, lower_v:upper_v
 | 
			
		||||
            ]
 | 
			
		||||
            single_map[lower_u:upper_u, lower_v - 1] = single_map[
 | 
			
		||||
                lower_u:upper_u, lower_v
 | 
			
		||||
            ]
 | 
			
		||||
            single_map[lower_u:upper_u, upper_v] = single_map[
 | 
			
		||||
                lower_u:upper_u, upper_v - 1
 | 
			
		||||
            ]
 | 
			
		||||
            single_map[lower_u - 1, lower_v - 1] = single_map[lower_u, lower_v]
 | 
			
		||||
            single_map[lower_u - 1, upper_v] = single_map[lower_u, upper_v - 1]
 | 
			
		||||
            single_map[upper_u, lower_v - 1] = single_map[upper_u - 1, lower_v]
 | 
			
		||||
            single_map[upper_u, upper_v] = single_map[upper_u - 1, upper_v - 1]
 | 
			
		||||
 | 
			
		||||
    def join_scene(self) -> "TexturesUV":
 | 
			
		||||
        """
 | 
			
		||||
        Return a new TexturesUV amalgamating the batch.
 | 
			
		||||
 | 
			
		||||
        We calculate a large single map which contains the original maps,
 | 
			
		||||
        and find verts_uvs to point into it. This will not replicate
 | 
			
		||||
        behavior of padding for verts_uvs values outside [0,1].
 | 
			
		||||
 | 
			
		||||
        If align_corners=False, we need to add an artificial border around
 | 
			
		||||
        every map.
 | 
			
		||||
 | 
			
		||||
        We use the function `pack_rectangles` to provide a layout for the
 | 
			
		||||
        single map. _place_map_into_single_map is used to copy the maps
 | 
			
		||||
        into the single map. The merging of verts_uvs and faces_uvs are
 | 
			
		||||
        handled locally in this function.
 | 
			
		||||
        """
 | 
			
		||||
        maps = self.maps_list()
 | 
			
		||||
        heights_and_widths = []
 | 
			
		||||
        extra_border = 0 if self.align_corners else 2
 | 
			
		||||
        for map_ in maps:
 | 
			
		||||
            heights_and_widths.append(
 | 
			
		||||
                (map_.shape[0] + extra_border, map_.shape[1] + extra_border)
 | 
			
		||||
            )
 | 
			
		||||
        merging_plan = pack_rectangles(heights_and_widths)
 | 
			
		||||
        # pyre-fixme[16]: `Tensor` has no attribute `new_zeros`.
 | 
			
		||||
        single_map = maps[0].new_zeros((*merging_plan.total_size, 3))
 | 
			
		||||
        verts_uvs = self.verts_uvs_list()
 | 
			
		||||
        verts_uvs_merged = []
 | 
			
		||||
 | 
			
		||||
        for map_, loc, uvs in zip(maps, merging_plan.locations, verts_uvs):
 | 
			
		||||
            new_uvs = uvs.clone()
 | 
			
		||||
            self._place_map_into_single_map(single_map, map_, loc)
 | 
			
		||||
            do_flip = loc[2]
 | 
			
		||||
            x_shape = map_.shape[1] if do_flip else map_.shape[0]
 | 
			
		||||
            y_shape = map_.shape[0] if do_flip else map_.shape[1]
 | 
			
		||||
 | 
			
		||||
            if do_flip:
 | 
			
		||||
                # Here we have flipped / transposed the map.
 | 
			
		||||
                # In uvs, the y values are decreasing from 1 to 0 and the x
 | 
			
		||||
                # values increase from 0 to 1. We subtract all values from 1
 | 
			
		||||
                # as the x's become y's and the y's become x's.
 | 
			
		||||
                new_uvs = 1.0 - new_uvs[:, [1, 0]]
 | 
			
		||||
                if TYPE_CHECKING:
 | 
			
		||||
                    new_uvs = torch.Tensor(new_uvs)
 | 
			
		||||
 | 
			
		||||
            # If align_corners is True, then an index of x (where x is in
 | 
			
		||||
            # the range 0 .. map_.shape[]-1) in one of the input maps
 | 
			
		||||
            # was hit by a u of x/(map_.shape[]-1).
 | 
			
		||||
            # That x is located at the index loc[] + x in the single_map, and
 | 
			
		||||
            # to hit that we need u to equal (loc[] + x) / (total_size[]-1)
 | 
			
		||||
            # so the old u should be mapped to
 | 
			
		||||
            #   { u*(map_.shape[]-1) + loc[] } / (total_size[]-1)
 | 
			
		||||
 | 
			
		||||
            # If align_corners is False, then an index of x (where x is in
 | 
			
		||||
            # the range 1 .. map_.shape[]-2) in one of the input maps
 | 
			
		||||
            # was hit by a u of (x+0.5)/(map_.shape[]).
 | 
			
		||||
            # That x is located at the index loc[] + 1 + x in the single_map,
 | 
			
		||||
            # (where the 1 is for the border)
 | 
			
		||||
            # and to hit that we need u to equal (loc[] + 1 + x + 0.5) / (total_size[])
 | 
			
		||||
            # so the old u should be mapped to
 | 
			
		||||
            #   { loc[] + 1 + u*map_.shape[]-0.5 + 0.5 } / (total_size[])
 | 
			
		||||
            #  = { loc[] + 1 + u*map_.shape[] } / (total_size[])
 | 
			
		||||
 | 
			
		||||
            # We change the y's in new_uvs for the scaling of height,
 | 
			
		||||
            # and the x's for the scaling of width.
 | 
			
		||||
            # That is why the 1's and 0's are mismatched in these lines.
 | 
			
		||||
            one_if_align = 1 if self.align_corners else 0
 | 
			
		||||
            one_if_not_align = 1 - one_if_align
 | 
			
		||||
            denom_x = merging_plan.total_size[0] - one_if_align
 | 
			
		||||
            scale_x = x_shape - one_if_align
 | 
			
		||||
            denom_y = merging_plan.total_size[1] - one_if_align
 | 
			
		||||
            scale_y = y_shape - one_if_align
 | 
			
		||||
            new_uvs[:, 1] *= scale_x / denom_x
 | 
			
		||||
            new_uvs[:, 1] += (loc[0] + one_if_not_align) / denom_x
 | 
			
		||||
            new_uvs[:, 0] *= scale_y / denom_y
 | 
			
		||||
            new_uvs[:, 0] += (loc[1] + one_if_not_align) / denom_y
 | 
			
		||||
 | 
			
		||||
            verts_uvs_merged.append(new_uvs)
 | 
			
		||||
 | 
			
		||||
        faces_uvs_merged = []
 | 
			
		||||
        offset = 0
 | 
			
		||||
        for faces_uvs_, verts_uvs_ in zip(self.faces_uvs_list(), verts_uvs):
 | 
			
		||||
            faces_uvs_merged.append(offset + faces_uvs_)
 | 
			
		||||
            offset += verts_uvs_.shape[0]
 | 
			
		||||
 | 
			
		||||
        return self.__class__(
 | 
			
		||||
            maps=[single_map],
 | 
			
		||||
            verts_uvs=[torch.cat(verts_uvs_merged)],
 | 
			
		||||
            faces_uvs=[torch.cat(faces_uvs_merged)],
 | 
			
		||||
            align_corners=self.align_corners,
 | 
			
		||||
            padding_mode=self.padding_mode,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TexturesVertex(TexturesBase):
 | 
			
		||||
    def __init__(
 | 
			
		||||
@ -1156,3 +1304,9 @@ class TexturesVertex(TexturesBase):
 | 
			
		||||
        new_tex = self.__class__(verts_features=verts_features_list)
 | 
			
		||||
        new_tex._num_verts_per_mesh = num_faces_per_mesh
 | 
			
		||||
        return new_tex
 | 
			
		||||
 | 
			
		||||
    def join_scene(self) -> "TexturesVertex":
 | 
			
		||||
        """
 | 
			
		||||
        Return a new TexturesVertex amalgamating the batch.
 | 
			
		||||
        """
 | 
			
		||||
        return self.__class__(verts_features=[torch.cat(self.verts_features_list())])
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,8 @@
 | 
			
		||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from typing import List, NamedTuple, Tuple
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from pytorch3d.ops import interpolate_face_attributes
 | 
			
		||||
 | 
			
		||||
@ -58,3 +60,184 @@ def _interpolate_zbuf(
 | 
			
		||||
    ]  # (1, H, W, K)
 | 
			
		||||
    zbuf[pix_to_face == -1] = -1
 | 
			
		||||
    return zbuf
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# -----------  Rectangle Packing  -------------------- #
 | 
			
		||||
 | 
			
		||||
# Note the order of members matters here because it determines the queue order.
 | 
			
		||||
# We want to place longer rectangles first.
 | 
			
		||||
class _UnplacedRectangle(NamedTuple):
 | 
			
		||||
    size: Tuple[int, int]
 | 
			
		||||
    ind: int
 | 
			
		||||
    flipped: bool
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _try_place_rectangle(
 | 
			
		||||
    rect: _UnplacedRectangle,
 | 
			
		||||
    placed_so_far: List[Tuple[int, int, bool]],
 | 
			
		||||
    occupied: List[Tuple[int, int]],
 | 
			
		||||
) -> bool:
 | 
			
		||||
    """
 | 
			
		||||
    Try to place rect within the current bounding box.
 | 
			
		||||
    Part of the implementation of pack_rectangles.
 | 
			
		||||
 | 
			
		||||
    Note that the arguments `placed_so_far` and `occupied` are modified.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        rect: rectangle to place
 | 
			
		||||
        placed_so_far: the locations decided upon so far - a list of
 | 
			
		||||
                    (x, y, whether flipped). The nth element is the
 | 
			
		||||
                    location of the nth rectangle if it has been decided.
 | 
			
		||||
                    (modified in place)
 | 
			
		||||
        occupied: the nodes of the graph of extents of rightmost placed
 | 
			
		||||
                    rectangles - (modified in place)
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        True on success.
 | 
			
		||||
 | 
			
		||||
    Example:
 | 
			
		||||
    (We always have placed the first rectangle horizontally and other
 | 
			
		||||
    rectangles above it.)
 | 
			
		||||
    Let's say the placed boxes 1-4 are layed out like this.
 | 
			
		||||
    The coordinates of the points marked X are stored in occupied.
 | 
			
		||||
    It is to the right of the X's that we seek to place rect.
 | 
			
		||||
 | 
			
		||||
     +-----------------------X
 | 
			
		||||
     |2                      |
 | 
			
		||||
     |                       +---X
 | 
			
		||||
     |                       |4  |
 | 
			
		||||
     |                       |   |
 | 
			
		||||
     |                       +---+X
 | 
			
		||||
     |                       |3   |
 | 
			
		||||
     |                       |    |
 | 
			
		||||
     +-----------------------+----+------X
 | 
			
		||||
y    |1                                  |
 | 
			
		||||
^    |     --->x                         |
 | 
			
		||||
|    +-----------------------------------+
 | 
			
		||||
 | 
			
		||||
     We want to place this rectangle.
 | 
			
		||||
 | 
			
		||||
              +-+
 | 
			
		||||
              |5|
 | 
			
		||||
              | |
 | 
			
		||||
              | |   = rect
 | 
			
		||||
              | |
 | 
			
		||||
              | |
 | 
			
		||||
              | |
 | 
			
		||||
              +-+
 | 
			
		||||
 | 
			
		||||
      The call will succeed, returning True, leaving us with
 | 
			
		||||
 | 
			
		||||
      +-----------------------X
 | 
			
		||||
      |2                      |    +-X
 | 
			
		||||
      |                       +---+|5|
 | 
			
		||||
      |                       |4  || |
 | 
			
		||||
      |                       |   || |
 | 
			
		||||
      |                       +---++ |
 | 
			
		||||
      |                       |3   | |
 | 
			
		||||
      |                       |    | |
 | 
			
		||||
      +-----------------------+----+-+----X
 | 
			
		||||
      |1                                  |
 | 
			
		||||
      |                                   |
 | 
			
		||||
      +-----------------------------------+ .
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
    total_width = occupied[0][0]
 | 
			
		||||
    needed_height = rect.size[1]
 | 
			
		||||
    current_start_idx = None
 | 
			
		||||
    current_max_width = 0
 | 
			
		||||
    previous_height = 0
 | 
			
		||||
    currently_packed = 0
 | 
			
		||||
    for idx, interval in enumerate(occupied):
 | 
			
		||||
        if interval[0] <= total_width - rect.size[0]:
 | 
			
		||||
            currently_packed += interval[1] - previous_height
 | 
			
		||||
            current_max_width = max(interval[0], current_max_width)
 | 
			
		||||
            if current_start_idx is None:
 | 
			
		||||
                current_start_idx = idx
 | 
			
		||||
            if currently_packed >= needed_height:
 | 
			
		||||
                current_max_width = max(interval[0], current_max_width)
 | 
			
		||||
                placed_so_far[rect.ind] = (
 | 
			
		||||
                    current_max_width,
 | 
			
		||||
                    occupied[current_start_idx - 1][1],
 | 
			
		||||
                    rect.flipped,
 | 
			
		||||
                )
 | 
			
		||||
                new_occupied = (
 | 
			
		||||
                    current_max_width + rect.size[0],
 | 
			
		||||
                    occupied[current_start_idx - 1][1] + needed_height,
 | 
			
		||||
                )
 | 
			
		||||
                if currently_packed == needed_height:
 | 
			
		||||
                    occupied[idx] = new_occupied
 | 
			
		||||
                    del occupied[current_start_idx:idx]
 | 
			
		||||
                elif idx > current_start_idx:
 | 
			
		||||
                    occupied[idx - 1] = new_occupied
 | 
			
		||||
                    del occupied[current_start_idx : (idx - 1)]
 | 
			
		||||
                else:
 | 
			
		||||
                    occupied.insert(idx, new_occupied)
 | 
			
		||||
                return True
 | 
			
		||||
        else:
 | 
			
		||||
            current_start_idx = None
 | 
			
		||||
            current_max_width = 0
 | 
			
		||||
            currently_packed = 0
 | 
			
		||||
        previous_height = interval[1]
 | 
			
		||||
    return False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PackedRectangles(NamedTuple):
 | 
			
		||||
    total_size: Tuple[int, int]
 | 
			
		||||
    locations: List[Tuple[int, int, bool]]  # (x,y) and whether flipped
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def pack_rectangles(sizes: List[Tuple[int, int]]) -> PackedRectangles:
 | 
			
		||||
    """
 | 
			
		||||
    Naive rectangle packing in to a large rectangle. Flipping (i.e. rotating
 | 
			
		||||
    a rectangle by 90 degrees) is allowed.
 | 
			
		||||
 | 
			
		||||
    This is used to join several uv maps into a single scene, see
 | 
			
		||||
    TexturesUV.join_scene.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        sizes: List of sizes of rectangles to pack
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        total_size: size of total large rectangle
 | 
			
		||||
        rectangles: location for each of the input rectangles
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    if len(sizes) < 2:
 | 
			
		||||
        raise ValueError("Cannot pack less than two boxes")
 | 
			
		||||
 | 
			
		||||
    queue = []
 | 
			
		||||
    for i, size in enumerate(sizes):
 | 
			
		||||
        if size[0] < size[1]:
 | 
			
		||||
            queue.append(_UnplacedRectangle((size[1], size[0]), i, True))
 | 
			
		||||
        else:
 | 
			
		||||
            queue.append(_UnplacedRectangle((size[0], size[1]), i, False))
 | 
			
		||||
    queue.sort()
 | 
			
		||||
    placed_so_far = [(-1, -1, False)] * len(sizes)
 | 
			
		||||
 | 
			
		||||
    biggest = queue.pop()
 | 
			
		||||
    total_width, current_height = biggest.size
 | 
			
		||||
    placed_so_far[biggest.ind] = (0, 0, biggest.flipped)
 | 
			
		||||
 | 
			
		||||
    second = queue.pop()
 | 
			
		||||
    placed_so_far[second.ind] = (0, current_height, second.flipped)
 | 
			
		||||
    current_height += second.size[1]
 | 
			
		||||
    occupied = [biggest.size, (second.size[0], current_height)]
 | 
			
		||||
 | 
			
		||||
    for rect in reversed(queue):
 | 
			
		||||
        if _try_place_rectangle(rect, placed_so_far, occupied):
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        rotated = _UnplacedRectangle(
 | 
			
		||||
            (rect.size[1], rect.size[0]), rect.ind, not rect.flipped
 | 
			
		||||
        )
 | 
			
		||||
        if _try_place_rectangle(rotated, placed_so_far, occupied):
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        # rect wasn't placed in the current bounding box,
 | 
			
		||||
        # so we add extra space to fit it in.
 | 
			
		||||
        placed_so_far[rect.ind] = (0, current_height, rect.flipped)
 | 
			
		||||
        current_height += rect.size[1]
 | 
			
		||||
        occupied.append((rect.size[0], current_height))
 | 
			
		||||
 | 
			
		||||
    return PackedRectangles((total_width, current_height), placed_so_far)
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,6 @@
 | 
			
		||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 | 
			
		||||
 | 
			
		||||
from .meshes import Meshes, join_meshes_as_batch
 | 
			
		||||
from .meshes import Meshes, join_meshes_as_batch, join_meshes_as_scene
 | 
			
		||||
from .pointclouds import Pointclouds
 | 
			
		||||
from .utils import list_to_packed, list_to_padded, packed_to_list, padded_to_list
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1254,7 +1254,7 @@ class Meshes(object):
 | 
			
		||||
        """
 | 
			
		||||
        verts_packed = self.verts_packed()
 | 
			
		||||
        if vert_offsets_packed.shape != verts_packed.shape:
 | 
			
		||||
            raise ValueError("Verts offsets must have dimension (all_v, 2).")
 | 
			
		||||
            raise ValueError("Verts offsets must have dimension (all_v, 3).")
 | 
			
		||||
        # update verts packed
 | 
			
		||||
        self._verts_packed = verts_packed + vert_offsets_packed
 | 
			
		||||
        new_verts_list = list(
 | 
			
		||||
@ -1548,26 +1548,43 @@ def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True):
 | 
			
		||||
    return Meshes(verts=verts, faces=faces, textures=tex)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def join_mesh(meshes: Union[Meshes, List[Meshes]]) -> Meshes:
 | 
			
		||||
def join_meshes_as_scene(
 | 
			
		||||
    meshes: Union[Meshes, List[Meshes]], include_textures: bool = True
 | 
			
		||||
) -> Meshes:
 | 
			
		||||
    """
 | 
			
		||||
    Joins a batch of meshes in the form of a Meshes object or a list of Meshes
 | 
			
		||||
    objects as a single mesh. If the input is a list, the Meshes objects in the list
 | 
			
		||||
    must all be on the same device. This version ignores all textures in the input meshes.
 | 
			
		||||
    objects as a single mesh. If the input is a list, the Meshes objects in the
 | 
			
		||||
    list must all be on the same device. Unless include_textures is False, the
 | 
			
		||||
    meshes must all have the same type of texture or must all not have textures.
 | 
			
		||||
 | 
			
		||||
    If textures are included, then the textures are joined as a single scene in
 | 
			
		||||
    addition to the meshes. For this, texture types have an appropriate method
 | 
			
		||||
    called join_scene which joins mesh textures into a single texture.
 | 
			
		||||
    If the textures are TexturesAtlas then they must have the same resolution.
 | 
			
		||||
    If they are TexturesUV then they must have the same align_corners and
 | 
			
		||||
    padding_mode. Values in verts_uvs outside [0, 1] will not
 | 
			
		||||
    be respected.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        meshes: Meshes object that contains a batch of meshes or a list of Meshes objects
 | 
			
		||||
        meshes: Meshes object that contains a batch of meshes, or a list of
 | 
			
		||||
                    Meshes objects.
 | 
			
		||||
        include_textures: (bool) whether to try to join the textures.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        new Meshes object containing a single mesh
 | 
			
		||||
    """
 | 
			
		||||
    if isinstance(meshes, List):
 | 
			
		||||
        meshes = join_meshes_as_batch(meshes, include_textures=False)
 | 
			
		||||
        meshes = join_meshes_as_batch(meshes, include_textures=include_textures)
 | 
			
		||||
 | 
			
		||||
    if len(meshes) == 1:
 | 
			
		||||
        return meshes
 | 
			
		||||
    verts = meshes.verts_packed()  # (sum(V_n), 3)
 | 
			
		||||
    # Offset automatically done by faces_packed
 | 
			
		||||
    faces = meshes.faces_packed()  # (sum(F_n), 3)
 | 
			
		||||
    textures = None
 | 
			
		||||
 | 
			
		||||
    mesh = Meshes(verts=verts.unsqueeze(0), faces=faces.unsqueeze(0))
 | 
			
		||||
    if include_textures and meshes.textures is not None:
 | 
			
		||||
        textures = meshes.textures.join_scene()
 | 
			
		||||
 | 
			
		||||
    mesh = Meshes(verts=verts.unsqueeze(0), faces=faces.unsqueeze(0), textures=textures)
 | 
			
		||||
    return mesh
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										
											BIN
										
									
								
								tests/data/test_joinatlas_final.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						| 
		 After Width: | Height: | Size: 25 KiB  | 
							
								
								
									
										
											BIN
										
									
								
								tests/data/test_joinuvs0_final.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						| 
		 After Width: | Height: | Size: 12 KiB  | 
							
								
								
									
										
											BIN
										
									
								
								tests/data/test_joinuvs0_map.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						| 
		 After Width: | Height: | Size: 807 B  | 
							
								
								
									
										
											BIN
										
									
								
								tests/data/test_joinuvs1_final.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						| 
		 After Width: | Height: | Size: 12 KiB  | 
							
								
								
									
										
											BIN
										
									
								
								tests/data/test_joinuvs1_map.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						| 
		 After Width: | Height: | Size: 819 B  | 
							
								
								
									
										
											BIN
										
									
								
								tests/data/test_joinuvs2_final.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						| 
		 After Width: | Height: | Size: 11 KiB  | 
							
								
								
									
										
											BIN
										
									
								
								tests/data/test_joinuvs2_map.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						| 
		 After Width: | Height: | Size: 806 B  | 
							
								
								
									
										
											BIN
										
									
								
								tests/data/test_joinverts_final.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						| 
		 After Width: | Height: | Size: 11 KiB  | 
@ -33,7 +33,11 @@ from pytorch3d.renderer.mesh.shader import (
 | 
			
		||||
    SoftSilhouetteShader,
 | 
			
		||||
    TexturedSoftPhongShader,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.structures.meshes import Meshes, join_mesh, join_meshes_as_batch
 | 
			
		||||
from pytorch3d.structures.meshes import (
 | 
			
		||||
    Meshes,
 | 
			
		||||
    join_meshes_as_batch,
 | 
			
		||||
    join_meshes_as_scene,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.utils.ico_sphere import ico_sphere
 | 
			
		||||
from pytorch3d.utils.torus import torus
 | 
			
		||||
 | 
			
		||||
@ -571,6 +575,288 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        self.assertClose(outputs[0][0, ..., :3], outputs[1][0, ..., :3], atol=1e-5)
 | 
			
		||||
        self.assertClose(outputs[0][1, ..., :3], outputs[2][0, ..., :3], atol=1e-5)
 | 
			
		||||
 | 
			
		||||
    def test_join_uvs(self):
 | 
			
		||||
        """Meshes with TexturesUV joined into a scene"""
 | 
			
		||||
        # Test the result of rendering three tori with separate textures.
 | 
			
		||||
        # The expected result is consistent with rendering them each alone.
 | 
			
		||||
        # This tests TexturesUV.join_scene with rectangle flipping,
 | 
			
		||||
        # and we check the form of the merged map as well.
 | 
			
		||||
        torch.manual_seed(1)
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
 | 
			
		||||
        R, T = look_at_view_transform(18, 0, 0)
 | 
			
		||||
        cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
 | 
			
		||||
 | 
			
		||||
        raster_settings = RasterizationSettings(
 | 
			
		||||
            image_size=256, blur_radius=0.0, faces_per_pixel=1
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        lights = PointLights(
 | 
			
		||||
            device=device,
 | 
			
		||||
            ambient_color=((1.0, 1.0, 1.0),),
 | 
			
		||||
            diffuse_color=((0.0, 0.0, 0.0),),
 | 
			
		||||
            specular_color=((0.0, 0.0, 0.0),),
 | 
			
		||||
        )
 | 
			
		||||
        blend_params = BlendParams(
 | 
			
		||||
            sigma=1e-1,
 | 
			
		||||
            gamma=1e-4,
 | 
			
		||||
            background_color=torch.tensor([1.0, 1.0, 1.0], device=device),
 | 
			
		||||
        )
 | 
			
		||||
        renderer = MeshRenderer(
 | 
			
		||||
            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
 | 
			
		||||
            shader=HardPhongShader(
 | 
			
		||||
                device=device, blend_params=blend_params, cameras=cameras, lights=lights
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        plain_torus = torus(r=1, R=4, sides=5, rings=6, device=device)
 | 
			
		||||
        [verts] = plain_torus.verts_list()
 | 
			
		||||
        verts_shifted1 = verts.clone()
 | 
			
		||||
        verts_shifted1 *= 0.5
 | 
			
		||||
        verts_shifted1[:, 1] += 7
 | 
			
		||||
        verts_shifted2 = verts.clone()
 | 
			
		||||
        verts_shifted2 *= 0.5
 | 
			
		||||
        verts_shifted2[:, 1] -= 7
 | 
			
		||||
 | 
			
		||||
        [faces] = plain_torus.faces_list()
 | 
			
		||||
        nocolor = torch.zeros((100, 100), device=device)
 | 
			
		||||
        color_gradient = torch.linspace(0, 1, steps=100, device=device)
 | 
			
		||||
        color_gradient1 = color_gradient[None].expand_as(nocolor)
 | 
			
		||||
        color_gradient2 = color_gradient[:, None].expand_as(nocolor)
 | 
			
		||||
        colors1 = torch.stack([nocolor, color_gradient1, color_gradient2], dim=2)
 | 
			
		||||
        colors2 = torch.stack([color_gradient1, color_gradient2, nocolor], dim=2)
 | 
			
		||||
        verts_uvs1 = torch.rand(size=(verts.shape[0], 2), device=device)
 | 
			
		||||
        verts_uvs2 = torch.rand(size=(verts.shape[0], 2), device=device)
 | 
			
		||||
 | 
			
		||||
        for i, align_corners, padding_mode in [
 | 
			
		||||
            (0, True, "border"),
 | 
			
		||||
            (1, False, "border"),
 | 
			
		||||
            (2, False, "zeros"),
 | 
			
		||||
        ]:
 | 
			
		||||
            textures1 = TexturesUV(
 | 
			
		||||
                maps=[colors1],
 | 
			
		||||
                faces_uvs=[faces],
 | 
			
		||||
                verts_uvs=[verts_uvs1],
 | 
			
		||||
                align_corners=align_corners,
 | 
			
		||||
                padding_mode=padding_mode,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            # These downsamplings of colors2 are chosen to ensure a flip and a non flip
 | 
			
		||||
            # when the maps are merged.
 | 
			
		||||
            # We have maps of size (100, 100), (50, 99) and (99, 50).
 | 
			
		||||
            textures2 = TexturesUV(
 | 
			
		||||
                maps=[colors2[::2, :-1]],
 | 
			
		||||
                faces_uvs=[faces],
 | 
			
		||||
                verts_uvs=[verts_uvs2],
 | 
			
		||||
                align_corners=align_corners,
 | 
			
		||||
                padding_mode=padding_mode,
 | 
			
		||||
            )
 | 
			
		||||
            offset = torch.tensor([0, 0, 0.5], device=device)
 | 
			
		||||
            textures3 = TexturesUV(
 | 
			
		||||
                maps=[colors2[:-1, ::2] + offset],
 | 
			
		||||
                faces_uvs=[faces],
 | 
			
		||||
                verts_uvs=[verts_uvs2],
 | 
			
		||||
                align_corners=align_corners,
 | 
			
		||||
                padding_mode=padding_mode,
 | 
			
		||||
            )
 | 
			
		||||
            mesh1 = Meshes(verts=[verts], faces=[faces], textures=textures1)
 | 
			
		||||
            mesh2 = Meshes(verts=[verts_shifted1], faces=[faces], textures=textures2)
 | 
			
		||||
            mesh3 = Meshes(verts=[verts_shifted2], faces=[faces], textures=textures3)
 | 
			
		||||
            mesh = join_meshes_as_scene([mesh1, mesh2, mesh3])
 | 
			
		||||
 | 
			
		||||
            output = renderer(mesh)[0, ..., :3].cpu()
 | 
			
		||||
            output1 = renderer(mesh1)[0, ..., :3].cpu()
 | 
			
		||||
            output2 = renderer(mesh2)[0, ..., :3].cpu()
 | 
			
		||||
            output3 = renderer(mesh3)[0, ..., :3].cpu()
 | 
			
		||||
            # The background color is white and the objects do not overlap, so we can
 | 
			
		||||
            # predict the merged image by taking the minimum over every channel
 | 
			
		||||
            merged = torch.min(torch.min(output1, output2), output3)
 | 
			
		||||
 | 
			
		||||
            image_ref = load_rgb_image(f"test_joinuvs{i}_final.png", DATA_DIR)
 | 
			
		||||
            map_ref = load_rgb_image(f"test_joinuvs{i}_map.png", DATA_DIR)
 | 
			
		||||
 | 
			
		||||
            if DEBUG:
 | 
			
		||||
                Image.fromarray((output.numpy() * 255).astype(np.uint8)).save(
 | 
			
		||||
                    DATA_DIR / f"test_joinuvs{i}_final_.png"
 | 
			
		||||
                )
 | 
			
		||||
                Image.fromarray((output.numpy() * 255).astype(np.uint8)).save(
 | 
			
		||||
                    DATA_DIR / f"test_joinuvs{i}_merged.png"
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                Image.fromarray((output1.numpy() * 255).astype(np.uint8)).save(
 | 
			
		||||
                    DATA_DIR / f"test_joinuvs{i}_1.png"
 | 
			
		||||
                )
 | 
			
		||||
                Image.fromarray((output2.numpy() * 255).astype(np.uint8)).save(
 | 
			
		||||
                    DATA_DIR / f"test_joinuvs{i}_2.png"
 | 
			
		||||
                )
 | 
			
		||||
                Image.fromarray((output3.numpy() * 255).astype(np.uint8)).save(
 | 
			
		||||
                    DATA_DIR / f"test_joinuvs{i}_3.png"
 | 
			
		||||
                )
 | 
			
		||||
                Image.fromarray(
 | 
			
		||||
                    (mesh.textures.maps_padded()[0].cpu().numpy() * 255).astype(
 | 
			
		||||
                        np.uint8
 | 
			
		||||
                    )
 | 
			
		||||
                ).save(DATA_DIR / f"test_joinuvs{i}_map_.png")
 | 
			
		||||
                Image.fromarray(
 | 
			
		||||
                    (mesh2.textures.maps_padded()[0].cpu().numpy() * 255).astype(
 | 
			
		||||
                        np.uint8
 | 
			
		||||
                    )
 | 
			
		||||
                ).save(DATA_DIR / f"test_joinuvs{i}_map2.png")
 | 
			
		||||
                Image.fromarray(
 | 
			
		||||
                    (mesh3.textures.maps_padded()[0].cpu().numpy() * 255).astype(
 | 
			
		||||
                        np.uint8
 | 
			
		||||
                    )
 | 
			
		||||
                ).save(DATA_DIR / f"test_joinuvs{i}_map3.png")
 | 
			
		||||
 | 
			
		||||
            self.assertClose(output, merged, atol=0.015)
 | 
			
		||||
            self.assertClose(output, image_ref, atol=0.05)
 | 
			
		||||
            self.assertClose(mesh.textures.maps_padded()[0].cpu(), map_ref, atol=0.05)
 | 
			
		||||
 | 
			
		||||
    def test_join_verts(self):
 | 
			
		||||
        """Meshes with TexturesVertex joined into a scene"""
 | 
			
		||||
        # Test the result of rendering two tori with separate textures.
 | 
			
		||||
        # The expected result is consistent with rendering them each alone.
 | 
			
		||||
        torch.manual_seed(1)
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        plain_torus = torus(r=1, R=4, sides=5, rings=6, device=device)
 | 
			
		||||
        [verts] = plain_torus.verts_list()
 | 
			
		||||
        verts_shifted1 = verts.clone()
 | 
			
		||||
        verts_shifted1 *= 0.5
 | 
			
		||||
        verts_shifted1[:, 1] += 7
 | 
			
		||||
 | 
			
		||||
        faces = plain_torus.faces_list()
 | 
			
		||||
        textures1 = TexturesVertex(verts_features=[torch.rand_like(verts)])
 | 
			
		||||
        textures2 = TexturesVertex(verts_features=[torch.rand_like(verts)])
 | 
			
		||||
        mesh1 = Meshes(verts=[verts], faces=faces, textures=textures1)
 | 
			
		||||
        mesh2 = Meshes(verts=[verts_shifted1], faces=faces, textures=textures2)
 | 
			
		||||
        mesh = join_meshes_as_scene([mesh1, mesh2])
 | 
			
		||||
 | 
			
		||||
        R, T = look_at_view_transform(18, 0, 0)
 | 
			
		||||
        cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
 | 
			
		||||
 | 
			
		||||
        raster_settings = RasterizationSettings(
 | 
			
		||||
            image_size=256, blur_radius=0.0, faces_per_pixel=1
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        lights = PointLights(
 | 
			
		||||
            device=device,
 | 
			
		||||
            ambient_color=((1.0, 1.0, 1.0),),
 | 
			
		||||
            diffuse_color=((0.0, 0.0, 0.0),),
 | 
			
		||||
            specular_color=((0.0, 0.0, 0.0),),
 | 
			
		||||
        )
 | 
			
		||||
        blend_params = BlendParams(
 | 
			
		||||
            sigma=1e-1,
 | 
			
		||||
            gamma=1e-4,
 | 
			
		||||
            background_color=torch.tensor([1.0, 1.0, 1.0], device=device),
 | 
			
		||||
        )
 | 
			
		||||
        renderer = MeshRenderer(
 | 
			
		||||
            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
 | 
			
		||||
            shader=HardPhongShader(
 | 
			
		||||
                device=device, blend_params=blend_params, cameras=cameras, lights=lights
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        output = renderer(mesh)
 | 
			
		||||
 | 
			
		||||
        image_ref = load_rgb_image("test_joinverts_final.png", DATA_DIR)
 | 
			
		||||
 | 
			
		||||
        if DEBUG:
 | 
			
		||||
            debugging_outputs = []
 | 
			
		||||
            for mesh_ in [mesh1, mesh2]:
 | 
			
		||||
                debugging_outputs.append(renderer(mesh_))
 | 
			
		||||
            Image.fromarray(
 | 
			
		||||
                (output[0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
 | 
			
		||||
            ).save(DATA_DIR / "test_joinverts_final_.png")
 | 
			
		||||
            Image.fromarray(
 | 
			
		||||
                (debugging_outputs[0][0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
 | 
			
		||||
            ).save(DATA_DIR / "test_joinverts_1.png")
 | 
			
		||||
            Image.fromarray(
 | 
			
		||||
                (debugging_outputs[1][0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
 | 
			
		||||
            ).save(DATA_DIR / "test_joinverts_2.png")
 | 
			
		||||
 | 
			
		||||
        result = output[0, ..., :3].cpu()
 | 
			
		||||
        self.assertClose(result, image_ref, atol=0.05)
 | 
			
		||||
 | 
			
		||||
    def test_join_atlas(self):
 | 
			
		||||
        """Meshes with TexturesAtlas joined into a scene"""
 | 
			
		||||
        # Test the result of rendering two tori with separate textures.
 | 
			
		||||
        # The expected result is consistent with rendering them each alone.
 | 
			
		||||
        torch.manual_seed(1)
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        plain_torus = torus(r=1, R=4, sides=5, rings=6, device=device)
 | 
			
		||||
        [verts] = plain_torus.verts_list()
 | 
			
		||||
        verts_shifted1 = verts.clone()
 | 
			
		||||
        verts_shifted1 *= 1.2
 | 
			
		||||
        verts_shifted1[:, 0] += 4
 | 
			
		||||
        verts_shifted1[:, 1] += 5
 | 
			
		||||
        verts[:, 0] -= 4
 | 
			
		||||
        verts[:, 1] -= 4
 | 
			
		||||
 | 
			
		||||
        [faces] = plain_torus.faces_list()
 | 
			
		||||
        map_size = 3
 | 
			
		||||
        # Two random atlases.
 | 
			
		||||
        # The averaging of the random numbers here is not consistent with the
 | 
			
		||||
        # meaning of the atlases, but makes each face a bit smoother than
 | 
			
		||||
        # if everything had a random color.
 | 
			
		||||
        atlas1 = torch.rand(size=(faces.shape[0], map_size, map_size, 3), device=device)
 | 
			
		||||
        atlas1[:, 1] = 0.5 * atlas1[:, 0] + 0.5 * atlas1[:, 2]
 | 
			
		||||
        atlas1[:, :, 1] = 0.5 * atlas1[:, :, 0] + 0.5 * atlas1[:, :, 2]
 | 
			
		||||
        atlas2 = torch.rand(size=(faces.shape[0], map_size, map_size, 3), device=device)
 | 
			
		||||
        atlas2[:, 1] = 0.5 * atlas2[:, 0] + 0.5 * atlas2[:, 2]
 | 
			
		||||
        atlas2[:, :, 1] = 0.5 * atlas2[:, :, 0] + 0.5 * atlas2[:, :, 2]
 | 
			
		||||
 | 
			
		||||
        textures1 = TexturesAtlas(atlas=[atlas1])
 | 
			
		||||
        textures2 = TexturesAtlas(atlas=[atlas2])
 | 
			
		||||
        mesh1 = Meshes(verts=[verts], faces=[faces], textures=textures1)
 | 
			
		||||
        mesh2 = Meshes(verts=[verts_shifted1], faces=[faces], textures=textures2)
 | 
			
		||||
        mesh_joined = join_meshes_as_scene([mesh1, mesh2])
 | 
			
		||||
 | 
			
		||||
        R, T = look_at_view_transform(18, 0, 0)
 | 
			
		||||
        cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
 | 
			
		||||
 | 
			
		||||
        raster_settings = RasterizationSettings(
 | 
			
		||||
            image_size=512, blur_radius=0.0, faces_per_pixel=1
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        lights = PointLights(
 | 
			
		||||
            device=device,
 | 
			
		||||
            ambient_color=((1.0, 1.0, 1.0),),
 | 
			
		||||
            diffuse_color=((0.0, 0.0, 0.0),),
 | 
			
		||||
            specular_color=((0.0, 0.0, 0.0),),
 | 
			
		||||
        )
 | 
			
		||||
        blend_params = BlendParams(
 | 
			
		||||
            sigma=1e-1,
 | 
			
		||||
            gamma=1e-4,
 | 
			
		||||
            background_color=torch.tensor([1.0, 1.0, 1.0], device=device),
 | 
			
		||||
        )
 | 
			
		||||
        renderer = MeshRenderer(
 | 
			
		||||
            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
 | 
			
		||||
            shader=HardPhongShader(
 | 
			
		||||
                device=device, blend_params=blend_params, cameras=cameras, lights=lights
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        output = renderer(mesh_joined)
 | 
			
		||||
 | 
			
		||||
        image_ref = load_rgb_image("test_joinatlas_final.png", DATA_DIR)
 | 
			
		||||
 | 
			
		||||
        if DEBUG:
 | 
			
		||||
            debugging_outputs = []
 | 
			
		||||
            for mesh_ in [mesh1, mesh2]:
 | 
			
		||||
                debugging_outputs.append(renderer(mesh_))
 | 
			
		||||
            Image.fromarray(
 | 
			
		||||
                (output[0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
 | 
			
		||||
            ).save(DATA_DIR / "test_joinatlas_final_.png")
 | 
			
		||||
            Image.fromarray(
 | 
			
		||||
                (debugging_outputs[0][0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
 | 
			
		||||
            ).save(DATA_DIR / "test_joinatlas_1.png")
 | 
			
		||||
            Image.fromarray(
 | 
			
		||||
                (debugging_outputs[1][0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
 | 
			
		||||
            ).save(DATA_DIR / "test_joinatlas_2.png")
 | 
			
		||||
 | 
			
		||||
        result = output[0, ..., :3].cpu()
 | 
			
		||||
        self.assertClose(result, image_ref, atol=0.05)
 | 
			
		||||
 | 
			
		||||
    def test_joined_spheres(self):
 | 
			
		||||
        """
 | 
			
		||||
        Test a list of Meshes can be joined as a single mesh and
 | 
			
		||||
@ -595,7 +881,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            sphere_mesh_list.append(
 | 
			
		||||
                Meshes(verts=verts, faces=sphere_list[i].faces_padded())
 | 
			
		||||
            )
 | 
			
		||||
        joined_sphere_mesh = join_mesh(sphere_mesh_list)
 | 
			
		||||
        joined_sphere_mesh = join_meshes_as_scene(sphere_mesh_list)
 | 
			
		||||
        joined_sphere_mesh.textures = TexturesVertex(
 | 
			
		||||
            verts_features=torch.ones_like(joined_sphere_mesh.verts_padded())
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
@ -12,6 +12,7 @@ from pytorch3d.renderer.mesh.textures import (
 | 
			
		||||
    TexturesUV,
 | 
			
		||||
    TexturesVertex,
 | 
			
		||||
    _list_to_padded_wrapper,
 | 
			
		||||
    pack_rectangles,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.structures import Meshes, list_to_packed, packed_to_list
 | 
			
		||||
from test_meshes import TestMeshes
 | 
			
		||||
@ -730,3 +731,80 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        index = torch.tensor([1, 2], dtype=torch.int64)
 | 
			
		||||
        tryindex(self, index, tex, meshes, source)
 | 
			
		||||
        tryindex(self, [2, 4], tex, meshes, source)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestRectanglePacking(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
    def setUp(self) -> None:
 | 
			
		||||
        super().setUp()
 | 
			
		||||
        torch.manual_seed(42)
 | 
			
		||||
 | 
			
		||||
    def wrap_pack(self, sizes):
 | 
			
		||||
        """
 | 
			
		||||
        Call the pack_rectangles function, which we want to test,
 | 
			
		||||
        and return its outputs.
 | 
			
		||||
        Additionally makes some sanity checks on the output.
 | 
			
		||||
        """
 | 
			
		||||
        res = pack_rectangles(sizes)
 | 
			
		||||
        total = res.total_size
 | 
			
		||||
        self.assertGreaterEqual(total[0], 0)
 | 
			
		||||
        self.assertGreaterEqual(total[1], 0)
 | 
			
		||||
        mask = torch.zeros(total, dtype=torch.bool)
 | 
			
		||||
        seen_x_bound = False
 | 
			
		||||
        seen_y_bound = False
 | 
			
		||||
        for (in_x, in_y), loc in zip(sizes, res.locations):
 | 
			
		||||
            self.assertGreaterEqual(loc[0], 0)
 | 
			
		||||
            self.assertGreaterEqual(loc[1], 0)
 | 
			
		||||
            placed_x, placed_y = (in_y, in_x) if loc[2] else (in_x, in_y)
 | 
			
		||||
            upper_x = placed_x + loc[0]
 | 
			
		||||
            upper_y = placed_y + loc[1]
 | 
			
		||||
            self.assertGreaterEqual(total[0], upper_x)
 | 
			
		||||
            if total[0] == upper_x:
 | 
			
		||||
                seen_x_bound = True
 | 
			
		||||
            self.assertGreaterEqual(total[1], upper_y)
 | 
			
		||||
            if total[1] == upper_y:
 | 
			
		||||
                seen_y_bound = True
 | 
			
		||||
            already_taken = torch.sum(mask[loc[0] : upper_x, loc[1] : upper_y])
 | 
			
		||||
            self.assertEqual(already_taken, 0)
 | 
			
		||||
            mask[loc[0] : upper_x, loc[1] : upper_y] = 1
 | 
			
		||||
        self.assertTrue(seen_x_bound)
 | 
			
		||||
        self.assertTrue(seen_y_bound)
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(torch.all(torch.sum(mask, dim=0, dtype=torch.int32) > 0))
 | 
			
		||||
        self.assertTrue(torch.all(torch.sum(mask, dim=1, dtype=torch.int32) > 0))
 | 
			
		||||
        return res
 | 
			
		||||
 | 
			
		||||
    def assert_bb(self, sizes, expected):
 | 
			
		||||
        """
 | 
			
		||||
        Apply the pack_rectangles function to sizes and verify the
 | 
			
		||||
        bounding box dimensions are expected.
 | 
			
		||||
        """
 | 
			
		||||
        self.assertSetEqual(set(self.wrap_pack(sizes).total_size), expected)
 | 
			
		||||
 | 
			
		||||
    def test_simple(self):
 | 
			
		||||
        self.assert_bb([(3, 4), (4, 3)], {6, 4})
 | 
			
		||||
        self.assert_bb([(2, 2), (2, 4), (2, 2)], {4, 4})
 | 
			
		||||
 | 
			
		||||
        # many squares
 | 
			
		||||
        self.assert_bb([(2, 2)] * 9, {2, 18})
 | 
			
		||||
 | 
			
		||||
        # One big square and many small ones.
 | 
			
		||||
        self.assert_bb([(3, 3)] + [(1, 1)] * 2, {3, 4})
 | 
			
		||||
        self.assert_bb([(3, 3)] + [(1, 1)] * 3, {3, 4})
 | 
			
		||||
        self.assert_bb([(3, 3)] + [(1, 1)] * 4, {3, 5})
 | 
			
		||||
        self.assert_bb([(3, 3)] + [(1, 1)] * 5, {3, 5})
 | 
			
		||||
        self.assert_bb([(1, 1)] * 6 + [(3, 3)], {3, 5})
 | 
			
		||||
        self.assert_bb([(3, 3)] + [(1, 1)] * 7, {3, 6})
 | 
			
		||||
 | 
			
		||||
        # many identical rectangles
 | 
			
		||||
        self.assert_bb([(7, 190)] * 4 + [(190, 7)] * 4, {190, 56})
 | 
			
		||||
 | 
			
		||||
        # require placing the flipped version of a rectangle
 | 
			
		||||
        self.assert_bb([(1, 100), (5, 96), (4, 5)], {100, 6})
 | 
			
		||||
 | 
			
		||||
    def test_random(self):
 | 
			
		||||
        for _ in range(5):
 | 
			
		||||
            vals = torch.randint(size=(20, 2), low=1, high=18)
 | 
			
		||||
            sizes = []
 | 
			
		||||
            for j in range(vals.shape[0]):
 | 
			
		||||
                sizes.append((int(vals[j, 0]), int(vals[j, 1])))
 | 
			
		||||
            self.wrap_pack(sizes)
 | 
			
		||||
 | 
			
		||||