diff --git a/pytorch3d/renderer/mesh/textures.py b/pytorch3d/renderer/mesh/textures.py index 2dabbd08..351a142a 100644 --- a/pytorch3d/renderer/mesh/textures.py +++ b/pytorch3d/renderer/mesh/textures.py @@ -707,31 +707,43 @@ class TexturesUV(TexturesBase): else: raise ValueError("Expected verts_uvs to be a tensor or list") - if isinstance(maps, torch.Tensor): - if maps.ndim != 4 or maps.shape[0] != self._N: - msg = "Expected maps to be of shape (N, H, W, C); got %r" - raise ValueError(msg % repr(maps.shape)) - self._maps_padded = maps - self._maps_list = None - elif isinstance(maps, (list, tuple)): - if len(maps) != self._N: - raise ValueError("Expected one texture map per mesh in the batch.") + if isinstance(maps, (list, tuple)): self._maps_list = maps - if self._N > 0: - maps = _pad_texture_maps(maps, align_corners=self.align_corners) - else: - maps = torch.empty( - (self._N, 0, 0, 3), dtype=torch.float32, device=self.device - ) - self._maps_padded = maps else: - raise ValueError("Expected maps to be a tensor or list.") + self._maps_list = None + self._maps_padded = self._format_maps_padded(maps) if self._maps_padded.device != self.device: raise ValueError("maps must be on the same device as verts/faces uvs.") self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device) + def _format_maps_padded( + self, maps: Union[torch.Tensor, List[torch.Tensor]] + ) -> torch.Tensor: + if isinstance(maps, torch.Tensor): + if maps.ndim != 4 or maps.shape[0] != self._N: + msg = "Expected maps to be of shape (N, H, W, C); got %r" + raise ValueError(msg % repr(maps.shape)) + return maps + + if isinstance(maps, (list, tuple)): + if len(maps) != self._N: + raise ValueError("Expected one texture map per mesh in the batch.") + if self._N > 0: + if not all(map.ndim == 3 for map in maps): + raise ValueError("Invalid number of dimensions in texture maps") + if not all(map.shape[2] == maps[0].shape[2] for map in maps): + raise ValueError("Inconsistent number of channels in maps") + maps_padded = _pad_texture_maps(maps, align_corners=self.align_corners) + else: + maps_padded = torch.empty( + (self._N, 0, 0, 3), dtype=torch.float32, device=self.device + ) + return maps_padded + + raise ValueError("Expected maps to be a tensor or list of tensors.") + def clone(self) -> "TexturesUV": tex = self.__class__( self.maps_padded().clone(),