use C for #channels in textures

Summary: Comments in textures.py were inconsistent in describing the number of channels, sometimes C, sometimes D, sometimes 3. Now always C.

Reviewed By: patricklabatut

Differential Revision: D29263435

fbshipit-source-id: 7c1260c164c52852dc9e14d0e12da4cfb64af408
This commit is contained in:
Jeremy Reizenstein 2021-06-22 16:06:50 -07:00 committed by Facebook GitHub Bot
parent c639198c97
commit c538725885

View File

@ -109,17 +109,17 @@ def _pad_texture_maps(
Pad all texture images so they have the same height and width. Pad all texture images so they have the same height and width.
Args: Args:
images: list of N tensors of shape (H_i, W_i, 3) images: list of N tensors of shape (H_i, W_i, C)
align_corners: used for interpolation align_corners: used for interpolation
Returns: Returns:
tex_maps: Tensor of shape (N, max_H, max_W, 3) tex_maps: Tensor of shape (N, max_H, max_W, C)
""" """
tex_maps = [] tex_maps = []
max_H = 0 max_H = 0
max_W = 0 max_W = 0
for im in images: for im in images:
h, w, _3 = im.shape h, w, _C = im.shape
if h > max_H: if h > max_H:
max_H = h max_H = h
if w > max_W: if w > max_W:
@ -134,7 +134,7 @@ def _pad_texture_maps(
image_BCHW, size=max_shape, mode="bilinear", align_corners=align_corners image_BCHW, size=max_shape, mode="bilinear", align_corners=align_corners
) )
tex_maps[i] = new_image_BCHW[0].permute(1, 2, 0) tex_maps[i] = new_image_BCHW[0].permute(1, 2, 0)
tex_maps = torch.stack(tex_maps, dim=0) # (num_tex_maps, max_H, max_W, 3) tex_maps = torch.stack(tex_maps, dim=0) # (num_tex_maps, max_H, max_W, C)
return tex_maps return tex_maps
@ -288,12 +288,12 @@ def Textures(
Args: Args:
maps: texture map per mesh. This can either be a list of maps maps: texture map per mesh. This can either be a list of maps
[(H, W, 3)] or a padded tensor of shape (N, H, W, 3). [(H, W, C)] or a padded tensor of shape (N, H, W, C).
faces_uvs: (N, F, 3) tensor giving the index into verts_uvs for each faces_uvs: (N, F, 3) tensor giving the index into verts_uvs for each
vertex in the face. Padding value is assumed to be -1. vertex in the face. Padding value is assumed to be -1.
verts_uvs: (N, V, 2) tensor giving the uv coordinate per vertex. verts_uvs: (N, V, 2) tensor giving the uv coordinate per vertex.
verts_rgb: (N, V, 3) tensor giving the rgb color per vertex. Padding verts_rgb: (N, V, C) tensor giving the color per vertex. Padding
value is assumed to be -1. value is assumed to be -1. (C=3 for RGB.)
Returns: Returns:
@ -327,7 +327,7 @@ class TexturesAtlas(TexturesBase):
This is based on the implementation from SoftRasterizer [1]. This is based on the implementation from SoftRasterizer [1].
Args: Args:
atlas: (N, F, R, R, D) tensor giving the per face texture map. atlas: (N, F, R, R, C) tensor giving the per face texture map.
The atlas can be created during obj loading with the The atlas can be created during obj loading with the
pytorch3d.io.load_obj function - in the input arguments pytorch3d.io.load_obj function - in the input arguments
set `create_texture_atlas=True`. The atlas will be set `create_texture_atlas=True`. The atlas will be
@ -354,7 +354,7 @@ class TexturesAtlas(TexturesBase):
) )
if not correct_format: if not correct_format:
msg = ( msg = (
"Expected atlas to be a list of tensors of shape (F, R, R, D) " "Expected atlas to be a list of tensors of shape (F, R, R, C) "
"with the same value of R." "with the same value of R."
) )
raise ValueError(msg) raise ValueError(msg)
@ -373,7 +373,7 @@ class TexturesAtlas(TexturesBase):
elif torch.is_tensor(atlas): elif torch.is_tensor(atlas):
if atlas.ndim != 5: if atlas.ndim != 5:
msg = "Expected atlas to be of shape (N, F, R, R, D); got %r" msg = "Expected atlas to be of shape (N, F, R, R, C); got %r"
raise ValueError(msg % repr(atlas.ndim)) raise ValueError(msg % repr(atlas.ndim))
self._atlas_padded = atlas self._atlas_padded = atlas
self._atlas_list = None self._atlas_list = None
@ -499,7 +499,7 @@ class TexturesAtlas(TexturesBase):
representation) which overlap the pixel. representation) which overlap the pixel.
Returns: Returns:
texels: (N, H, W, K, 3) texels: (N, H, W, K, C)
""" """
N, H, W, K = fragments.pix_to_face.shape N, H, W, K = fragments.pix_to_face.shape
atlas_packed = self.atlas_packed() atlas_packed = self.atlas_packed()
@ -532,7 +532,7 @@ class TexturesAtlas(TexturesBase):
""" """
Samples texture from each vertex for each face in the mesh. Samples texture from each vertex for each face in the mesh.
For N meshes with {Fi} number of faces, it returns a For N meshes with {Fi} number of faces, it returns a
tensor of shape sum(Fi)x3xD (D = 3 for RGB). tensor of shape sum(Fi)x3xC (C = 3 for RGB).
You can use the utils function in structures.utils to convert the You can use the utils function in structures.utils to convert the
packed representation to a list or padded. packed representation to a list or padded.
""" """
@ -603,7 +603,8 @@ class TexturesUV(TexturesBase):
Args: Args:
maps: texture map per mesh. This can either be a list of maps maps: texture map per mesh. This can either be a list of maps
[(H, W, 3)] or a padded tensor of shape (N, H, W, 3) [(H, W, C)] or a padded tensor of shape (N, H, W, C).
For RGB, C = 3.
faces_uvs: (N, F, 3) LongTensor giving the index into verts_uvs faces_uvs: (N, F, 3) LongTensor giving the index into verts_uvs
for each face for each face
verts_uvs: (N, V, 2) tensor giving the uv coordinates per vertex verts_uvs: (N, V, 2) tensor giving the uv coordinates per vertex
@ -708,7 +709,7 @@ class TexturesUV(TexturesBase):
if isinstance(maps, torch.Tensor): if isinstance(maps, torch.Tensor):
if maps.ndim != 4 or maps.shape[0] != self._N: if maps.ndim != 4 or maps.shape[0] != self._N:
msg = "Expected maps to be of shape (N, H, W, 3); got %r" msg = "Expected maps to be of shape (N, H, W, C); got %r"
raise ValueError(msg % repr(maps.shape)) raise ValueError(msg % repr(maps.shape))
self._maps_padded = maps self._maps_padded = maps
self._maps_list = None self._maps_list = None
@ -1061,8 +1062,8 @@ class TexturesUV(TexturesBase):
Used by join_scene. Used by join_scene.
Args: Args:
single_map: (total_H, total_W, 3) single_map: (total_H, total_W, C)
map_: (H, W, 3) source data map_: (H, W, C) source data
location: where to place map location: where to place map
""" """
do_flip = location.flipped do_flip = location.flipped
@ -1246,10 +1247,10 @@ class TexturesVertex(TexturesBase):
): ):
""" """
Batched texture representation where each vertex in a mesh Batched texture representation where each vertex in a mesh
has a D dimensional feature vector. has a C dimensional feature vector.
Args: Args:
verts_features: list of (Vi, D) or (N, V, D) tensor giving a feature verts_features: list of (Vi, C) or (N, V, C) tensor giving a feature
vector with arbitrary dimensions for each vertex. vector with arbitrary dimensions for each vertex.
""" """
if isinstance(verts_features, (tuple, list)): if isinstance(verts_features, (tuple, list)):
@ -1258,7 +1259,7 @@ class TexturesVertex(TexturesBase):
) )
if not correct_shape: if not correct_shape:
raise ValueError( raise ValueError(
"Expected verts_features to be a list of tensors of shape (V, D)." "Expected verts_features to be a list of tensors of shape (V, C)."
) )
self._verts_features_list = verts_features self._verts_features_list = verts_features
@ -1276,7 +1277,7 @@ class TexturesVertex(TexturesBase):
elif torch.is_tensor(verts_features): elif torch.is_tensor(verts_features):
if verts_features.ndim != 3: if verts_features.ndim != 3:
msg = "Expected verts_features to be of shape (N, V, D); got %r" msg = "Expected verts_features to be of shape (N, V, C); got %r"
raise ValueError(msg % repr(verts_features.shape)) raise ValueError(msg % repr(verts_features.shape))
self._verts_features_padded = verts_features self._verts_features_padded = verts_features
self._verts_features_list = None self._verts_features_list = None