Improve textures type annotations

Summary: Improve type annotations for textures and remove a few pyre fixmes

Reviewed By: nikhilaravi

Differential Revision: D28942630

fbshipit-source-id: 422f2bdf07b435869461ca103d71473aa0c2b814
This commit is contained in:
Patrick Labatut 2021-06-08 02:14:59 -07:00 committed by Facebook GitHub Bot
parent d76c00721c
commit c710d8c101

View File

@ -273,7 +273,7 @@ class TexturesBase:
def Textures(
maps: Union[List, torch.Tensor, None] = None,
maps: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
faces_uvs: Optional[torch.Tensor] = None,
verts_uvs: Optional[torch.Tensor] = None,
verts_rgb: Optional[torch.Tensor] = None,
@ -305,20 +305,19 @@ def Textures(
PendingDeprecationWarning,
)
if all(x is not None for x in [faces_uvs, verts_uvs, maps]):
# pyre-fixme[6]: Expected `Union[List[torch.Tensor], torch.Tensor]` for 1st
# param but got `Union[None, List[typing.Any], torch.Tensor]`.
if faces_uvs is not None and verts_uvs is not None and maps is not None:
return TexturesUV(maps=maps, faces_uvs=faces_uvs, verts_uvs=verts_uvs)
elif verts_rgb is not None:
if verts_rgb is not None:
return TexturesVertex(verts_features=verts_rgb)
else:
raise ValueError(
"Textures either requires all three of (faces uvs, verts uvs, maps) or verts rgb"
)
class TexturesAtlas(TexturesBase):
def __init__(self, atlas: Union[torch.Tensor, List, None]):
def __init__(self, atlas: Union[torch.Tensor, List[torch.Tensor]]):
"""
A texture representation where each face has a square texture map.
This is based on the implementation from SoftRasterizer [1].
@ -369,22 +368,17 @@ class TexturesAtlas(TexturesBase):
self.device = atlas[0].device
elif torch.is_tensor(atlas):
# pyre-fixme[16]: `Optional` has no attribute `ndim`.
if atlas.ndim != 5:
msg = "Expected atlas to be of shape (N, F, R, R, D); got %r"
raise ValueError(msg % repr(atlas.ndim))
self._atlas_padded = atlas
self._atlas_list = None
# pyre-fixme[16]: `Optional` has no attribute `device`.
self.device = atlas.device
# These values may be overridden when textures is
# passed into the Meshes constructor. For more details
# refer to the __init__ of Meshes.
# pyre-fixme[6]: Expected `Sized` for 1st param but got
# `Optional[torch.Tensor]`.
self._N = len(atlas)
# pyre-fixme[16]: `Optional` has no attribute `shape`.
max_F = atlas.shape[1]
self._num_faces_per_mesh = [max_F] * self._N
else:
@ -699,9 +693,7 @@ class TexturesUV(TexturesBase):
else:
raise ValueError("Expected verts_uvs to be a tensor or list")
if torch.is_tensor(maps):
# pyre-fixme[16]: `List` has no attribute `ndim`.
# pyre-fixme[16]: `List` has no attribute `shape`.
if isinstance(maps, torch.Tensor):
if maps.ndim != 4 or maps.shape[0] != self._N:
msg = "Expected maps to be of shape (N, H, W, 3); got %r"
raise ValueError(msg % repr(maps.shape))
@ -726,7 +718,7 @@ class TexturesUV(TexturesBase):
self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)
def clone(self):
def clone(self) -> "TexturesUV":
tex = self.__class__(
self.maps_padded().clone(),
self.faces_uvs_padded().clone(),
@ -747,7 +739,7 @@ class TexturesUV(TexturesBase):
tex.valid = self.valid.clone()
return tex
def detach(self):
def detach(self) -> "TexturesUV":
tex = self.__class__(
self.maps_padded().detach(),
self.faces_uvs_padded().detach(),
@ -1185,7 +1177,7 @@ class TexturesUV(TexturesBase):
padding_mode=self.padding_mode,
)
def centers_for_image(self, index):
def centers_for_image(self, index: int) -> torch.Tensor:
"""
Return the locations in the texture map which correspond to the given
verts_uvs, for one of the meshes. This is potentially useful for
@ -1279,7 +1271,7 @@ class TexturesVertex(TexturesBase):
# refer to the __init__ of Meshes.
self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)
def clone(self):
def clone(self) -> "TexturesVertex":
tex = self.__class__(self.verts_features_padded().clone())
if self._verts_features_list is not None:
tex._verts_features_list = [f.clone() for f in self._verts_features_list]
@ -1287,7 +1279,7 @@ class TexturesVertex(TexturesBase):
tex.valid = self.valid.clone()
return tex
def detach(self):
def detach(self) -> "TexturesVertex":
tex = self.__class__(self.verts_features_padded().detach())
if self._verts_features_list is not None:
tex._verts_features_list = [f.detach() for f in self._verts_features_list]
@ -1295,7 +1287,7 @@ class TexturesVertex(TexturesBase):
tex.valid = self.valid.detach()
return tex
def __getitem__(self, index):
def __getitem__(self, index) -> "TexturesVertex":
props = ["verts_features_list", "_num_verts_per_mesh"]
new_props = self._getitem(index, props)
verts_features = new_props["verts_features_list"]