diff --git a/pytorch3d/renderer/points/compositor.py b/pytorch3d/renderer/points/compositor.py index 17483ec7..94959be4 100644 --- a/pytorch3d/renderer/points/compositor.py +++ b/pytorch3d/renderer/points/compositor.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import warnings from typing import List, Optional, Tuple, Union import torch @@ -85,7 +84,10 @@ def _add_background_color_to_images(pix_idxs, images, background_color): if not torch.is_tensor(background_color): background_color = images.new_tensor(background_color) - if len(background_color.shape) != 1: + if background_color.ndim == 0: + background_color = background_color.expand(images.shape[1]) + + if background_color.ndim > 1: raise ValueError("Wrong shape of background_color") background_color = background_color.to(images) @@ -98,7 +100,8 @@ def _add_background_color_to_images(pix_idxs, images, background_color): if images.shape[1] != background_color.shape[0]: raise ValueError( - f"background color has {background_color.shape[0] } channels not {images.shape[1]}" + "Background color has %s channels not %s" + % (background_color.shape[0], images.shape[1]) ) num_background_pixels = background_mask.sum()