align_corners and padding for TexturesUV

Summary:
Allow, and make default, align_corners=True for texture maps. Allow changing the padding_mode and set the default to be "border" which produces more logical results. Some new documentation.

The previous behavior corresponds to padding_mode="zeros" and align_corners=False.

Reviewed By: gkioxari

Differential Revision: D23268775

fbshipit-source-id: 58d6229baa591baa69705bcf97471c80ba3651de
This commit is contained in:
Jeremy Reizenstein 2020-08-25 11:26:58 -07:00 committed by Facebook GitHub Bot
parent d0cec028c7
commit e25ccab3d9
8 changed files with 104 additions and 31 deletions

View File

@ -98,13 +98,14 @@ def _padded_to_list_wrapper(
def _pad_texture_maps( def _pad_texture_maps(
images: Union[Tuple[torch.Tensor], List[torch.Tensor]] images: Union[Tuple[torch.Tensor], List[torch.Tensor]], align_corners: bool
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Pad all texture images so they have the same height and width. Pad all texture images so they have the same height and width.
Args: Args:
images: list of N tensors of shape (H, W, 3) images: list of N tensors of shape (H_i, W_i, 3)
align_corners: used for interpolation
Returns: Returns:
tex_maps: Tensor of shape (N, max_H, max_W, 3) tex_maps: Tensor of shape (N, max_H, max_W, 3)
@ -125,7 +126,7 @@ def _pad_texture_maps(
if image.shape[:2] != max_shape: if image.shape[:2] != max_shape:
image_BCHW = image.permute(2, 0, 1)[None] image_BCHW = image.permute(2, 0, 1)[None]
new_image_BCHW = interpolate( new_image_BCHW = interpolate(
image_BCHW, size=max_shape, mode="bilinear", align_corners=False image_BCHW, size=max_shape, mode="bilinear", align_corners=align_corners
) )
tex_maps[i] = new_image_BCHW[0].permute(1, 2, 0) tex_maps[i] = new_image_BCHW[0].permute(1, 2, 0)
tex_maps = torch.stack(tex_maps, dim=0) # (num_tex_maps, max_H, max_W, 3) tex_maps = torch.stack(tex_maps, dim=0) # (num_tex_maps, max_H, max_W, 3)
@ -535,6 +536,8 @@ class TexturesUV(TexturesBase):
maps: Union[torch.Tensor, List[torch.Tensor]], maps: Union[torch.Tensor, List[torch.Tensor]],
faces_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]], faces_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
verts_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]], verts_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
padding_mode: str = "border",
align_corners: bool = True,
): ):
""" """
Textures are represented as a per mesh texture map and uv coordinates for each Textures are represented as a per mesh texture map and uv coordinates for each
@ -543,11 +546,42 @@ class TexturesUV(TexturesBase):
Args: Args:
maps: texture map per mesh. This can either be a list of maps maps: texture map per mesh. This can either be a list of maps
[(H, W, 3)] or a padded tensor of shape (N, H, W, 3) [(H, W, 3)] or a padded tensor of shape (N, H, W, 3)
faces_uvs: (N, F, 3) tensor giving the index into verts_uvs for each face faces_uvs: (N, F, 3) LongTensor giving the index into verts_uvs
for each face
verts_uvs: (N, V, 2) tensor giving the uv coordinates per vertex verts_uvs: (N, V, 2) tensor giving the uv coordinates per vertex
(a FloatTensor with values between 0 and 1) (a FloatTensor with values between 0 and 1).
align_corners: If true, the extreme values 0 and 1 for verts_uvs
indicate the centers of the edge pixels in the maps.
padding_mode: padding mode for outside grid values
("zeros", "border" or "reflection").
The align_corners and padding_mode arguments correspond to the arguments
of the `grid_sample` torch function. There is an informative illustration of
the two align_corners options at
https://discuss.pytorch.org/t/22663/9 .
An example of how the indexing into the maps, with align_corners=True
works is as follows.
If maps[i] has shape [101, 1001] and the value of verts_uvs[i][j]
is [0.4, 0.3], then a value of j in faces_uvs[i] means a vertex
whose color is given by maps[i][700, 40]. padding_mode affects what
happens if a value in verts_uvs is less than 0 or greater than 1.
Note that increasing a value in verts_uvs[..., 0] increases an index
in maps, whereas increasing a value in verts_uvs[..., 1] _decreases_
an _earlier_ index in maps.
If align_corners=False, an example would be as follows.
If maps[i] has shape [100, 1000] and the value of verts_uvs[i][j]
is [0.405, 0.2995], then a value of j in faces_uvs[i] means a vertex
whose color is given by maps[i][700, 40].
In this case, padding_mode even matters for values in verts_uvs
slightly above 0 or slightly below 1. In this case, it matters if the
first value is outside the interval [0.0005, 0.9995] or if the second
is outside the interval [0.005, 0.995].
""" """
super().__init__() super().__init__()
self.padding_mode = padding_mode
self.align_corners = align_corners
if isinstance(faces_uvs, (list, tuple)): if isinstance(faces_uvs, (list, tuple)):
for fv in faces_uvs: for fv in faces_uvs:
# pyre-fixme[16]: `Tensor` has no attribute `ndim`. # pyre-fixme[16]: `Tensor` has no attribute `ndim`.
@ -632,7 +666,7 @@ class TexturesUV(TexturesBase):
raise ValueError("Expected one texture map per mesh in the batch.") raise ValueError("Expected one texture map per mesh in the batch.")
self._maps_list = maps self._maps_list = maps
if self._N > 0: if self._N > 0:
maps = _pad_texture_maps(maps) maps = _pad_texture_maps(maps, align_corners=self.align_corners)
else: else:
maps = torch.empty( maps = torch.empty(
(self._N, 0, 0, 3), dtype=torch.float32, device=self.device (self._N, 0, 0, 3), dtype=torch.float32, device=self.device
@ -698,11 +732,19 @@ class TexturesUV(TexturesBase):
# if index has multiple values then faces/verts/maps may be a list of tensors # if index has multiple values then faces/verts/maps may be a list of tensors
if all(isinstance(f, (list, tuple)) for f in [faces_uvs, verts_uvs, maps]): if all(isinstance(f, (list, tuple)) for f in [faces_uvs, verts_uvs, maps]):
new_tex = self.__class__( new_tex = self.__class__(
faces_uvs=faces_uvs, verts_uvs=verts_uvs, maps=maps faces_uvs=faces_uvs,
verts_uvs=verts_uvs,
maps=maps,
padding_mode=self.padding_mode,
align_corners=self.align_corners,
) )
elif all(torch.is_tensor(f) for f in [faces_uvs, verts_uvs, maps]): elif all(torch.is_tensor(f) for f in [faces_uvs, verts_uvs, maps]):
new_tex = self.__class__( new_tex = self.__class__(
faces_uvs=[faces_uvs], verts_uvs=[verts_uvs], maps=[maps] faces_uvs=[faces_uvs],
verts_uvs=[verts_uvs],
maps=[maps],
padding_mode=self.padding_mode,
align_corners=self.align_corners,
) )
else: else:
raise ValueError("Not all values are provided in the correct format") raise ValueError("Not all values are provided in the correct format")
@ -785,6 +827,8 @@ class TexturesUV(TexturesBase):
maps=new_props["maps_padded"], maps=new_props["maps_padded"],
faces_uvs=new_props["faces_uvs_padded"], faces_uvs=new_props["faces_uvs_padded"],
verts_uvs=new_props["verts_uvs_padded"], verts_uvs=new_props["verts_uvs_padded"],
padding_mode=self.padding_mode,
align_corners=self.align_corners,
) )
new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"] new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"]
@ -859,7 +903,12 @@ class TexturesUV(TexturesBase):
texture_maps = torch.flip(texture_maps, [2]) # flip y axis of the texture map texture_maps = torch.flip(texture_maps, [2]) # flip y axis of the texture map
if texture_maps.device != pixel_uvs.device: if texture_maps.device != pixel_uvs.device:
texture_maps = texture_maps.to(pixel_uvs.device) texture_maps = texture_maps.to(pixel_uvs.device)
texels = F.grid_sample(texture_maps, pixel_uvs, align_corners=False) texels = F.grid_sample(
texture_maps,
pixel_uvs,
align_corners=self.align_corners,
padding_mode=self.padding_mode,
)
# texels now has shape (NK, C, H_out, W_out) # texels now has shape (NK, C, H_out, W_out)
texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2) texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2)
return texels return texels
@ -881,6 +930,17 @@ class TexturesUV(TexturesBase):
if not tex_types_same: if not tex_types_same:
raise ValueError("All textures must be of type TexturesUV.") raise ValueError("All textures must be of type TexturesUV.")
padding_modes_same = all(
tex.padding_mode == self.padding_mode for tex in textures
)
if not padding_modes_same:
raise ValueError("All textures must have the same padding_mode.")
align_corners_same = all(
tex.align_corners == self.align_corners for tex in textures
)
if not align_corners_same:
raise ValueError("All textures must have the same align_corners value.")
verts_uvs_list = [] verts_uvs_list = []
faces_uvs_list = [] faces_uvs_list = []
maps_list = [] maps_list = []
@ -896,7 +956,11 @@ class TexturesUV(TexturesBase):
maps_list += tex_map_list maps_list += tex_map_list
new_tex = self.__class__( new_tex = self.__class__(
maps=maps_list, verts_uvs=verts_uvs_list, faces_uvs=faces_uvs_list maps=maps_list,
verts_uvs=verts_uvs_list,
faces_uvs=faces_uvs_list,
padding_mode=self.padding_mode,
align_corners=self.align_corners,
) )
new_tex._num_faces_per_mesh = num_faces_per_mesh new_tex._num_faces_per_mesh = num_faces_per_mesh
return new_tex return new_tex

View File

@ -1505,9 +1505,13 @@ def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True):
Merge multiple Meshes objects, i.e. concatenate the meshes objects. They Merge multiple Meshes objects, i.e. concatenate the meshes objects. They
must all be on the same device. If include_textures is true, they must all must all be on the same device. If include_textures is true, they must all
be compatible, either all or none having textures, and all the Textures be compatible, either all or none having textures, and all the Textures
objects having the same members. If include_textures is False, textures are objects being the same type. If include_textures is False, textures are
ignored. ignored.
If the textures are TexturesAtlas then being the same type includes having
the same resolution. If they are TexturesUV then it includes having the same
align_corners and padding_mode.
Args: Args:
meshes: list of meshes. meshes: list of meshes.
include_textures: (bool) whether to try to join the textures. include_textures: (bool) whether to try to join the textures.

Binary file not shown.

Before

Width:  |  Height:  |  Size: 43 KiB

After

Width:  |  Height:  |  Size: 43 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 31 KiB

After

Width:  |  Height:  |  Size: 31 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 30 KiB

After

Width:  |  Height:  |  Size: 30 KiB

View File

@ -373,10 +373,10 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
self.assertFalse(new_mesh.verts_packed().requires_grad) self.assertFalse(new_mesh.verts_packed().requires_grad)
self.assertClose(new_mesh.verts_packed(), mesh.verts_packed()) self.assertClose(new_mesh.verts_packed(), mesh.verts_packed())
self.assertTrue(new_mesh.verts_padded().requires_grad == False) self.assertFalse(new_mesh.verts_padded().requires_grad)
self.assertClose(new_mesh.verts_padded(), mesh.verts_padded()) self.assertClose(new_mesh.verts_padded(), mesh.verts_padded())
for v, newv in zip(mesh.verts_list(), new_mesh.verts_list()): for v, newv in zip(mesh.verts_list(), new_mesh.verts_list()):
self.assertTrue(newv.requires_grad == False) self.assertFalse(newv.requires_grad)
self.assertClose(newv, v) self.assertClose(newv, v)
def test_laplacian_packed(self): def test_laplacian_packed(self):

View File

@ -411,11 +411,11 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
new_clouds = clouds.detach() new_clouds = clouds.detach()
for cloud in new_clouds.points_list(): for cloud in new_clouds.points_list():
self.assertTrue(cloud.requires_grad == False) self.assertFalse(cloud.requires_grad)
for normal in new_clouds.normals_list(): for normal in new_clouds.normals_list():
self.assertTrue(normal.requires_grad == False) self.assertFalse(normal.requires_grad)
for feats in new_clouds.features_list(): for feats in new_clouds.features_list():
self.assertTrue(feats.requires_grad == False) self.assertFalse(feats.requires_grad)
for attrib in [ for attrib in [
"points_packed", "points_packed",
@ -425,9 +425,7 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
"normals_padded", "normals_padded",
"features_padded", "features_padded",
]: ]:
self.assertTrue( self.assertFalse(getattr(new_clouds, attrib)().requires_grad)
getattr(new_clouds, attrib)().requires_grad == False
)
self.assertCloudsEqual(clouds, new_clouds) self.assertCloudsEqual(clouds, new_clouds)

View File

@ -443,19 +443,26 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
dists=pix_to_face, dists=pix_to_face,
) )
tex = TexturesUV(maps=tex_map, faces_uvs=[face_uvs], verts_uvs=[vert_uvs]) for align_corners in [True, False]:
meshes = Meshes(verts=[dummy_verts], faces=[face_uvs], textures=tex) tex = TexturesUV(
mesh_textures = meshes.textures maps=tex_map,
texels = mesh_textures.sample_textures(fragments) faces_uvs=[face_uvs],
verts_uvs=[vert_uvs],
align_corners=align_corners,
)
meshes = Meshes(verts=[dummy_verts], faces=[face_uvs], textures=tex)
mesh_textures = meshes.textures
texels = mesh_textures.sample_textures(fragments)
# Expected output # Expected output
pixel_uvs = interpolated_uvs * 2.0 - 1.0 pixel_uvs = interpolated_uvs * 2.0 - 1.0
pixel_uvs = pixel_uvs.view(2, 1, 1, 2) pixel_uvs = pixel_uvs.view(2, 1, 1, 2)
tex_map = torch.flip(tex_map, [1]) tex_map_ = torch.flip(tex_map, [1]).permute(0, 3, 1, 2)
tex_map = tex_map.permute(0, 3, 1, 2) tex_map_ = torch.cat([tex_map_, tex_map_], dim=0)
tex_map = torch.cat([tex_map, tex_map], dim=0) expected_out = F.grid_sample(
expected_out = F.grid_sample(tex_map, pixel_uvs, align_corners=False) tex_map_, pixel_uvs, align_corners=align_corners, padding_mode="border"
self.assertTrue(torch.allclose(texels.squeeze(), expected_out.squeeze())) )
self.assertTrue(torch.allclose(texels.squeeze(), expected_out.squeeze()))
def test_textures_uv_init_fail(self): def test_textures_uv_init_fail(self):
# Maps has wrong shape # Maps has wrong shape