From c710d8c101fe0178f4536d1e9c531a010aac03ee Mon Sep 17 00:00:00 2001 From: Patrick Labatut Date: Tue, 8 Jun 2021 02:14:59 -0700 Subject: [PATCH] Improve textures type annotations Summary: Improve type annotations for textures and remove a few pyre fixmes Reviewed By: nikhilaravi Differential Revision: D28942630 fbshipit-source-id: 422f2bdf07b435869461ca103d71473aa0c2b814 --- pytorch3d/renderer/mesh/textures.py | 40 ++++++++++++----------------- 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/pytorch3d/renderer/mesh/textures.py b/pytorch3d/renderer/mesh/textures.py index abfc0a54..6588c8e0 100644 --- a/pytorch3d/renderer/mesh/textures.py +++ b/pytorch3d/renderer/mesh/textures.py @@ -273,7 +273,7 @@ class TexturesBase: def Textures( - maps: Union[List, torch.Tensor, None] = None, + maps: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, faces_uvs: Optional[torch.Tensor] = None, verts_uvs: Optional[torch.Tensor] = None, verts_rgb: Optional[torch.Tensor] = None, @@ -305,20 +305,19 @@ def Textures( PendingDeprecationWarning, ) - if all(x is not None for x in [faces_uvs, verts_uvs, maps]): - # pyre-fixme[6]: Expected `Union[List[torch.Tensor], torch.Tensor]` for 1st - # param but got `Union[None, List[typing.Any], torch.Tensor]`. + if faces_uvs is not None and verts_uvs is not None and maps is not None: return TexturesUV(maps=maps, faces_uvs=faces_uvs, verts_uvs=verts_uvs) - elif verts_rgb is not None: + + if verts_rgb is not None: return TexturesVertex(verts_features=verts_rgb) - else: - raise ValueError( - "Textures either requires all three of (faces uvs, verts uvs, maps) or verts rgb" - ) + + raise ValueError( + "Textures either requires all three of (faces uvs, verts uvs, maps) or verts rgb" + ) class TexturesAtlas(TexturesBase): - def __init__(self, atlas: Union[torch.Tensor, List, None]): + def __init__(self, atlas: Union[torch.Tensor, List[torch.Tensor]]): """ A texture representation where each face has a square texture map. This is based on the implementation from SoftRasterizer [1]. @@ -369,22 +368,17 @@ class TexturesAtlas(TexturesBase): self.device = atlas[0].device elif torch.is_tensor(atlas): - # pyre-fixme[16]: `Optional` has no attribute `ndim`. if atlas.ndim != 5: msg = "Expected atlas to be of shape (N, F, R, R, D); got %r" raise ValueError(msg % repr(atlas.ndim)) self._atlas_padded = atlas self._atlas_list = None - # pyre-fixme[16]: `Optional` has no attribute `device`. self.device = atlas.device # These values may be overridden when textures is # passed into the Meshes constructor. For more details # refer to the __init__ of Meshes. - # pyre-fixme[6]: Expected `Sized` for 1st param but got - # `Optional[torch.Tensor]`. self._N = len(atlas) - # pyre-fixme[16]: `Optional` has no attribute `shape`. max_F = atlas.shape[1] self._num_faces_per_mesh = [max_F] * self._N else: @@ -699,9 +693,7 @@ class TexturesUV(TexturesBase): else: raise ValueError("Expected verts_uvs to be a tensor or list") - if torch.is_tensor(maps): - # pyre-fixme[16]: `List` has no attribute `ndim`. - # pyre-fixme[16]: `List` has no attribute `shape`. + if isinstance(maps, torch.Tensor): if maps.ndim != 4 or maps.shape[0] != self._N: msg = "Expected maps to be of shape (N, H, W, 3); got %r" raise ValueError(msg % repr(maps.shape)) @@ -726,7 +718,7 @@ class TexturesUV(TexturesBase): self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device) - def clone(self): + def clone(self) -> "TexturesUV": tex = self.__class__( self.maps_padded().clone(), self.faces_uvs_padded().clone(), @@ -747,7 +739,7 @@ class TexturesUV(TexturesBase): tex.valid = self.valid.clone() return tex - def detach(self): + def detach(self) -> "TexturesUV": tex = self.__class__( self.maps_padded().detach(), self.faces_uvs_padded().detach(), @@ -1185,7 +1177,7 @@ class TexturesUV(TexturesBase): padding_mode=self.padding_mode, ) - def centers_for_image(self, index): + def centers_for_image(self, index: int) -> torch.Tensor: """ Return the locations in the texture map which correspond to the given verts_uvs, for one of the meshes. This is potentially useful for @@ -1279,7 +1271,7 @@ class TexturesVertex(TexturesBase): # refer to the __init__ of Meshes. self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device) - def clone(self): + def clone(self) -> "TexturesVertex": tex = self.__class__(self.verts_features_padded().clone()) if self._verts_features_list is not None: tex._verts_features_list = [f.clone() for f in self._verts_features_list] @@ -1287,7 +1279,7 @@ class TexturesVertex(TexturesBase): tex.valid = self.valid.clone() return tex - def detach(self): + def detach(self) -> "TexturesVertex": tex = self.__class__(self.verts_features_padded().detach()) if self._verts_features_list is not None: tex._verts_features_list = [f.detach() for f in self._verts_features_list] @@ -1295,7 +1287,7 @@ class TexturesVertex(TexturesBase): tex.valid = self.valid.detach() return tex - def __getitem__(self, index): + def __getitem__(self, index) -> "TexturesVertex": props = ["verts_features_list", "_num_verts_per_mesh"] new_props = self._getitem(index, props) verts_features = new_props["verts_features_list"]