diff --git a/pytorch3d/renderer/mesh/textures.py b/pytorch3d/renderer/mesh/textures.py index 66b272f0..bddf624e 100644 --- a/pytorch3d/renderer/mesh/textures.py +++ b/pytorch3d/renderer/mesh/textures.py @@ -596,6 +596,7 @@ class TexturesUV(TexturesBase): verts_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]], padding_mode: str = "border", align_corners: bool = True, + sampling_mode: str = "bilinear", ) -> None: """ Textures are represented as a per mesh texture map and uv coordinates for each @@ -613,6 +614,9 @@ class TexturesUV(TexturesBase): indicate the centers of the edge pixels in the maps. padding_mode: padding mode for outside grid values ("zeros", "border" or "reflection"). + sampling_mode: type of interpolation used to sample the texture. + Corresponds to the mode parameter in PyTorch's + grid_sample ("nearest" or "bilinear"). The align_corners and padding_mode arguments correspond to the arguments of the `grid_sample` torch function. There is an informative illustration of @@ -641,6 +645,7 @@ class TexturesUV(TexturesBase): """ self.padding_mode = padding_mode self.align_corners = align_corners + self.sampling_mode = sampling_mode if isinstance(faces_uvs, (list, tuple)): for fv in faces_uvs: if fv.ndim != 2 or fv.shape[-1] != 3: @@ -749,6 +754,9 @@ class TexturesUV(TexturesBase): self.maps_padded().clone(), self.faces_uvs_padded().clone(), self.verts_uvs_padded().clone(), + align_corners=self.align_corners, + padding_mode=self.padding_mode, + sampling_mode=self.sampling_mode, ) if self._maps_list is not None: tex._maps_list = [m.clone() for m in self._maps_list] @@ -770,6 +778,9 @@ class TexturesUV(TexturesBase): self.maps_padded().detach(), self.faces_uvs_padded().detach(), self.verts_uvs_padded().detach(), + align_corners=self.align_corners, + padding_mode=self.padding_mode, + sampling_mode=self.sampling_mode, ) if self._maps_list is not None: tex._maps_list = [m.detach() for m in self._maps_list] @@ -801,6 +812,7 @@ class TexturesUV(TexturesBase): maps=maps, padding_mode=self.padding_mode, align_corners=self.align_corners, + sampling_mode=self.sampling_mode, ) elif all(torch.is_tensor(f) for f in [faces_uvs, verts_uvs, maps]): new_tex = self.__class__( @@ -809,6 +821,7 @@ class TexturesUV(TexturesBase): maps=[maps], padding_mode=self.padding_mode, align_corners=self.align_corners, + sampling_mode=self.sampling_mode, ) else: raise ValueError("Not all values are provided in the correct format") @@ -889,6 +902,7 @@ class TexturesUV(TexturesBase): verts_uvs=new_props["verts_uvs_padded"], padding_mode=self.padding_mode, align_corners=self.align_corners, + sampling_mode=self.sampling_mode, ) new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"] @@ -966,6 +980,7 @@ class TexturesUV(TexturesBase): texels = F.grid_sample( texture_maps, pixel_uvs, + mode=self.sampling_mode, align_corners=self.align_corners, padding_mode=self.padding_mode, ) @@ -1003,6 +1018,7 @@ class TexturesUV(TexturesBase): textures = F.grid_sample( texture_maps, faces_verts_uvs, + mode=self.sampling_mode, align_corners=self.align_corners, padding_mode=self.padding_mode, ) # NxCxmax(Fi)x3 @@ -1060,6 +1076,7 @@ class TexturesUV(TexturesBase): faces_uvs=faces_uvs_list, padding_mode=self.padding_mode, align_corners=self.align_corners, + sampling_mode=self.sampling_mode, ) new_tex._num_faces_per_mesh = num_faces_per_mesh return new_tex @@ -1227,6 +1244,7 @@ class TexturesUV(TexturesBase): faces_uvs=[torch.cat(faces_uvs_merged)], align_corners=self.align_corners, padding_mode=self.padding_mode, + sampling_mode=self.sampling_mode, ) def centers_for_image(self, index: int) -> torch.Tensor: @@ -1259,6 +1277,7 @@ class TexturesUV(TexturesBase): torch.flip(coords.to(texture_image), [2]), # Convert from [0, 1] -> [-1, 1] range expected by grid sample verts_uvs[:, None] * 2.0 - 1, + mode=self.sampling_mode, align_corners=self.align_corners, padding_mode=self.padding_mode, ).cpu()