mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Improve textures type annotations
Summary: Improve type annotations for textures and remove a few pyre fixmes Reviewed By: nikhilaravi Differential Revision: D28942630 fbshipit-source-id: 422f2bdf07b435869461ca103d71473aa0c2b814
This commit is contained in:
parent
d76c00721c
commit
c710d8c101
@ -273,7 +273,7 @@ class TexturesBase:
|
||||
|
||||
|
||||
def Textures(
|
||||
maps: Union[List, torch.Tensor, None] = None,
|
||||
maps: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||
faces_uvs: Optional[torch.Tensor] = None,
|
||||
verts_uvs: Optional[torch.Tensor] = None,
|
||||
verts_rgb: Optional[torch.Tensor] = None,
|
||||
@ -305,20 +305,19 @@ def Textures(
|
||||
PendingDeprecationWarning,
|
||||
)
|
||||
|
||||
if all(x is not None for x in [faces_uvs, verts_uvs, maps]):
|
||||
# pyre-fixme[6]: Expected `Union[List[torch.Tensor], torch.Tensor]` for 1st
|
||||
# param but got `Union[None, List[typing.Any], torch.Tensor]`.
|
||||
if faces_uvs is not None and verts_uvs is not None and maps is not None:
|
||||
return TexturesUV(maps=maps, faces_uvs=faces_uvs, verts_uvs=verts_uvs)
|
||||
elif verts_rgb is not None:
|
||||
|
||||
if verts_rgb is not None:
|
||||
return TexturesVertex(verts_features=verts_rgb)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Textures either requires all three of (faces uvs, verts uvs, maps) or verts rgb"
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
"Textures either requires all three of (faces uvs, verts uvs, maps) or verts rgb"
|
||||
)
|
||||
|
||||
|
||||
class TexturesAtlas(TexturesBase):
|
||||
def __init__(self, atlas: Union[torch.Tensor, List, None]):
|
||||
def __init__(self, atlas: Union[torch.Tensor, List[torch.Tensor]]):
|
||||
"""
|
||||
A texture representation where each face has a square texture map.
|
||||
This is based on the implementation from SoftRasterizer [1].
|
||||
@ -369,22 +368,17 @@ class TexturesAtlas(TexturesBase):
|
||||
self.device = atlas[0].device
|
||||
|
||||
elif torch.is_tensor(atlas):
|
||||
# pyre-fixme[16]: `Optional` has no attribute `ndim`.
|
||||
if atlas.ndim != 5:
|
||||
msg = "Expected atlas to be of shape (N, F, R, R, D); got %r"
|
||||
raise ValueError(msg % repr(atlas.ndim))
|
||||
self._atlas_padded = atlas
|
||||
self._atlas_list = None
|
||||
# pyre-fixme[16]: `Optional` has no attribute `device`.
|
||||
self.device = atlas.device
|
||||
|
||||
# These values may be overridden when textures is
|
||||
# passed into the Meshes constructor. For more details
|
||||
# refer to the __init__ of Meshes.
|
||||
# pyre-fixme[6]: Expected `Sized` for 1st param but got
|
||||
# `Optional[torch.Tensor]`.
|
||||
self._N = len(atlas)
|
||||
# pyre-fixme[16]: `Optional` has no attribute `shape`.
|
||||
max_F = atlas.shape[1]
|
||||
self._num_faces_per_mesh = [max_F] * self._N
|
||||
else:
|
||||
@ -699,9 +693,7 @@ class TexturesUV(TexturesBase):
|
||||
else:
|
||||
raise ValueError("Expected verts_uvs to be a tensor or list")
|
||||
|
||||
if torch.is_tensor(maps):
|
||||
# pyre-fixme[16]: `List` has no attribute `ndim`.
|
||||
# pyre-fixme[16]: `List` has no attribute `shape`.
|
||||
if isinstance(maps, torch.Tensor):
|
||||
if maps.ndim != 4 or maps.shape[0] != self._N:
|
||||
msg = "Expected maps to be of shape (N, H, W, 3); got %r"
|
||||
raise ValueError(msg % repr(maps.shape))
|
||||
@ -726,7 +718,7 @@ class TexturesUV(TexturesBase):
|
||||
|
||||
self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)
|
||||
|
||||
def clone(self):
|
||||
def clone(self) -> "TexturesUV":
|
||||
tex = self.__class__(
|
||||
self.maps_padded().clone(),
|
||||
self.faces_uvs_padded().clone(),
|
||||
@ -747,7 +739,7 @@ class TexturesUV(TexturesBase):
|
||||
tex.valid = self.valid.clone()
|
||||
return tex
|
||||
|
||||
def detach(self):
|
||||
def detach(self) -> "TexturesUV":
|
||||
tex = self.__class__(
|
||||
self.maps_padded().detach(),
|
||||
self.faces_uvs_padded().detach(),
|
||||
@ -1185,7 +1177,7 @@ class TexturesUV(TexturesBase):
|
||||
padding_mode=self.padding_mode,
|
||||
)
|
||||
|
||||
def centers_for_image(self, index):
|
||||
def centers_for_image(self, index: int) -> torch.Tensor:
|
||||
"""
|
||||
Return the locations in the texture map which correspond to the given
|
||||
verts_uvs, for one of the meshes. This is potentially useful for
|
||||
@ -1279,7 +1271,7 @@ class TexturesVertex(TexturesBase):
|
||||
# refer to the __init__ of Meshes.
|
||||
self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)
|
||||
|
||||
def clone(self):
|
||||
def clone(self) -> "TexturesVertex":
|
||||
tex = self.__class__(self.verts_features_padded().clone())
|
||||
if self._verts_features_list is not None:
|
||||
tex._verts_features_list = [f.clone() for f in self._verts_features_list]
|
||||
@ -1287,7 +1279,7 @@ class TexturesVertex(TexturesBase):
|
||||
tex.valid = self.valid.clone()
|
||||
return tex
|
||||
|
||||
def detach(self):
|
||||
def detach(self) -> "TexturesVertex":
|
||||
tex = self.__class__(self.verts_features_padded().detach())
|
||||
if self._verts_features_list is not None:
|
||||
tex._verts_features_list = [f.detach() for f in self._verts_features_list]
|
||||
@ -1295,7 +1287,7 @@ class TexturesVertex(TexturesBase):
|
||||
tex.valid = self.valid.detach()
|
||||
return tex
|
||||
|
||||
def __getitem__(self, index):
|
||||
def __getitem__(self, index) -> "TexturesVertex":
|
||||
props = ["verts_features_list", "_num_verts_per_mesh"]
|
||||
new_props = self._getitem(index, props)
|
||||
verts_features = new_props["verts_features_list"]
|
||||
|
Loading…
x
Reference in New Issue
Block a user