TexturesUV multiple maps

Summary: Implements the  the TexturesUV with multiple map ids.

Reviewed By: bottler

Differential Revision: D53944063

fbshipit-source-id: 06c25eb6d69f72db0484f16566dd2ca32a560b82
This commit is contained in:
Cijo Jose 2024-03-12 06:59:31 -07:00 committed by Facebook GitHub Bot
parent 7566530669
commit 38cf0dc1c5
2 changed files with 483 additions and 81 deletions

View File

@ -149,6 +149,58 @@ def _pad_texture_maps(
return tex_maps
def _pad_texture_multiple_maps(
multiple_texture_maps: Union[Tuple[torch.Tensor], List[torch.Tensor]],
align_corners: bool,
) -> torch.Tensor:
"""
Pad all texture images so they have the same height and width.
Args:
images: list of N tensors of shape (M_i, H_i, W_i, C)
M_i : Number of texture maps:w
align_corners: used for interpolation
Returns:
tex_maps: Tensor of shape (N, max_M, max_H, max_W, C)
"""
tex_maps = []
max_M = 0
max_H = 0
max_W = 0
C = 0
for im in multiple_texture_maps:
m, h, w, C = im.shape
if m > max_M:
max_M = m
if h > max_H:
max_H = h
if w > max_W:
max_W = w
tex_maps.append(im)
max_shape = (max_M, max_H, max_W, C)
max_im_shape = (max_H, max_W)
for i, tms in enumerate(tex_maps):
new_tex_maps = torch.zeros(max_shape)
for j in range(tms.shape[0]):
im = tms[j]
if im.shape[:2] != max_im_shape:
image_BCHW = im.permute(2, 0, 1)[None]
new_image_BCHW = interpolate(
image_BCHW,
size=max_im_shape,
mode="bilinear",
align_corners=align_corners,
)
new_tex_maps[j] = new_image_BCHW[0].permute(1, 2, 0)
else:
new_tex_maps[j] = im
tex_maps[i] = new_tex_maps
tex_maps = torch.stack(tex_maps, dim=0) # (num_tex_maps, max_H, max_W, C)
return tex_maps
# A base class for defining a batch of textures
# with helper methods.
# This is also useful to have so that inside `Meshes`
@ -199,13 +251,20 @@ class TexturesBase:
t = getattr(self, p)
if callable(t):
t = t() # class method
if isinstance(t, list):
if t is None:
new_props[p] = None
elif isinstance(t, list):
if not all(isinstance(elem, (int, float)) for elem in t):
raise ValueError("Extend only supports lists of scalars")
t = [[ti] * N for ti in t]
new_props[p] = list(itertools.chain(*t))
elif torch.is_tensor(t):
new_props[p] = t.repeat_interleave(N, dim=0)
else:
raise ValueError(
f"Property {p} has unsupported type {type(t)}."
"Only tensors and lists are supported."
)
return new_props
def _getitem(self, index: Union[int, slice], props: List[str]):
@ -218,7 +277,7 @@ class TexturesBase:
t = getattr(self, p)
if callable(t):
t = t() # class method
new_props[p] = t[index]
new_props[p] = t[index] if t is not None else None
elif isinstance(index, list):
index = torch.tensor(index)
if isinstance(index, torch.Tensor):
@ -230,8 +289,7 @@ class TexturesBase:
t = getattr(self, p)
if callable(t):
t = t() # class method
new_props[p] = [t[i] for i in index]
new_props[p] = [t[i] for i in index] if t is not None else None
return new_props
def sample_textures(self) -> torch.Tensor:
@ -644,6 +702,10 @@ class TexturesUV(TexturesBase):
maps: Union[torch.Tensor, List[torch.Tensor]],
faces_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
verts_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
*,
maps_ids: Optional[
Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
] = None,
padding_mode: str = "border",
align_corners: bool = True,
sampling_mode: str = "bilinear",
@ -653,20 +715,33 @@ class TexturesUV(TexturesBase):
vertex in each face. NOTE: this class only supports one texture map per mesh.
Args:
maps: texture map per mesh. This can either be a list of maps
[(H, W, C)] or a padded tensor of shape (N, H, W, C).
For RGB, C = 3.
maps: Either (1) a texture map per mesh. This can either be a list of maps
[(H, W, C)] or a padded tensor of shape (N, H, W, C).
For RGB, C = 3. In this case maps_ids must be None.
Or (2) a set of M texture maps per mesh. This can either be a list of sets
[(M, H, W, C)] or a padded tensor of shape (N, M, H, W, C).
For RGB, C = 3. In this case maps_ids must be provided to
identify which is relevant to each face.
faces_uvs: (N, F, 3) LongTensor giving the index into verts_uvs
for each face
for each face
verts_uvs: (N, V, 2) tensor giving the uv coordinates per vertex
(a FloatTensor with values between 0 and 1).
(a FloatTensor with values between 0 and 1).
maps_ids: Used if there are to be multiple maps per face. This can be either a list of map_ids [(F,)]
or a long tensor of shape (N, F) giving the id of the texture map
for each face. If maps_ids is present, the maps has an extra dimension M
(so maps_padded is (N, M, H, W, C) and maps_list has elements of
shape (M, H, W, C)).
Specifically, the color
of a vertex V is given by an average of maps_padded[i, maps_ids[i, f], u, v, :]
over u and v integers adjacent to
_verts_uvs_padded[i, _faces_uvs_padded[i, f, 0], :] .
align_corners: If true, the extreme values 0 and 1 for verts_uvs
indicate the centers of the edge pixels in the maps.
indicate the centers of the edge pixels in the maps.
padding_mode: padding mode for outside grid values
("zeros", "border" or "reflection").
("zeros", "border" or "reflection").
sampling_mode: type of interpolation used to sample the texture.
Corresponds to the mode parameter in PyTorch's
grid_sample ("nearest" or "bilinear").
Corresponds to the mode parameter in PyTorch's
grid_sample ("nearest" or "bilinear").
The align_corners and padding_mode arguments correspond to the arguments
of the `grid_sample` torch function. There is an informative illustration of
@ -762,6 +837,8 @@ class TexturesUV(TexturesBase):
else:
raise ValueError("Expected verts_uvs to be a tensor or list")
self._maps_ids_padded, self._maps_ids_list = self._format_maps_ids(maps_ids)
if isinstance(maps, (list, tuple)):
self._maps_list = maps
else:
@ -770,14 +847,73 @@ class TexturesUV(TexturesBase):
if self._maps_padded.device != self.device:
raise ValueError("maps must be on the same device as verts/faces uvs.")
self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)
def _format_maps_ids(
self,
maps_ids: Optional[
Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
],
) -> Tuple[
Optional[torch.Tensor], Optional[Union[List[torch.Tensor], Tuple[torch.Tensor]]]
]:
if maps_ids is None:
return None, None
elif isinstance(maps_ids, (list, tuple)):
for mid in maps_ids:
if mid.ndim != 1:
msg = "Expected maps_ids to be of shape (F,); got %r"
raise ValueError(msg % repr(mid.shape))
if len(maps_ids) != self._N:
raise ValueError(
"map_ids, faces_uvs and verts_uvs must have the same batch dimension"
)
if not all(mid.device == self.device for mid in maps_ids):
raise ValueError(
"maps_ids and verts/faces uvs must be on the same device"
)
if not all(
mid.shape[0] == nfm
for mid, nfm in zip(maps_ids, self._num_faces_per_mesh)
):
raise ValueError(
"map_ids and faces_uvs must have the same number of faces per mesh"
)
if not all(mid.device == self.device for mid in maps_ids):
raise ValueError(
"maps_ids and verts/faces uvs must be on the same device"
)
if not self._num_faces_per_mesh:
return torch.Tensor(), maps_ids
return list_to_padded(maps_ids, pad_value=0), maps_ids
elif isinstance(maps_ids, torch.Tensor):
if maps_ids.ndim != 2 or maps_ids.shape[0] != self._N:
msg = "Expected maps_ids to be of shape (N, F); got %r"
raise ValueError(msg % repr(maps_ids.shape))
maps_ids_padded = maps_ids
max_F = max(self._num_faces_per_mesh)
if not maps_ids.shape[1] == max_F:
raise ValueError(
"map_ids and faces_uvs must have the same number of faces per mesh"
)
if maps_ids.device != self.device:
raise ValueError(
"maps_ids and verts/faces uvs must be on the same device"
)
return maps_ids_padded, None
raise ValueError("Expected maps_ids to be a tensor or list")
def _format_maps_padded(
self, maps: Union[torch.Tensor, List[torch.Tensor]]
) -> torch.Tensor:
maps_ids_none = self._maps_ids_padded is None
if isinstance(maps, torch.Tensor):
if maps.ndim != 4 or maps.shape[0] != self._N:
if not maps_ids_none:
if maps.ndim != 5 or maps.shape[0] != self._N:
msg = "Expected maps to be of shape (N, M, H, W, C); got %r"
raise ValueError(msg % repr(maps.shape))
elif maps.ndim != 4 or maps.shape[0] != self._N:
msg = "Expected maps to be of shape (N, H, W, C); got %r"
raise ValueError(msg % repr(maps.shape))
return maps
@ -786,15 +922,27 @@ class TexturesUV(TexturesBase):
if len(maps) != self._N:
raise ValueError("Expected one texture map per mesh in the batch.")
if self._N > 0:
if not all(map.ndim == 3 for map in maps):
ndim = 3 if maps_ids_none else 4
if not all(map.ndim == ndim for map in maps):
raise ValueError("Invalid number of dimensions in texture maps")
if not all(map.shape[2] == maps[0].shape[2] for map in maps):
if not all(map.shape[-1] == maps[0].shape[-1] for map in maps):
raise ValueError("Inconsistent number of channels in maps")
maps_padded = _pad_texture_maps(maps, align_corners=self.align_corners)
else:
maps_padded = torch.empty(
(self._N, 0, 0, 3), dtype=torch.float32, device=self.device
maps_padded = (
_pad_texture_maps(maps, align_corners=self.align_corners)
if maps_ids_none
else _pad_texture_multiple_maps(
maps, align_corners=self.align_corners
)
)
else:
if maps_ids_none:
maps_padded = torch.empty(
(self._N, 0, 0, 3), dtype=torch.float32, device=self.device
)
else:
maps_padded = torch.empty(
(self._N, 0, 0, 0, 3), dtype=torch.float32, device=self.device
)
return maps_padded
raise ValueError("Expected maps to be a tensor or list of tensors.")
@ -804,6 +952,11 @@ class TexturesUV(TexturesBase):
self.maps_padded().clone(),
self.faces_uvs_padded().clone(),
self.verts_uvs_padded().clone(),
maps_ids=(
self._maps_ids_padded.clone()
if self._maps_ids_padded is not None
else None
),
align_corners=self.align_corners,
padding_mode=self.padding_mode,
sampling_mode=self.sampling_mode,
@ -814,6 +967,8 @@ class TexturesUV(TexturesBase):
tex._verts_uvs_list = [v.clone() for v in self._verts_uvs_list]
if self._faces_uvs_list is not None:
tex._faces_uvs_list = [f.clone() for f in self._faces_uvs_list]
if self._maps_ids_list is not None:
tex._maps_ids_list = [f.clone() for f in self._maps_ids_list]
num_faces = (
self._num_faces_per_mesh.clone()
if torch.is_tensor(self._num_faces_per_mesh)
@ -828,6 +983,11 @@ class TexturesUV(TexturesBase):
self.maps_padded().detach(),
self.faces_uvs_padded().detach(),
self.verts_uvs_padded().detach(),
maps_ids=(
self._maps_ids_padded.detach()
if self._maps_ids_padded is not None
else None
),
align_corners=self.align_corners,
padding_mode=self.padding_mode,
sampling_mode=self.sampling_mode,
@ -838,6 +998,8 @@ class TexturesUV(TexturesBase):
tex._verts_uvs_list = [v.detach() for v in self._verts_uvs_list]
if self._faces_uvs_list is not None:
tex._faces_uvs_list = [f.detach() for f in self._faces_uvs_list]
if self._maps_ids_list is not None:
tex._maps_ids_list = [mi.detach() for mi in self._maps_ids_list]
num_faces = (
self._num_faces_per_mesh.detach()
if torch.is_tensor(self._num_faces_per_mesh)
@ -848,27 +1010,44 @@ class TexturesUV(TexturesBase):
return tex
def __getitem__(self, index) -> "TexturesUV":
props = ["verts_uvs_list", "faces_uvs_list", "maps_list", "_num_faces_per_mesh"]
props = [
"faces_uvs_list",
"verts_uvs_list",
"maps_list",
"maps_ids_list",
"_num_faces_per_mesh",
]
new_props = self._getitem(index, props)
faces_uvs = new_props["faces_uvs_list"]
verts_uvs = new_props["verts_uvs_list"]
maps = new_props["maps_list"]
maps_ids = new_props["maps_ids_list"]
# if index has multiple values then faces/verts/maps may be a list of tensors
if all(isinstance(f, (list, tuple)) for f in [faces_uvs, verts_uvs, maps]):
if maps_ids is not None and not isinstance(maps_ids, (list, tuple)):
raise ValueError(
"Maps ids are not in the correct format expected list or tuple"
)
new_tex = self.__class__(
faces_uvs=faces_uvs,
verts_uvs=verts_uvs,
maps=maps,
maps_ids=maps_ids,
padding_mode=self.padding_mode,
align_corners=self.align_corners,
sampling_mode=self.sampling_mode,
)
elif all(torch.is_tensor(f) for f in [faces_uvs, verts_uvs, maps]):
if maps_ids is not None and not torch.is_tensor(maps_ids):
raise ValueError(
"Maps ids are not in the correct format expected tensor"
)
new_tex = self.__class__(
faces_uvs=[faces_uvs],
verts_uvs=[verts_uvs],
maps=[maps],
maps_ids=[maps_ids] if maps_ids is not None else None,
padding_mode=self.padding_mode,
align_corners=self.align_corners,
sampling_mode=self.sampling_mode,
@ -927,6 +1106,17 @@ class TexturesUV(TexturesBase):
self._verts_uvs_list = list(self._verts_uvs_padded.unbind(0))
return self._verts_uvs_list
def maps_ids_padded(self) -> Optional[torch.Tensor]:
return self._maps_ids_padded
def maps_ids_list(self) -> Optional[List[torch.Tensor]]:
if self._maps_ids_list is not None:
return self._maps_ids_list
elif self._maps_ids_padded is not None:
return self._maps_ids_padded.unbind(0)
else:
return None
# Currently only the padded maps are used.
def maps_padded(self) -> torch.Tensor:
return self._maps_padded
@ -943,6 +1133,7 @@ class TexturesUV(TexturesBase):
"maps_padded",
"verts_uvs_padded",
"faces_uvs_padded",
"maps_ids_padded",
"_num_faces_per_mesh",
],
)
@ -950,6 +1141,7 @@ class TexturesUV(TexturesBase):
maps=new_props["maps_padded"],
faces_uvs=new_props["faces_uvs_padded"],
verts_uvs=new_props["verts_uvs_padded"],
maps_ids=new_props["maps_ids_padded"],
padding_mode=self.padding_mode,
align_corners=self.align_corners,
sampling_mode=self.sampling_mode,
@ -992,7 +1184,6 @@ class TexturesUV(TexturesBase):
i[j] for i, j in zip(self.verts_uvs_list(), self.faces_uvs_list())
]
faces_verts_uvs = torch.cat(packing_list)
texture_maps = self.maps_padded()
# pixel_uvs: (N, H, W, K, 2)
pixel_uvs = interpolate_face_attributes(
@ -1000,49 +1191,91 @@ class TexturesUV(TexturesBase):
)
N, H_out, W_out, K = fragments.pix_to_face.shape
N, H_in, W_in, C = texture_maps.shape # 3 for RGB
# pixel_uvs: (N, H, W, K, 2) -> (N, K, H, W, 2) -> (NK, H, W, 2)
pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(N * K, H_out, W_out, 2)
texture_maps = self.maps_padded()
maps_ids_padded = self.maps_ids_padded()
if maps_ids_padded is None:
# pixel_uvs: (N, H, W, K, 2) -> (N, K, H, W, 2) -> (NK, H, W, 2)
pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(N * K, H_out, W_out, 2)
N, H_in, W_in, C = texture_maps.shape # 3 for RGB
# textures.map:
# (N, H, W, C) -> (N, C, H, W) -> (1, N, C, H, W)
# -> expand (K, N, C, H, W) -> reshape (N*K, C, H, W)
texture_maps = (
texture_maps.permute(0, 3, 1, 2)[None, ...]
.expand(K, -1, -1, -1, -1)
.transpose(0, 1)
.reshape(N * K, C, H_in, W_in)
)
# textures.map:
# (N, H, W, C) -> (N, C, H, W) -> (1, N, C, H, W)
# -> expand (K, N, C, H, W) -> reshape (N*K, C, H, W)
texture_maps = (
texture_maps.permute(0, 3, 1, 2)[None, ...]
.expand(K, -1, -1, -1, -1)
.transpose(0, 1)
.reshape(N * K, C, H_in, W_in)
)
# Textures: (N*K, C, H, W), pixel_uvs: (N*K, H, W, 2)
# Now need to format the pixel uvs and the texture map correctly!
# From pytorch docs, grid_sample takes `grid` and `input`:
# grid specifies the sampling pixel locations normalized by
# the input spatial dimensions It should have most
# values in the range of [-1, 1]. Values x = -1, y = -1
# is the left-top pixel of input, and values x = 1, y = 1 is the
# right-bottom pixel of input.
# Textures: (N*K, C, H, W), pixel_uvs: (N*K, H, W, 2)
# Now need to format the pixel uvs and the texture map correctly!
# From pytorch docs, grid_sample takes `grid` and `input`:
# grid specifies the sampling pixel locations normalized by
# the input spatial dimensions It should have most
# values in the range of [-1, 1]. Values x = -1, y = -1
# is the left-top pixel of input, and values x = 1, y = 1 is the
# right-bottom pixel of input.
# map to a range of [-1, 1] and flip the y axis
pixel_uvs = torch.lerp(
pixel_uvs.new_tensor([-1.0, 1.0]),
pixel_uvs.new_tensor([1.0, -1.0]),
pixel_uvs,
)
# map to a range of [-1, 1] and flip the y axis
pixel_uvs = torch.lerp(
pixel_uvs.new_tensor([-1.0, 1.0]),
pixel_uvs.new_tensor([1.0, -1.0]),
pixel_uvs,
)
if texture_maps.device != pixel_uvs.device:
texture_maps = texture_maps.to(pixel_uvs.device)
texels = F.grid_sample(
texture_maps,
pixel_uvs,
mode=self.sampling_mode,
align_corners=self.align_corners,
padding_mode=self.padding_mode,
)
# texels now has shape (NK, C, H_out, W_out)
texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2)
return texels
else:
# We have maps_ids_padded: (N, F), textures_map: (N, M, Hi, Wi, C),fragmenmts.pix_to_face: (N, Ho, Wo, K)
# Get pixel_to_map_ids: (N, K, Ho, Wo) by indexing pix_to_face into maps_ids
N, M, H_in, W_in, C = texture_maps.shape # 3 for RGB
if texture_maps.device != pixel_uvs.device:
texture_maps = texture_maps.to(pixel_uvs.device)
texels = F.grid_sample(
texture_maps,
pixel_uvs,
mode=self.sampling_mode,
align_corners=self.align_corners,
padding_mode=self.padding_mode,
)
# texels now has shape (NK, C, H_out, W_out)
texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2)
return texels
mask = fragments.pix_to_face < 0
pix_to_face = fragments.pix_to_face.clone()
pix_to_face[mask] = 0
pixel_to_map_ids = (
maps_ids_padded.flatten()
.gather(0, pix_to_face.flatten())
.view(N, K, H_out, W_out)
)
# Normalize between -1 and 1 with M (number of maps)
pixel_to_map_ids = (2.0 * pixel_to_map_ids.float() / float(M - 1)) - 1
pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4)
pixel_uvs = torch.lerp(
pixel_uvs.new_tensor([-1.0, 1.0]),
pixel_uvs.new_tensor([1.0, -1.0]),
pixel_uvs,
)
# N x H_out x W_out x K x 3
pixel_uvms = torch.cat((pixel_uvs, pixel_to_map_ids.unsqueeze(4)), dim=4)
# (N, M, H, W, C) -> (N, C, M, H, W)
texture_maps = texture_maps.permute(0, 4, 1, 2, 3)
if texture_maps.device != pixel_uvs.device:
texture_maps = texture_maps.to(pixel_uvs.device)
texels = F.grid_sample(
texture_maps,
pixel_uvms,
mode=self.sampling_mode,
align_corners=self.align_corners,
padding_mode=self.padding_mode,
)
# (N, C, K, H_out, W_out) -> (N, H_out, W_out, K, C)
texels = texels.permute(0, 3, 4, 2, 1).contiguous()
return texels
def faces_verts_textures_packed(self) -> torch.Tensor:
"""
@ -1065,25 +1298,41 @@ class TexturesUV(TexturesBase):
faces_verts_uvs = _list_to_padded_wrapper(
packing_list, pad_value=0.0
) # Nxmax(Fi)x3x2
texture_maps = self.maps_padded() # NxHxWxC
texture_maps = texture_maps.permute(0, 3, 1, 2) # NxCxHxW
# map to a range of [-1, 1] and flip the y axis
faces_verts_uvs = torch.lerp(
faces_verts_uvs.new_tensor([-1.0, 1.0]),
faces_verts_uvs.new_tensor([1.0, -1.0]),
faces_verts_uvs,
)
texture_maps = self.maps_padded() # NxHxWxC or NxMxHxWxC
maps_ids_padded = self.maps_ids_padded()
if maps_ids_padded is None:
texture_maps = texture_maps.permute(0, 3, 1, 2) # NxCxHxW
else:
M = texture_maps.shape[1]
# (N, M, H, W, C) -> (N, C, M, H, W)
texture_maps = texture_maps.permute(0, 4, 1, 2, 3)
# expand maps_ids to (N, F, 3, 1)
maps_ids_padded = maps_ids_padded[:, :, None, None].expand(-1, -1, 3, -1)
maps_ids_padded = (2.0 * maps_ids_padded.float() / float(M - 1)) - 1.0
# (N, F, 3, 2+1) -> (N, 1, F, 3, 3)
faces_verts_uvs = torch.cat(
(faces_verts_uvs, maps_ids_padded), dim=3
).unsqueeze(1)
# (N, M, H, W, C) -> (N, C, H, W, M)
# texture_maps = texture_maps.permute(0, 4, 2, 3, 1)
textures = F.grid_sample(
texture_maps,
faces_verts_uvs,
mode=self.sampling_mode,
align_corners=self.align_corners,
padding_mode=self.padding_mode,
) # NxCxmax(Fi)x3
textures = textures.permute(0, 2, 3, 1) # Nxmax(Fi)x3xC
) # (N, C, max(Fi), 3)
if maps_ids_padded is not None:
textures = textures.squeeze(dim=2)
# (N, C, max(Fi), 3) -> (N, max(Fi), 3, C)
textures = textures.permute(0, 2, 3, 1)
textures = _padded_to_list_wrapper(
textures, split_size=self._num_faces_per_mesh
) # list of N {Fix3xC} tensors
@ -1102,6 +1351,11 @@ class TexturesUV(TexturesBase):
new_tex: TexturesUV object with the combined
textures from self and the list `textures`.
"""
if self.maps_ids_padded() is not None:
# TODO
raise NotImplementedError(
"join_batch does not support TexturesUV with multiple maps"
)
tex_types_same = all(isinstance(tex, TexturesUV) for tex in textures)
if not tex_types_same:
raise ValueError("All textures must be of type TexturesUV.")
@ -1137,8 +1391,8 @@ class TexturesUV(TexturesBase):
new_tex = self.__class__(
maps=maps_list,
verts_uvs=verts_uvs_list,
faces_uvs=faces_uvs_list,
verts_uvs=verts_uvs_list,
padding_mode=self.padding_mode,
align_corners=self.align_corners,
sampling_mode=self.sampling_mode,
@ -1205,6 +1459,9 @@ class TexturesUV(TexturesBase):
_place_map_into_single_map is used to copy the maps into the single map.
The merging of verts_uvs and faces_uvs is handled locally in this function.
"""
if self.maps_ids_padded() is not None:
# TODO
raise NotImplementedError("join_scene does not support multiple maps.")
maps = self.maps_list()
heights_and_widths = []
extra_border = 0 if self.align_corners else 2
@ -1305,8 +1562,8 @@ class TexturesUV(TexturesBase):
return self.__class__(
maps=[single_map],
verts_uvs=[torch.cat(verts_uvs_merged)],
faces_uvs=[torch.cat(faces_uvs_merged)],
verts_uvs=[torch.cat(verts_uvs_merged)],
align_corners=self.align_corners,
padding_mode=self.padding_mode,
sampling_mode=self.sampling_mode,
@ -1326,6 +1583,9 @@ class TexturesUV(TexturesBase):
centers: coordinates of points in the texture image
- a FloatTensor of shape (V,2)
"""
if self.maps_ids_padded() is not None:
# TODO: invent a visualization for the multiple maps case
raise NotImplementedError("This function does not support multiple maps.")
if self._N != 1:
raise ValueError(
"This function only supports plotting textures for one mesh."
@ -1388,7 +1648,9 @@ class TexturesUV(TexturesBase):
A "TexturesUV in which faces_uvs_padded, verts_uvs_padded, and maps_padded
have length sum(len(faces) for faces in faces_ids_list)
"""
if self.maps_ids_padded() is not None:
# TODO
raise NotImplementedError("This function does not support multiple maps.")
if len(faces_ids_list) != len(self.faces_uvs_padded()):
raise IndexError(
"faces_uvs_padded must be of " "the same length as face_ids_list."
@ -1407,12 +1669,12 @@ class TexturesUV(TexturesBase):
sub_maps.append(map_)
return self.__class__(
sub_maps,
sub_faces_uvs,
sub_verts_uvs,
self.padding_mode,
self.align_corners,
self.sampling_mode,
maps=sub_maps,
faces_uvs=sub_faces_uvs,
verts_uvs=sub_verts_uvs,
padding_mode=self.padding_mode,
align_corners=self.align_corners,
sampling_mode=self.sampling_mode,
)

View File

@ -718,6 +718,22 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
verts_uvs=torch.rand(size=(5, 15, 2)),
)
# maps ids are not none but maps doesn't have multiple map indices
with self.assertRaisesRegex(ValueError, "map"):
TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
maps_ids=torch.randint(0, 1, (5, 10), dtype=torch.long),
)
# maps ids is none but maps have multiple map indices
with self.assertRaisesRegex(ValueError, "map"):
TexturesUV(
maps=torch.ones((5, 2, 16, 16, 3)),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
)
def test_faces_verts_textures(self):
device = torch.device("cuda:0")
N, V, F, H, W = 2, 5, 12, 8, 8
@ -755,6 +771,47 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
self.assertClose(faces_verts_texs, tex.faces_verts_textures_packed())
def test_faces_verts_multiple_map_textures(self):
device = torch.device("cuda:0")
N, M, V, F, H, W = 2, 3, 5, 12, 8, 8
vert_uvs = torch.rand((N, V, 2), dtype=torch.float32, device=device)
face_uvs = torch.randint(
high=V, size=(N, F, 3), dtype=torch.int64, device=device
)
map_ids = torch.randint(0, M, (N, F), device=device)
maps = torch.rand((N, M, H, W, 3), dtype=torch.float32, device=device)
tex = TexturesUV(
maps=maps, verts_uvs=vert_uvs, faces_uvs=face_uvs, maps_ids=map_ids
)
# naive faces_verts_textures
faces_verts_texs = []
for n in range(N):
temp = torch.zeros((F, 3, 3), device=device, dtype=torch.float32)
for f in range(F):
uv0 = vert_uvs[n, face_uvs[n, f, 0]]
uv1 = vert_uvs[n, face_uvs[n, f, 1]]
uv2 = vert_uvs[n, face_uvs[n, f, 2]]
map_id = map_ids[n, f]
idx = torch.stack((uv0, uv1, uv2), dim=0).view(1, 1, 3, 2) # 1x1x3x2
idx = idx * 2.0 - 1.0
imap = maps[n, map_id].view(1, H, W, 3).permute(0, 3, 1, 2) # 1x3xHxW
imap = torch.flip(imap, [2])
texts = torch.nn.functional.grid_sample(
imap,
idx,
align_corners=tex.align_corners,
padding_mode=tex.padding_mode,
) # 1x3x1x3
temp[f] = texts[0, :, 0, :].permute(1, 0)
faces_verts_texs.append(temp)
faces_verts_texs = torch.cat(faces_verts_texs, 0)
self.assertClose(faces_verts_texs, tex.faces_verts_textures_packed())
def test_clone(self):
tex = TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
@ -781,6 +838,37 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
self.assertSeparate(tex.maps_list()[i], tex_cloned.maps_list()[i])
self.assertClose(tex.maps_list()[i], tex_cloned.maps_list()[i])
def test_multiple_maps_clone(self):
tex = TexturesUV(
maps=torch.ones((5, 3, 16, 16, 3)),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
maps_ids=torch.randint(0, 3, (5, 10)),
)
tex.faces_uvs_list()
tex.verts_uvs_list()
tex_cloned = tex.clone()
self.assertSeparate(tex._faces_uvs_padded, tex_cloned._faces_uvs_padded)
self.assertClose(tex._faces_uvs_padded, tex_cloned._faces_uvs_padded)
self.assertSeparate(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded)
self.assertClose(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded)
self.assertSeparate(tex._maps_padded, tex_cloned._maps_padded)
self.assertClose(tex._maps_padded, tex_cloned._maps_padded)
self.assertSeparate(tex.valid, tex_cloned.valid)
self.assertTrue(tex.valid.eq(tex_cloned.valid).all())
self.assertSeparate(tex._maps_ids_padded, tex_cloned._maps_ids_padded)
self.assertClose(tex._maps_ids_padded, tex_cloned._maps_ids_padded)
for i in range(tex._N):
self.assertSeparate(tex._faces_uvs_list[i], tex_cloned._faces_uvs_list[i])
self.assertClose(tex._faces_uvs_list[i], tex_cloned._faces_uvs_list[i])
self.assertSeparate(tex._verts_uvs_list[i], tex_cloned._verts_uvs_list[i])
self.assertClose(tex._verts_uvs_list[i], tex_cloned._verts_uvs_list[i])
# tex._maps_list is not use anywhere so it's not stored. We call it explicitly
self.assertSeparate(tex.maps_list()[i], tex_cloned.maps_list()[i])
self.assertClose(tex.maps_list()[i], tex_cloned.maps_list()[i])
self.assertSeparate(tex.maps_ids_list()[i], tex_cloned.maps_ids_list()[i])
self.assertClose(tex.maps_ids_list()[i], tex_cloned.maps_ids_list()[i])
def test_detach(self):
tex = TexturesUV(
maps=torch.ones((5, 16, 16, 3), requires_grad=True),
@ -805,6 +893,35 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
self.assertFalse(tex_detached.maps_list()[i].requires_grad)
self.assertClose(tex.maps_list()[i], tex_detached.maps_list()[i])
def test_multiple_maps_detach(self):
tex = TexturesUV(
maps=torch.ones((5, 3, 16, 16, 3), requires_grad=True),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
maps_ids=torch.randint(0, 3, (5, 10)),
)
tex.faces_uvs_list()
tex.verts_uvs_list()
tex_detached = tex.detach()
self.assertFalse(tex_detached._maps_padded.requires_grad)
self.assertClose(tex._maps_padded, tex_detached._maps_padded)
self.assertFalse(tex_detached._verts_uvs_padded.requires_grad)
self.assertClose(tex._verts_uvs_padded, tex_detached._verts_uvs_padded)
self.assertFalse(tex_detached._faces_uvs_padded.requires_grad)
self.assertClose(tex._faces_uvs_padded, tex_detached._faces_uvs_padded)
self.assertFalse(tex_detached._maps_ids_padded.requires_grad)
self.assertClose(tex._maps_ids_padded, tex_detached._maps_ids_padded)
for i in range(tex._N):
self.assertFalse(tex_detached._verts_uvs_list[i].requires_grad)
self.assertClose(tex._verts_uvs_list[i], tex_detached._verts_uvs_list[i])
self.assertFalse(tex_detached._faces_uvs_list[i].requires_grad)
self.assertClose(tex._faces_uvs_list[i], tex_detached._faces_uvs_list[i])
# tex._maps_list is not use anywhere so it's not stored. We call it explicitly
self.assertFalse(tex_detached.maps_list()[i].requires_grad)
self.assertClose(tex.maps_list()[i], tex_detached.maps_list()[i])
self.assertFalse(tex_detached.maps_ids_list()[i].requires_grad)
self.assertClose(tex.maps_ids_list()[i], tex_detached.maps_ids_list()[i])
def test_extend(self):
B = 5
mesh = init_mesh(B, 30, 50)
@ -878,13 +995,15 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
torch.tensor([[0, 1, 2], [3, 4, 5]]),
] # (N, 3, 3)
verts_uvs_list = [torch.ones(9, 2), torch.ones(6, 2)]
maps_ids_given_list = [torch.randint(0, 3, (3,)), torch.randint(0, 3, (2,))]
num_faces_per_mesh = [f.shape[0] for f in faces_uvs_list]
num_verts_per_mesh = [v.shape[0] for v in verts_uvs_list]
tex = TexturesUV(
maps=torch.ones((N, 16, 16, 3)),
maps=torch.ones((N, 3, 16, 16, 3)),
faces_uvs=faces_uvs_list,
verts_uvs=verts_uvs_list,
maps_ids=maps_ids_given_list,
)
# This is set inside Meshes when textures is passed as an input.
@ -898,24 +1017,33 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
faces_list = tex1.faces_uvs_list()
faces_padded = tex1.faces_uvs_padded()
maps_ids_list = tex1.maps_ids_list()
maps_ids_padded = tex1.maps_ids_padded()
for f1, f2 in zip(faces_list, faces_uvs_list):
self.assertTrue((f1 == f2).all().item())
for f1, f2 in zip(verts_list, verts_uvs_list):
self.assertTrue((f1 == f2).all().item())
for f1, f2 in zip(maps_ids_given_list, maps_ids_list):
self.assertTrue((f1 == f2).all().item())
self.assertTrue(faces_padded.shape == (2, 3, 3))
self.assertTrue(verts_padded.shape == (2, 9, 2))
self.assertTrue(maps_ids_padded.shape == (2, 3))
# Case where num_faces_per_mesh is not set and faces_verts_uvs
# are initialized with a padded tensor.
tex2 = TexturesUV(
maps=torch.ones((N, 16, 16, 3)),
maps=torch.ones((N, 3, 16, 16, 3)),
verts_uvs=verts_padded,
faces_uvs=faces_padded,
maps_ids=maps_ids_padded,
)
faces_list = tex2.faces_uvs_list()
verts_list = tex2.verts_uvs_list()
maps_ids_list = tex2.maps_ids_list()
for i, (f1, f2) in enumerate(zip(faces_list, faces_uvs_list)):
n = num_faces_per_mesh[i]
@ -925,23 +1053,30 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
n = num_verts_per_mesh[i]
self.assertTrue((f1[:n] == f2).all().item())
for i, (f1, f2) in enumerate(zip(maps_ids_list, maps_ids_given_list)):
n = num_faces_per_mesh[i]
self.assertTrue((f1[:n] == f2).all().item())
def test_to(self):
tex = TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
maps=torch.ones((5, 3, 16, 16, 3)),
faces_uvs=torch.randint(size=(5, 10, 3), high=15),
verts_uvs=torch.rand(size=(5, 15, 2)),
maps_ids=torch.randint(0, 3, (5, 10)),
)
device = torch.device("cuda:0")
tex = tex.to(device)
self.assertEqual(tex._faces_uvs_padded.device, device)
self.assertEqual(tex._verts_uvs_padded.device, device)
self.assertEqual(tex._maps_padded.device, device)
self.assertEqual(tex._maps_ids_padded.device, device)
def test_mesh_to(self):
tex_cpu = TexturesUV(
maps=torch.ones((5, 16, 16, 3)),
maps=torch.ones((5, 3, 16, 16, 3)),
faces_uvs=torch.randint(size=(5, 10, 3), high=15),
verts_uvs=torch.rand(size=(5, 15, 2)),
maps_ids=torch.randint(0, 3, (5, 10)),
)
verts = torch.rand(size=(5, 15, 3))
faces = torch.randint(size=(5, 10, 3), high=15)
@ -952,24 +1087,29 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
self.assertEqual(tex._faces_uvs_padded.device, device)
self.assertEqual(tex._verts_uvs_padded.device, device)
self.assertEqual(tex._maps_padded.device, device)
self.assertEqual(tex._maps_ids_padded.device, device)
self.assertEqual(tex_cpu._verts_uvs_padded.device, cpu)
self.assertEqual(tex_cpu._maps_ids_padded.device, cpu)
self.assertEqual(tex_cpu.device, cpu)
self.assertEqual(tex.device, device)
def test_getitem(self):
N = 5
M = 3
V = 20
F = 10
source = {
"maps": torch.rand(size=(N, 1, 1, 3)),
"maps": torch.rand(size=(N, M, 1, 1, 3)),
"faces_uvs": torch.randint(size=(N, F, 3), high=V),
"verts_uvs": torch.randn(size=(N, V, 2)),
"maps_ids": torch.randint(0, M, (N, F)),
}
tex = TexturesUV(
maps=source["maps"],
faces_uvs=source["faces_uvs"],
verts_uvs=source["verts_uvs"],
maps_ids=source["maps_ids"],
)
verts = torch.rand(size=(N, V, 3))