Improve textures type annotations

Summary: Improve type annotations for textures and remove a few pyre fixmes

Reviewed By: nikhilaravi

Differential Revision: D28942630

fbshipit-source-id: 422f2bdf07b435869461ca103d71473aa0c2b814
This commit is contained in:
Patrick Labatut 2021-06-08 02:14:59 -07:00 committed by Facebook GitHub Bot
parent d76c00721c
commit c710d8c101

View File

@ -273,7 +273,7 @@ class TexturesBase:
def Textures( def Textures(
maps: Union[List, torch.Tensor, None] = None, maps: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
faces_uvs: Optional[torch.Tensor] = None, faces_uvs: Optional[torch.Tensor] = None,
verts_uvs: Optional[torch.Tensor] = None, verts_uvs: Optional[torch.Tensor] = None,
verts_rgb: Optional[torch.Tensor] = None, verts_rgb: Optional[torch.Tensor] = None,
@ -305,20 +305,19 @@ def Textures(
PendingDeprecationWarning, PendingDeprecationWarning,
) )
if all(x is not None for x in [faces_uvs, verts_uvs, maps]): if faces_uvs is not None and verts_uvs is not None and maps is not None:
# pyre-fixme[6]: Expected `Union[List[torch.Tensor], torch.Tensor]` for 1st
# param but got `Union[None, List[typing.Any], torch.Tensor]`.
return TexturesUV(maps=maps, faces_uvs=faces_uvs, verts_uvs=verts_uvs) return TexturesUV(maps=maps, faces_uvs=faces_uvs, verts_uvs=verts_uvs)
elif verts_rgb is not None:
if verts_rgb is not None:
return TexturesVertex(verts_features=verts_rgb) return TexturesVertex(verts_features=verts_rgb)
else:
raise ValueError( raise ValueError(
"Textures either requires all three of (faces uvs, verts uvs, maps) or verts rgb" "Textures either requires all three of (faces uvs, verts uvs, maps) or verts rgb"
) )
class TexturesAtlas(TexturesBase): class TexturesAtlas(TexturesBase):
def __init__(self, atlas: Union[torch.Tensor, List, None]): def __init__(self, atlas: Union[torch.Tensor, List[torch.Tensor]]):
""" """
A texture representation where each face has a square texture map. A texture representation where each face has a square texture map.
This is based on the implementation from SoftRasterizer [1]. This is based on the implementation from SoftRasterizer [1].
@ -369,22 +368,17 @@ class TexturesAtlas(TexturesBase):
self.device = atlas[0].device self.device = atlas[0].device
elif torch.is_tensor(atlas): elif torch.is_tensor(atlas):
# pyre-fixme[16]: `Optional` has no attribute `ndim`.
if atlas.ndim != 5: if atlas.ndim != 5:
msg = "Expected atlas to be of shape (N, F, R, R, D); got %r" msg = "Expected atlas to be of shape (N, F, R, R, D); got %r"
raise ValueError(msg % repr(atlas.ndim)) raise ValueError(msg % repr(atlas.ndim))
self._atlas_padded = atlas self._atlas_padded = atlas
self._atlas_list = None self._atlas_list = None
# pyre-fixme[16]: `Optional` has no attribute `device`.
self.device = atlas.device self.device = atlas.device
# These values may be overridden when textures is # These values may be overridden when textures is
# passed into the Meshes constructor. For more details # passed into the Meshes constructor. For more details
# refer to the __init__ of Meshes. # refer to the __init__ of Meshes.
# pyre-fixme[6]: Expected `Sized` for 1st param but got
# `Optional[torch.Tensor]`.
self._N = len(atlas) self._N = len(atlas)
# pyre-fixme[16]: `Optional` has no attribute `shape`.
max_F = atlas.shape[1] max_F = atlas.shape[1]
self._num_faces_per_mesh = [max_F] * self._N self._num_faces_per_mesh = [max_F] * self._N
else: else:
@ -699,9 +693,7 @@ class TexturesUV(TexturesBase):
else: else:
raise ValueError("Expected verts_uvs to be a tensor or list") raise ValueError("Expected verts_uvs to be a tensor or list")
if torch.is_tensor(maps): if isinstance(maps, torch.Tensor):
# pyre-fixme[16]: `List` has no attribute `ndim`.
# pyre-fixme[16]: `List` has no attribute `shape`.
if maps.ndim != 4 or maps.shape[0] != self._N: if maps.ndim != 4 or maps.shape[0] != self._N:
msg = "Expected maps to be of shape (N, H, W, 3); got %r" msg = "Expected maps to be of shape (N, H, W, 3); got %r"
raise ValueError(msg % repr(maps.shape)) raise ValueError(msg % repr(maps.shape))
@ -726,7 +718,7 @@ class TexturesUV(TexturesBase):
self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device) self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)
def clone(self): def clone(self) -> "TexturesUV":
tex = self.__class__( tex = self.__class__(
self.maps_padded().clone(), self.maps_padded().clone(),
self.faces_uvs_padded().clone(), self.faces_uvs_padded().clone(),
@ -747,7 +739,7 @@ class TexturesUV(TexturesBase):
tex.valid = self.valid.clone() tex.valid = self.valid.clone()
return tex return tex
def detach(self): def detach(self) -> "TexturesUV":
tex = self.__class__( tex = self.__class__(
self.maps_padded().detach(), self.maps_padded().detach(),
self.faces_uvs_padded().detach(), self.faces_uvs_padded().detach(),
@ -1185,7 +1177,7 @@ class TexturesUV(TexturesBase):
padding_mode=self.padding_mode, padding_mode=self.padding_mode,
) )
def centers_for_image(self, index): def centers_for_image(self, index: int) -> torch.Tensor:
""" """
Return the locations in the texture map which correspond to the given Return the locations in the texture map which correspond to the given
verts_uvs, for one of the meshes. This is potentially useful for verts_uvs, for one of the meshes. This is potentially useful for
@ -1279,7 +1271,7 @@ class TexturesVertex(TexturesBase):
# refer to the __init__ of Meshes. # refer to the __init__ of Meshes.
self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device) self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)
def clone(self): def clone(self) -> "TexturesVertex":
tex = self.__class__(self.verts_features_padded().clone()) tex = self.__class__(self.verts_features_padded().clone())
if self._verts_features_list is not None: if self._verts_features_list is not None:
tex._verts_features_list = [f.clone() for f in self._verts_features_list] tex._verts_features_list = [f.clone() for f in self._verts_features_list]
@ -1287,7 +1279,7 @@ class TexturesVertex(TexturesBase):
tex.valid = self.valid.clone() tex.valid = self.valid.clone()
return tex return tex
def detach(self): def detach(self) -> "TexturesVertex":
tex = self.__class__(self.verts_features_padded().detach()) tex = self.__class__(self.verts_features_padded().detach())
if self._verts_features_list is not None: if self._verts_features_list is not None:
tex._verts_features_list = [f.detach() for f in self._verts_features_list] tex._verts_features_list = [f.detach() for f in self._verts_features_list]
@ -1295,7 +1287,7 @@ class TexturesVertex(TexturesBase):
tex.valid = self.valid.detach() tex.valid = self.valid.detach()
return tex return tex
def __getitem__(self, index): def __getitem__(self, index) -> "TexturesVertex":
props = ["verts_features_list", "_num_verts_per_mesh"] props = ["verts_features_list", "_num_verts_per_mesh"]
new_props = self._getitem(index, props) new_props = self._getitem(index, props)
verts_features = new_props["verts_features_list"] verts_features = new_props["verts_features_list"]