Adding the option to choose the texture sampling mode in TexturesUV.

Summary:
This diff adds the `sample_mode` parameter to `TexturesUV` to control the interpolation mode during texture sampling. It simply gets forwarded to `torch.nn.funcitonal.grid_sample`.

This option was requested in this [GitHub issue](https://github.com/facebookresearch/pytorch3d/issues/805).

Reviewed By: patricklabatut

Differential Revision: D32665185

fbshipit-source-id: ac0bc66a018bd4cb20d75fec2d7c11145dd20199
This commit is contained in:
Ana Dodik 2021-11-29 07:00:13 -08:00 committed by Facebook GitHub Bot
parent e4456dba2f
commit d9f709599b

View File

@ -596,6 +596,7 @@ class TexturesUV(TexturesBase):
verts_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
padding_mode: str = "border",
align_corners: bool = True,
sampling_mode: str = "bilinear",
) -> None:
"""
Textures are represented as a per mesh texture map and uv coordinates for each
@ -613,6 +614,9 @@ class TexturesUV(TexturesBase):
indicate the centers of the edge pixels in the maps.
padding_mode: padding mode for outside grid values
("zeros", "border" or "reflection").
sampling_mode: type of interpolation used to sample the texture.
Corresponds to the mode parameter in PyTorch's
grid_sample ("nearest" or "bilinear").
The align_corners and padding_mode arguments correspond to the arguments
of the `grid_sample` torch function. There is an informative illustration of
@ -641,6 +645,7 @@ class TexturesUV(TexturesBase):
"""
self.padding_mode = padding_mode
self.align_corners = align_corners
self.sampling_mode = sampling_mode
if isinstance(faces_uvs, (list, tuple)):
for fv in faces_uvs:
if fv.ndim != 2 or fv.shape[-1] != 3:
@ -749,6 +754,9 @@ class TexturesUV(TexturesBase):
self.maps_padded().clone(),
self.faces_uvs_padded().clone(),
self.verts_uvs_padded().clone(),
align_corners=self.align_corners,
padding_mode=self.padding_mode,
sampling_mode=self.sampling_mode,
)
if self._maps_list is not None:
tex._maps_list = [m.clone() for m in self._maps_list]
@ -770,6 +778,9 @@ class TexturesUV(TexturesBase):
self.maps_padded().detach(),
self.faces_uvs_padded().detach(),
self.verts_uvs_padded().detach(),
align_corners=self.align_corners,
padding_mode=self.padding_mode,
sampling_mode=self.sampling_mode,
)
if self._maps_list is not None:
tex._maps_list = [m.detach() for m in self._maps_list]
@ -801,6 +812,7 @@ class TexturesUV(TexturesBase):
maps=maps,
padding_mode=self.padding_mode,
align_corners=self.align_corners,
sampling_mode=self.sampling_mode,
)
elif all(torch.is_tensor(f) for f in [faces_uvs, verts_uvs, maps]):
new_tex = self.__class__(
@ -809,6 +821,7 @@ class TexturesUV(TexturesBase):
maps=[maps],
padding_mode=self.padding_mode,
align_corners=self.align_corners,
sampling_mode=self.sampling_mode,
)
else:
raise ValueError("Not all values are provided in the correct format")
@ -889,6 +902,7 @@ class TexturesUV(TexturesBase):
verts_uvs=new_props["verts_uvs_padded"],
padding_mode=self.padding_mode,
align_corners=self.align_corners,
sampling_mode=self.sampling_mode,
)
new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"]
@ -966,6 +980,7 @@ class TexturesUV(TexturesBase):
texels = F.grid_sample(
texture_maps,
pixel_uvs,
mode=self.sampling_mode,
align_corners=self.align_corners,
padding_mode=self.padding_mode,
)
@ -1003,6 +1018,7 @@ class TexturesUV(TexturesBase):
textures = F.grid_sample(
texture_maps,
faces_verts_uvs,
mode=self.sampling_mode,
align_corners=self.align_corners,
padding_mode=self.padding_mode,
) # NxCxmax(Fi)x3
@ -1060,6 +1076,7 @@ class TexturesUV(TexturesBase):
faces_uvs=faces_uvs_list,
padding_mode=self.padding_mode,
align_corners=self.align_corners,
sampling_mode=self.sampling_mode,
)
new_tex._num_faces_per_mesh = num_faces_per_mesh
return new_tex
@ -1227,6 +1244,7 @@ class TexturesUV(TexturesBase):
faces_uvs=[torch.cat(faces_uvs_merged)],
align_corners=self.align_corners,
padding_mode=self.padding_mode,
sampling_mode=self.sampling_mode,
)
def centers_for_image(self, index: int) -> torch.Tensor:
@ -1259,6 +1277,7 @@ class TexturesUV(TexturesBase):
torch.flip(coords.to(texture_image), [2]),
# Convert from [0, 1] -> [-1, 1] range expected by grid sample
verts_uvs[:, None] * 2.0 - 1,
mode=self.sampling_mode,
align_corners=self.align_corners,
padding_mode=self.padding_mode,
).cpu()