reallow scalar background color for point rendering

Summary: A scalar background color is not meant to be allowed for the point renderer. It used to be ignored with a warning, but a recent code change made it an error. It was being used, at least in the black (value=0.0) case. Re-enable it.

Reviewed By: nikhilaravi

Differential Revision: D34519651

fbshipit-source-id: d37dcf145bb7b8999c9265cf8fc39b084059dd18
This commit is contained in:
Jeremy Reizenstein 2022-03-01 05:12:55 -08:00 committed by Facebook GitHub Bot
parent 84a569c0aa
commit 69b27d160e

View File

@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import warnings
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
@ -85,7 +84,10 @@ def _add_background_color_to_images(pix_idxs, images, background_color):
if not torch.is_tensor(background_color): if not torch.is_tensor(background_color):
background_color = images.new_tensor(background_color) background_color = images.new_tensor(background_color)
if len(background_color.shape) != 1: if background_color.ndim == 0:
background_color = background_color.expand(images.shape[1])
if background_color.ndim > 1:
raise ValueError("Wrong shape of background_color") raise ValueError("Wrong shape of background_color")
background_color = background_color.to(images) background_color = background_color.to(images)
@ -98,7 +100,8 @@ def _add_background_color_to_images(pix_idxs, images, background_color):
if images.shape[1] != background_color.shape[0]: if images.shape[1] != background_color.shape[0]:
raise ValueError( raise ValueError(
f"background color has {background_color.shape[0] } channels not {images.shape[1]}" "Background color has %s channels not %s"
% (background_color.shape[0], images.shape[1])
) )
num_background_pixels = background_mask.sum() num_background_pixels = background_mask.sum()