mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 09:52:11 +08:00 
			
		
		
		
	TexturesUV multiple maps
Summary: Implements the the TexturesUV with multiple map ids. Reviewed By: bottler Differential Revision: D53944063 fbshipit-source-id: 06c25eb6d69f72db0484f16566dd2ca32a560b82
This commit is contained in:
		
							parent
							
								
									7566530669
								
							
						
					
					
						commit
						38cf0dc1c5
					
				@ -149,6 +149,58 @@ def _pad_texture_maps(
 | 
			
		||||
    return tex_maps
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _pad_texture_multiple_maps(
 | 
			
		||||
    multiple_texture_maps: Union[Tuple[torch.Tensor], List[torch.Tensor]],
 | 
			
		||||
    align_corners: bool,
 | 
			
		||||
) -> torch.Tensor:
 | 
			
		||||
    """
 | 
			
		||||
    Pad all texture images so they have the same height and width.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        images: list of N tensors of shape (M_i, H_i, W_i, C)
 | 
			
		||||
        M_i : Number of texture maps:w
 | 
			
		||||
 | 
			
		||||
        align_corners: used for interpolation
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        tex_maps: Tensor of shape (N, max_M, max_H, max_W, C)
 | 
			
		||||
    """
 | 
			
		||||
    tex_maps = []
 | 
			
		||||
    max_M = 0
 | 
			
		||||
    max_H = 0
 | 
			
		||||
    max_W = 0
 | 
			
		||||
    C = 0
 | 
			
		||||
    for im in multiple_texture_maps:
 | 
			
		||||
        m, h, w, C = im.shape
 | 
			
		||||
        if m > max_M:
 | 
			
		||||
            max_M = m
 | 
			
		||||
        if h > max_H:
 | 
			
		||||
            max_H = h
 | 
			
		||||
        if w > max_W:
 | 
			
		||||
            max_W = w
 | 
			
		||||
        tex_maps.append(im)
 | 
			
		||||
    max_shape = (max_M, max_H, max_W, C)
 | 
			
		||||
    max_im_shape = (max_H, max_W)
 | 
			
		||||
    for i, tms in enumerate(tex_maps):
 | 
			
		||||
        new_tex_maps = torch.zeros(max_shape)
 | 
			
		||||
        for j in range(tms.shape[0]):
 | 
			
		||||
            im = tms[j]
 | 
			
		||||
            if im.shape[:2] != max_im_shape:
 | 
			
		||||
                image_BCHW = im.permute(2, 0, 1)[None]
 | 
			
		||||
                new_image_BCHW = interpolate(
 | 
			
		||||
                    image_BCHW,
 | 
			
		||||
                    size=max_im_shape,
 | 
			
		||||
                    mode="bilinear",
 | 
			
		||||
                    align_corners=align_corners,
 | 
			
		||||
                )
 | 
			
		||||
                new_tex_maps[j] = new_image_BCHW[0].permute(1, 2, 0)
 | 
			
		||||
            else:
 | 
			
		||||
                new_tex_maps[j] = im
 | 
			
		||||
        tex_maps[i] = new_tex_maps
 | 
			
		||||
    tex_maps = torch.stack(tex_maps, dim=0)  # (num_tex_maps, max_H, max_W, C)
 | 
			
		||||
    return tex_maps
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# A base class for defining a batch of textures
 | 
			
		||||
# with helper methods.
 | 
			
		||||
# This is also useful to have so that inside `Meshes`
 | 
			
		||||
@ -199,13 +251,20 @@ class TexturesBase:
 | 
			
		||||
            t = getattr(self, p)
 | 
			
		||||
            if callable(t):
 | 
			
		||||
                t = t()  # class method
 | 
			
		||||
            if isinstance(t, list):
 | 
			
		||||
            if t is None:
 | 
			
		||||
                new_props[p] = None
 | 
			
		||||
            elif isinstance(t, list):
 | 
			
		||||
                if not all(isinstance(elem, (int, float)) for elem in t):
 | 
			
		||||
                    raise ValueError("Extend only supports lists of scalars")
 | 
			
		||||
                t = [[ti] * N for ti in t]
 | 
			
		||||
                new_props[p] = list(itertools.chain(*t))
 | 
			
		||||
            elif torch.is_tensor(t):
 | 
			
		||||
                new_props[p] = t.repeat_interleave(N, dim=0)
 | 
			
		||||
            else:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    f"Property {p} has unsupported type {type(t)}."
 | 
			
		||||
                    "Only tensors and lists are supported."
 | 
			
		||||
                )
 | 
			
		||||
        return new_props
 | 
			
		||||
 | 
			
		||||
    def _getitem(self, index: Union[int, slice], props: List[str]):
 | 
			
		||||
@ -218,7 +277,7 @@ class TexturesBase:
 | 
			
		||||
                t = getattr(self, p)
 | 
			
		||||
                if callable(t):
 | 
			
		||||
                    t = t()  # class method
 | 
			
		||||
                new_props[p] = t[index]
 | 
			
		||||
                new_props[p] = t[index] if t is not None else None
 | 
			
		||||
        elif isinstance(index, list):
 | 
			
		||||
            index = torch.tensor(index)
 | 
			
		||||
        if isinstance(index, torch.Tensor):
 | 
			
		||||
@ -230,8 +289,7 @@ class TexturesBase:
 | 
			
		||||
                t = getattr(self, p)
 | 
			
		||||
                if callable(t):
 | 
			
		||||
                    t = t()  # class method
 | 
			
		||||
                new_props[p] = [t[i] for i in index]
 | 
			
		||||
 | 
			
		||||
                new_props[p] = [t[i] for i in index] if t is not None else None
 | 
			
		||||
        return new_props
 | 
			
		||||
 | 
			
		||||
    def sample_textures(self) -> torch.Tensor:
 | 
			
		||||
@ -644,6 +702,10 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
        maps: Union[torch.Tensor, List[torch.Tensor]],
 | 
			
		||||
        faces_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
 | 
			
		||||
        verts_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
 | 
			
		||||
        *,
 | 
			
		||||
        maps_ids: Optional[
 | 
			
		||||
            Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
 | 
			
		||||
        ] = None,
 | 
			
		||||
        padding_mode: str = "border",
 | 
			
		||||
        align_corners: bool = True,
 | 
			
		||||
        sampling_mode: str = "bilinear",
 | 
			
		||||
@ -653,20 +715,33 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
        vertex in each face. NOTE: this class only supports one texture map per mesh.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            maps: texture map per mesh. This can either be a list of maps
 | 
			
		||||
              [(H, W, C)] or a padded tensor of shape (N, H, W, C).
 | 
			
		||||
              For RGB, C = 3.
 | 
			
		||||
            maps: Either (1) a texture map per mesh. This can either be a list of maps
 | 
			
		||||
                    [(H, W, C)] or a padded tensor of shape (N, H, W, C).
 | 
			
		||||
                    For RGB, C = 3. In this case maps_ids must be None.
 | 
			
		||||
                Or (2) a set of M texture maps per mesh. This can either be a list of sets
 | 
			
		||||
                    [(M, H, W, C)] or a padded tensor of shape (N, M, H, W, C).
 | 
			
		||||
                    For RGB, C = 3. In this case maps_ids must be provided to
 | 
			
		||||
                    identify which is relevant to each face.
 | 
			
		||||
            faces_uvs: (N, F, 3) LongTensor giving the index into verts_uvs
 | 
			
		||||
                        for each face
 | 
			
		||||
                    for each face
 | 
			
		||||
            verts_uvs: (N, V, 2) tensor giving the uv coordinates per vertex
 | 
			
		||||
                        (a FloatTensor with values between 0 and 1).
 | 
			
		||||
                    (a FloatTensor with values between 0 and 1).
 | 
			
		||||
            maps_ids: Used if there are to be multiple maps per face. This can be either a list of map_ids [(F,)]
 | 
			
		||||
                    or a long tensor of shape (N, F) giving the id of the texture map
 | 
			
		||||
                    for each face. If maps_ids is present, the maps has an extra dimension M
 | 
			
		||||
                    (so maps_padded is (N, M, H, W, C) and maps_list has elements of
 | 
			
		||||
                    shape (M, H, W, C)).
 | 
			
		||||
                    Specifically, the color
 | 
			
		||||
                    of a vertex V is given by an average of maps_padded[i, maps_ids[i, f], u, v, :]
 | 
			
		||||
                    over u and v integers adjacent to
 | 
			
		||||
                    _verts_uvs_padded[i, _faces_uvs_padded[i, f, 0], :] .
 | 
			
		||||
            align_corners: If true, the extreme values 0 and 1 for verts_uvs
 | 
			
		||||
                            indicate the centers of the edge pixels in the maps.
 | 
			
		||||
                    indicate the centers of the edge pixels in the maps.
 | 
			
		||||
            padding_mode: padding mode for outside grid values
 | 
			
		||||
                                ("zeros", "border" or "reflection").
 | 
			
		||||
                    ("zeros", "border" or "reflection").
 | 
			
		||||
            sampling_mode: type of interpolation used to sample the texture.
 | 
			
		||||
                            Corresponds to the mode parameter in PyTorch's
 | 
			
		||||
                            grid_sample ("nearest" or "bilinear").
 | 
			
		||||
                    Corresponds to the mode parameter in PyTorch's
 | 
			
		||||
                    grid_sample ("nearest" or "bilinear").
 | 
			
		||||
 | 
			
		||||
        The align_corners and padding_mode arguments correspond to the arguments
 | 
			
		||||
        of the `grid_sample` torch function. There is an informative illustration of
 | 
			
		||||
@ -762,6 +837,8 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError("Expected verts_uvs to be a tensor or list")
 | 
			
		||||
 | 
			
		||||
        self._maps_ids_padded, self._maps_ids_list = self._format_maps_ids(maps_ids)
 | 
			
		||||
 | 
			
		||||
        if isinstance(maps, (list, tuple)):
 | 
			
		||||
            self._maps_list = maps
 | 
			
		||||
        else:
 | 
			
		||||
@ -770,14 +847,73 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
 | 
			
		||||
        if self._maps_padded.device != self.device:
 | 
			
		||||
            raise ValueError("maps must be on the same device as verts/faces uvs.")
 | 
			
		||||
 | 
			
		||||
        self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)
 | 
			
		||||
 | 
			
		||||
    def _format_maps_ids(
 | 
			
		||||
        self,
 | 
			
		||||
        maps_ids: Optional[
 | 
			
		||||
            Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
 | 
			
		||||
        ],
 | 
			
		||||
    ) -> Tuple[
 | 
			
		||||
        Optional[torch.Tensor], Optional[Union[List[torch.Tensor], Tuple[torch.Tensor]]]
 | 
			
		||||
    ]:
 | 
			
		||||
        if maps_ids is None:
 | 
			
		||||
            return None, None
 | 
			
		||||
        elif isinstance(maps_ids, (list, tuple)):
 | 
			
		||||
            for mid in maps_ids:
 | 
			
		||||
                if mid.ndim != 1:
 | 
			
		||||
                    msg = "Expected maps_ids to be of shape (F,); got %r"
 | 
			
		||||
                    raise ValueError(msg % repr(mid.shape))
 | 
			
		||||
            if len(maps_ids) != self._N:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "map_ids, faces_uvs and verts_uvs must have the same batch dimension"
 | 
			
		||||
                )
 | 
			
		||||
            if not all(mid.device == self.device for mid in maps_ids):
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "maps_ids and verts/faces uvs must be on the same device"
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            if not all(
 | 
			
		||||
                mid.shape[0] == nfm
 | 
			
		||||
                for mid, nfm in zip(maps_ids, self._num_faces_per_mesh)
 | 
			
		||||
            ):
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "map_ids and faces_uvs must have the same number of faces per mesh"
 | 
			
		||||
                )
 | 
			
		||||
            if not all(mid.device == self.device for mid in maps_ids):
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "maps_ids and verts/faces uvs must be on the same device"
 | 
			
		||||
                )
 | 
			
		||||
            if not self._num_faces_per_mesh:
 | 
			
		||||
                return torch.Tensor(), maps_ids
 | 
			
		||||
            return list_to_padded(maps_ids, pad_value=0), maps_ids
 | 
			
		||||
        elif isinstance(maps_ids, torch.Tensor):
 | 
			
		||||
            if maps_ids.ndim != 2 or maps_ids.shape[0] != self._N:
 | 
			
		||||
                msg = "Expected maps_ids to be of shape (N, F); got %r"
 | 
			
		||||
                raise ValueError(msg % repr(maps_ids.shape))
 | 
			
		||||
            maps_ids_padded = maps_ids
 | 
			
		||||
            max_F = max(self._num_faces_per_mesh)
 | 
			
		||||
            if not maps_ids.shape[1] == max_F:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "map_ids and faces_uvs must have the same number of faces per mesh"
 | 
			
		||||
                )
 | 
			
		||||
            if maps_ids.device != self.device:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "maps_ids and verts/faces uvs must be on the same device"
 | 
			
		||||
                )
 | 
			
		||||
            return maps_ids_padded, None
 | 
			
		||||
        raise ValueError("Expected maps_ids to be a tensor or list")
 | 
			
		||||
 | 
			
		||||
    def _format_maps_padded(
 | 
			
		||||
        self, maps: Union[torch.Tensor, List[torch.Tensor]]
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        maps_ids_none = self._maps_ids_padded is None
 | 
			
		||||
        if isinstance(maps, torch.Tensor):
 | 
			
		||||
            if maps.ndim != 4 or maps.shape[0] != self._N:
 | 
			
		||||
            if not maps_ids_none:
 | 
			
		||||
                if maps.ndim != 5 or maps.shape[0] != self._N:
 | 
			
		||||
                    msg = "Expected maps to be of shape (N, M, H, W, C); got %r"
 | 
			
		||||
                    raise ValueError(msg % repr(maps.shape))
 | 
			
		||||
            elif maps.ndim != 4 or maps.shape[0] != self._N:
 | 
			
		||||
                msg = "Expected maps to be of shape (N, H, W, C); got %r"
 | 
			
		||||
                raise ValueError(msg % repr(maps.shape))
 | 
			
		||||
            return maps
 | 
			
		||||
@ -786,15 +922,27 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
            if len(maps) != self._N:
 | 
			
		||||
                raise ValueError("Expected one texture map per mesh in the batch.")
 | 
			
		||||
            if self._N > 0:
 | 
			
		||||
                if not all(map.ndim == 3 for map in maps):
 | 
			
		||||
                ndim = 3 if maps_ids_none else 4
 | 
			
		||||
                if not all(map.ndim == ndim for map in maps):
 | 
			
		||||
                    raise ValueError("Invalid number of dimensions in texture maps")
 | 
			
		||||
                if not all(map.shape[2] == maps[0].shape[2] for map in maps):
 | 
			
		||||
                if not all(map.shape[-1] == maps[0].shape[-1] for map in maps):
 | 
			
		||||
                    raise ValueError("Inconsistent number of channels in maps")
 | 
			
		||||
                maps_padded = _pad_texture_maps(maps, align_corners=self.align_corners)
 | 
			
		||||
            else:
 | 
			
		||||
                maps_padded = torch.empty(
 | 
			
		||||
                    (self._N, 0, 0, 3), dtype=torch.float32, device=self.device
 | 
			
		||||
                maps_padded = (
 | 
			
		||||
                    _pad_texture_maps(maps, align_corners=self.align_corners)
 | 
			
		||||
                    if maps_ids_none
 | 
			
		||||
                    else _pad_texture_multiple_maps(
 | 
			
		||||
                        maps, align_corners=self.align_corners
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                if maps_ids_none:
 | 
			
		||||
                    maps_padded = torch.empty(
 | 
			
		||||
                        (self._N, 0, 0, 3), dtype=torch.float32, device=self.device
 | 
			
		||||
                    )
 | 
			
		||||
                else:
 | 
			
		||||
                    maps_padded = torch.empty(
 | 
			
		||||
                        (self._N, 0, 0, 0, 3), dtype=torch.float32, device=self.device
 | 
			
		||||
                    )
 | 
			
		||||
            return maps_padded
 | 
			
		||||
 | 
			
		||||
        raise ValueError("Expected maps to be a tensor or list of tensors.")
 | 
			
		||||
@ -804,6 +952,11 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
            self.maps_padded().clone(),
 | 
			
		||||
            self.faces_uvs_padded().clone(),
 | 
			
		||||
            self.verts_uvs_padded().clone(),
 | 
			
		||||
            maps_ids=(
 | 
			
		||||
                self._maps_ids_padded.clone()
 | 
			
		||||
                if self._maps_ids_padded is not None
 | 
			
		||||
                else None
 | 
			
		||||
            ),
 | 
			
		||||
            align_corners=self.align_corners,
 | 
			
		||||
            padding_mode=self.padding_mode,
 | 
			
		||||
            sampling_mode=self.sampling_mode,
 | 
			
		||||
@ -814,6 +967,8 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
            tex._verts_uvs_list = [v.clone() for v in self._verts_uvs_list]
 | 
			
		||||
        if self._faces_uvs_list is not None:
 | 
			
		||||
            tex._faces_uvs_list = [f.clone() for f in self._faces_uvs_list]
 | 
			
		||||
        if self._maps_ids_list is not None:
 | 
			
		||||
            tex._maps_ids_list = [f.clone() for f in self._maps_ids_list]
 | 
			
		||||
        num_faces = (
 | 
			
		||||
            self._num_faces_per_mesh.clone()
 | 
			
		||||
            if torch.is_tensor(self._num_faces_per_mesh)
 | 
			
		||||
@ -828,6 +983,11 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
            self.maps_padded().detach(),
 | 
			
		||||
            self.faces_uvs_padded().detach(),
 | 
			
		||||
            self.verts_uvs_padded().detach(),
 | 
			
		||||
            maps_ids=(
 | 
			
		||||
                self._maps_ids_padded.detach()
 | 
			
		||||
                if self._maps_ids_padded is not None
 | 
			
		||||
                else None
 | 
			
		||||
            ),
 | 
			
		||||
            align_corners=self.align_corners,
 | 
			
		||||
            padding_mode=self.padding_mode,
 | 
			
		||||
            sampling_mode=self.sampling_mode,
 | 
			
		||||
@ -838,6 +998,8 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
            tex._verts_uvs_list = [v.detach() for v in self._verts_uvs_list]
 | 
			
		||||
        if self._faces_uvs_list is not None:
 | 
			
		||||
            tex._faces_uvs_list = [f.detach() for f in self._faces_uvs_list]
 | 
			
		||||
        if self._maps_ids_list is not None:
 | 
			
		||||
            tex._maps_ids_list = [mi.detach() for mi in self._maps_ids_list]
 | 
			
		||||
        num_faces = (
 | 
			
		||||
            self._num_faces_per_mesh.detach()
 | 
			
		||||
            if torch.is_tensor(self._num_faces_per_mesh)
 | 
			
		||||
@ -848,27 +1010,44 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
        return tex
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, index) -> "TexturesUV":
 | 
			
		||||
        props = ["verts_uvs_list", "faces_uvs_list", "maps_list", "_num_faces_per_mesh"]
 | 
			
		||||
        props = [
 | 
			
		||||
            "faces_uvs_list",
 | 
			
		||||
            "verts_uvs_list",
 | 
			
		||||
            "maps_list",
 | 
			
		||||
            "maps_ids_list",
 | 
			
		||||
            "_num_faces_per_mesh",
 | 
			
		||||
        ]
 | 
			
		||||
        new_props = self._getitem(index, props)
 | 
			
		||||
        faces_uvs = new_props["faces_uvs_list"]
 | 
			
		||||
        verts_uvs = new_props["verts_uvs_list"]
 | 
			
		||||
        maps = new_props["maps_list"]
 | 
			
		||||
        maps_ids = new_props["maps_ids_list"]
 | 
			
		||||
 | 
			
		||||
        # if index has multiple values then faces/verts/maps may be a list of tensors
 | 
			
		||||
        if all(isinstance(f, (list, tuple)) for f in [faces_uvs, verts_uvs, maps]):
 | 
			
		||||
            if maps_ids is not None and not isinstance(maps_ids, (list, tuple)):
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "Maps ids are  not in the correct format expected list or tuple"
 | 
			
		||||
                )
 | 
			
		||||
            new_tex = self.__class__(
 | 
			
		||||
                faces_uvs=faces_uvs,
 | 
			
		||||
                verts_uvs=verts_uvs,
 | 
			
		||||
                maps=maps,
 | 
			
		||||
                maps_ids=maps_ids,
 | 
			
		||||
                padding_mode=self.padding_mode,
 | 
			
		||||
                align_corners=self.align_corners,
 | 
			
		||||
                sampling_mode=self.sampling_mode,
 | 
			
		||||
            )
 | 
			
		||||
        elif all(torch.is_tensor(f) for f in [faces_uvs, verts_uvs, maps]):
 | 
			
		||||
            if maps_ids is not None and not torch.is_tensor(maps_ids):
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "Maps ids are not in the correct format expected tensor"
 | 
			
		||||
                )
 | 
			
		||||
            new_tex = self.__class__(
 | 
			
		||||
                faces_uvs=[faces_uvs],
 | 
			
		||||
                verts_uvs=[verts_uvs],
 | 
			
		||||
                maps=[maps],
 | 
			
		||||
                maps_ids=[maps_ids] if maps_ids is not None else None,
 | 
			
		||||
                padding_mode=self.padding_mode,
 | 
			
		||||
                align_corners=self.align_corners,
 | 
			
		||||
                sampling_mode=self.sampling_mode,
 | 
			
		||||
@ -927,6 +1106,17 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
                self._verts_uvs_list = list(self._verts_uvs_padded.unbind(0))
 | 
			
		||||
        return self._verts_uvs_list
 | 
			
		||||
 | 
			
		||||
    def maps_ids_padded(self) -> Optional[torch.Tensor]:
 | 
			
		||||
        return self._maps_ids_padded
 | 
			
		||||
 | 
			
		||||
    def maps_ids_list(self) -> Optional[List[torch.Tensor]]:
 | 
			
		||||
        if self._maps_ids_list is not None:
 | 
			
		||||
            return self._maps_ids_list
 | 
			
		||||
        elif self._maps_ids_padded is not None:
 | 
			
		||||
            return self._maps_ids_padded.unbind(0)
 | 
			
		||||
        else:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
    # Currently only the padded maps are used.
 | 
			
		||||
    def maps_padded(self) -> torch.Tensor:
 | 
			
		||||
        return self._maps_padded
 | 
			
		||||
@ -943,6 +1133,7 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
                "maps_padded",
 | 
			
		||||
                "verts_uvs_padded",
 | 
			
		||||
                "faces_uvs_padded",
 | 
			
		||||
                "maps_ids_padded",
 | 
			
		||||
                "_num_faces_per_mesh",
 | 
			
		||||
            ],
 | 
			
		||||
        )
 | 
			
		||||
@ -950,6 +1141,7 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
            maps=new_props["maps_padded"],
 | 
			
		||||
            faces_uvs=new_props["faces_uvs_padded"],
 | 
			
		||||
            verts_uvs=new_props["verts_uvs_padded"],
 | 
			
		||||
            maps_ids=new_props["maps_ids_padded"],
 | 
			
		||||
            padding_mode=self.padding_mode,
 | 
			
		||||
            align_corners=self.align_corners,
 | 
			
		||||
            sampling_mode=self.sampling_mode,
 | 
			
		||||
@ -992,7 +1184,6 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
                i[j] for i, j in zip(self.verts_uvs_list(), self.faces_uvs_list())
 | 
			
		||||
            ]
 | 
			
		||||
            faces_verts_uvs = torch.cat(packing_list)
 | 
			
		||||
        texture_maps = self.maps_padded()
 | 
			
		||||
 | 
			
		||||
        # pixel_uvs: (N, H, W, K, 2)
 | 
			
		||||
        pixel_uvs = interpolate_face_attributes(
 | 
			
		||||
@ -1000,49 +1191,91 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        N, H_out, W_out, K = fragments.pix_to_face.shape
 | 
			
		||||
        N, H_in, W_in, C = texture_maps.shape  # 3 for RGB
 | 
			
		||||
 | 
			
		||||
        # pixel_uvs: (N, H, W, K, 2) -> (N, K, H, W, 2) -> (NK, H, W, 2)
 | 
			
		||||
        pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(N * K, H_out, W_out, 2)
 | 
			
		||||
        texture_maps = self.maps_padded()
 | 
			
		||||
        maps_ids_padded = self.maps_ids_padded()
 | 
			
		||||
        if maps_ids_padded is None:
 | 
			
		||||
            # pixel_uvs: (N, H, W, K, 2) -> (N, K, H, W, 2) -> (NK, H, W, 2)
 | 
			
		||||
            pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(N * K, H_out, W_out, 2)
 | 
			
		||||
            N, H_in, W_in, C = texture_maps.shape  # 3 for RGB
 | 
			
		||||
 | 
			
		||||
        # textures.map:
 | 
			
		||||
        #   (N, H, W, C) -> (N, C, H, W) -> (1, N, C, H, W)
 | 
			
		||||
        #   -> expand (K, N, C, H, W) -> reshape (N*K, C, H, W)
 | 
			
		||||
        texture_maps = (
 | 
			
		||||
            texture_maps.permute(0, 3, 1, 2)[None, ...]
 | 
			
		||||
            .expand(K, -1, -1, -1, -1)
 | 
			
		||||
            .transpose(0, 1)
 | 
			
		||||
            .reshape(N * K, C, H_in, W_in)
 | 
			
		||||
        )
 | 
			
		||||
            # textures.map:
 | 
			
		||||
            #   (N, H, W, C) -> (N, C, H, W) -> (1, N, C, H, W)
 | 
			
		||||
            #   -> expand (K, N, C, H, W) -> reshape (N*K, C, H, W)
 | 
			
		||||
            texture_maps = (
 | 
			
		||||
                texture_maps.permute(0, 3, 1, 2)[None, ...]
 | 
			
		||||
                .expand(K, -1, -1, -1, -1)
 | 
			
		||||
                .transpose(0, 1)
 | 
			
		||||
                .reshape(N * K, C, H_in, W_in)
 | 
			
		||||
            )
 | 
			
		||||
            # Textures: (N*K, C, H, W), pixel_uvs: (N*K, H, W, 2)
 | 
			
		||||
            # Now need to format the pixel uvs and the texture map correctly!
 | 
			
		||||
            # From pytorch docs, grid_sample takes `grid` and `input`:
 | 
			
		||||
            #   grid specifies the sampling pixel locations normalized by
 | 
			
		||||
            #   the input spatial dimensions It should have most
 | 
			
		||||
            #   values in the range of [-1, 1]. Values x = -1, y = -1
 | 
			
		||||
            #   is the left-top pixel of input, and values x = 1, y = 1 is the
 | 
			
		||||
            #   right-bottom pixel of input.
 | 
			
		||||
 | 
			
		||||
        # Textures: (N*K, C, H, W), pixel_uvs: (N*K, H, W, 2)
 | 
			
		||||
        # Now need to format the pixel uvs and the texture map correctly!
 | 
			
		||||
        # From pytorch docs, grid_sample takes `grid` and `input`:
 | 
			
		||||
        #   grid specifies the sampling pixel locations normalized by
 | 
			
		||||
        #   the input spatial dimensions It should have most
 | 
			
		||||
        #   values in the range of [-1, 1]. Values x = -1, y = -1
 | 
			
		||||
        #   is the left-top pixel of input, and values x = 1, y = 1 is the
 | 
			
		||||
        #   right-bottom pixel of input.
 | 
			
		||||
            # map to a range of [-1, 1] and flip the y axis
 | 
			
		||||
            pixel_uvs = torch.lerp(
 | 
			
		||||
                pixel_uvs.new_tensor([-1.0, 1.0]),
 | 
			
		||||
                pixel_uvs.new_tensor([1.0, -1.0]),
 | 
			
		||||
                pixel_uvs,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # map to a range of [-1, 1] and flip the y axis
 | 
			
		||||
        pixel_uvs = torch.lerp(
 | 
			
		||||
            pixel_uvs.new_tensor([-1.0, 1.0]),
 | 
			
		||||
            pixel_uvs.new_tensor([1.0, -1.0]),
 | 
			
		||||
            pixel_uvs,
 | 
			
		||||
        )
 | 
			
		||||
            if texture_maps.device != pixel_uvs.device:
 | 
			
		||||
                texture_maps = texture_maps.to(pixel_uvs.device)
 | 
			
		||||
            texels = F.grid_sample(
 | 
			
		||||
                texture_maps,
 | 
			
		||||
                pixel_uvs,
 | 
			
		||||
                mode=self.sampling_mode,
 | 
			
		||||
                align_corners=self.align_corners,
 | 
			
		||||
                padding_mode=self.padding_mode,
 | 
			
		||||
            )
 | 
			
		||||
            # texels now has shape (NK, C, H_out, W_out)
 | 
			
		||||
            texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2)
 | 
			
		||||
            return texels
 | 
			
		||||
        else:
 | 
			
		||||
            # We have maps_ids_padded: (N, F), textures_map: (N, M, Hi, Wi, C),fragmenmts.pix_to_face: (N, Ho, Wo, K)
 | 
			
		||||
            # Get pixel_to_map_ids: (N, K, Ho, Wo) by indexing pix_to_face into maps_ids
 | 
			
		||||
            N, M, H_in, W_in, C = texture_maps.shape  # 3 for RGB
 | 
			
		||||
 | 
			
		||||
        if texture_maps.device != pixel_uvs.device:
 | 
			
		||||
            texture_maps = texture_maps.to(pixel_uvs.device)
 | 
			
		||||
        texels = F.grid_sample(
 | 
			
		||||
            texture_maps,
 | 
			
		||||
            pixel_uvs,
 | 
			
		||||
            mode=self.sampling_mode,
 | 
			
		||||
            align_corners=self.align_corners,
 | 
			
		||||
            padding_mode=self.padding_mode,
 | 
			
		||||
        )
 | 
			
		||||
        # texels now has shape (NK, C, H_out, W_out)
 | 
			
		||||
        texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2)
 | 
			
		||||
        return texels
 | 
			
		||||
            mask = fragments.pix_to_face < 0
 | 
			
		||||
            pix_to_face = fragments.pix_to_face.clone()
 | 
			
		||||
            pix_to_face[mask] = 0
 | 
			
		||||
 | 
			
		||||
            pixel_to_map_ids = (
 | 
			
		||||
                maps_ids_padded.flatten()
 | 
			
		||||
                .gather(0, pix_to_face.flatten())
 | 
			
		||||
                .view(N, K, H_out, W_out)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            # Normalize between -1 and 1 with M (number of maps)
 | 
			
		||||
            pixel_to_map_ids = (2.0 * pixel_to_map_ids.float() / float(M - 1)) - 1
 | 
			
		||||
            pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4)
 | 
			
		||||
            pixel_uvs = torch.lerp(
 | 
			
		||||
                pixel_uvs.new_tensor([-1.0, 1.0]),
 | 
			
		||||
                pixel_uvs.new_tensor([1.0, -1.0]),
 | 
			
		||||
                pixel_uvs,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            # N x H_out x W_out x K x 3
 | 
			
		||||
            pixel_uvms = torch.cat((pixel_uvs, pixel_to_map_ids.unsqueeze(4)), dim=4)
 | 
			
		||||
            # (N, M, H, W, C) -> (N, C, M, H, W)
 | 
			
		||||
            texture_maps = texture_maps.permute(0, 4, 1, 2, 3)
 | 
			
		||||
            if texture_maps.device != pixel_uvs.device:
 | 
			
		||||
                texture_maps = texture_maps.to(pixel_uvs.device)
 | 
			
		||||
            texels = F.grid_sample(
 | 
			
		||||
                texture_maps,
 | 
			
		||||
                pixel_uvms,
 | 
			
		||||
                mode=self.sampling_mode,
 | 
			
		||||
                align_corners=self.align_corners,
 | 
			
		||||
                padding_mode=self.padding_mode,
 | 
			
		||||
            )
 | 
			
		||||
            # (N, C, K, H_out, W_out) -> (N, H_out, W_out, K, C)
 | 
			
		||||
            texels = texels.permute(0, 3, 4, 2, 1).contiguous()
 | 
			
		||||
            return texels
 | 
			
		||||
 | 
			
		||||
    def faces_verts_textures_packed(self) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
@ -1065,25 +1298,41 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
            faces_verts_uvs = _list_to_padded_wrapper(
 | 
			
		||||
                packing_list, pad_value=0.0
 | 
			
		||||
            )  # Nxmax(Fi)x3x2
 | 
			
		||||
        texture_maps = self.maps_padded()  # NxHxWxC
 | 
			
		||||
        texture_maps = texture_maps.permute(0, 3, 1, 2)  # NxCxHxW
 | 
			
		||||
 | 
			
		||||
        # map to a range of [-1, 1] and flip the y axis
 | 
			
		||||
        faces_verts_uvs = torch.lerp(
 | 
			
		||||
            faces_verts_uvs.new_tensor([-1.0, 1.0]),
 | 
			
		||||
            faces_verts_uvs.new_tensor([1.0, -1.0]),
 | 
			
		||||
            faces_verts_uvs,
 | 
			
		||||
        )
 | 
			
		||||
        texture_maps = self.maps_padded()  # NxHxWxC or NxMxHxWxC
 | 
			
		||||
        maps_ids_padded = self.maps_ids_padded()
 | 
			
		||||
        if maps_ids_padded is None:
 | 
			
		||||
            texture_maps = texture_maps.permute(0, 3, 1, 2)  # NxCxHxW
 | 
			
		||||
        else:
 | 
			
		||||
            M = texture_maps.shape[1]
 | 
			
		||||
            # (N, M, H, W, C) -> (N, C, M, H, W)
 | 
			
		||||
            texture_maps = texture_maps.permute(0, 4, 1, 2, 3)
 | 
			
		||||
            # expand maps_ids to (N, F, 3, 1)
 | 
			
		||||
            maps_ids_padded = maps_ids_padded[:, :, None, None].expand(-1, -1, 3, -1)
 | 
			
		||||
            maps_ids_padded = (2.0 * maps_ids_padded.float() / float(M - 1)) - 1.0
 | 
			
		||||
 | 
			
		||||
            # (N, F, 3, 2+1) -> (N, 1, F, 3, 3)
 | 
			
		||||
            faces_verts_uvs = torch.cat(
 | 
			
		||||
                (faces_verts_uvs, maps_ids_padded), dim=3
 | 
			
		||||
            ).unsqueeze(1)
 | 
			
		||||
            # (N, M, H, W, C) -> (N, C, H, W, M)
 | 
			
		||||
            # texture_maps = texture_maps.permute(0, 4, 2, 3, 1)
 | 
			
		||||
        textures = F.grid_sample(
 | 
			
		||||
            texture_maps,
 | 
			
		||||
            faces_verts_uvs,
 | 
			
		||||
            mode=self.sampling_mode,
 | 
			
		||||
            align_corners=self.align_corners,
 | 
			
		||||
            padding_mode=self.padding_mode,
 | 
			
		||||
        )  # NxCxmax(Fi)x3
 | 
			
		||||
 | 
			
		||||
        textures = textures.permute(0, 2, 3, 1)  # Nxmax(Fi)x3xC
 | 
			
		||||
        )  # (N, C, max(Fi), 3)
 | 
			
		||||
        if maps_ids_padded is not None:
 | 
			
		||||
            textures = textures.squeeze(dim=2)
 | 
			
		||||
        # (N, C, max(Fi), 3) -> (N, max(Fi), 3, C)
 | 
			
		||||
        textures = textures.permute(0, 2, 3, 1)
 | 
			
		||||
        textures = _padded_to_list_wrapper(
 | 
			
		||||
            textures, split_size=self._num_faces_per_mesh
 | 
			
		||||
        )  # list of N {Fix3xC} tensors
 | 
			
		||||
@ -1102,6 +1351,11 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
            new_tex: TexturesUV object with the combined
 | 
			
		||||
            textures from self and the list `textures`.
 | 
			
		||||
        """
 | 
			
		||||
        if self.maps_ids_padded() is not None:
 | 
			
		||||
            # TODO
 | 
			
		||||
            raise NotImplementedError(
 | 
			
		||||
                "join_batch does not support TexturesUV with multiple maps"
 | 
			
		||||
            )
 | 
			
		||||
        tex_types_same = all(isinstance(tex, TexturesUV) for tex in textures)
 | 
			
		||||
        if not tex_types_same:
 | 
			
		||||
            raise ValueError("All textures must be of type TexturesUV.")
 | 
			
		||||
@ -1137,8 +1391,8 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
 | 
			
		||||
        new_tex = self.__class__(
 | 
			
		||||
            maps=maps_list,
 | 
			
		||||
            verts_uvs=verts_uvs_list,
 | 
			
		||||
            faces_uvs=faces_uvs_list,
 | 
			
		||||
            verts_uvs=verts_uvs_list,
 | 
			
		||||
            padding_mode=self.padding_mode,
 | 
			
		||||
            align_corners=self.align_corners,
 | 
			
		||||
            sampling_mode=self.sampling_mode,
 | 
			
		||||
@ -1205,6 +1459,9 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
        _place_map_into_single_map is used to copy the maps into the single map.
 | 
			
		||||
        The merging of verts_uvs and faces_uvs is handled locally in this function.
 | 
			
		||||
        """
 | 
			
		||||
        if self.maps_ids_padded() is not None:
 | 
			
		||||
            # TODO
 | 
			
		||||
            raise NotImplementedError("join_scene does not support multiple maps.")
 | 
			
		||||
        maps = self.maps_list()
 | 
			
		||||
        heights_and_widths = []
 | 
			
		||||
        extra_border = 0 if self.align_corners else 2
 | 
			
		||||
@ -1305,8 +1562,8 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
 | 
			
		||||
        return self.__class__(
 | 
			
		||||
            maps=[single_map],
 | 
			
		||||
            verts_uvs=[torch.cat(verts_uvs_merged)],
 | 
			
		||||
            faces_uvs=[torch.cat(faces_uvs_merged)],
 | 
			
		||||
            verts_uvs=[torch.cat(verts_uvs_merged)],
 | 
			
		||||
            align_corners=self.align_corners,
 | 
			
		||||
            padding_mode=self.padding_mode,
 | 
			
		||||
            sampling_mode=self.sampling_mode,
 | 
			
		||||
@ -1326,6 +1583,9 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
            centers: coordinates of points in the texture image
 | 
			
		||||
                - a FloatTensor of shape (V,2)
 | 
			
		||||
        """
 | 
			
		||||
        if self.maps_ids_padded() is not None:
 | 
			
		||||
            # TODO: invent a visualization for the multiple maps case
 | 
			
		||||
            raise NotImplementedError("This function does not support multiple maps.")
 | 
			
		||||
        if self._N != 1:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "This function only supports plotting textures for one mesh."
 | 
			
		||||
@ -1388,7 +1648,9 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
            A  "TexturesUV in which faces_uvs_padded, verts_uvs_padded, and maps_padded
 | 
			
		||||
            have length sum(len(faces) for faces in faces_ids_list)
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        if self.maps_ids_padded() is not None:
 | 
			
		||||
            # TODO
 | 
			
		||||
            raise NotImplementedError("This function does not support multiple maps.")
 | 
			
		||||
        if len(faces_ids_list) != len(self.faces_uvs_padded()):
 | 
			
		||||
            raise IndexError(
 | 
			
		||||
                "faces_uvs_padded must be of " "the same length as face_ids_list."
 | 
			
		||||
@ -1407,12 +1669,12 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
                sub_maps.append(map_)
 | 
			
		||||
 | 
			
		||||
        return self.__class__(
 | 
			
		||||
            sub_maps,
 | 
			
		||||
            sub_faces_uvs,
 | 
			
		||||
            sub_verts_uvs,
 | 
			
		||||
            self.padding_mode,
 | 
			
		||||
            self.align_corners,
 | 
			
		||||
            self.sampling_mode,
 | 
			
		||||
            maps=sub_maps,
 | 
			
		||||
            faces_uvs=sub_faces_uvs,
 | 
			
		||||
            verts_uvs=sub_verts_uvs,
 | 
			
		||||
            padding_mode=self.padding_mode,
 | 
			
		||||
            align_corners=self.align_corners,
 | 
			
		||||
            sampling_mode=self.sampling_mode,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -718,6 +718,22 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
                verts_uvs=torch.rand(size=(5, 15, 2)),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # maps ids are not none but maps doesn't have multiple map indices
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "map"):
 | 
			
		||||
            TexturesUV(
 | 
			
		||||
                maps=torch.ones((5, 16, 16, 3)),
 | 
			
		||||
                faces_uvs=torch.rand(size=(5, 10, 3)),
 | 
			
		||||
                verts_uvs=torch.rand(size=(5, 15, 2)),
 | 
			
		||||
                maps_ids=torch.randint(0, 1, (5, 10), dtype=torch.long),
 | 
			
		||||
            )
 | 
			
		||||
        # maps ids is none but maps have multiple map indices
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "map"):
 | 
			
		||||
            TexturesUV(
 | 
			
		||||
                maps=torch.ones((5, 2, 16, 16, 3)),
 | 
			
		||||
                faces_uvs=torch.rand(size=(5, 10, 3)),
 | 
			
		||||
                verts_uvs=torch.rand(size=(5, 15, 2)),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def test_faces_verts_textures(self):
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        N, V, F, H, W = 2, 5, 12, 8, 8
 | 
			
		||||
@ -755,6 +771,47 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
        self.assertClose(faces_verts_texs, tex.faces_verts_textures_packed())
 | 
			
		||||
 | 
			
		||||
    def test_faces_verts_multiple_map_textures(self):
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        N, M, V, F, H, W = 2, 3, 5, 12, 8, 8
 | 
			
		||||
        vert_uvs = torch.rand((N, V, 2), dtype=torch.float32, device=device)
 | 
			
		||||
        face_uvs = torch.randint(
 | 
			
		||||
            high=V, size=(N, F, 3), dtype=torch.int64, device=device
 | 
			
		||||
        )
 | 
			
		||||
        map_ids = torch.randint(0, M, (N, F), device=device)
 | 
			
		||||
        maps = torch.rand((N, M, H, W, 3), dtype=torch.float32, device=device)
 | 
			
		||||
 | 
			
		||||
        tex = TexturesUV(
 | 
			
		||||
            maps=maps, verts_uvs=vert_uvs, faces_uvs=face_uvs, maps_ids=map_ids
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # naive faces_verts_textures
 | 
			
		||||
        faces_verts_texs = []
 | 
			
		||||
        for n in range(N):
 | 
			
		||||
            temp = torch.zeros((F, 3, 3), device=device, dtype=torch.float32)
 | 
			
		||||
            for f in range(F):
 | 
			
		||||
                uv0 = vert_uvs[n, face_uvs[n, f, 0]]
 | 
			
		||||
                uv1 = vert_uvs[n, face_uvs[n, f, 1]]
 | 
			
		||||
                uv2 = vert_uvs[n, face_uvs[n, f, 2]]
 | 
			
		||||
                map_id = map_ids[n, f]
 | 
			
		||||
 | 
			
		||||
                idx = torch.stack((uv0, uv1, uv2), dim=0).view(1, 1, 3, 2)  # 1x1x3x2
 | 
			
		||||
                idx = idx * 2.0 - 1.0
 | 
			
		||||
                imap = maps[n, map_id].view(1, H, W, 3).permute(0, 3, 1, 2)  # 1x3xHxW
 | 
			
		||||
                imap = torch.flip(imap, [2])
 | 
			
		||||
 | 
			
		||||
                texts = torch.nn.functional.grid_sample(
 | 
			
		||||
                    imap,
 | 
			
		||||
                    idx,
 | 
			
		||||
                    align_corners=tex.align_corners,
 | 
			
		||||
                    padding_mode=tex.padding_mode,
 | 
			
		||||
                )  # 1x3x1x3
 | 
			
		||||
                temp[f] = texts[0, :, 0, :].permute(1, 0)
 | 
			
		||||
            faces_verts_texs.append(temp)
 | 
			
		||||
        faces_verts_texs = torch.cat(faces_verts_texs, 0)
 | 
			
		||||
 | 
			
		||||
        self.assertClose(faces_verts_texs, tex.faces_verts_textures_packed())
 | 
			
		||||
 | 
			
		||||
    def test_clone(self):
 | 
			
		||||
        tex = TexturesUV(
 | 
			
		||||
            maps=torch.ones((5, 16, 16, 3)),
 | 
			
		||||
@ -781,6 +838,37 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            self.assertSeparate(tex.maps_list()[i], tex_cloned.maps_list()[i])
 | 
			
		||||
            self.assertClose(tex.maps_list()[i], tex_cloned.maps_list()[i])
 | 
			
		||||
 | 
			
		||||
    def test_multiple_maps_clone(self):
 | 
			
		||||
        tex = TexturesUV(
 | 
			
		||||
            maps=torch.ones((5, 3, 16, 16, 3)),
 | 
			
		||||
            faces_uvs=torch.rand(size=(5, 10, 3)),
 | 
			
		||||
            verts_uvs=torch.rand(size=(5, 15, 2)),
 | 
			
		||||
            maps_ids=torch.randint(0, 3, (5, 10)),
 | 
			
		||||
        )
 | 
			
		||||
        tex.faces_uvs_list()
 | 
			
		||||
        tex.verts_uvs_list()
 | 
			
		||||
        tex_cloned = tex.clone()
 | 
			
		||||
        self.assertSeparate(tex._faces_uvs_padded, tex_cloned._faces_uvs_padded)
 | 
			
		||||
        self.assertClose(tex._faces_uvs_padded, tex_cloned._faces_uvs_padded)
 | 
			
		||||
        self.assertSeparate(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded)
 | 
			
		||||
        self.assertClose(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded)
 | 
			
		||||
        self.assertSeparate(tex._maps_padded, tex_cloned._maps_padded)
 | 
			
		||||
        self.assertClose(tex._maps_padded, tex_cloned._maps_padded)
 | 
			
		||||
        self.assertSeparate(tex.valid, tex_cloned.valid)
 | 
			
		||||
        self.assertTrue(tex.valid.eq(tex_cloned.valid).all())
 | 
			
		||||
        self.assertSeparate(tex._maps_ids_padded, tex_cloned._maps_ids_padded)
 | 
			
		||||
        self.assertClose(tex._maps_ids_padded, tex_cloned._maps_ids_padded)
 | 
			
		||||
        for i in range(tex._N):
 | 
			
		||||
            self.assertSeparate(tex._faces_uvs_list[i], tex_cloned._faces_uvs_list[i])
 | 
			
		||||
            self.assertClose(tex._faces_uvs_list[i], tex_cloned._faces_uvs_list[i])
 | 
			
		||||
            self.assertSeparate(tex._verts_uvs_list[i], tex_cloned._verts_uvs_list[i])
 | 
			
		||||
            self.assertClose(tex._verts_uvs_list[i], tex_cloned._verts_uvs_list[i])
 | 
			
		||||
            # tex._maps_list is not use anywhere so it's not stored. We call it explicitly
 | 
			
		||||
            self.assertSeparate(tex.maps_list()[i], tex_cloned.maps_list()[i])
 | 
			
		||||
            self.assertClose(tex.maps_list()[i], tex_cloned.maps_list()[i])
 | 
			
		||||
            self.assertSeparate(tex.maps_ids_list()[i], tex_cloned.maps_ids_list()[i])
 | 
			
		||||
            self.assertClose(tex.maps_ids_list()[i], tex_cloned.maps_ids_list()[i])
 | 
			
		||||
 | 
			
		||||
    def test_detach(self):
 | 
			
		||||
        tex = TexturesUV(
 | 
			
		||||
            maps=torch.ones((5, 16, 16, 3), requires_grad=True),
 | 
			
		||||
@ -805,6 +893,35 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            self.assertFalse(tex_detached.maps_list()[i].requires_grad)
 | 
			
		||||
            self.assertClose(tex.maps_list()[i], tex_detached.maps_list()[i])
 | 
			
		||||
 | 
			
		||||
    def test_multiple_maps_detach(self):
 | 
			
		||||
        tex = TexturesUV(
 | 
			
		||||
            maps=torch.ones((5, 3, 16, 16, 3), requires_grad=True),
 | 
			
		||||
            faces_uvs=torch.rand(size=(5, 10, 3)),
 | 
			
		||||
            verts_uvs=torch.rand(size=(5, 15, 2)),
 | 
			
		||||
            maps_ids=torch.randint(0, 3, (5, 10)),
 | 
			
		||||
        )
 | 
			
		||||
        tex.faces_uvs_list()
 | 
			
		||||
        tex.verts_uvs_list()
 | 
			
		||||
        tex_detached = tex.detach()
 | 
			
		||||
        self.assertFalse(tex_detached._maps_padded.requires_grad)
 | 
			
		||||
        self.assertClose(tex._maps_padded, tex_detached._maps_padded)
 | 
			
		||||
        self.assertFalse(tex_detached._verts_uvs_padded.requires_grad)
 | 
			
		||||
        self.assertClose(tex._verts_uvs_padded, tex_detached._verts_uvs_padded)
 | 
			
		||||
        self.assertFalse(tex_detached._faces_uvs_padded.requires_grad)
 | 
			
		||||
        self.assertClose(tex._faces_uvs_padded, tex_detached._faces_uvs_padded)
 | 
			
		||||
        self.assertFalse(tex_detached._maps_ids_padded.requires_grad)
 | 
			
		||||
        self.assertClose(tex._maps_ids_padded, tex_detached._maps_ids_padded)
 | 
			
		||||
        for i in range(tex._N):
 | 
			
		||||
            self.assertFalse(tex_detached._verts_uvs_list[i].requires_grad)
 | 
			
		||||
            self.assertClose(tex._verts_uvs_list[i], tex_detached._verts_uvs_list[i])
 | 
			
		||||
            self.assertFalse(tex_detached._faces_uvs_list[i].requires_grad)
 | 
			
		||||
            self.assertClose(tex._faces_uvs_list[i], tex_detached._faces_uvs_list[i])
 | 
			
		||||
            # tex._maps_list is not use anywhere so it's not stored. We call it explicitly
 | 
			
		||||
            self.assertFalse(tex_detached.maps_list()[i].requires_grad)
 | 
			
		||||
            self.assertClose(tex.maps_list()[i], tex_detached.maps_list()[i])
 | 
			
		||||
            self.assertFalse(tex_detached.maps_ids_list()[i].requires_grad)
 | 
			
		||||
            self.assertClose(tex.maps_ids_list()[i], tex_detached.maps_ids_list()[i])
 | 
			
		||||
 | 
			
		||||
    def test_extend(self):
 | 
			
		||||
        B = 5
 | 
			
		||||
        mesh = init_mesh(B, 30, 50)
 | 
			
		||||
@ -878,13 +995,15 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            torch.tensor([[0, 1, 2], [3, 4, 5]]),
 | 
			
		||||
        ]  # (N, 3, 3)
 | 
			
		||||
        verts_uvs_list = [torch.ones(9, 2), torch.ones(6, 2)]
 | 
			
		||||
        maps_ids_given_list = [torch.randint(0, 3, (3,)), torch.randint(0, 3, (2,))]
 | 
			
		||||
 | 
			
		||||
        num_faces_per_mesh = [f.shape[0] for f in faces_uvs_list]
 | 
			
		||||
        num_verts_per_mesh = [v.shape[0] for v in verts_uvs_list]
 | 
			
		||||
        tex = TexturesUV(
 | 
			
		||||
            maps=torch.ones((N, 16, 16, 3)),
 | 
			
		||||
            maps=torch.ones((N, 3, 16, 16, 3)),
 | 
			
		||||
            faces_uvs=faces_uvs_list,
 | 
			
		||||
            verts_uvs=verts_uvs_list,
 | 
			
		||||
            maps_ids=maps_ids_given_list,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # This is set inside Meshes when textures is passed as an input.
 | 
			
		||||
@ -898,24 +1017,33 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        faces_list = tex1.faces_uvs_list()
 | 
			
		||||
        faces_padded = tex1.faces_uvs_padded()
 | 
			
		||||
 | 
			
		||||
        maps_ids_list = tex1.maps_ids_list()
 | 
			
		||||
        maps_ids_padded = tex1.maps_ids_padded()
 | 
			
		||||
 | 
			
		||||
        for f1, f2 in zip(faces_list, faces_uvs_list):
 | 
			
		||||
            self.assertTrue((f1 == f2).all().item())
 | 
			
		||||
 | 
			
		||||
        for f1, f2 in zip(verts_list, verts_uvs_list):
 | 
			
		||||
            self.assertTrue((f1 == f2).all().item())
 | 
			
		||||
 | 
			
		||||
        for f1, f2 in zip(maps_ids_given_list, maps_ids_list):
 | 
			
		||||
            self.assertTrue((f1 == f2).all().item())
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(faces_padded.shape == (2, 3, 3))
 | 
			
		||||
        self.assertTrue(verts_padded.shape == (2, 9, 2))
 | 
			
		||||
        self.assertTrue(maps_ids_padded.shape == (2, 3))
 | 
			
		||||
 | 
			
		||||
        # Case where num_faces_per_mesh is not set and faces_verts_uvs
 | 
			
		||||
        # are initialized with a padded tensor.
 | 
			
		||||
        tex2 = TexturesUV(
 | 
			
		||||
            maps=torch.ones((N, 16, 16, 3)),
 | 
			
		||||
            maps=torch.ones((N, 3, 16, 16, 3)),
 | 
			
		||||
            verts_uvs=verts_padded,
 | 
			
		||||
            faces_uvs=faces_padded,
 | 
			
		||||
            maps_ids=maps_ids_padded,
 | 
			
		||||
        )
 | 
			
		||||
        faces_list = tex2.faces_uvs_list()
 | 
			
		||||
        verts_list = tex2.verts_uvs_list()
 | 
			
		||||
        maps_ids_list = tex2.maps_ids_list()
 | 
			
		||||
 | 
			
		||||
        for i, (f1, f2) in enumerate(zip(faces_list, faces_uvs_list)):
 | 
			
		||||
            n = num_faces_per_mesh[i]
 | 
			
		||||
@ -925,23 +1053,30 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            n = num_verts_per_mesh[i]
 | 
			
		||||
            self.assertTrue((f1[:n] == f2).all().item())
 | 
			
		||||
 | 
			
		||||
        for i, (f1, f2) in enumerate(zip(maps_ids_list, maps_ids_given_list)):
 | 
			
		||||
            n = num_faces_per_mesh[i]
 | 
			
		||||
            self.assertTrue((f1[:n] == f2).all().item())
 | 
			
		||||
 | 
			
		||||
    def test_to(self):
 | 
			
		||||
        tex = TexturesUV(
 | 
			
		||||
            maps=torch.ones((5, 16, 16, 3)),
 | 
			
		||||
            maps=torch.ones((5, 3, 16, 16, 3)),
 | 
			
		||||
            faces_uvs=torch.randint(size=(5, 10, 3), high=15),
 | 
			
		||||
            verts_uvs=torch.rand(size=(5, 15, 2)),
 | 
			
		||||
            maps_ids=torch.randint(0, 3, (5, 10)),
 | 
			
		||||
        )
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        tex = tex.to(device)
 | 
			
		||||
        self.assertEqual(tex._faces_uvs_padded.device, device)
 | 
			
		||||
        self.assertEqual(tex._verts_uvs_padded.device, device)
 | 
			
		||||
        self.assertEqual(tex._maps_padded.device, device)
 | 
			
		||||
        self.assertEqual(tex._maps_ids_padded.device, device)
 | 
			
		||||
 | 
			
		||||
    def test_mesh_to(self):
 | 
			
		||||
        tex_cpu = TexturesUV(
 | 
			
		||||
            maps=torch.ones((5, 16, 16, 3)),
 | 
			
		||||
            maps=torch.ones((5, 3, 16, 16, 3)),
 | 
			
		||||
            faces_uvs=torch.randint(size=(5, 10, 3), high=15),
 | 
			
		||||
            verts_uvs=torch.rand(size=(5, 15, 2)),
 | 
			
		||||
            maps_ids=torch.randint(0, 3, (5, 10)),
 | 
			
		||||
        )
 | 
			
		||||
        verts = torch.rand(size=(5, 15, 3))
 | 
			
		||||
        faces = torch.randint(size=(5, 10, 3), high=15)
 | 
			
		||||
@ -952,24 +1087,29 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        self.assertEqual(tex._faces_uvs_padded.device, device)
 | 
			
		||||
        self.assertEqual(tex._verts_uvs_padded.device, device)
 | 
			
		||||
        self.assertEqual(tex._maps_padded.device, device)
 | 
			
		||||
        self.assertEqual(tex._maps_ids_padded.device, device)
 | 
			
		||||
        self.assertEqual(tex_cpu._verts_uvs_padded.device, cpu)
 | 
			
		||||
        self.assertEqual(tex_cpu._maps_ids_padded.device, cpu)
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(tex_cpu.device, cpu)
 | 
			
		||||
        self.assertEqual(tex.device, device)
 | 
			
		||||
 | 
			
		||||
    def test_getitem(self):
 | 
			
		||||
        N = 5
 | 
			
		||||
        M = 3
 | 
			
		||||
        V = 20
 | 
			
		||||
        F = 10
 | 
			
		||||
        source = {
 | 
			
		||||
            "maps": torch.rand(size=(N, 1, 1, 3)),
 | 
			
		||||
            "maps": torch.rand(size=(N, M, 1, 1, 3)),
 | 
			
		||||
            "faces_uvs": torch.randint(size=(N, F, 3), high=V),
 | 
			
		||||
            "verts_uvs": torch.randn(size=(N, V, 2)),
 | 
			
		||||
            "maps_ids": torch.randint(0, M, (N, F)),
 | 
			
		||||
        }
 | 
			
		||||
        tex = TexturesUV(
 | 
			
		||||
            maps=source["maps"],
 | 
			
		||||
            faces_uvs=source["faces_uvs"],
 | 
			
		||||
            verts_uvs=source["verts_uvs"],
 | 
			
		||||
            maps_ids=source["maps_ids"],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        verts = torch.rand(size=(N, V, 3))
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user