mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Adding the option to choose the texture sampling mode in TexturesUV.
Summary: This diff adds the `sample_mode` parameter to `TexturesUV` to control the interpolation mode during texture sampling. It simply gets forwarded to `torch.nn.funcitonal.grid_sample`. This option was requested in this [GitHub issue](https://github.com/facebookresearch/pytorch3d/issues/805). Reviewed By: patricklabatut Differential Revision: D32665185 fbshipit-source-id: ac0bc66a018bd4cb20d75fec2d7c11145dd20199
This commit is contained in:
		
							parent
							
								
									e4456dba2f
								
							
						
					
					
						commit
						d9f709599b
					
				@ -596,6 +596,7 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
        verts_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
 | 
			
		||||
        padding_mode: str = "border",
 | 
			
		||||
        align_corners: bool = True,
 | 
			
		||||
        sampling_mode: str = "bilinear",
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Textures are represented as a per mesh texture map and uv coordinates for each
 | 
			
		||||
@ -613,6 +614,9 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
                            indicate the centers of the edge pixels in the maps.
 | 
			
		||||
            padding_mode: padding mode for outside grid values
 | 
			
		||||
                                ("zeros", "border" or "reflection").
 | 
			
		||||
            sampling_mode: type of interpolation used to sample the texture.
 | 
			
		||||
                            Corresponds to the mode parameter in PyTorch's
 | 
			
		||||
                            grid_sample ("nearest" or "bilinear").
 | 
			
		||||
 | 
			
		||||
        The align_corners and padding_mode arguments correspond to the arguments
 | 
			
		||||
        of the `grid_sample` torch function. There is an informative illustration of
 | 
			
		||||
@ -641,6 +645,7 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
        """
 | 
			
		||||
        self.padding_mode = padding_mode
 | 
			
		||||
        self.align_corners = align_corners
 | 
			
		||||
        self.sampling_mode = sampling_mode
 | 
			
		||||
        if isinstance(faces_uvs, (list, tuple)):
 | 
			
		||||
            for fv in faces_uvs:
 | 
			
		||||
                if fv.ndim != 2 or fv.shape[-1] != 3:
 | 
			
		||||
@ -749,6 +754,9 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
            self.maps_padded().clone(),
 | 
			
		||||
            self.faces_uvs_padded().clone(),
 | 
			
		||||
            self.verts_uvs_padded().clone(),
 | 
			
		||||
            align_corners=self.align_corners,
 | 
			
		||||
            padding_mode=self.padding_mode,
 | 
			
		||||
            sampling_mode=self.sampling_mode,
 | 
			
		||||
        )
 | 
			
		||||
        if self._maps_list is not None:
 | 
			
		||||
            tex._maps_list = [m.clone() for m in self._maps_list]
 | 
			
		||||
@ -770,6 +778,9 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
            self.maps_padded().detach(),
 | 
			
		||||
            self.faces_uvs_padded().detach(),
 | 
			
		||||
            self.verts_uvs_padded().detach(),
 | 
			
		||||
            align_corners=self.align_corners,
 | 
			
		||||
            padding_mode=self.padding_mode,
 | 
			
		||||
            sampling_mode=self.sampling_mode,
 | 
			
		||||
        )
 | 
			
		||||
        if self._maps_list is not None:
 | 
			
		||||
            tex._maps_list = [m.detach() for m in self._maps_list]
 | 
			
		||||
@ -801,6 +812,7 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
                maps=maps,
 | 
			
		||||
                padding_mode=self.padding_mode,
 | 
			
		||||
                align_corners=self.align_corners,
 | 
			
		||||
                sampling_mode=self.sampling_mode,
 | 
			
		||||
            )
 | 
			
		||||
        elif all(torch.is_tensor(f) for f in [faces_uvs, verts_uvs, maps]):
 | 
			
		||||
            new_tex = self.__class__(
 | 
			
		||||
@ -809,6 +821,7 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
                maps=[maps],
 | 
			
		||||
                padding_mode=self.padding_mode,
 | 
			
		||||
                align_corners=self.align_corners,
 | 
			
		||||
                sampling_mode=self.sampling_mode,
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError("Not all values are provided in the correct format")
 | 
			
		||||
@ -889,6 +902,7 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
            verts_uvs=new_props["verts_uvs_padded"],
 | 
			
		||||
            padding_mode=self.padding_mode,
 | 
			
		||||
            align_corners=self.align_corners,
 | 
			
		||||
            sampling_mode=self.sampling_mode,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"]
 | 
			
		||||
@ -966,6 +980,7 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
        texels = F.grid_sample(
 | 
			
		||||
            texture_maps,
 | 
			
		||||
            pixel_uvs,
 | 
			
		||||
            mode=self.sampling_mode,
 | 
			
		||||
            align_corners=self.align_corners,
 | 
			
		||||
            padding_mode=self.padding_mode,
 | 
			
		||||
        )
 | 
			
		||||
@ -1003,6 +1018,7 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
        textures = F.grid_sample(
 | 
			
		||||
            texture_maps,
 | 
			
		||||
            faces_verts_uvs,
 | 
			
		||||
            mode=self.sampling_mode,
 | 
			
		||||
            align_corners=self.align_corners,
 | 
			
		||||
            padding_mode=self.padding_mode,
 | 
			
		||||
        )  # NxCxmax(Fi)x3
 | 
			
		||||
@ -1060,6 +1076,7 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
            faces_uvs=faces_uvs_list,
 | 
			
		||||
            padding_mode=self.padding_mode,
 | 
			
		||||
            align_corners=self.align_corners,
 | 
			
		||||
            sampling_mode=self.sampling_mode,
 | 
			
		||||
        )
 | 
			
		||||
        new_tex._num_faces_per_mesh = num_faces_per_mesh
 | 
			
		||||
        return new_tex
 | 
			
		||||
@ -1227,6 +1244,7 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
            faces_uvs=[torch.cat(faces_uvs_merged)],
 | 
			
		||||
            align_corners=self.align_corners,
 | 
			
		||||
            padding_mode=self.padding_mode,
 | 
			
		||||
            sampling_mode=self.sampling_mode,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def centers_for_image(self, index: int) -> torch.Tensor:
 | 
			
		||||
@ -1259,6 +1277,7 @@ class TexturesUV(TexturesBase):
 | 
			
		||||
                torch.flip(coords.to(texture_image), [2]),
 | 
			
		||||
                # Convert from [0, 1] -> [-1, 1] range expected by grid sample
 | 
			
		||||
                verts_uvs[:, None] * 2.0 - 1,
 | 
			
		||||
                mode=self.sampling_mode,
 | 
			
		||||
                align_corners=self.align_corners,
 | 
			
		||||
                padding_mode=self.padding_mode,
 | 
			
		||||
            ).cpu()
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user