texture map list validation

Summary: Add some more validation of a list of texture maps. Move the initialisation of maps_padded to a new function to reduce complexity.

Reviewed By: nikhilaravi

Differential Revision: D29263443

fbshipit-source-id: 153e262d2e9af21090570768020fca019e364024
This commit is contained in:
Jeremy Reizenstein 2021-06-22 16:06:50 -07:00 committed by Facebook GitHub Bot
parent 2a0660baab
commit 279f4a154d

View File

@ -707,31 +707,43 @@ class TexturesUV(TexturesBase):
else:
raise ValueError("Expected verts_uvs to be a tensor or list")
if isinstance(maps, torch.Tensor):
if maps.ndim != 4 or maps.shape[0] != self._N:
msg = "Expected maps to be of shape (N, H, W, C); got %r"
raise ValueError(msg % repr(maps.shape))
self._maps_padded = maps
self._maps_list = None
elif isinstance(maps, (list, tuple)):
if len(maps) != self._N:
raise ValueError("Expected one texture map per mesh in the batch.")
if isinstance(maps, (list, tuple)):
self._maps_list = maps
if self._N > 0:
maps = _pad_texture_maps(maps, align_corners=self.align_corners)
else:
maps = torch.empty(
(self._N, 0, 0, 3), dtype=torch.float32, device=self.device
)
self._maps_padded = maps
else:
raise ValueError("Expected maps to be a tensor or list.")
self._maps_list = None
self._maps_padded = self._format_maps_padded(maps)
if self._maps_padded.device != self.device:
raise ValueError("maps must be on the same device as verts/faces uvs.")
self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)
def _format_maps_padded(
self, maps: Union[torch.Tensor, List[torch.Tensor]]
) -> torch.Tensor:
if isinstance(maps, torch.Tensor):
if maps.ndim != 4 or maps.shape[0] != self._N:
msg = "Expected maps to be of shape (N, H, W, C); got %r"
raise ValueError(msg % repr(maps.shape))
return maps
if isinstance(maps, (list, tuple)):
if len(maps) != self._N:
raise ValueError("Expected one texture map per mesh in the batch.")
if self._N > 0:
if not all(map.ndim == 3 for map in maps):
raise ValueError("Invalid number of dimensions in texture maps")
if not all(map.shape[2] == maps[0].shape[2] for map in maps):
raise ValueError("Inconsistent number of channels in maps")
maps_padded = _pad_texture_maps(maps, align_corners=self.align_corners)
else:
maps_padded = torch.empty(
(self._N, 0, 0, 3), dtype=torch.float32, device=self.device
)
return maps_padded
raise ValueError("Expected maps to be a tensor or list of tensors.")
def clone(self) -> "TexturesUV":
tex = self.__class__(
self.maps_padded().clone(),