mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Improve textures type annotations
Summary: Improve type annotations for textures and remove a few pyre fixmes Reviewed By: nikhilaravi Differential Revision: D28942630 fbshipit-source-id: 422f2bdf07b435869461ca103d71473aa0c2b814
This commit is contained in:
parent
d76c00721c
commit
c710d8c101
@ -273,7 +273,7 @@ class TexturesBase:
|
|||||||
|
|
||||||
|
|
||||||
def Textures(
|
def Textures(
|
||||||
maps: Union[List, torch.Tensor, None] = None,
|
maps: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||||
faces_uvs: Optional[torch.Tensor] = None,
|
faces_uvs: Optional[torch.Tensor] = None,
|
||||||
verts_uvs: Optional[torch.Tensor] = None,
|
verts_uvs: Optional[torch.Tensor] = None,
|
||||||
verts_rgb: Optional[torch.Tensor] = None,
|
verts_rgb: Optional[torch.Tensor] = None,
|
||||||
@ -305,20 +305,19 @@ def Textures(
|
|||||||
PendingDeprecationWarning,
|
PendingDeprecationWarning,
|
||||||
)
|
)
|
||||||
|
|
||||||
if all(x is not None for x in [faces_uvs, verts_uvs, maps]):
|
if faces_uvs is not None and verts_uvs is not None and maps is not None:
|
||||||
# pyre-fixme[6]: Expected `Union[List[torch.Tensor], torch.Tensor]` for 1st
|
|
||||||
# param but got `Union[None, List[typing.Any], torch.Tensor]`.
|
|
||||||
return TexturesUV(maps=maps, faces_uvs=faces_uvs, verts_uvs=verts_uvs)
|
return TexturesUV(maps=maps, faces_uvs=faces_uvs, verts_uvs=verts_uvs)
|
||||||
elif verts_rgb is not None:
|
|
||||||
|
if verts_rgb is not None:
|
||||||
return TexturesVertex(verts_features=verts_rgb)
|
return TexturesVertex(verts_features=verts_rgb)
|
||||||
else:
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Textures either requires all three of (faces uvs, verts uvs, maps) or verts rgb"
|
"Textures either requires all three of (faces uvs, verts uvs, maps) or verts rgb"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TexturesAtlas(TexturesBase):
|
class TexturesAtlas(TexturesBase):
|
||||||
def __init__(self, atlas: Union[torch.Tensor, List, None]):
|
def __init__(self, atlas: Union[torch.Tensor, List[torch.Tensor]]):
|
||||||
"""
|
"""
|
||||||
A texture representation where each face has a square texture map.
|
A texture representation where each face has a square texture map.
|
||||||
This is based on the implementation from SoftRasterizer [1].
|
This is based on the implementation from SoftRasterizer [1].
|
||||||
@ -369,22 +368,17 @@ class TexturesAtlas(TexturesBase):
|
|||||||
self.device = atlas[0].device
|
self.device = atlas[0].device
|
||||||
|
|
||||||
elif torch.is_tensor(atlas):
|
elif torch.is_tensor(atlas):
|
||||||
# pyre-fixme[16]: `Optional` has no attribute `ndim`.
|
|
||||||
if atlas.ndim != 5:
|
if atlas.ndim != 5:
|
||||||
msg = "Expected atlas to be of shape (N, F, R, R, D); got %r"
|
msg = "Expected atlas to be of shape (N, F, R, R, D); got %r"
|
||||||
raise ValueError(msg % repr(atlas.ndim))
|
raise ValueError(msg % repr(atlas.ndim))
|
||||||
self._atlas_padded = atlas
|
self._atlas_padded = atlas
|
||||||
self._atlas_list = None
|
self._atlas_list = None
|
||||||
# pyre-fixme[16]: `Optional` has no attribute `device`.
|
|
||||||
self.device = atlas.device
|
self.device = atlas.device
|
||||||
|
|
||||||
# These values may be overridden when textures is
|
# These values may be overridden when textures is
|
||||||
# passed into the Meshes constructor. For more details
|
# passed into the Meshes constructor. For more details
|
||||||
# refer to the __init__ of Meshes.
|
# refer to the __init__ of Meshes.
|
||||||
# pyre-fixme[6]: Expected `Sized` for 1st param but got
|
|
||||||
# `Optional[torch.Tensor]`.
|
|
||||||
self._N = len(atlas)
|
self._N = len(atlas)
|
||||||
# pyre-fixme[16]: `Optional` has no attribute `shape`.
|
|
||||||
max_F = atlas.shape[1]
|
max_F = atlas.shape[1]
|
||||||
self._num_faces_per_mesh = [max_F] * self._N
|
self._num_faces_per_mesh = [max_F] * self._N
|
||||||
else:
|
else:
|
||||||
@ -699,9 +693,7 @@ class TexturesUV(TexturesBase):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Expected verts_uvs to be a tensor or list")
|
raise ValueError("Expected verts_uvs to be a tensor or list")
|
||||||
|
|
||||||
if torch.is_tensor(maps):
|
if isinstance(maps, torch.Tensor):
|
||||||
# pyre-fixme[16]: `List` has no attribute `ndim`.
|
|
||||||
# pyre-fixme[16]: `List` has no attribute `shape`.
|
|
||||||
if maps.ndim != 4 or maps.shape[0] != self._N:
|
if maps.ndim != 4 or maps.shape[0] != self._N:
|
||||||
msg = "Expected maps to be of shape (N, H, W, 3); got %r"
|
msg = "Expected maps to be of shape (N, H, W, 3); got %r"
|
||||||
raise ValueError(msg % repr(maps.shape))
|
raise ValueError(msg % repr(maps.shape))
|
||||||
@ -726,7 +718,7 @@ class TexturesUV(TexturesBase):
|
|||||||
|
|
||||||
self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)
|
self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)
|
||||||
|
|
||||||
def clone(self):
|
def clone(self) -> "TexturesUV":
|
||||||
tex = self.__class__(
|
tex = self.__class__(
|
||||||
self.maps_padded().clone(),
|
self.maps_padded().clone(),
|
||||||
self.faces_uvs_padded().clone(),
|
self.faces_uvs_padded().clone(),
|
||||||
@ -747,7 +739,7 @@ class TexturesUV(TexturesBase):
|
|||||||
tex.valid = self.valid.clone()
|
tex.valid = self.valid.clone()
|
||||||
return tex
|
return tex
|
||||||
|
|
||||||
def detach(self):
|
def detach(self) -> "TexturesUV":
|
||||||
tex = self.__class__(
|
tex = self.__class__(
|
||||||
self.maps_padded().detach(),
|
self.maps_padded().detach(),
|
||||||
self.faces_uvs_padded().detach(),
|
self.faces_uvs_padded().detach(),
|
||||||
@ -1185,7 +1177,7 @@ class TexturesUV(TexturesBase):
|
|||||||
padding_mode=self.padding_mode,
|
padding_mode=self.padding_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
def centers_for_image(self, index):
|
def centers_for_image(self, index: int) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Return the locations in the texture map which correspond to the given
|
Return the locations in the texture map which correspond to the given
|
||||||
verts_uvs, for one of the meshes. This is potentially useful for
|
verts_uvs, for one of the meshes. This is potentially useful for
|
||||||
@ -1279,7 +1271,7 @@ class TexturesVertex(TexturesBase):
|
|||||||
# refer to the __init__ of Meshes.
|
# refer to the __init__ of Meshes.
|
||||||
self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)
|
self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)
|
||||||
|
|
||||||
def clone(self):
|
def clone(self) -> "TexturesVertex":
|
||||||
tex = self.__class__(self.verts_features_padded().clone())
|
tex = self.__class__(self.verts_features_padded().clone())
|
||||||
if self._verts_features_list is not None:
|
if self._verts_features_list is not None:
|
||||||
tex._verts_features_list = [f.clone() for f in self._verts_features_list]
|
tex._verts_features_list = [f.clone() for f in self._verts_features_list]
|
||||||
@ -1287,7 +1279,7 @@ class TexturesVertex(TexturesBase):
|
|||||||
tex.valid = self.valid.clone()
|
tex.valid = self.valid.clone()
|
||||||
return tex
|
return tex
|
||||||
|
|
||||||
def detach(self):
|
def detach(self) -> "TexturesVertex":
|
||||||
tex = self.__class__(self.verts_features_padded().detach())
|
tex = self.__class__(self.verts_features_padded().detach())
|
||||||
if self._verts_features_list is not None:
|
if self._verts_features_list is not None:
|
||||||
tex._verts_features_list = [f.detach() for f in self._verts_features_list]
|
tex._verts_features_list = [f.detach() for f in self._verts_features_list]
|
||||||
@ -1295,7 +1287,7 @@ class TexturesVertex(TexturesBase):
|
|||||||
tex.valid = self.valid.detach()
|
tex.valid = self.valid.detach()
|
||||||
return tex
|
return tex
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index) -> "TexturesVertex":
|
||||||
props = ["verts_features_list", "_num_verts_per_mesh"]
|
props = ["verts_features_list", "_num_verts_per_mesh"]
|
||||||
new_props = self._getitem(index, props)
|
new_props = self._getitem(index, props)
|
||||||
verts_features = new_props["verts_features_list"]
|
verts_features = new_props["verts_features_list"]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user