mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Adding the option to choose the texture sampling mode in TexturesUV.
Summary: This diff adds the `sample_mode` parameter to `TexturesUV` to control the interpolation mode during texture sampling. It simply gets forwarded to `torch.nn.funcitonal.grid_sample`. This option was requested in this [GitHub issue](https://github.com/facebookresearch/pytorch3d/issues/805). Reviewed By: patricklabatut Differential Revision: D32665185 fbshipit-source-id: ac0bc66a018bd4cb20d75fec2d7c11145dd20199
This commit is contained in:
parent
e4456dba2f
commit
d9f709599b
@ -596,6 +596,7 @@ class TexturesUV(TexturesBase):
|
||||
verts_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
|
||||
padding_mode: str = "border",
|
||||
align_corners: bool = True,
|
||||
sampling_mode: str = "bilinear",
|
||||
) -> None:
|
||||
"""
|
||||
Textures are represented as a per mesh texture map and uv coordinates for each
|
||||
@ -613,6 +614,9 @@ class TexturesUV(TexturesBase):
|
||||
indicate the centers of the edge pixels in the maps.
|
||||
padding_mode: padding mode for outside grid values
|
||||
("zeros", "border" or "reflection").
|
||||
sampling_mode: type of interpolation used to sample the texture.
|
||||
Corresponds to the mode parameter in PyTorch's
|
||||
grid_sample ("nearest" or "bilinear").
|
||||
|
||||
The align_corners and padding_mode arguments correspond to the arguments
|
||||
of the `grid_sample` torch function. There is an informative illustration of
|
||||
@ -641,6 +645,7 @@ class TexturesUV(TexturesBase):
|
||||
"""
|
||||
self.padding_mode = padding_mode
|
||||
self.align_corners = align_corners
|
||||
self.sampling_mode = sampling_mode
|
||||
if isinstance(faces_uvs, (list, tuple)):
|
||||
for fv in faces_uvs:
|
||||
if fv.ndim != 2 or fv.shape[-1] != 3:
|
||||
@ -749,6 +754,9 @@ class TexturesUV(TexturesBase):
|
||||
self.maps_padded().clone(),
|
||||
self.faces_uvs_padded().clone(),
|
||||
self.verts_uvs_padded().clone(),
|
||||
align_corners=self.align_corners,
|
||||
padding_mode=self.padding_mode,
|
||||
sampling_mode=self.sampling_mode,
|
||||
)
|
||||
if self._maps_list is not None:
|
||||
tex._maps_list = [m.clone() for m in self._maps_list]
|
||||
@ -770,6 +778,9 @@ class TexturesUV(TexturesBase):
|
||||
self.maps_padded().detach(),
|
||||
self.faces_uvs_padded().detach(),
|
||||
self.verts_uvs_padded().detach(),
|
||||
align_corners=self.align_corners,
|
||||
padding_mode=self.padding_mode,
|
||||
sampling_mode=self.sampling_mode,
|
||||
)
|
||||
if self._maps_list is not None:
|
||||
tex._maps_list = [m.detach() for m in self._maps_list]
|
||||
@ -801,6 +812,7 @@ class TexturesUV(TexturesBase):
|
||||
maps=maps,
|
||||
padding_mode=self.padding_mode,
|
||||
align_corners=self.align_corners,
|
||||
sampling_mode=self.sampling_mode,
|
||||
)
|
||||
elif all(torch.is_tensor(f) for f in [faces_uvs, verts_uvs, maps]):
|
||||
new_tex = self.__class__(
|
||||
@ -809,6 +821,7 @@ class TexturesUV(TexturesBase):
|
||||
maps=[maps],
|
||||
padding_mode=self.padding_mode,
|
||||
align_corners=self.align_corners,
|
||||
sampling_mode=self.sampling_mode,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Not all values are provided in the correct format")
|
||||
@ -889,6 +902,7 @@ class TexturesUV(TexturesBase):
|
||||
verts_uvs=new_props["verts_uvs_padded"],
|
||||
padding_mode=self.padding_mode,
|
||||
align_corners=self.align_corners,
|
||||
sampling_mode=self.sampling_mode,
|
||||
)
|
||||
|
||||
new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"]
|
||||
@ -966,6 +980,7 @@ class TexturesUV(TexturesBase):
|
||||
texels = F.grid_sample(
|
||||
texture_maps,
|
||||
pixel_uvs,
|
||||
mode=self.sampling_mode,
|
||||
align_corners=self.align_corners,
|
||||
padding_mode=self.padding_mode,
|
||||
)
|
||||
@ -1003,6 +1018,7 @@ class TexturesUV(TexturesBase):
|
||||
textures = F.grid_sample(
|
||||
texture_maps,
|
||||
faces_verts_uvs,
|
||||
mode=self.sampling_mode,
|
||||
align_corners=self.align_corners,
|
||||
padding_mode=self.padding_mode,
|
||||
) # NxCxmax(Fi)x3
|
||||
@ -1060,6 +1076,7 @@ class TexturesUV(TexturesBase):
|
||||
faces_uvs=faces_uvs_list,
|
||||
padding_mode=self.padding_mode,
|
||||
align_corners=self.align_corners,
|
||||
sampling_mode=self.sampling_mode,
|
||||
)
|
||||
new_tex._num_faces_per_mesh = num_faces_per_mesh
|
||||
return new_tex
|
||||
@ -1227,6 +1244,7 @@ class TexturesUV(TexturesBase):
|
||||
faces_uvs=[torch.cat(faces_uvs_merged)],
|
||||
align_corners=self.align_corners,
|
||||
padding_mode=self.padding_mode,
|
||||
sampling_mode=self.sampling_mode,
|
||||
)
|
||||
|
||||
def centers_for_image(self, index: int) -> torch.Tensor:
|
||||
@ -1259,6 +1277,7 @@ class TexturesUV(TexturesBase):
|
||||
torch.flip(coords.to(texture_image), [2]),
|
||||
# Convert from [0, 1] -> [-1, 1] range expected by grid sample
|
||||
verts_uvs[:, None] * 2.0 - 1,
|
||||
mode=self.sampling_mode,
|
||||
align_corners=self.align_corners,
|
||||
padding_mode=self.padding_mode,
|
||||
).cpu()
|
||||
|
Loading…
x
Reference in New Issue
Block a user