mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Adding the option to choose the texture sampling mode in TexturesUV.
Summary: This diff adds the `sample_mode` parameter to `TexturesUV` to control the interpolation mode during texture sampling. It simply gets forwarded to `torch.nn.funcitonal.grid_sample`. This option was requested in this [GitHub issue](https://github.com/facebookresearch/pytorch3d/issues/805). Reviewed By: patricklabatut Differential Revision: D32665185 fbshipit-source-id: ac0bc66a018bd4cb20d75fec2d7c11145dd20199
This commit is contained in:
parent
e4456dba2f
commit
d9f709599b
@ -596,6 +596,7 @@ class TexturesUV(TexturesBase):
|
|||||||
verts_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
|
verts_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
|
||||||
padding_mode: str = "border",
|
padding_mode: str = "border",
|
||||||
align_corners: bool = True,
|
align_corners: bool = True,
|
||||||
|
sampling_mode: str = "bilinear",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Textures are represented as a per mesh texture map and uv coordinates for each
|
Textures are represented as a per mesh texture map and uv coordinates for each
|
||||||
@ -613,6 +614,9 @@ class TexturesUV(TexturesBase):
|
|||||||
indicate the centers of the edge pixels in the maps.
|
indicate the centers of the edge pixels in the maps.
|
||||||
padding_mode: padding mode for outside grid values
|
padding_mode: padding mode for outside grid values
|
||||||
("zeros", "border" or "reflection").
|
("zeros", "border" or "reflection").
|
||||||
|
sampling_mode: type of interpolation used to sample the texture.
|
||||||
|
Corresponds to the mode parameter in PyTorch's
|
||||||
|
grid_sample ("nearest" or "bilinear").
|
||||||
|
|
||||||
The align_corners and padding_mode arguments correspond to the arguments
|
The align_corners and padding_mode arguments correspond to the arguments
|
||||||
of the `grid_sample` torch function. There is an informative illustration of
|
of the `grid_sample` torch function. There is an informative illustration of
|
||||||
@ -641,6 +645,7 @@ class TexturesUV(TexturesBase):
|
|||||||
"""
|
"""
|
||||||
self.padding_mode = padding_mode
|
self.padding_mode = padding_mode
|
||||||
self.align_corners = align_corners
|
self.align_corners = align_corners
|
||||||
|
self.sampling_mode = sampling_mode
|
||||||
if isinstance(faces_uvs, (list, tuple)):
|
if isinstance(faces_uvs, (list, tuple)):
|
||||||
for fv in faces_uvs:
|
for fv in faces_uvs:
|
||||||
if fv.ndim != 2 or fv.shape[-1] != 3:
|
if fv.ndim != 2 or fv.shape[-1] != 3:
|
||||||
@ -749,6 +754,9 @@ class TexturesUV(TexturesBase):
|
|||||||
self.maps_padded().clone(),
|
self.maps_padded().clone(),
|
||||||
self.faces_uvs_padded().clone(),
|
self.faces_uvs_padded().clone(),
|
||||||
self.verts_uvs_padded().clone(),
|
self.verts_uvs_padded().clone(),
|
||||||
|
align_corners=self.align_corners,
|
||||||
|
padding_mode=self.padding_mode,
|
||||||
|
sampling_mode=self.sampling_mode,
|
||||||
)
|
)
|
||||||
if self._maps_list is not None:
|
if self._maps_list is not None:
|
||||||
tex._maps_list = [m.clone() for m in self._maps_list]
|
tex._maps_list = [m.clone() for m in self._maps_list]
|
||||||
@ -770,6 +778,9 @@ class TexturesUV(TexturesBase):
|
|||||||
self.maps_padded().detach(),
|
self.maps_padded().detach(),
|
||||||
self.faces_uvs_padded().detach(),
|
self.faces_uvs_padded().detach(),
|
||||||
self.verts_uvs_padded().detach(),
|
self.verts_uvs_padded().detach(),
|
||||||
|
align_corners=self.align_corners,
|
||||||
|
padding_mode=self.padding_mode,
|
||||||
|
sampling_mode=self.sampling_mode,
|
||||||
)
|
)
|
||||||
if self._maps_list is not None:
|
if self._maps_list is not None:
|
||||||
tex._maps_list = [m.detach() for m in self._maps_list]
|
tex._maps_list = [m.detach() for m in self._maps_list]
|
||||||
@ -801,6 +812,7 @@ class TexturesUV(TexturesBase):
|
|||||||
maps=maps,
|
maps=maps,
|
||||||
padding_mode=self.padding_mode,
|
padding_mode=self.padding_mode,
|
||||||
align_corners=self.align_corners,
|
align_corners=self.align_corners,
|
||||||
|
sampling_mode=self.sampling_mode,
|
||||||
)
|
)
|
||||||
elif all(torch.is_tensor(f) for f in [faces_uvs, verts_uvs, maps]):
|
elif all(torch.is_tensor(f) for f in [faces_uvs, verts_uvs, maps]):
|
||||||
new_tex = self.__class__(
|
new_tex = self.__class__(
|
||||||
@ -809,6 +821,7 @@ class TexturesUV(TexturesBase):
|
|||||||
maps=[maps],
|
maps=[maps],
|
||||||
padding_mode=self.padding_mode,
|
padding_mode=self.padding_mode,
|
||||||
align_corners=self.align_corners,
|
align_corners=self.align_corners,
|
||||||
|
sampling_mode=self.sampling_mode,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Not all values are provided in the correct format")
|
raise ValueError("Not all values are provided in the correct format")
|
||||||
@ -889,6 +902,7 @@ class TexturesUV(TexturesBase):
|
|||||||
verts_uvs=new_props["verts_uvs_padded"],
|
verts_uvs=new_props["verts_uvs_padded"],
|
||||||
padding_mode=self.padding_mode,
|
padding_mode=self.padding_mode,
|
||||||
align_corners=self.align_corners,
|
align_corners=self.align_corners,
|
||||||
|
sampling_mode=self.sampling_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"]
|
new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"]
|
||||||
@ -966,6 +980,7 @@ class TexturesUV(TexturesBase):
|
|||||||
texels = F.grid_sample(
|
texels = F.grid_sample(
|
||||||
texture_maps,
|
texture_maps,
|
||||||
pixel_uvs,
|
pixel_uvs,
|
||||||
|
mode=self.sampling_mode,
|
||||||
align_corners=self.align_corners,
|
align_corners=self.align_corners,
|
||||||
padding_mode=self.padding_mode,
|
padding_mode=self.padding_mode,
|
||||||
)
|
)
|
||||||
@ -1003,6 +1018,7 @@ class TexturesUV(TexturesBase):
|
|||||||
textures = F.grid_sample(
|
textures = F.grid_sample(
|
||||||
texture_maps,
|
texture_maps,
|
||||||
faces_verts_uvs,
|
faces_verts_uvs,
|
||||||
|
mode=self.sampling_mode,
|
||||||
align_corners=self.align_corners,
|
align_corners=self.align_corners,
|
||||||
padding_mode=self.padding_mode,
|
padding_mode=self.padding_mode,
|
||||||
) # NxCxmax(Fi)x3
|
) # NxCxmax(Fi)x3
|
||||||
@ -1060,6 +1076,7 @@ class TexturesUV(TexturesBase):
|
|||||||
faces_uvs=faces_uvs_list,
|
faces_uvs=faces_uvs_list,
|
||||||
padding_mode=self.padding_mode,
|
padding_mode=self.padding_mode,
|
||||||
align_corners=self.align_corners,
|
align_corners=self.align_corners,
|
||||||
|
sampling_mode=self.sampling_mode,
|
||||||
)
|
)
|
||||||
new_tex._num_faces_per_mesh = num_faces_per_mesh
|
new_tex._num_faces_per_mesh = num_faces_per_mesh
|
||||||
return new_tex
|
return new_tex
|
||||||
@ -1227,6 +1244,7 @@ class TexturesUV(TexturesBase):
|
|||||||
faces_uvs=[torch.cat(faces_uvs_merged)],
|
faces_uvs=[torch.cat(faces_uvs_merged)],
|
||||||
align_corners=self.align_corners,
|
align_corners=self.align_corners,
|
||||||
padding_mode=self.padding_mode,
|
padding_mode=self.padding_mode,
|
||||||
|
sampling_mode=self.sampling_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
def centers_for_image(self, index: int) -> torch.Tensor:
|
def centers_for_image(self, index: int) -> torch.Tensor:
|
||||||
@ -1259,6 +1277,7 @@ class TexturesUV(TexturesBase):
|
|||||||
torch.flip(coords.to(texture_image), [2]),
|
torch.flip(coords.to(texture_image), [2]),
|
||||||
# Convert from [0, 1] -> [-1, 1] range expected by grid sample
|
# Convert from [0, 1] -> [-1, 1] range expected by grid sample
|
||||||
verts_uvs[:, None] * 2.0 - 1,
|
verts_uvs[:, None] * 2.0 - 1,
|
||||||
|
mode=self.sampling_mode,
|
||||||
align_corners=self.align_corners,
|
align_corners=self.align_corners,
|
||||||
padding_mode=self.padding_mode,
|
padding_mode=self.padding_mode,
|
||||||
).cpu()
|
).cpu()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user