mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
align_corners and padding for TexturesUV
Summary: Allow, and make default, align_corners=True for texture maps. Allow changing the padding_mode and set the default to be "border" which produces more logical results. Some new documentation. The previous behavior corresponds to padding_mode="zeros" and align_corners=False. Reviewed By: gkioxari Differential Revision: D23268775 fbshipit-source-id: 58d6229baa591baa69705bcf97471c80ba3651de
This commit is contained in:
parent
d0cec028c7
commit
e25ccab3d9
@ -98,13 +98,14 @@ def _padded_to_list_wrapper(
|
||||
|
||||
|
||||
def _pad_texture_maps(
|
||||
images: Union[Tuple[torch.Tensor], List[torch.Tensor]]
|
||||
images: Union[Tuple[torch.Tensor], List[torch.Tensor]], align_corners: bool
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Pad all texture images so they have the same height and width.
|
||||
|
||||
Args:
|
||||
images: list of N tensors of shape (H, W, 3)
|
||||
images: list of N tensors of shape (H_i, W_i, 3)
|
||||
align_corners: used for interpolation
|
||||
|
||||
Returns:
|
||||
tex_maps: Tensor of shape (N, max_H, max_W, 3)
|
||||
@ -125,7 +126,7 @@ def _pad_texture_maps(
|
||||
if image.shape[:2] != max_shape:
|
||||
image_BCHW = image.permute(2, 0, 1)[None]
|
||||
new_image_BCHW = interpolate(
|
||||
image_BCHW, size=max_shape, mode="bilinear", align_corners=False
|
||||
image_BCHW, size=max_shape, mode="bilinear", align_corners=align_corners
|
||||
)
|
||||
tex_maps[i] = new_image_BCHW[0].permute(1, 2, 0)
|
||||
tex_maps = torch.stack(tex_maps, dim=0) # (num_tex_maps, max_H, max_W, 3)
|
||||
@ -535,6 +536,8 @@ class TexturesUV(TexturesBase):
|
||||
maps: Union[torch.Tensor, List[torch.Tensor]],
|
||||
faces_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
|
||||
verts_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
|
||||
padding_mode: str = "border",
|
||||
align_corners: bool = True,
|
||||
):
|
||||
"""
|
||||
Textures are represented as a per mesh texture map and uv coordinates for each
|
||||
@ -543,11 +546,42 @@ class TexturesUV(TexturesBase):
|
||||
Args:
|
||||
maps: texture map per mesh. This can either be a list of maps
|
||||
[(H, W, 3)] or a padded tensor of shape (N, H, W, 3)
|
||||
faces_uvs: (N, F, 3) tensor giving the index into verts_uvs for each face
|
||||
faces_uvs: (N, F, 3) LongTensor giving the index into verts_uvs
|
||||
for each face
|
||||
verts_uvs: (N, V, 2) tensor giving the uv coordinates per vertex
|
||||
(a FloatTensor with values between 0 and 1)
|
||||
(a FloatTensor with values between 0 and 1).
|
||||
align_corners: If true, the extreme values 0 and 1 for verts_uvs
|
||||
indicate the centers of the edge pixels in the maps.
|
||||
padding_mode: padding mode for outside grid values
|
||||
("zeros", "border" or "reflection").
|
||||
|
||||
The align_corners and padding_mode arguments correspond to the arguments
|
||||
of the `grid_sample` torch function. There is an informative illustration of
|
||||
the two align_corners options at
|
||||
https://discuss.pytorch.org/t/22663/9 .
|
||||
|
||||
An example of how the indexing into the maps, with align_corners=True
|
||||
works is as follows.
|
||||
If maps[i] has shape [101, 1001] and the value of verts_uvs[i][j]
|
||||
is [0.4, 0.3], then a value of j in faces_uvs[i] means a vertex
|
||||
whose color is given by maps[i][700, 40]. padding_mode affects what
|
||||
happens if a value in verts_uvs is less than 0 or greater than 1.
|
||||
Note that increasing a value in verts_uvs[..., 0] increases an index
|
||||
in maps, whereas increasing a value in verts_uvs[..., 1] _decreases_
|
||||
an _earlier_ index in maps.
|
||||
|
||||
If align_corners=False, an example would be as follows.
|
||||
If maps[i] has shape [100, 1000] and the value of verts_uvs[i][j]
|
||||
is [0.405, 0.2995], then a value of j in faces_uvs[i] means a vertex
|
||||
whose color is given by maps[i][700, 40].
|
||||
In this case, padding_mode even matters for values in verts_uvs
|
||||
slightly above 0 or slightly below 1. In this case, it matters if the
|
||||
first value is outside the interval [0.0005, 0.9995] or if the second
|
||||
is outside the interval [0.005, 0.995].
|
||||
"""
|
||||
super().__init__()
|
||||
self.padding_mode = padding_mode
|
||||
self.align_corners = align_corners
|
||||
if isinstance(faces_uvs, (list, tuple)):
|
||||
for fv in faces_uvs:
|
||||
# pyre-fixme[16]: `Tensor` has no attribute `ndim`.
|
||||
@ -632,7 +666,7 @@ class TexturesUV(TexturesBase):
|
||||
raise ValueError("Expected one texture map per mesh in the batch.")
|
||||
self._maps_list = maps
|
||||
if self._N > 0:
|
||||
maps = _pad_texture_maps(maps)
|
||||
maps = _pad_texture_maps(maps, align_corners=self.align_corners)
|
||||
else:
|
||||
maps = torch.empty(
|
||||
(self._N, 0, 0, 3), dtype=torch.float32, device=self.device
|
||||
@ -698,11 +732,19 @@ class TexturesUV(TexturesBase):
|
||||
# if index has multiple values then faces/verts/maps may be a list of tensors
|
||||
if all(isinstance(f, (list, tuple)) for f in [faces_uvs, verts_uvs, maps]):
|
||||
new_tex = self.__class__(
|
||||
faces_uvs=faces_uvs, verts_uvs=verts_uvs, maps=maps
|
||||
faces_uvs=faces_uvs,
|
||||
verts_uvs=verts_uvs,
|
||||
maps=maps,
|
||||
padding_mode=self.padding_mode,
|
||||
align_corners=self.align_corners,
|
||||
)
|
||||
elif all(torch.is_tensor(f) for f in [faces_uvs, verts_uvs, maps]):
|
||||
new_tex = self.__class__(
|
||||
faces_uvs=[faces_uvs], verts_uvs=[verts_uvs], maps=[maps]
|
||||
faces_uvs=[faces_uvs],
|
||||
verts_uvs=[verts_uvs],
|
||||
maps=[maps],
|
||||
padding_mode=self.padding_mode,
|
||||
align_corners=self.align_corners,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Not all values are provided in the correct format")
|
||||
@ -785,6 +827,8 @@ class TexturesUV(TexturesBase):
|
||||
maps=new_props["maps_padded"],
|
||||
faces_uvs=new_props["faces_uvs_padded"],
|
||||
verts_uvs=new_props["verts_uvs_padded"],
|
||||
padding_mode=self.padding_mode,
|
||||
align_corners=self.align_corners,
|
||||
)
|
||||
|
||||
new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"]
|
||||
@ -859,7 +903,12 @@ class TexturesUV(TexturesBase):
|
||||
texture_maps = torch.flip(texture_maps, [2]) # flip y axis of the texture map
|
||||
if texture_maps.device != pixel_uvs.device:
|
||||
texture_maps = texture_maps.to(pixel_uvs.device)
|
||||
texels = F.grid_sample(texture_maps, pixel_uvs, align_corners=False)
|
||||
texels = F.grid_sample(
|
||||
texture_maps,
|
||||
pixel_uvs,
|
||||
align_corners=self.align_corners,
|
||||
padding_mode=self.padding_mode,
|
||||
)
|
||||
# texels now has shape (NK, C, H_out, W_out)
|
||||
texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2)
|
||||
return texels
|
||||
@ -881,6 +930,17 @@ class TexturesUV(TexturesBase):
|
||||
if not tex_types_same:
|
||||
raise ValueError("All textures must be of type TexturesUV.")
|
||||
|
||||
padding_modes_same = all(
|
||||
tex.padding_mode == self.padding_mode for tex in textures
|
||||
)
|
||||
if not padding_modes_same:
|
||||
raise ValueError("All textures must have the same padding_mode.")
|
||||
align_corners_same = all(
|
||||
tex.align_corners == self.align_corners for tex in textures
|
||||
)
|
||||
if not align_corners_same:
|
||||
raise ValueError("All textures must have the same align_corners value.")
|
||||
|
||||
verts_uvs_list = []
|
||||
faces_uvs_list = []
|
||||
maps_list = []
|
||||
@ -896,7 +956,11 @@ class TexturesUV(TexturesBase):
|
||||
maps_list += tex_map_list
|
||||
|
||||
new_tex = self.__class__(
|
||||
maps=maps_list, verts_uvs=verts_uvs_list, faces_uvs=faces_uvs_list
|
||||
maps=maps_list,
|
||||
verts_uvs=verts_uvs_list,
|
||||
faces_uvs=faces_uvs_list,
|
||||
padding_mode=self.padding_mode,
|
||||
align_corners=self.align_corners,
|
||||
)
|
||||
new_tex._num_faces_per_mesh = num_faces_per_mesh
|
||||
return new_tex
|
||||
|
@ -1505,9 +1505,13 @@ def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True):
|
||||
Merge multiple Meshes objects, i.e. concatenate the meshes objects. They
|
||||
must all be on the same device. If include_textures is true, they must all
|
||||
be compatible, either all or none having textures, and all the Textures
|
||||
objects having the same members. If include_textures is False, textures are
|
||||
objects being the same type. If include_textures is False, textures are
|
||||
ignored.
|
||||
|
||||
If the textures are TexturesAtlas then being the same type includes having
|
||||
the same resolution. If they are TexturesUV then it includes having the same
|
||||
align_corners and padding_mode.
|
||||
|
||||
Args:
|
||||
meshes: list of meshes.
|
||||
include_textures: (bool) whether to try to join the textures.
|
||||
|
Binary file not shown.
Before Width: | Height: | Size: 43 KiB After Width: | Height: | Size: 43 KiB |
Binary file not shown.
Before Width: | Height: | Size: 31 KiB After Width: | Height: | Size: 31 KiB |
Binary file not shown.
Before Width: | Height: | Size: 30 KiB After Width: | Height: | Size: 30 KiB |
@ -373,10 +373,10 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
self.assertFalse(new_mesh.verts_packed().requires_grad)
|
||||
self.assertClose(new_mesh.verts_packed(), mesh.verts_packed())
|
||||
self.assertTrue(new_mesh.verts_padded().requires_grad == False)
|
||||
self.assertFalse(new_mesh.verts_padded().requires_grad)
|
||||
self.assertClose(new_mesh.verts_padded(), mesh.verts_padded())
|
||||
for v, newv in zip(mesh.verts_list(), new_mesh.verts_list()):
|
||||
self.assertTrue(newv.requires_grad == False)
|
||||
self.assertFalse(newv.requires_grad)
|
||||
self.assertClose(newv, v)
|
||||
|
||||
def test_laplacian_packed(self):
|
||||
|
@ -411,11 +411,11 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
|
||||
new_clouds = clouds.detach()
|
||||
|
||||
for cloud in new_clouds.points_list():
|
||||
self.assertTrue(cloud.requires_grad == False)
|
||||
self.assertFalse(cloud.requires_grad)
|
||||
for normal in new_clouds.normals_list():
|
||||
self.assertTrue(normal.requires_grad == False)
|
||||
self.assertFalse(normal.requires_grad)
|
||||
for feats in new_clouds.features_list():
|
||||
self.assertTrue(feats.requires_grad == False)
|
||||
self.assertFalse(feats.requires_grad)
|
||||
|
||||
for attrib in [
|
||||
"points_packed",
|
||||
@ -425,9 +425,7 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
|
||||
"normals_padded",
|
||||
"features_padded",
|
||||
]:
|
||||
self.assertTrue(
|
||||
getattr(new_clouds, attrib)().requires_grad == False
|
||||
)
|
||||
self.assertFalse(getattr(new_clouds, attrib)().requires_grad)
|
||||
|
||||
self.assertCloudsEqual(clouds, new_clouds)
|
||||
|
||||
|
@ -443,7 +443,13 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
|
||||
dists=pix_to_face,
|
||||
)
|
||||
|
||||
tex = TexturesUV(maps=tex_map, faces_uvs=[face_uvs], verts_uvs=[vert_uvs])
|
||||
for align_corners in [True, False]:
|
||||
tex = TexturesUV(
|
||||
maps=tex_map,
|
||||
faces_uvs=[face_uvs],
|
||||
verts_uvs=[vert_uvs],
|
||||
align_corners=align_corners,
|
||||
)
|
||||
meshes = Meshes(verts=[dummy_verts], faces=[face_uvs], textures=tex)
|
||||
mesh_textures = meshes.textures
|
||||
texels = mesh_textures.sample_textures(fragments)
|
||||
@ -451,10 +457,11 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
|
||||
# Expected output
|
||||
pixel_uvs = interpolated_uvs * 2.0 - 1.0
|
||||
pixel_uvs = pixel_uvs.view(2, 1, 1, 2)
|
||||
tex_map = torch.flip(tex_map, [1])
|
||||
tex_map = tex_map.permute(0, 3, 1, 2)
|
||||
tex_map = torch.cat([tex_map, tex_map], dim=0)
|
||||
expected_out = F.grid_sample(tex_map, pixel_uvs, align_corners=False)
|
||||
tex_map_ = torch.flip(tex_map, [1]).permute(0, 3, 1, 2)
|
||||
tex_map_ = torch.cat([tex_map_, tex_map_], dim=0)
|
||||
expected_out = F.grid_sample(
|
||||
tex_map_, pixel_uvs, align_corners=align_corners, padding_mode="border"
|
||||
)
|
||||
self.assertTrue(torch.allclose(texels.squeeze(), expected_out.squeeze()))
|
||||
|
||||
def test_textures_uv_init_fail(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user