diff --git a/pytorch3d/renderer/mesh/textures.py b/pytorch3d/renderer/mesh/textures.py index ed662152..736be41d 100644 --- a/pytorch3d/renderer/mesh/textures.py +++ b/pytorch3d/renderer/mesh/textures.py @@ -149,6 +149,58 @@ def _pad_texture_maps( return tex_maps +def _pad_texture_multiple_maps( + multiple_texture_maps: Union[Tuple[torch.Tensor], List[torch.Tensor]], + align_corners: bool, +) -> torch.Tensor: + """ + Pad all texture images so they have the same height and width. + + Args: + images: list of N tensors of shape (M_i, H_i, W_i, C) + M_i : Number of texture maps:w + + align_corners: used for interpolation + + Returns: + tex_maps: Tensor of shape (N, max_M, max_H, max_W, C) + """ + tex_maps = [] + max_M = 0 + max_H = 0 + max_W = 0 + C = 0 + for im in multiple_texture_maps: + m, h, w, C = im.shape + if m > max_M: + max_M = m + if h > max_H: + max_H = h + if w > max_W: + max_W = w + tex_maps.append(im) + max_shape = (max_M, max_H, max_W, C) + max_im_shape = (max_H, max_W) + for i, tms in enumerate(tex_maps): + new_tex_maps = torch.zeros(max_shape) + for j in range(tms.shape[0]): + im = tms[j] + if im.shape[:2] != max_im_shape: + image_BCHW = im.permute(2, 0, 1)[None] + new_image_BCHW = interpolate( + image_BCHW, + size=max_im_shape, + mode="bilinear", + align_corners=align_corners, + ) + new_tex_maps[j] = new_image_BCHW[0].permute(1, 2, 0) + else: + new_tex_maps[j] = im + tex_maps[i] = new_tex_maps + tex_maps = torch.stack(tex_maps, dim=0) # (num_tex_maps, max_H, max_W, C) + return tex_maps + + # A base class for defining a batch of textures # with helper methods. # This is also useful to have so that inside `Meshes` @@ -199,13 +251,20 @@ class TexturesBase: t = getattr(self, p) if callable(t): t = t() # class method - if isinstance(t, list): + if t is None: + new_props[p] = None + elif isinstance(t, list): if not all(isinstance(elem, (int, float)) for elem in t): raise ValueError("Extend only supports lists of scalars") t = [[ti] * N for ti in t] new_props[p] = list(itertools.chain(*t)) elif torch.is_tensor(t): new_props[p] = t.repeat_interleave(N, dim=0) + else: + raise ValueError( + f"Property {p} has unsupported type {type(t)}." + "Only tensors and lists are supported." + ) return new_props def _getitem(self, index: Union[int, slice], props: List[str]): @@ -218,7 +277,7 @@ class TexturesBase: t = getattr(self, p) if callable(t): t = t() # class method - new_props[p] = t[index] + new_props[p] = t[index] if t is not None else None elif isinstance(index, list): index = torch.tensor(index) if isinstance(index, torch.Tensor): @@ -230,8 +289,7 @@ class TexturesBase: t = getattr(self, p) if callable(t): t = t() # class method - new_props[p] = [t[i] for i in index] - + new_props[p] = [t[i] for i in index] if t is not None else None return new_props def sample_textures(self) -> torch.Tensor: @@ -644,6 +702,10 @@ class TexturesUV(TexturesBase): maps: Union[torch.Tensor, List[torch.Tensor]], faces_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]], verts_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]], + *, + maps_ids: Optional[ + Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]] + ] = None, padding_mode: str = "border", align_corners: bool = True, sampling_mode: str = "bilinear", @@ -653,20 +715,33 @@ class TexturesUV(TexturesBase): vertex in each face. NOTE: this class only supports one texture map per mesh. Args: - maps: texture map per mesh. This can either be a list of maps - [(H, W, C)] or a padded tensor of shape (N, H, W, C). - For RGB, C = 3. + maps: Either (1) a texture map per mesh. This can either be a list of maps + [(H, W, C)] or a padded tensor of shape (N, H, W, C). + For RGB, C = 3. In this case maps_ids must be None. + Or (2) a set of M texture maps per mesh. This can either be a list of sets + [(M, H, W, C)] or a padded tensor of shape (N, M, H, W, C). + For RGB, C = 3. In this case maps_ids must be provided to + identify which is relevant to each face. faces_uvs: (N, F, 3) LongTensor giving the index into verts_uvs - for each face + for each face verts_uvs: (N, V, 2) tensor giving the uv coordinates per vertex - (a FloatTensor with values between 0 and 1). + (a FloatTensor with values between 0 and 1). + maps_ids: Used if there are to be multiple maps per face. This can be either a list of map_ids [(F,)] + or a long tensor of shape (N, F) giving the id of the texture map + for each face. If maps_ids is present, the maps has an extra dimension M + (so maps_padded is (N, M, H, W, C) and maps_list has elements of + shape (M, H, W, C)). + Specifically, the color + of a vertex V is given by an average of maps_padded[i, maps_ids[i, f], u, v, :] + over u and v integers adjacent to + _verts_uvs_padded[i, _faces_uvs_padded[i, f, 0], :] . align_corners: If true, the extreme values 0 and 1 for verts_uvs - indicate the centers of the edge pixels in the maps. + indicate the centers of the edge pixels in the maps. padding_mode: padding mode for outside grid values - ("zeros", "border" or "reflection"). + ("zeros", "border" or "reflection"). sampling_mode: type of interpolation used to sample the texture. - Corresponds to the mode parameter in PyTorch's - grid_sample ("nearest" or "bilinear"). + Corresponds to the mode parameter in PyTorch's + grid_sample ("nearest" or "bilinear"). The align_corners and padding_mode arguments correspond to the arguments of the `grid_sample` torch function. There is an informative illustration of @@ -762,6 +837,8 @@ class TexturesUV(TexturesBase): else: raise ValueError("Expected verts_uvs to be a tensor or list") + self._maps_ids_padded, self._maps_ids_list = self._format_maps_ids(maps_ids) + if isinstance(maps, (list, tuple)): self._maps_list = maps else: @@ -770,14 +847,73 @@ class TexturesUV(TexturesBase): if self._maps_padded.device != self.device: raise ValueError("maps must be on the same device as verts/faces uvs.") - self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device) + def _format_maps_ids( + self, + maps_ids: Optional[ + Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]] + ], + ) -> Tuple[ + Optional[torch.Tensor], Optional[Union[List[torch.Tensor], Tuple[torch.Tensor]]] + ]: + if maps_ids is None: + return None, None + elif isinstance(maps_ids, (list, tuple)): + for mid in maps_ids: + if mid.ndim != 1: + msg = "Expected maps_ids to be of shape (F,); got %r" + raise ValueError(msg % repr(mid.shape)) + if len(maps_ids) != self._N: + raise ValueError( + "map_ids, faces_uvs and verts_uvs must have the same batch dimension" + ) + if not all(mid.device == self.device for mid in maps_ids): + raise ValueError( + "maps_ids and verts/faces uvs must be on the same device" + ) + + if not all( + mid.shape[0] == nfm + for mid, nfm in zip(maps_ids, self._num_faces_per_mesh) + ): + raise ValueError( + "map_ids and faces_uvs must have the same number of faces per mesh" + ) + if not all(mid.device == self.device for mid in maps_ids): + raise ValueError( + "maps_ids and verts/faces uvs must be on the same device" + ) + if not self._num_faces_per_mesh: + return torch.Tensor(), maps_ids + return list_to_padded(maps_ids, pad_value=0), maps_ids + elif isinstance(maps_ids, torch.Tensor): + if maps_ids.ndim != 2 or maps_ids.shape[0] != self._N: + msg = "Expected maps_ids to be of shape (N, F); got %r" + raise ValueError(msg % repr(maps_ids.shape)) + maps_ids_padded = maps_ids + max_F = max(self._num_faces_per_mesh) + if not maps_ids.shape[1] == max_F: + raise ValueError( + "map_ids and faces_uvs must have the same number of faces per mesh" + ) + if maps_ids.device != self.device: + raise ValueError( + "maps_ids and verts/faces uvs must be on the same device" + ) + return maps_ids_padded, None + raise ValueError("Expected maps_ids to be a tensor or list") + def _format_maps_padded( self, maps: Union[torch.Tensor, List[torch.Tensor]] ) -> torch.Tensor: + maps_ids_none = self._maps_ids_padded is None if isinstance(maps, torch.Tensor): - if maps.ndim != 4 or maps.shape[0] != self._N: + if not maps_ids_none: + if maps.ndim != 5 or maps.shape[0] != self._N: + msg = "Expected maps to be of shape (N, M, H, W, C); got %r" + raise ValueError(msg % repr(maps.shape)) + elif maps.ndim != 4 or maps.shape[0] != self._N: msg = "Expected maps to be of shape (N, H, W, C); got %r" raise ValueError(msg % repr(maps.shape)) return maps @@ -786,15 +922,27 @@ class TexturesUV(TexturesBase): if len(maps) != self._N: raise ValueError("Expected one texture map per mesh in the batch.") if self._N > 0: - if not all(map.ndim == 3 for map in maps): + ndim = 3 if maps_ids_none else 4 + if not all(map.ndim == ndim for map in maps): raise ValueError("Invalid number of dimensions in texture maps") - if not all(map.shape[2] == maps[0].shape[2] for map in maps): + if not all(map.shape[-1] == maps[0].shape[-1] for map in maps): raise ValueError("Inconsistent number of channels in maps") - maps_padded = _pad_texture_maps(maps, align_corners=self.align_corners) - else: - maps_padded = torch.empty( - (self._N, 0, 0, 3), dtype=torch.float32, device=self.device + maps_padded = ( + _pad_texture_maps(maps, align_corners=self.align_corners) + if maps_ids_none + else _pad_texture_multiple_maps( + maps, align_corners=self.align_corners + ) ) + else: + if maps_ids_none: + maps_padded = torch.empty( + (self._N, 0, 0, 3), dtype=torch.float32, device=self.device + ) + else: + maps_padded = torch.empty( + (self._N, 0, 0, 0, 3), dtype=torch.float32, device=self.device + ) return maps_padded raise ValueError("Expected maps to be a tensor or list of tensors.") @@ -804,6 +952,11 @@ class TexturesUV(TexturesBase): self.maps_padded().clone(), self.faces_uvs_padded().clone(), self.verts_uvs_padded().clone(), + maps_ids=( + self._maps_ids_padded.clone() + if self._maps_ids_padded is not None + else None + ), align_corners=self.align_corners, padding_mode=self.padding_mode, sampling_mode=self.sampling_mode, @@ -814,6 +967,8 @@ class TexturesUV(TexturesBase): tex._verts_uvs_list = [v.clone() for v in self._verts_uvs_list] if self._faces_uvs_list is not None: tex._faces_uvs_list = [f.clone() for f in self._faces_uvs_list] + if self._maps_ids_list is not None: + tex._maps_ids_list = [f.clone() for f in self._maps_ids_list] num_faces = ( self._num_faces_per_mesh.clone() if torch.is_tensor(self._num_faces_per_mesh) @@ -828,6 +983,11 @@ class TexturesUV(TexturesBase): self.maps_padded().detach(), self.faces_uvs_padded().detach(), self.verts_uvs_padded().detach(), + maps_ids=( + self._maps_ids_padded.detach() + if self._maps_ids_padded is not None + else None + ), align_corners=self.align_corners, padding_mode=self.padding_mode, sampling_mode=self.sampling_mode, @@ -838,6 +998,8 @@ class TexturesUV(TexturesBase): tex._verts_uvs_list = [v.detach() for v in self._verts_uvs_list] if self._faces_uvs_list is not None: tex._faces_uvs_list = [f.detach() for f in self._faces_uvs_list] + if self._maps_ids_list is not None: + tex._maps_ids_list = [mi.detach() for mi in self._maps_ids_list] num_faces = ( self._num_faces_per_mesh.detach() if torch.is_tensor(self._num_faces_per_mesh) @@ -848,27 +1010,44 @@ class TexturesUV(TexturesBase): return tex def __getitem__(self, index) -> "TexturesUV": - props = ["verts_uvs_list", "faces_uvs_list", "maps_list", "_num_faces_per_mesh"] + props = [ + "faces_uvs_list", + "verts_uvs_list", + "maps_list", + "maps_ids_list", + "_num_faces_per_mesh", + ] new_props = self._getitem(index, props) faces_uvs = new_props["faces_uvs_list"] verts_uvs = new_props["verts_uvs_list"] maps = new_props["maps_list"] + maps_ids = new_props["maps_ids_list"] # if index has multiple values then faces/verts/maps may be a list of tensors if all(isinstance(f, (list, tuple)) for f in [faces_uvs, verts_uvs, maps]): + if maps_ids is not None and not isinstance(maps_ids, (list, tuple)): + raise ValueError( + "Maps ids are not in the correct format expected list or tuple" + ) new_tex = self.__class__( faces_uvs=faces_uvs, verts_uvs=verts_uvs, maps=maps, + maps_ids=maps_ids, padding_mode=self.padding_mode, align_corners=self.align_corners, sampling_mode=self.sampling_mode, ) elif all(torch.is_tensor(f) for f in [faces_uvs, verts_uvs, maps]): + if maps_ids is not None and not torch.is_tensor(maps_ids): + raise ValueError( + "Maps ids are not in the correct format expected tensor" + ) new_tex = self.__class__( faces_uvs=[faces_uvs], verts_uvs=[verts_uvs], maps=[maps], + maps_ids=[maps_ids] if maps_ids is not None else None, padding_mode=self.padding_mode, align_corners=self.align_corners, sampling_mode=self.sampling_mode, @@ -927,6 +1106,17 @@ class TexturesUV(TexturesBase): self._verts_uvs_list = list(self._verts_uvs_padded.unbind(0)) return self._verts_uvs_list + def maps_ids_padded(self) -> Optional[torch.Tensor]: + return self._maps_ids_padded + + def maps_ids_list(self) -> Optional[List[torch.Tensor]]: + if self._maps_ids_list is not None: + return self._maps_ids_list + elif self._maps_ids_padded is not None: + return self._maps_ids_padded.unbind(0) + else: + return None + # Currently only the padded maps are used. def maps_padded(self) -> torch.Tensor: return self._maps_padded @@ -943,6 +1133,7 @@ class TexturesUV(TexturesBase): "maps_padded", "verts_uvs_padded", "faces_uvs_padded", + "maps_ids_padded", "_num_faces_per_mesh", ], ) @@ -950,6 +1141,7 @@ class TexturesUV(TexturesBase): maps=new_props["maps_padded"], faces_uvs=new_props["faces_uvs_padded"], verts_uvs=new_props["verts_uvs_padded"], + maps_ids=new_props["maps_ids_padded"], padding_mode=self.padding_mode, align_corners=self.align_corners, sampling_mode=self.sampling_mode, @@ -992,7 +1184,6 @@ class TexturesUV(TexturesBase): i[j] for i, j in zip(self.verts_uvs_list(), self.faces_uvs_list()) ] faces_verts_uvs = torch.cat(packing_list) - texture_maps = self.maps_padded() # pixel_uvs: (N, H, W, K, 2) pixel_uvs = interpolate_face_attributes( @@ -1000,49 +1191,91 @@ class TexturesUV(TexturesBase): ) N, H_out, W_out, K = fragments.pix_to_face.shape - N, H_in, W_in, C = texture_maps.shape # 3 for RGB - # pixel_uvs: (N, H, W, K, 2) -> (N, K, H, W, 2) -> (NK, H, W, 2) - pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(N * K, H_out, W_out, 2) + texture_maps = self.maps_padded() + maps_ids_padded = self.maps_ids_padded() + if maps_ids_padded is None: + # pixel_uvs: (N, H, W, K, 2) -> (N, K, H, W, 2) -> (NK, H, W, 2) + pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(N * K, H_out, W_out, 2) + N, H_in, W_in, C = texture_maps.shape # 3 for RGB - # textures.map: - # (N, H, W, C) -> (N, C, H, W) -> (1, N, C, H, W) - # -> expand (K, N, C, H, W) -> reshape (N*K, C, H, W) - texture_maps = ( - texture_maps.permute(0, 3, 1, 2)[None, ...] - .expand(K, -1, -1, -1, -1) - .transpose(0, 1) - .reshape(N * K, C, H_in, W_in) - ) + # textures.map: + # (N, H, W, C) -> (N, C, H, W) -> (1, N, C, H, W) + # -> expand (K, N, C, H, W) -> reshape (N*K, C, H, W) + texture_maps = ( + texture_maps.permute(0, 3, 1, 2)[None, ...] + .expand(K, -1, -1, -1, -1) + .transpose(0, 1) + .reshape(N * K, C, H_in, W_in) + ) + # Textures: (N*K, C, H, W), pixel_uvs: (N*K, H, W, 2) + # Now need to format the pixel uvs and the texture map correctly! + # From pytorch docs, grid_sample takes `grid` and `input`: + # grid specifies the sampling pixel locations normalized by + # the input spatial dimensions It should have most + # values in the range of [-1, 1]. Values x = -1, y = -1 + # is the left-top pixel of input, and values x = 1, y = 1 is the + # right-bottom pixel of input. - # Textures: (N*K, C, H, W), pixel_uvs: (N*K, H, W, 2) - # Now need to format the pixel uvs and the texture map correctly! - # From pytorch docs, grid_sample takes `grid` and `input`: - # grid specifies the sampling pixel locations normalized by - # the input spatial dimensions It should have most - # values in the range of [-1, 1]. Values x = -1, y = -1 - # is the left-top pixel of input, and values x = 1, y = 1 is the - # right-bottom pixel of input. + # map to a range of [-1, 1] and flip the y axis + pixel_uvs = torch.lerp( + pixel_uvs.new_tensor([-1.0, 1.0]), + pixel_uvs.new_tensor([1.0, -1.0]), + pixel_uvs, + ) - # map to a range of [-1, 1] and flip the y axis - pixel_uvs = torch.lerp( - pixel_uvs.new_tensor([-1.0, 1.0]), - pixel_uvs.new_tensor([1.0, -1.0]), - pixel_uvs, - ) + if texture_maps.device != pixel_uvs.device: + texture_maps = texture_maps.to(pixel_uvs.device) + texels = F.grid_sample( + texture_maps, + pixel_uvs, + mode=self.sampling_mode, + align_corners=self.align_corners, + padding_mode=self.padding_mode, + ) + # texels now has shape (NK, C, H_out, W_out) + texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2) + return texels + else: + # We have maps_ids_padded: (N, F), textures_map: (N, M, Hi, Wi, C),fragmenmts.pix_to_face: (N, Ho, Wo, K) + # Get pixel_to_map_ids: (N, K, Ho, Wo) by indexing pix_to_face into maps_ids + N, M, H_in, W_in, C = texture_maps.shape # 3 for RGB - if texture_maps.device != pixel_uvs.device: - texture_maps = texture_maps.to(pixel_uvs.device) - texels = F.grid_sample( - texture_maps, - pixel_uvs, - mode=self.sampling_mode, - align_corners=self.align_corners, - padding_mode=self.padding_mode, - ) - # texels now has shape (NK, C, H_out, W_out) - texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2) - return texels + mask = fragments.pix_to_face < 0 + pix_to_face = fragments.pix_to_face.clone() + pix_to_face[mask] = 0 + + pixel_to_map_ids = ( + maps_ids_padded.flatten() + .gather(0, pix_to_face.flatten()) + .view(N, K, H_out, W_out) + ) + + # Normalize between -1 and 1 with M (number of maps) + pixel_to_map_ids = (2.0 * pixel_to_map_ids.float() / float(M - 1)) - 1 + pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4) + pixel_uvs = torch.lerp( + pixel_uvs.new_tensor([-1.0, 1.0]), + pixel_uvs.new_tensor([1.0, -1.0]), + pixel_uvs, + ) + + # N x H_out x W_out x K x 3 + pixel_uvms = torch.cat((pixel_uvs, pixel_to_map_ids.unsqueeze(4)), dim=4) + # (N, M, H, W, C) -> (N, C, M, H, W) + texture_maps = texture_maps.permute(0, 4, 1, 2, 3) + if texture_maps.device != pixel_uvs.device: + texture_maps = texture_maps.to(pixel_uvs.device) + texels = F.grid_sample( + texture_maps, + pixel_uvms, + mode=self.sampling_mode, + align_corners=self.align_corners, + padding_mode=self.padding_mode, + ) + # (N, C, K, H_out, W_out) -> (N, H_out, W_out, K, C) + texels = texels.permute(0, 3, 4, 2, 1).contiguous() + return texels def faces_verts_textures_packed(self) -> torch.Tensor: """ @@ -1065,25 +1298,41 @@ class TexturesUV(TexturesBase): faces_verts_uvs = _list_to_padded_wrapper( packing_list, pad_value=0.0 ) # Nxmax(Fi)x3x2 - texture_maps = self.maps_padded() # NxHxWxC - texture_maps = texture_maps.permute(0, 3, 1, 2) # NxCxHxW - # map to a range of [-1, 1] and flip the y axis faces_verts_uvs = torch.lerp( faces_verts_uvs.new_tensor([-1.0, 1.0]), faces_verts_uvs.new_tensor([1.0, -1.0]), faces_verts_uvs, ) + texture_maps = self.maps_padded() # NxHxWxC or NxMxHxWxC + maps_ids_padded = self.maps_ids_padded() + if maps_ids_padded is None: + texture_maps = texture_maps.permute(0, 3, 1, 2) # NxCxHxW + else: + M = texture_maps.shape[1] + # (N, M, H, W, C) -> (N, C, M, H, W) + texture_maps = texture_maps.permute(0, 4, 1, 2, 3) + # expand maps_ids to (N, F, 3, 1) + maps_ids_padded = maps_ids_padded[:, :, None, None].expand(-1, -1, 3, -1) + maps_ids_padded = (2.0 * maps_ids_padded.float() / float(M - 1)) - 1.0 + # (N, F, 3, 2+1) -> (N, 1, F, 3, 3) + faces_verts_uvs = torch.cat( + (faces_verts_uvs, maps_ids_padded), dim=3 + ).unsqueeze(1) + # (N, M, H, W, C) -> (N, C, H, W, M) + # texture_maps = texture_maps.permute(0, 4, 2, 3, 1) textures = F.grid_sample( texture_maps, faces_verts_uvs, mode=self.sampling_mode, align_corners=self.align_corners, padding_mode=self.padding_mode, - ) # NxCxmax(Fi)x3 - - textures = textures.permute(0, 2, 3, 1) # Nxmax(Fi)x3xC + ) # (N, C, max(Fi), 3) + if maps_ids_padded is not None: + textures = textures.squeeze(dim=2) + # (N, C, max(Fi), 3) -> (N, max(Fi), 3, C) + textures = textures.permute(0, 2, 3, 1) textures = _padded_to_list_wrapper( textures, split_size=self._num_faces_per_mesh ) # list of N {Fix3xC} tensors @@ -1102,6 +1351,11 @@ class TexturesUV(TexturesBase): new_tex: TexturesUV object with the combined textures from self and the list `textures`. """ + if self.maps_ids_padded() is not None: + # TODO + raise NotImplementedError( + "join_batch does not support TexturesUV with multiple maps" + ) tex_types_same = all(isinstance(tex, TexturesUV) for tex in textures) if not tex_types_same: raise ValueError("All textures must be of type TexturesUV.") @@ -1137,8 +1391,8 @@ class TexturesUV(TexturesBase): new_tex = self.__class__( maps=maps_list, - verts_uvs=verts_uvs_list, faces_uvs=faces_uvs_list, + verts_uvs=verts_uvs_list, padding_mode=self.padding_mode, align_corners=self.align_corners, sampling_mode=self.sampling_mode, @@ -1205,6 +1459,9 @@ class TexturesUV(TexturesBase): _place_map_into_single_map is used to copy the maps into the single map. The merging of verts_uvs and faces_uvs is handled locally in this function. """ + if self.maps_ids_padded() is not None: + # TODO + raise NotImplementedError("join_scene does not support multiple maps.") maps = self.maps_list() heights_and_widths = [] extra_border = 0 if self.align_corners else 2 @@ -1305,8 +1562,8 @@ class TexturesUV(TexturesBase): return self.__class__( maps=[single_map], - verts_uvs=[torch.cat(verts_uvs_merged)], faces_uvs=[torch.cat(faces_uvs_merged)], + verts_uvs=[torch.cat(verts_uvs_merged)], align_corners=self.align_corners, padding_mode=self.padding_mode, sampling_mode=self.sampling_mode, @@ -1326,6 +1583,9 @@ class TexturesUV(TexturesBase): centers: coordinates of points in the texture image - a FloatTensor of shape (V,2) """ + if self.maps_ids_padded() is not None: + # TODO: invent a visualization for the multiple maps case + raise NotImplementedError("This function does not support multiple maps.") if self._N != 1: raise ValueError( "This function only supports plotting textures for one mesh." @@ -1388,7 +1648,9 @@ class TexturesUV(TexturesBase): A "TexturesUV in which faces_uvs_padded, verts_uvs_padded, and maps_padded have length sum(len(faces) for faces in faces_ids_list) """ - + if self.maps_ids_padded() is not None: + # TODO + raise NotImplementedError("This function does not support multiple maps.") if len(faces_ids_list) != len(self.faces_uvs_padded()): raise IndexError( "faces_uvs_padded must be of " "the same length as face_ids_list." @@ -1407,12 +1669,12 @@ class TexturesUV(TexturesBase): sub_maps.append(map_) return self.__class__( - sub_maps, - sub_faces_uvs, - sub_verts_uvs, - self.padding_mode, - self.align_corners, - self.sampling_mode, + maps=sub_maps, + faces_uvs=sub_faces_uvs, + verts_uvs=sub_verts_uvs, + padding_mode=self.padding_mode, + align_corners=self.align_corners, + sampling_mode=self.sampling_mode, ) diff --git a/tests/test_texturing.py b/tests/test_texturing.py index 71ffa3e2..9035d982 100644 --- a/tests/test_texturing.py +++ b/tests/test_texturing.py @@ -718,6 +718,22 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase): verts_uvs=torch.rand(size=(5, 15, 2)), ) + # maps ids are not none but maps doesn't have multiple map indices + with self.assertRaisesRegex(ValueError, "map"): + TexturesUV( + maps=torch.ones((5, 16, 16, 3)), + faces_uvs=torch.rand(size=(5, 10, 3)), + verts_uvs=torch.rand(size=(5, 15, 2)), + maps_ids=torch.randint(0, 1, (5, 10), dtype=torch.long), + ) + # maps ids is none but maps have multiple map indices + with self.assertRaisesRegex(ValueError, "map"): + TexturesUV( + maps=torch.ones((5, 2, 16, 16, 3)), + faces_uvs=torch.rand(size=(5, 10, 3)), + verts_uvs=torch.rand(size=(5, 15, 2)), + ) + def test_faces_verts_textures(self): device = torch.device("cuda:0") N, V, F, H, W = 2, 5, 12, 8, 8 @@ -755,6 +771,47 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase): self.assertClose(faces_verts_texs, tex.faces_verts_textures_packed()) + def test_faces_verts_multiple_map_textures(self): + device = torch.device("cuda:0") + N, M, V, F, H, W = 2, 3, 5, 12, 8, 8 + vert_uvs = torch.rand((N, V, 2), dtype=torch.float32, device=device) + face_uvs = torch.randint( + high=V, size=(N, F, 3), dtype=torch.int64, device=device + ) + map_ids = torch.randint(0, M, (N, F), device=device) + maps = torch.rand((N, M, H, W, 3), dtype=torch.float32, device=device) + + tex = TexturesUV( + maps=maps, verts_uvs=vert_uvs, faces_uvs=face_uvs, maps_ids=map_ids + ) + + # naive faces_verts_textures + faces_verts_texs = [] + for n in range(N): + temp = torch.zeros((F, 3, 3), device=device, dtype=torch.float32) + for f in range(F): + uv0 = vert_uvs[n, face_uvs[n, f, 0]] + uv1 = vert_uvs[n, face_uvs[n, f, 1]] + uv2 = vert_uvs[n, face_uvs[n, f, 2]] + map_id = map_ids[n, f] + + idx = torch.stack((uv0, uv1, uv2), dim=0).view(1, 1, 3, 2) # 1x1x3x2 + idx = idx * 2.0 - 1.0 + imap = maps[n, map_id].view(1, H, W, 3).permute(0, 3, 1, 2) # 1x3xHxW + imap = torch.flip(imap, [2]) + + texts = torch.nn.functional.grid_sample( + imap, + idx, + align_corners=tex.align_corners, + padding_mode=tex.padding_mode, + ) # 1x3x1x3 + temp[f] = texts[0, :, 0, :].permute(1, 0) + faces_verts_texs.append(temp) + faces_verts_texs = torch.cat(faces_verts_texs, 0) + + self.assertClose(faces_verts_texs, tex.faces_verts_textures_packed()) + def test_clone(self): tex = TexturesUV( maps=torch.ones((5, 16, 16, 3)), @@ -781,6 +838,37 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase): self.assertSeparate(tex.maps_list()[i], tex_cloned.maps_list()[i]) self.assertClose(tex.maps_list()[i], tex_cloned.maps_list()[i]) + def test_multiple_maps_clone(self): + tex = TexturesUV( + maps=torch.ones((5, 3, 16, 16, 3)), + faces_uvs=torch.rand(size=(5, 10, 3)), + verts_uvs=torch.rand(size=(5, 15, 2)), + maps_ids=torch.randint(0, 3, (5, 10)), + ) + tex.faces_uvs_list() + tex.verts_uvs_list() + tex_cloned = tex.clone() + self.assertSeparate(tex._faces_uvs_padded, tex_cloned._faces_uvs_padded) + self.assertClose(tex._faces_uvs_padded, tex_cloned._faces_uvs_padded) + self.assertSeparate(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded) + self.assertClose(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded) + self.assertSeparate(tex._maps_padded, tex_cloned._maps_padded) + self.assertClose(tex._maps_padded, tex_cloned._maps_padded) + self.assertSeparate(tex.valid, tex_cloned.valid) + self.assertTrue(tex.valid.eq(tex_cloned.valid).all()) + self.assertSeparate(tex._maps_ids_padded, tex_cloned._maps_ids_padded) + self.assertClose(tex._maps_ids_padded, tex_cloned._maps_ids_padded) + for i in range(tex._N): + self.assertSeparate(tex._faces_uvs_list[i], tex_cloned._faces_uvs_list[i]) + self.assertClose(tex._faces_uvs_list[i], tex_cloned._faces_uvs_list[i]) + self.assertSeparate(tex._verts_uvs_list[i], tex_cloned._verts_uvs_list[i]) + self.assertClose(tex._verts_uvs_list[i], tex_cloned._verts_uvs_list[i]) + # tex._maps_list is not use anywhere so it's not stored. We call it explicitly + self.assertSeparate(tex.maps_list()[i], tex_cloned.maps_list()[i]) + self.assertClose(tex.maps_list()[i], tex_cloned.maps_list()[i]) + self.assertSeparate(tex.maps_ids_list()[i], tex_cloned.maps_ids_list()[i]) + self.assertClose(tex.maps_ids_list()[i], tex_cloned.maps_ids_list()[i]) + def test_detach(self): tex = TexturesUV( maps=torch.ones((5, 16, 16, 3), requires_grad=True), @@ -805,6 +893,35 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase): self.assertFalse(tex_detached.maps_list()[i].requires_grad) self.assertClose(tex.maps_list()[i], tex_detached.maps_list()[i]) + def test_multiple_maps_detach(self): + tex = TexturesUV( + maps=torch.ones((5, 3, 16, 16, 3), requires_grad=True), + faces_uvs=torch.rand(size=(5, 10, 3)), + verts_uvs=torch.rand(size=(5, 15, 2)), + maps_ids=torch.randint(0, 3, (5, 10)), + ) + tex.faces_uvs_list() + tex.verts_uvs_list() + tex_detached = tex.detach() + self.assertFalse(tex_detached._maps_padded.requires_grad) + self.assertClose(tex._maps_padded, tex_detached._maps_padded) + self.assertFalse(tex_detached._verts_uvs_padded.requires_grad) + self.assertClose(tex._verts_uvs_padded, tex_detached._verts_uvs_padded) + self.assertFalse(tex_detached._faces_uvs_padded.requires_grad) + self.assertClose(tex._faces_uvs_padded, tex_detached._faces_uvs_padded) + self.assertFalse(tex_detached._maps_ids_padded.requires_grad) + self.assertClose(tex._maps_ids_padded, tex_detached._maps_ids_padded) + for i in range(tex._N): + self.assertFalse(tex_detached._verts_uvs_list[i].requires_grad) + self.assertClose(tex._verts_uvs_list[i], tex_detached._verts_uvs_list[i]) + self.assertFalse(tex_detached._faces_uvs_list[i].requires_grad) + self.assertClose(tex._faces_uvs_list[i], tex_detached._faces_uvs_list[i]) + # tex._maps_list is not use anywhere so it's not stored. We call it explicitly + self.assertFalse(tex_detached.maps_list()[i].requires_grad) + self.assertClose(tex.maps_list()[i], tex_detached.maps_list()[i]) + self.assertFalse(tex_detached.maps_ids_list()[i].requires_grad) + self.assertClose(tex.maps_ids_list()[i], tex_detached.maps_ids_list()[i]) + def test_extend(self): B = 5 mesh = init_mesh(B, 30, 50) @@ -878,13 +995,15 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase): torch.tensor([[0, 1, 2], [3, 4, 5]]), ] # (N, 3, 3) verts_uvs_list = [torch.ones(9, 2), torch.ones(6, 2)] + maps_ids_given_list = [torch.randint(0, 3, (3,)), torch.randint(0, 3, (2,))] num_faces_per_mesh = [f.shape[0] for f in faces_uvs_list] num_verts_per_mesh = [v.shape[0] for v in verts_uvs_list] tex = TexturesUV( - maps=torch.ones((N, 16, 16, 3)), + maps=torch.ones((N, 3, 16, 16, 3)), faces_uvs=faces_uvs_list, verts_uvs=verts_uvs_list, + maps_ids=maps_ids_given_list, ) # This is set inside Meshes when textures is passed as an input. @@ -898,24 +1017,33 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase): faces_list = tex1.faces_uvs_list() faces_padded = tex1.faces_uvs_padded() + maps_ids_list = tex1.maps_ids_list() + maps_ids_padded = tex1.maps_ids_padded() + for f1, f2 in zip(faces_list, faces_uvs_list): self.assertTrue((f1 == f2).all().item()) for f1, f2 in zip(verts_list, verts_uvs_list): self.assertTrue((f1 == f2).all().item()) + for f1, f2 in zip(maps_ids_given_list, maps_ids_list): + self.assertTrue((f1 == f2).all().item()) + self.assertTrue(faces_padded.shape == (2, 3, 3)) self.assertTrue(verts_padded.shape == (2, 9, 2)) + self.assertTrue(maps_ids_padded.shape == (2, 3)) # Case where num_faces_per_mesh is not set and faces_verts_uvs # are initialized with a padded tensor. tex2 = TexturesUV( - maps=torch.ones((N, 16, 16, 3)), + maps=torch.ones((N, 3, 16, 16, 3)), verts_uvs=verts_padded, faces_uvs=faces_padded, + maps_ids=maps_ids_padded, ) faces_list = tex2.faces_uvs_list() verts_list = tex2.verts_uvs_list() + maps_ids_list = tex2.maps_ids_list() for i, (f1, f2) in enumerate(zip(faces_list, faces_uvs_list)): n = num_faces_per_mesh[i] @@ -925,23 +1053,30 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase): n = num_verts_per_mesh[i] self.assertTrue((f1[:n] == f2).all().item()) + for i, (f1, f2) in enumerate(zip(maps_ids_list, maps_ids_given_list)): + n = num_faces_per_mesh[i] + self.assertTrue((f1[:n] == f2).all().item()) + def test_to(self): tex = TexturesUV( - maps=torch.ones((5, 16, 16, 3)), + maps=torch.ones((5, 3, 16, 16, 3)), faces_uvs=torch.randint(size=(5, 10, 3), high=15), verts_uvs=torch.rand(size=(5, 15, 2)), + maps_ids=torch.randint(0, 3, (5, 10)), ) device = torch.device("cuda:0") tex = tex.to(device) self.assertEqual(tex._faces_uvs_padded.device, device) self.assertEqual(tex._verts_uvs_padded.device, device) self.assertEqual(tex._maps_padded.device, device) + self.assertEqual(tex._maps_ids_padded.device, device) def test_mesh_to(self): tex_cpu = TexturesUV( - maps=torch.ones((5, 16, 16, 3)), + maps=torch.ones((5, 3, 16, 16, 3)), faces_uvs=torch.randint(size=(5, 10, 3), high=15), verts_uvs=torch.rand(size=(5, 15, 2)), + maps_ids=torch.randint(0, 3, (5, 10)), ) verts = torch.rand(size=(5, 15, 3)) faces = torch.randint(size=(5, 10, 3), high=15) @@ -952,24 +1087,29 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase): self.assertEqual(tex._faces_uvs_padded.device, device) self.assertEqual(tex._verts_uvs_padded.device, device) self.assertEqual(tex._maps_padded.device, device) + self.assertEqual(tex._maps_ids_padded.device, device) self.assertEqual(tex_cpu._verts_uvs_padded.device, cpu) + self.assertEqual(tex_cpu._maps_ids_padded.device, cpu) self.assertEqual(tex_cpu.device, cpu) self.assertEqual(tex.device, device) def test_getitem(self): N = 5 + M = 3 V = 20 F = 10 source = { - "maps": torch.rand(size=(N, 1, 1, 3)), + "maps": torch.rand(size=(N, M, 1, 1, 3)), "faces_uvs": torch.randint(size=(N, F, 3), high=V), "verts_uvs": torch.randn(size=(N, V, 2)), + "maps_ids": torch.randint(0, M, (N, F)), } tex = TexturesUV( maps=source["maps"], faces_uvs=source["faces_uvs"], verts_uvs=source["verts_uvs"], + maps_ids=source["maps_ids"], ) verts = torch.rand(size=(N, V, 3))