mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	align_corners and padding for TexturesUV
Summary: Allow, and make default, align_corners=True for texture maps. Allow changing the padding_mode and set the default to be "border" which produces more logical results. Some new documentation. The previous behavior corresponds to padding_mode="zeros" and align_corners=False. Reviewed By: gkioxari Differential Revision: D23268775 fbshipit-source-id: 58d6229baa591baa69705bcf97471c80ba3651de
This commit is contained in:
		
							parent
							
								
									d0cec028c7
								
							
						
					
					
						commit
						e25ccab3d9
					
				@ -98,13 +98,14 @@ def _padded_to_list_wrapper(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _pad_texture_maps(
 | 
					def _pad_texture_maps(
 | 
				
			||||||
    images: Union[Tuple[torch.Tensor], List[torch.Tensor]]
 | 
					    images: Union[Tuple[torch.Tensor], List[torch.Tensor]], align_corners: bool
 | 
				
			||||||
) -> torch.Tensor:
 | 
					) -> torch.Tensor:
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Pad all texture images so they have the same height and width.
 | 
					    Pad all texture images so they have the same height and width.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Args:
 | 
					    Args:
 | 
				
			||||||
        images: list of N tensors of shape (H, W, 3)
 | 
					        images: list of N tensors of shape (H_i, W_i, 3)
 | 
				
			||||||
 | 
					        align_corners: used for interpolation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Returns:
 | 
					    Returns:
 | 
				
			||||||
        tex_maps: Tensor of shape (N, max_H, max_W, 3)
 | 
					        tex_maps: Tensor of shape (N, max_H, max_W, 3)
 | 
				
			||||||
@ -125,7 +126,7 @@ def _pad_texture_maps(
 | 
				
			|||||||
        if image.shape[:2] != max_shape:
 | 
					        if image.shape[:2] != max_shape:
 | 
				
			||||||
            image_BCHW = image.permute(2, 0, 1)[None]
 | 
					            image_BCHW = image.permute(2, 0, 1)[None]
 | 
				
			||||||
            new_image_BCHW = interpolate(
 | 
					            new_image_BCHW = interpolate(
 | 
				
			||||||
                image_BCHW, size=max_shape, mode="bilinear", align_corners=False
 | 
					                image_BCHW, size=max_shape, mode="bilinear", align_corners=align_corners
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            tex_maps[i] = new_image_BCHW[0].permute(1, 2, 0)
 | 
					            tex_maps[i] = new_image_BCHW[0].permute(1, 2, 0)
 | 
				
			||||||
    tex_maps = torch.stack(tex_maps, dim=0)  # (num_tex_maps, max_H, max_W, 3)
 | 
					    tex_maps = torch.stack(tex_maps, dim=0)  # (num_tex_maps, max_H, max_W, 3)
 | 
				
			||||||
@ -535,6 +536,8 @@ class TexturesUV(TexturesBase):
 | 
				
			|||||||
        maps: Union[torch.Tensor, List[torch.Tensor]],
 | 
					        maps: Union[torch.Tensor, List[torch.Tensor]],
 | 
				
			||||||
        faces_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
 | 
					        faces_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
 | 
				
			||||||
        verts_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
 | 
					        verts_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
 | 
				
			||||||
 | 
					        padding_mode: str = "border",
 | 
				
			||||||
 | 
					        align_corners: bool = True,
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Textures are represented as a per mesh texture map and uv coordinates for each
 | 
					        Textures are represented as a per mesh texture map and uv coordinates for each
 | 
				
			||||||
@ -543,11 +546,42 @@ class TexturesUV(TexturesBase):
 | 
				
			|||||||
        Args:
 | 
					        Args:
 | 
				
			||||||
            maps: texture map per mesh. This can either be a list of maps
 | 
					            maps: texture map per mesh. This can either be a list of maps
 | 
				
			||||||
              [(H, W, 3)] or a padded tensor of shape (N, H, W, 3)
 | 
					              [(H, W, 3)] or a padded tensor of shape (N, H, W, 3)
 | 
				
			||||||
            faces_uvs: (N, F, 3) tensor giving the index into verts_uvs for each face
 | 
					            faces_uvs: (N, F, 3) LongTensor giving the index into verts_uvs
 | 
				
			||||||
 | 
					                        for each face
 | 
				
			||||||
            verts_uvs: (N, V, 2) tensor giving the uv coordinates per vertex
 | 
					            verts_uvs: (N, V, 2) tensor giving the uv coordinates per vertex
 | 
				
			||||||
                         (a FloatTensor with values between 0 and 1)
 | 
					                        (a FloatTensor with values between 0 and 1).
 | 
				
			||||||
 | 
					            align_corners: If true, the extreme values 0 and 1 for verts_uvs
 | 
				
			||||||
 | 
					                            indicate the centers of the edge pixels in the maps.
 | 
				
			||||||
 | 
					            padding_mode: padding mode for outside grid values
 | 
				
			||||||
 | 
					                                ("zeros", "border" or "reflection").
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        The align_corners and padding_mode arguments correspond to the arguments
 | 
				
			||||||
 | 
					        of the `grid_sample` torch function. There is an informative illustration of
 | 
				
			||||||
 | 
					        the two align_corners options at
 | 
				
			||||||
 | 
					        https://discuss.pytorch.org/t/22663/9 .
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        An example of how the indexing into the maps, with align_corners=True
 | 
				
			||||||
 | 
					        works is as follows.
 | 
				
			||||||
 | 
					        If maps[i] has shape [101, 1001] and the value of verts_uvs[i][j]
 | 
				
			||||||
 | 
					        is [0.4, 0.3], then a value of j in faces_uvs[i] means a vertex
 | 
				
			||||||
 | 
					        whose color is given by maps[i][700, 40]. padding_mode affects what
 | 
				
			||||||
 | 
					        happens if a value in verts_uvs is less than 0 or greater than 1.
 | 
				
			||||||
 | 
					        Note that increasing a value in verts_uvs[..., 0] increases an index
 | 
				
			||||||
 | 
					        in maps, whereas increasing a value in verts_uvs[..., 1] _decreases_
 | 
				
			||||||
 | 
					        an _earlier_ index in maps.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        If align_corners=False, an example would be as follows.
 | 
				
			||||||
 | 
					        If maps[i] has shape [100, 1000] and the value of verts_uvs[i][j]
 | 
				
			||||||
 | 
					        is [0.405, 0.2995], then a value of j in faces_uvs[i] means a vertex
 | 
				
			||||||
 | 
					        whose color is given by maps[i][700, 40].
 | 
				
			||||||
 | 
					        In this case, padding_mode even matters for values in verts_uvs
 | 
				
			||||||
 | 
					        slightly above 0 or slightly below 1. In this case, it matters if the
 | 
				
			||||||
 | 
					        first value is outside the interval [0.0005, 0.9995] or if the second
 | 
				
			||||||
 | 
					        is outside the interval [0.005, 0.995].
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        super().__init__()
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.padding_mode = padding_mode
 | 
				
			||||||
 | 
					        self.align_corners = align_corners
 | 
				
			||||||
        if isinstance(faces_uvs, (list, tuple)):
 | 
					        if isinstance(faces_uvs, (list, tuple)):
 | 
				
			||||||
            for fv in faces_uvs:
 | 
					            for fv in faces_uvs:
 | 
				
			||||||
                # pyre-fixme[16]: `Tensor` has no attribute `ndim`.
 | 
					                # pyre-fixme[16]: `Tensor` has no attribute `ndim`.
 | 
				
			||||||
@ -632,7 +666,7 @@ class TexturesUV(TexturesBase):
 | 
				
			|||||||
                raise ValueError("Expected one texture map per mesh in the batch.")
 | 
					                raise ValueError("Expected one texture map per mesh in the batch.")
 | 
				
			||||||
            self._maps_list = maps
 | 
					            self._maps_list = maps
 | 
				
			||||||
            if self._N > 0:
 | 
					            if self._N > 0:
 | 
				
			||||||
                maps = _pad_texture_maps(maps)
 | 
					                maps = _pad_texture_maps(maps, align_corners=self.align_corners)
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                maps = torch.empty(
 | 
					                maps = torch.empty(
 | 
				
			||||||
                    (self._N, 0, 0, 3), dtype=torch.float32, device=self.device
 | 
					                    (self._N, 0, 0, 3), dtype=torch.float32, device=self.device
 | 
				
			||||||
@ -698,11 +732,19 @@ class TexturesUV(TexturesBase):
 | 
				
			|||||||
        # if index has multiple values then faces/verts/maps may be a list of tensors
 | 
					        # if index has multiple values then faces/verts/maps may be a list of tensors
 | 
				
			||||||
        if all(isinstance(f, (list, tuple)) for f in [faces_uvs, verts_uvs, maps]):
 | 
					        if all(isinstance(f, (list, tuple)) for f in [faces_uvs, verts_uvs, maps]):
 | 
				
			||||||
            new_tex = self.__class__(
 | 
					            new_tex = self.__class__(
 | 
				
			||||||
                faces_uvs=faces_uvs, verts_uvs=verts_uvs, maps=maps
 | 
					                faces_uvs=faces_uvs,
 | 
				
			||||||
 | 
					                verts_uvs=verts_uvs,
 | 
				
			||||||
 | 
					                maps=maps,
 | 
				
			||||||
 | 
					                padding_mode=self.padding_mode,
 | 
				
			||||||
 | 
					                align_corners=self.align_corners,
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
        elif all(torch.is_tensor(f) for f in [faces_uvs, verts_uvs, maps]):
 | 
					        elif all(torch.is_tensor(f) for f in [faces_uvs, verts_uvs, maps]):
 | 
				
			||||||
            new_tex = self.__class__(
 | 
					            new_tex = self.__class__(
 | 
				
			||||||
                faces_uvs=[faces_uvs], verts_uvs=[verts_uvs], maps=[maps]
 | 
					                faces_uvs=[faces_uvs],
 | 
				
			||||||
 | 
					                verts_uvs=[verts_uvs],
 | 
				
			||||||
 | 
					                maps=[maps],
 | 
				
			||||||
 | 
					                padding_mode=self.padding_mode,
 | 
				
			||||||
 | 
					                align_corners=self.align_corners,
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            raise ValueError("Not all values are provided in the correct format")
 | 
					            raise ValueError("Not all values are provided in the correct format")
 | 
				
			||||||
@ -785,6 +827,8 @@ class TexturesUV(TexturesBase):
 | 
				
			|||||||
            maps=new_props["maps_padded"],
 | 
					            maps=new_props["maps_padded"],
 | 
				
			||||||
            faces_uvs=new_props["faces_uvs_padded"],
 | 
					            faces_uvs=new_props["faces_uvs_padded"],
 | 
				
			||||||
            verts_uvs=new_props["verts_uvs_padded"],
 | 
					            verts_uvs=new_props["verts_uvs_padded"],
 | 
				
			||||||
 | 
					            padding_mode=self.padding_mode,
 | 
				
			||||||
 | 
					            align_corners=self.align_corners,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"]
 | 
					        new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"]
 | 
				
			||||||
@ -859,7 +903,12 @@ class TexturesUV(TexturesBase):
 | 
				
			|||||||
        texture_maps = torch.flip(texture_maps, [2])  # flip y axis of the texture map
 | 
					        texture_maps = torch.flip(texture_maps, [2])  # flip y axis of the texture map
 | 
				
			||||||
        if texture_maps.device != pixel_uvs.device:
 | 
					        if texture_maps.device != pixel_uvs.device:
 | 
				
			||||||
            texture_maps = texture_maps.to(pixel_uvs.device)
 | 
					            texture_maps = texture_maps.to(pixel_uvs.device)
 | 
				
			||||||
        texels = F.grid_sample(texture_maps, pixel_uvs, align_corners=False)
 | 
					        texels = F.grid_sample(
 | 
				
			||||||
 | 
					            texture_maps,
 | 
				
			||||||
 | 
					            pixel_uvs,
 | 
				
			||||||
 | 
					            align_corners=self.align_corners,
 | 
				
			||||||
 | 
					            padding_mode=self.padding_mode,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        # texels now has shape (NK, C, H_out, W_out)
 | 
					        # texels now has shape (NK, C, H_out, W_out)
 | 
				
			||||||
        texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2)
 | 
					        texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2)
 | 
				
			||||||
        return texels
 | 
					        return texels
 | 
				
			||||||
@ -881,6 +930,17 @@ class TexturesUV(TexturesBase):
 | 
				
			|||||||
        if not tex_types_same:
 | 
					        if not tex_types_same:
 | 
				
			||||||
            raise ValueError("All textures must be of type TexturesUV.")
 | 
					            raise ValueError("All textures must be of type TexturesUV.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        padding_modes_same = all(
 | 
				
			||||||
 | 
					            tex.padding_mode == self.padding_mode for tex in textures
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        if not padding_modes_same:
 | 
				
			||||||
 | 
					            raise ValueError("All textures must have the same padding_mode.")
 | 
				
			||||||
 | 
					        align_corners_same = all(
 | 
				
			||||||
 | 
					            tex.align_corners == self.align_corners for tex in textures
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        if not align_corners_same:
 | 
				
			||||||
 | 
					            raise ValueError("All textures must have the same align_corners value.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        verts_uvs_list = []
 | 
					        verts_uvs_list = []
 | 
				
			||||||
        faces_uvs_list = []
 | 
					        faces_uvs_list = []
 | 
				
			||||||
        maps_list = []
 | 
					        maps_list = []
 | 
				
			||||||
@ -896,7 +956,11 @@ class TexturesUV(TexturesBase):
 | 
				
			|||||||
            maps_list += tex_map_list
 | 
					            maps_list += tex_map_list
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        new_tex = self.__class__(
 | 
					        new_tex = self.__class__(
 | 
				
			||||||
            maps=maps_list, verts_uvs=verts_uvs_list, faces_uvs=faces_uvs_list
 | 
					            maps=maps_list,
 | 
				
			||||||
 | 
					            verts_uvs=verts_uvs_list,
 | 
				
			||||||
 | 
					            faces_uvs=faces_uvs_list,
 | 
				
			||||||
 | 
					            padding_mode=self.padding_mode,
 | 
				
			||||||
 | 
					            align_corners=self.align_corners,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        new_tex._num_faces_per_mesh = num_faces_per_mesh
 | 
					        new_tex._num_faces_per_mesh = num_faces_per_mesh
 | 
				
			||||||
        return new_tex
 | 
					        return new_tex
 | 
				
			||||||
 | 
				
			|||||||
@ -1505,9 +1505,13 @@ def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True):
 | 
				
			|||||||
    Merge multiple Meshes objects, i.e. concatenate the meshes objects. They
 | 
					    Merge multiple Meshes objects, i.e. concatenate the meshes objects. They
 | 
				
			||||||
    must all be on the same device. If include_textures is true, they must all
 | 
					    must all be on the same device. If include_textures is true, they must all
 | 
				
			||||||
    be compatible, either all or none having textures, and all the Textures
 | 
					    be compatible, either all or none having textures, and all the Textures
 | 
				
			||||||
    objects having the same members. If  include_textures is False, textures are
 | 
					    objects being the same type. If include_textures is False, textures are
 | 
				
			||||||
    ignored.
 | 
					    ignored.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    If the textures are TexturesAtlas then being the same type includes having
 | 
				
			||||||
 | 
					    the same resolution. If they are TexturesUV then it includes having the same
 | 
				
			||||||
 | 
					    align_corners and padding_mode.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Args:
 | 
					    Args:
 | 
				
			||||||
        meshes: list of meshes.
 | 
					        meshes: list of meshes.
 | 
				
			||||||
        include_textures: (bool) whether to try to join the textures.
 | 
					        include_textures: (bool) whether to try to join the textures.
 | 
				
			||||||
 | 
				
			|||||||
										
											Binary file not shown.
										
									
								
							| 
		 Before Width: | Height: | Size: 43 KiB After Width: | Height: | Size: 43 KiB  | 
										
											Binary file not shown.
										
									
								
							| 
		 Before Width: | Height: | Size: 31 KiB After Width: | Height: | Size: 31 KiB  | 
										
											Binary file not shown.
										
									
								
							| 
		 Before Width: | Height: | Size: 30 KiB After Width: | Height: | Size: 30 KiB  | 
@ -373,10 +373,10 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            self.assertFalse(new_mesh.verts_packed().requires_grad)
 | 
					            self.assertFalse(new_mesh.verts_packed().requires_grad)
 | 
				
			||||||
            self.assertClose(new_mesh.verts_packed(), mesh.verts_packed())
 | 
					            self.assertClose(new_mesh.verts_packed(), mesh.verts_packed())
 | 
				
			||||||
            self.assertTrue(new_mesh.verts_padded().requires_grad == False)
 | 
					            self.assertFalse(new_mesh.verts_padded().requires_grad)
 | 
				
			||||||
            self.assertClose(new_mesh.verts_padded(), mesh.verts_padded())
 | 
					            self.assertClose(new_mesh.verts_padded(), mesh.verts_padded())
 | 
				
			||||||
            for v, newv in zip(mesh.verts_list(), new_mesh.verts_list()):
 | 
					            for v, newv in zip(mesh.verts_list(), new_mesh.verts_list()):
 | 
				
			||||||
                self.assertTrue(newv.requires_grad == False)
 | 
					                self.assertFalse(newv.requires_grad)
 | 
				
			||||||
                self.assertClose(newv, v)
 | 
					                self.assertClose(newv, v)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_laplacian_packed(self):
 | 
					    def test_laplacian_packed(self):
 | 
				
			||||||
 | 
				
			|||||||
@ -411,11 +411,11 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
 | 
				
			|||||||
                new_clouds = clouds.detach()
 | 
					                new_clouds = clouds.detach()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                for cloud in new_clouds.points_list():
 | 
					                for cloud in new_clouds.points_list():
 | 
				
			||||||
                    self.assertTrue(cloud.requires_grad == False)
 | 
					                    self.assertFalse(cloud.requires_grad)
 | 
				
			||||||
                for normal in new_clouds.normals_list():
 | 
					                for normal in new_clouds.normals_list():
 | 
				
			||||||
                    self.assertTrue(normal.requires_grad == False)
 | 
					                    self.assertFalse(normal.requires_grad)
 | 
				
			||||||
                for feats in new_clouds.features_list():
 | 
					                for feats in new_clouds.features_list():
 | 
				
			||||||
                    self.assertTrue(feats.requires_grad == False)
 | 
					                    self.assertFalse(feats.requires_grad)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                for attrib in [
 | 
					                for attrib in [
 | 
				
			||||||
                    "points_packed",
 | 
					                    "points_packed",
 | 
				
			||||||
@ -425,9 +425,7 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
 | 
				
			|||||||
                    "normals_padded",
 | 
					                    "normals_padded",
 | 
				
			||||||
                    "features_padded",
 | 
					                    "features_padded",
 | 
				
			||||||
                ]:
 | 
					                ]:
 | 
				
			||||||
                    self.assertTrue(
 | 
					                    self.assertFalse(getattr(new_clouds, attrib)().requires_grad)
 | 
				
			||||||
                        getattr(new_clouds, attrib)().requires_grad == False
 | 
					 | 
				
			||||||
                    )
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                self.assertCloudsEqual(clouds, new_clouds)
 | 
					                self.assertCloudsEqual(clouds, new_clouds)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -443,19 +443,26 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
 | 
				
			|||||||
            dists=pix_to_face,
 | 
					            dists=pix_to_face,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        tex = TexturesUV(maps=tex_map, faces_uvs=[face_uvs], verts_uvs=[vert_uvs])
 | 
					        for align_corners in [True, False]:
 | 
				
			||||||
        meshes = Meshes(verts=[dummy_verts], faces=[face_uvs], textures=tex)
 | 
					            tex = TexturesUV(
 | 
				
			||||||
        mesh_textures = meshes.textures
 | 
					                maps=tex_map,
 | 
				
			||||||
        texels = mesh_textures.sample_textures(fragments)
 | 
					                faces_uvs=[face_uvs],
 | 
				
			||||||
 | 
					                verts_uvs=[vert_uvs],
 | 
				
			||||||
 | 
					                align_corners=align_corners,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            meshes = Meshes(verts=[dummy_verts], faces=[face_uvs], textures=tex)
 | 
				
			||||||
 | 
					            mesh_textures = meshes.textures
 | 
				
			||||||
 | 
					            texels = mesh_textures.sample_textures(fragments)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Expected output
 | 
					            # Expected output
 | 
				
			||||||
        pixel_uvs = interpolated_uvs * 2.0 - 1.0
 | 
					            pixel_uvs = interpolated_uvs * 2.0 - 1.0
 | 
				
			||||||
        pixel_uvs = pixel_uvs.view(2, 1, 1, 2)
 | 
					            pixel_uvs = pixel_uvs.view(2, 1, 1, 2)
 | 
				
			||||||
        tex_map = torch.flip(tex_map, [1])
 | 
					            tex_map_ = torch.flip(tex_map, [1]).permute(0, 3, 1, 2)
 | 
				
			||||||
        tex_map = tex_map.permute(0, 3, 1, 2)
 | 
					            tex_map_ = torch.cat([tex_map_, tex_map_], dim=0)
 | 
				
			||||||
        tex_map = torch.cat([tex_map, tex_map], dim=0)
 | 
					            expected_out = F.grid_sample(
 | 
				
			||||||
        expected_out = F.grid_sample(tex_map, pixel_uvs, align_corners=False)
 | 
					                tex_map_, pixel_uvs, align_corners=align_corners, padding_mode="border"
 | 
				
			||||||
        self.assertTrue(torch.allclose(texels.squeeze(), expected_out.squeeze()))
 | 
					            )
 | 
				
			||||||
 | 
					            self.assertTrue(torch.allclose(texels.squeeze(), expected_out.squeeze()))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_textures_uv_init_fail(self):
 | 
					    def test_textures_uv_init_fail(self):
 | 
				
			||||||
        # Maps has wrong shape
 | 
					        # Maps has wrong shape
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user