mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
texture map list validation
Summary: Add some more validation of a list of texture maps. Move the initialisation of maps_padded to a new function to reduce complexity. Reviewed By: nikhilaravi Differential Revision: D29263443 fbshipit-source-id: 153e262d2e9af21090570768020fca019e364024
This commit is contained in:
parent
2a0660baab
commit
279f4a154d
@ -707,31 +707,43 @@ class TexturesUV(TexturesBase):
|
||||
else:
|
||||
raise ValueError("Expected verts_uvs to be a tensor or list")
|
||||
|
||||
if isinstance(maps, torch.Tensor):
|
||||
if maps.ndim != 4 or maps.shape[0] != self._N:
|
||||
msg = "Expected maps to be of shape (N, H, W, C); got %r"
|
||||
raise ValueError(msg % repr(maps.shape))
|
||||
self._maps_padded = maps
|
||||
self._maps_list = None
|
||||
elif isinstance(maps, (list, tuple)):
|
||||
if len(maps) != self._N:
|
||||
raise ValueError("Expected one texture map per mesh in the batch.")
|
||||
if isinstance(maps, (list, tuple)):
|
||||
self._maps_list = maps
|
||||
if self._N > 0:
|
||||
maps = _pad_texture_maps(maps, align_corners=self.align_corners)
|
||||
else:
|
||||
maps = torch.empty(
|
||||
(self._N, 0, 0, 3), dtype=torch.float32, device=self.device
|
||||
)
|
||||
self._maps_padded = maps
|
||||
else:
|
||||
raise ValueError("Expected maps to be a tensor or list.")
|
||||
self._maps_list = None
|
||||
self._maps_padded = self._format_maps_padded(maps)
|
||||
|
||||
if self._maps_padded.device != self.device:
|
||||
raise ValueError("maps must be on the same device as verts/faces uvs.")
|
||||
|
||||
self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)
|
||||
|
||||
def _format_maps_padded(
|
||||
self, maps: Union[torch.Tensor, List[torch.Tensor]]
|
||||
) -> torch.Tensor:
|
||||
if isinstance(maps, torch.Tensor):
|
||||
if maps.ndim != 4 or maps.shape[0] != self._N:
|
||||
msg = "Expected maps to be of shape (N, H, W, C); got %r"
|
||||
raise ValueError(msg % repr(maps.shape))
|
||||
return maps
|
||||
|
||||
if isinstance(maps, (list, tuple)):
|
||||
if len(maps) != self._N:
|
||||
raise ValueError("Expected one texture map per mesh in the batch.")
|
||||
if self._N > 0:
|
||||
if not all(map.ndim == 3 for map in maps):
|
||||
raise ValueError("Invalid number of dimensions in texture maps")
|
||||
if not all(map.shape[2] == maps[0].shape[2] for map in maps):
|
||||
raise ValueError("Inconsistent number of channels in maps")
|
||||
maps_padded = _pad_texture_maps(maps, align_corners=self.align_corners)
|
||||
else:
|
||||
maps_padded = torch.empty(
|
||||
(self._N, 0, 0, 3), dtype=torch.float32, device=self.device
|
||||
)
|
||||
return maps_padded
|
||||
|
||||
raise ValueError("Expected maps to be a tensor or list of tensors.")
|
||||
|
||||
def clone(self) -> "TexturesUV":
|
||||
tex = self.__class__(
|
||||
self.maps_padded().clone(),
|
||||
|
Loading…
x
Reference in New Issue
Block a user