mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-22 07:10:34 +08:00
amalgamate meshes with texture into a single scene
Summary: Add a join_scene method to all the textures to allow the join_mesh function to include textures. Rename the join_mesh function to join_meshes_as_scene. For TexturesAtlas, we now interpolate if the user attempts to have the resolution vary across the batch. This doesn't look great if the resolution is already very low. For TexturesUV, a rectangle packing function is required, this does something simple. Reviewed By: gkioxari Differential Revision: D23188773 fbshipit-source-id: c013db061a04076e13e90ccc168a7913e933a9c5
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e25ccab3d9
commit
909dc83505
@@ -2,7 +2,7 @@
|
||||
|
||||
import itertools
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -10,6 +10,8 @@ from pytorch3d.ops import interpolate_face_attributes
|
||||
from pytorch3d.structures.utils import list_to_packed, list_to_padded, padded_to_list
|
||||
from torch.nn.functional import interpolate
|
||||
|
||||
from .utils import pack_rectangles
|
||||
|
||||
|
||||
# This file contains classes and helper functions for texturing.
|
||||
# There are three types of textures: TexturesVertex, TexturesAtlas
|
||||
@@ -329,6 +331,7 @@ class TexturesAtlas(TexturesBase):
|
||||
|
||||
[1] Liu et al, 'Soft Rasterizer: A Differentiable Renderer for Image-based
|
||||
3D Reasoning', ICCV 2019
|
||||
See also https://github.com/ShichenLiu/SoftRas/issues/21
|
||||
"""
|
||||
if isinstance(atlas, (list, tuple)):
|
||||
correct_format = all(
|
||||
@@ -336,11 +339,15 @@ class TexturesAtlas(TexturesBase):
|
||||
torch.is_tensor(elem)
|
||||
and elem.ndim == 4
|
||||
and elem.shape[1] == elem.shape[2]
|
||||
and elem.shape[1] == atlas[0].shape[1]
|
||||
)
|
||||
for elem in atlas
|
||||
)
|
||||
if not correct_format:
|
||||
msg = "Expected atlas to be a list of tensors of shape (F, R, R, D)"
|
||||
msg = (
|
||||
"Expected atlas to be a list of tensors of shape (F, R, R, D) "
|
||||
"with the same value of R."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
self._atlas_list = atlas
|
||||
self._atlas_padded = None
|
||||
@@ -529,6 +536,12 @@ class TexturesAtlas(TexturesBase):
|
||||
new_tex._num_faces_per_mesh = num_faces_per_mesh
|
||||
return new_tex
|
||||
|
||||
def join_scene(self) -> "TexturesAtlas":
|
||||
"""
|
||||
Return a new TexturesAtlas amalgamating the batch.
|
||||
"""
|
||||
return self.__class__(atlas=[torch.cat(self.atlas_list())])
|
||||
|
||||
|
||||
class TexturesUV(TexturesBase):
|
||||
def __init__(
|
||||
@@ -560,7 +573,7 @@ class TexturesUV(TexturesBase):
|
||||
the two align_corners options at
|
||||
https://discuss.pytorch.org/t/22663/9 .
|
||||
|
||||
An example of how the indexing into the maps, with align_corners=True
|
||||
An example of how the indexing into the maps, with align_corners=True,
|
||||
works is as follows.
|
||||
If maps[i] has shape [101, 1001] and the value of verts_uvs[i][j]
|
||||
is [0.4, 0.3], then a value of j in faces_uvs[i] means a vertex
|
||||
@@ -574,10 +587,11 @@ class TexturesUV(TexturesBase):
|
||||
If maps[i] has shape [100, 1000] and the value of verts_uvs[i][j]
|
||||
is [0.405, 0.2995], then a value of j in faces_uvs[i] means a vertex
|
||||
whose color is given by maps[i][700, 40].
|
||||
In this case, padding_mode even matters for values in verts_uvs
|
||||
slightly above 0 or slightly below 1. In this case, it matters if the
|
||||
first value is outside the interval [0.0005, 0.9995] or if the second
|
||||
is outside the interval [0.005, 0.995].
|
||||
When align_corners=False, padding_mode even matters for values in
|
||||
verts_uvs slightly above 0 or slightly below 1. In this case, the
|
||||
padding_mode matters if the first value is outside the interval
|
||||
[0.0005, 0.9995] or if the second is outside the interval
|
||||
[0.005, 0.995].
|
||||
"""
|
||||
super().__init__()
|
||||
self.padding_mode = padding_mode
|
||||
@@ -805,12 +819,9 @@ class TexturesUV(TexturesBase):
|
||||
def maps_padded(self) -> torch.Tensor:
|
||||
return self._maps_padded
|
||||
|
||||
def maps_list(self) -> torch.Tensor:
|
||||
# maps_list is not used anywhere currently - maps
|
||||
# are padded to ensure the (H, W) of all maps is the
|
||||
# same across the batch and we don't store the
|
||||
# unpadded sizes of the maps. Therefore just return
|
||||
# the unbinded padded tensor.
|
||||
def maps_list(self) -> List[torch.Tensor]:
|
||||
if self._maps_list is not None:
|
||||
return self._maps_list
|
||||
return self._maps_padded.unbind(0)
|
||||
|
||||
def extend(self, N: int) -> "TexturesUV":
|
||||
@@ -965,6 +976,143 @@ class TexturesUV(TexturesBase):
|
||||
new_tex._num_faces_per_mesh = num_faces_per_mesh
|
||||
return new_tex
|
||||
|
||||
def _place_map_into_single_map(
|
||||
self,
|
||||
single_map: torch.Tensor,
|
||||
map_: torch.Tensor,
|
||||
location: Tuple[int, int, bool], # (x,y) and whether flipped
|
||||
) -> None:
|
||||
"""
|
||||
Copy map into a larger tensor single_map at the destination specified by location.
|
||||
If align_corners is False, we add the needed border around the destination.
|
||||
|
||||
Used by join_scene.
|
||||
|
||||
Args:
|
||||
single_map: (total_H, total_W, 3)
|
||||
map_: (H, W, 3) source data
|
||||
location: where to place map
|
||||
"""
|
||||
do_flip = location[2]
|
||||
source = map_.transpose(0, 1) if do_flip else map_
|
||||
border_width = 0 if self.align_corners else 1
|
||||
lower_u = location[0] + border_width
|
||||
lower_v = location[1] + border_width
|
||||
upper_u = lower_u + source.shape[0]
|
||||
upper_v = lower_v + source.shape[1]
|
||||
single_map[lower_u:upper_u, lower_v:upper_v] = source
|
||||
|
||||
if self.padding_mode != "zeros" and not self.align_corners:
|
||||
single_map[lower_u - 1, lower_v:upper_v] = single_map[
|
||||
lower_u, lower_v:upper_v
|
||||
]
|
||||
single_map[upper_u, lower_v:upper_v] = single_map[
|
||||
upper_u - 1, lower_v:upper_v
|
||||
]
|
||||
single_map[lower_u:upper_u, lower_v - 1] = single_map[
|
||||
lower_u:upper_u, lower_v
|
||||
]
|
||||
single_map[lower_u:upper_u, upper_v] = single_map[
|
||||
lower_u:upper_u, upper_v - 1
|
||||
]
|
||||
single_map[lower_u - 1, lower_v - 1] = single_map[lower_u, lower_v]
|
||||
single_map[lower_u - 1, upper_v] = single_map[lower_u, upper_v - 1]
|
||||
single_map[upper_u, lower_v - 1] = single_map[upper_u - 1, lower_v]
|
||||
single_map[upper_u, upper_v] = single_map[upper_u - 1, upper_v - 1]
|
||||
|
||||
def join_scene(self) -> "TexturesUV":
|
||||
"""
|
||||
Return a new TexturesUV amalgamating the batch.
|
||||
|
||||
We calculate a large single map which contains the original maps,
|
||||
and find verts_uvs to point into it. This will not replicate
|
||||
behavior of padding for verts_uvs values outside [0,1].
|
||||
|
||||
If align_corners=False, we need to add an artificial border around
|
||||
every map.
|
||||
|
||||
We use the function `pack_rectangles` to provide a layout for the
|
||||
single map. _place_map_into_single_map is used to copy the maps
|
||||
into the single map. The merging of verts_uvs and faces_uvs are
|
||||
handled locally in this function.
|
||||
"""
|
||||
maps = self.maps_list()
|
||||
heights_and_widths = []
|
||||
extra_border = 0 if self.align_corners else 2
|
||||
for map_ in maps:
|
||||
heights_and_widths.append(
|
||||
(map_.shape[0] + extra_border, map_.shape[1] + extra_border)
|
||||
)
|
||||
merging_plan = pack_rectangles(heights_and_widths)
|
||||
# pyre-fixme[16]: `Tensor` has no attribute `new_zeros`.
|
||||
single_map = maps[0].new_zeros((*merging_plan.total_size, 3))
|
||||
verts_uvs = self.verts_uvs_list()
|
||||
verts_uvs_merged = []
|
||||
|
||||
for map_, loc, uvs in zip(maps, merging_plan.locations, verts_uvs):
|
||||
new_uvs = uvs.clone()
|
||||
self._place_map_into_single_map(single_map, map_, loc)
|
||||
do_flip = loc[2]
|
||||
x_shape = map_.shape[1] if do_flip else map_.shape[0]
|
||||
y_shape = map_.shape[0] if do_flip else map_.shape[1]
|
||||
|
||||
if do_flip:
|
||||
# Here we have flipped / transposed the map.
|
||||
# In uvs, the y values are decreasing from 1 to 0 and the x
|
||||
# values increase from 0 to 1. We subtract all values from 1
|
||||
# as the x's become y's and the y's become x's.
|
||||
new_uvs = 1.0 - new_uvs[:, [1, 0]]
|
||||
if TYPE_CHECKING:
|
||||
new_uvs = torch.Tensor(new_uvs)
|
||||
|
||||
# If align_corners is True, then an index of x (where x is in
|
||||
# the range 0 .. map_.shape[]-1) in one of the input maps
|
||||
# was hit by a u of x/(map_.shape[]-1).
|
||||
# That x is located at the index loc[] + x in the single_map, and
|
||||
# to hit that we need u to equal (loc[] + x) / (total_size[]-1)
|
||||
# so the old u should be mapped to
|
||||
# { u*(map_.shape[]-1) + loc[] } / (total_size[]-1)
|
||||
|
||||
# If align_corners is False, then an index of x (where x is in
|
||||
# the range 1 .. map_.shape[]-2) in one of the input maps
|
||||
# was hit by a u of (x+0.5)/(map_.shape[]).
|
||||
# That x is located at the index loc[] + 1 + x in the single_map,
|
||||
# (where the 1 is for the border)
|
||||
# and to hit that we need u to equal (loc[] + 1 + x + 0.5) / (total_size[])
|
||||
# so the old u should be mapped to
|
||||
# { loc[] + 1 + u*map_.shape[]-0.5 + 0.5 } / (total_size[])
|
||||
# = { loc[] + 1 + u*map_.shape[] } / (total_size[])
|
||||
|
||||
# We change the y's in new_uvs for the scaling of height,
|
||||
# and the x's for the scaling of width.
|
||||
# That is why the 1's and 0's are mismatched in these lines.
|
||||
one_if_align = 1 if self.align_corners else 0
|
||||
one_if_not_align = 1 - one_if_align
|
||||
denom_x = merging_plan.total_size[0] - one_if_align
|
||||
scale_x = x_shape - one_if_align
|
||||
denom_y = merging_plan.total_size[1] - one_if_align
|
||||
scale_y = y_shape - one_if_align
|
||||
new_uvs[:, 1] *= scale_x / denom_x
|
||||
new_uvs[:, 1] += (loc[0] + one_if_not_align) / denom_x
|
||||
new_uvs[:, 0] *= scale_y / denom_y
|
||||
new_uvs[:, 0] += (loc[1] + one_if_not_align) / denom_y
|
||||
|
||||
verts_uvs_merged.append(new_uvs)
|
||||
|
||||
faces_uvs_merged = []
|
||||
offset = 0
|
||||
for faces_uvs_, verts_uvs_ in zip(self.faces_uvs_list(), verts_uvs):
|
||||
faces_uvs_merged.append(offset + faces_uvs_)
|
||||
offset += verts_uvs_.shape[0]
|
||||
|
||||
return self.__class__(
|
||||
maps=[single_map],
|
||||
verts_uvs=[torch.cat(verts_uvs_merged)],
|
||||
faces_uvs=[torch.cat(faces_uvs_merged)],
|
||||
align_corners=self.align_corners,
|
||||
padding_mode=self.padding_mode,
|
||||
)
|
||||
|
||||
|
||||
class TexturesVertex(TexturesBase):
|
||||
def __init__(
|
||||
@@ -1156,3 +1304,9 @@ class TexturesVertex(TexturesBase):
|
||||
new_tex = self.__class__(verts_features=verts_features_list)
|
||||
new_tex._num_verts_per_mesh = num_faces_per_mesh
|
||||
return new_tex
|
||||
|
||||
def join_scene(self) -> "TexturesVertex":
|
||||
"""
|
||||
Return a new TexturesVertex amalgamating the batch.
|
||||
"""
|
||||
return self.__class__(verts_features=[torch.cat(self.verts_features_list())])
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
from typing import List, NamedTuple, Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d.ops import interpolate_face_attributes
|
||||
|
||||
@@ -58,3 +60,184 @@ def _interpolate_zbuf(
|
||||
] # (1, H, W, K)
|
||||
zbuf[pix_to_face == -1] = -1
|
||||
return zbuf
|
||||
|
||||
|
||||
# ----------- Rectangle Packing -------------------- #
|
||||
|
||||
# Note the order of members matters here because it determines the queue order.
|
||||
# We want to place longer rectangles first.
|
||||
class _UnplacedRectangle(NamedTuple):
|
||||
size: Tuple[int, int]
|
||||
ind: int
|
||||
flipped: bool
|
||||
|
||||
|
||||
def _try_place_rectangle(
|
||||
rect: _UnplacedRectangle,
|
||||
placed_so_far: List[Tuple[int, int, bool]],
|
||||
occupied: List[Tuple[int, int]],
|
||||
) -> bool:
|
||||
"""
|
||||
Try to place rect within the current bounding box.
|
||||
Part of the implementation of pack_rectangles.
|
||||
|
||||
Note that the arguments `placed_so_far` and `occupied` are modified.
|
||||
|
||||
Args:
|
||||
rect: rectangle to place
|
||||
placed_so_far: the locations decided upon so far - a list of
|
||||
(x, y, whether flipped). The nth element is the
|
||||
location of the nth rectangle if it has been decided.
|
||||
(modified in place)
|
||||
occupied: the nodes of the graph of extents of rightmost placed
|
||||
rectangles - (modified in place)
|
||||
|
||||
Returns:
|
||||
True on success.
|
||||
|
||||
Example:
|
||||
(We always have placed the first rectangle horizontally and other
|
||||
rectangles above it.)
|
||||
Let's say the placed boxes 1-4 are layed out like this.
|
||||
The coordinates of the points marked X are stored in occupied.
|
||||
It is to the right of the X's that we seek to place rect.
|
||||
|
||||
+-----------------------X
|
||||
|2 |
|
||||
| +---X
|
||||
| |4 |
|
||||
| | |
|
||||
| +---+X
|
||||
| |3 |
|
||||
| | |
|
||||
+-----------------------+----+------X
|
||||
y |1 |
|
||||
^ | --->x |
|
||||
| +-----------------------------------+
|
||||
|
||||
We want to place this rectangle.
|
||||
|
||||
+-+
|
||||
|5|
|
||||
| |
|
||||
| | = rect
|
||||
| |
|
||||
| |
|
||||
| |
|
||||
+-+
|
||||
|
||||
The call will succeed, returning True, leaving us with
|
||||
|
||||
+-----------------------X
|
||||
|2 | +-X
|
||||
| +---+|5|
|
||||
| |4 || |
|
||||
| | || |
|
||||
| +---++ |
|
||||
| |3 | |
|
||||
| | | |
|
||||
+-----------------------+----+-+----X
|
||||
|1 |
|
||||
| |
|
||||
+-----------------------------------+ .
|
||||
|
||||
"""
|
||||
total_width = occupied[0][0]
|
||||
needed_height = rect.size[1]
|
||||
current_start_idx = None
|
||||
current_max_width = 0
|
||||
previous_height = 0
|
||||
currently_packed = 0
|
||||
for idx, interval in enumerate(occupied):
|
||||
if interval[0] <= total_width - rect.size[0]:
|
||||
currently_packed += interval[1] - previous_height
|
||||
current_max_width = max(interval[0], current_max_width)
|
||||
if current_start_idx is None:
|
||||
current_start_idx = idx
|
||||
if currently_packed >= needed_height:
|
||||
current_max_width = max(interval[0], current_max_width)
|
||||
placed_so_far[rect.ind] = (
|
||||
current_max_width,
|
||||
occupied[current_start_idx - 1][1],
|
||||
rect.flipped,
|
||||
)
|
||||
new_occupied = (
|
||||
current_max_width + rect.size[0],
|
||||
occupied[current_start_idx - 1][1] + needed_height,
|
||||
)
|
||||
if currently_packed == needed_height:
|
||||
occupied[idx] = new_occupied
|
||||
del occupied[current_start_idx:idx]
|
||||
elif idx > current_start_idx:
|
||||
occupied[idx - 1] = new_occupied
|
||||
del occupied[current_start_idx : (idx - 1)]
|
||||
else:
|
||||
occupied.insert(idx, new_occupied)
|
||||
return True
|
||||
else:
|
||||
current_start_idx = None
|
||||
current_max_width = 0
|
||||
currently_packed = 0
|
||||
previous_height = interval[1]
|
||||
return False
|
||||
|
||||
|
||||
class PackedRectangles(NamedTuple):
|
||||
total_size: Tuple[int, int]
|
||||
locations: List[Tuple[int, int, bool]] # (x,y) and whether flipped
|
||||
|
||||
|
||||
def pack_rectangles(sizes: List[Tuple[int, int]]) -> PackedRectangles:
|
||||
"""
|
||||
Naive rectangle packing in to a large rectangle. Flipping (i.e. rotating
|
||||
a rectangle by 90 degrees) is allowed.
|
||||
|
||||
This is used to join several uv maps into a single scene, see
|
||||
TexturesUV.join_scene.
|
||||
|
||||
Args:
|
||||
sizes: List of sizes of rectangles to pack
|
||||
|
||||
Returns:
|
||||
total_size: size of total large rectangle
|
||||
rectangles: location for each of the input rectangles
|
||||
"""
|
||||
|
||||
if len(sizes) < 2:
|
||||
raise ValueError("Cannot pack less than two boxes")
|
||||
|
||||
queue = []
|
||||
for i, size in enumerate(sizes):
|
||||
if size[0] < size[1]:
|
||||
queue.append(_UnplacedRectangle((size[1], size[0]), i, True))
|
||||
else:
|
||||
queue.append(_UnplacedRectangle((size[0], size[1]), i, False))
|
||||
queue.sort()
|
||||
placed_so_far = [(-1, -1, False)] * len(sizes)
|
||||
|
||||
biggest = queue.pop()
|
||||
total_width, current_height = biggest.size
|
||||
placed_so_far[biggest.ind] = (0, 0, biggest.flipped)
|
||||
|
||||
second = queue.pop()
|
||||
placed_so_far[second.ind] = (0, current_height, second.flipped)
|
||||
current_height += second.size[1]
|
||||
occupied = [biggest.size, (second.size[0], current_height)]
|
||||
|
||||
for rect in reversed(queue):
|
||||
if _try_place_rectangle(rect, placed_so_far, occupied):
|
||||
continue
|
||||
|
||||
rotated = _UnplacedRectangle(
|
||||
(rect.size[1], rect.size[0]), rect.ind, not rect.flipped
|
||||
)
|
||||
if _try_place_rectangle(rotated, placed_so_far, occupied):
|
||||
continue
|
||||
|
||||
# rect wasn't placed in the current bounding box,
|
||||
# so we add extra space to fit it in.
|
||||
placed_so_far[rect.ind] = (0, current_height, rect.flipped)
|
||||
current_height += rect.size[1]
|
||||
occupied.append((rect.size[0], current_height))
|
||||
|
||||
return PackedRectangles((total_width, current_height), placed_so_far)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from .meshes import Meshes, join_meshes_as_batch
|
||||
from .meshes import Meshes, join_meshes_as_batch, join_meshes_as_scene
|
||||
from .pointclouds import Pointclouds
|
||||
from .utils import list_to_packed, list_to_padded, packed_to_list, padded_to_list
|
||||
|
||||
|
||||
@@ -1254,7 +1254,7 @@ class Meshes(object):
|
||||
"""
|
||||
verts_packed = self.verts_packed()
|
||||
if vert_offsets_packed.shape != verts_packed.shape:
|
||||
raise ValueError("Verts offsets must have dimension (all_v, 2).")
|
||||
raise ValueError("Verts offsets must have dimension (all_v, 3).")
|
||||
# update verts packed
|
||||
self._verts_packed = verts_packed + vert_offsets_packed
|
||||
new_verts_list = list(
|
||||
@@ -1548,26 +1548,43 @@ def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True):
|
||||
return Meshes(verts=verts, faces=faces, textures=tex)
|
||||
|
||||
|
||||
def join_mesh(meshes: Union[Meshes, List[Meshes]]) -> Meshes:
|
||||
def join_meshes_as_scene(
|
||||
meshes: Union[Meshes, List[Meshes]], include_textures: bool = True
|
||||
) -> Meshes:
|
||||
"""
|
||||
Joins a batch of meshes in the form of a Meshes object or a list of Meshes
|
||||
objects as a single mesh. If the input is a list, the Meshes objects in the list
|
||||
must all be on the same device. This version ignores all textures in the input meshes.
|
||||
objects as a single mesh. If the input is a list, the Meshes objects in the
|
||||
list must all be on the same device. Unless include_textures is False, the
|
||||
meshes must all have the same type of texture or must all not have textures.
|
||||
|
||||
If textures are included, then the textures are joined as a single scene in
|
||||
addition to the meshes. For this, texture types have an appropriate method
|
||||
called join_scene which joins mesh textures into a single texture.
|
||||
If the textures are TexturesAtlas then they must have the same resolution.
|
||||
If they are TexturesUV then they must have the same align_corners and
|
||||
padding_mode. Values in verts_uvs outside [0, 1] will not
|
||||
be respected.
|
||||
|
||||
Args:
|
||||
meshes: Meshes object that contains a batch of meshes or a list of Meshes objects
|
||||
meshes: Meshes object that contains a batch of meshes, or a list of
|
||||
Meshes objects.
|
||||
include_textures: (bool) whether to try to join the textures.
|
||||
|
||||
Returns:
|
||||
new Meshes object containing a single mesh
|
||||
"""
|
||||
if isinstance(meshes, List):
|
||||
meshes = join_meshes_as_batch(meshes, include_textures=False)
|
||||
meshes = join_meshes_as_batch(meshes, include_textures=include_textures)
|
||||
|
||||
if len(meshes) == 1:
|
||||
return meshes
|
||||
verts = meshes.verts_packed() # (sum(V_n), 3)
|
||||
# Offset automatically done by faces_packed
|
||||
faces = meshes.faces_packed() # (sum(F_n), 3)
|
||||
textures = None
|
||||
|
||||
mesh = Meshes(verts=verts.unsqueeze(0), faces=faces.unsqueeze(0))
|
||||
if include_textures and meshes.textures is not None:
|
||||
textures = meshes.textures.join_scene()
|
||||
|
||||
mesh = Meshes(verts=verts.unsqueeze(0), faces=faces.unsqueeze(0), textures=textures)
|
||||
return mesh
|
||||
|
||||
Reference in New Issue
Block a user