Jeremy Reizenstein 199309fcf7 logging
Summary: Use logging instead of printing in the internals of implicitron.

Reviewed By: davnov134

Differential Revision: D35247581

fbshipit-source-id: be5ddad5efe1409adbae0575d35ade6112b3be63
2022-04-04 06:53:16 -07:00

177 lines
5.2 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Any, Dict, List, Sequence

import torch
from visdom import Visdom
# Module-level logger used for the diagnostic messages emitted by the helpers
# in this file (missing predictions, failed concatenation, empty depth maps).
logger: logging.Logger = logging.getLogger(name=__name__)
def get_visdom_env(cfg):
    """
    Parse out visdom environment name from the input config.

    If `cfg.visdom_env` is empty (or otherwise falsy), the environment name
    falls back to the last component of the experiment directory path
    `cfg.exp_dir`. Trailing "/" separators in `exp_dir` are ignored, so
    "exps/my_run/" and "exps/my_run" both yield "my_run".

    Args:
        cfg: The global config object; must expose the attributes
            `visdom_env` and `exp_dir`.

    Returns:
        visdom_env: The name of the visdom environment.
    """
    if not cfg.visdom_env:
        # Without stripping, a trailing slash would make the last split
        # component "" and silently produce an unnamed environment.
        return cfg.exp_dir.rstrip("/").split("/")[-1]
    return cfg.visdom_env
# TODO: a proper singleton
# Module-level cached Visdom connection, created lazily by
# `get_visdom_connection` on its first call.
_viz_singleton = None


def get_visdom_connection(
    server: str = "http://localhost",
    port: int = 8097,
) -> Visdom:
    """
    Return the shared connection to a visdom server, creating it on first use.

    Only the arguments of the first call take effect: every later call
    returns the cached connection regardless of the `server` and `port`
    it is given.

    Args:
        server: Server address.
        port: Server port.

    Returns:
        connection: The connection object.
    """
    global _viz_singleton
    if _viz_singleton is not None:
        return _viz_singleton
    connection = Visdom(server=server, port=port)
    _viz_singleton = connection
    return connection
def visualize_basics(
    viz: Visdom,
    preds: Dict[str, Any],
    visdom_env_imgs: str,
    title: str = "",
    visualize_preds_keys: Sequence[str] = (
        "image_rgb",
        "images_render",
        "fg_probability",
        "masks_render",
        "depths_render",
        "depth_map",
    ),
    store_history: bool = False,
) -> None:
    """
    Visualize basic outputs of a `GenericModel` to visdom.

    Each requested entry of `preds` is detached and copied to the CPU.
    Entries whose key starts with "depth" are divided by their per-sample
    95th percentile (clamped to at least 1e-4). Single-channel images are
    repeated to 3 channels. Every image is then bilinearly resized so its
    height (dim 2 of the 4D tensor) is about 200 px, or 600 px for
    "images_render"/"depths_render" when `visdom_env_imgs` contains "_eval".

    All images are concatenated along dim 2 into a single "all" window.
    If that concatenation fails (e.g. the resized widths differ), each
    image is shown in its own window named after its key instead.

    Args:
        viz: The visdom object.
        preds: A dictionary containing `GenericModel` outputs.
        visdom_env_imgs: Target visdom environment name.
        title: The title of produced visdom window.
        visualize_preds_keys: The keys of `preds` to visualize. Keys that
            are missing from `preds` or map to `None` are skipped.
        store_history: Store the history buffer in visdom windows.
    """
    imout = {}
    for k in visualize_preds_keys:
        if k not in preds or preds[k] is None:
            logger.info(f"cant show {k}")
            continue
        v = preds[k].cpu().detach().clone()
        if k.startswith("depth"):
            # Divide by the per-sample 95th percentile, i.e. the smallest of
            # the top 5% values. k is clamped to >= 1 so that images with
            # fewer than 20 pixels do not yield an empty topk (IndexError).
            # `reshape` (unlike `view`) also accepts non-contiguous tensors.
            n_per_sample = v.numel() // v.shape[0]
            normfac = (
                v.reshape(v.shape[0], -1)
                .topk(k=max(int(0.05 * n_per_sample), 1), dim=-1)
                .values[:, -1]
            )
            v = v / normfac[:, None, None, None].clamp(1e-4)
        if v.shape[1] == 1:
            v = v.repeat(1, 3, 1, 1)
        target_height = (
            600.0
            if (
                "_eval" in visdom_env_imgs
                and k in ("images_render", "depths_render")
            )
            else 200.0
        )
        v = torch.nn.functional.interpolate(
            v,
            # pyre-fixme[6]: Expected `Optional[typing.List[float]]` for 2nd param
            # but got `float`.
            scale_factor=target_height / v.shape[2],
            mode="bilinear",
        )
        imout[k] = v

    # TODO: handle errors on the outside
    if imout:
        try:
            imout = {"all": torch.cat(list(imout.values()), dim=2)}
        except RuntimeError:
            # Shapes disagree outside dim 2; fall back to one window per key.
            logger.info("cant cat!")

    for k, v in imout.items():
        viz.images(
            v.clamp(0.0, 1.0),
            win=k,
            env=visdom_env_imgs,
            opts={"title": title + "_" + k, "store_history": store_history},
        )
def make_depth_image(
    depths: torch.Tensor,
    masks: torch.Tensor,
    max_quantile: float = 0.98,
    min_quantile: float = 0.02,
    min_out_depth: float = 0.1,
    max_out_depth: float = 0.9,
) -> torch.Tensor:
    """
    Convert a batch of depth maps to a grayscale image.

    For every sample, the valid pixels are those with depth > 1e-6 and
    mask > 0.5. The depth values at the `min_quantile` and `max_quantile`
    quantiles of the valid pixels are mapped linearly to `min_out_depth`
    and `max_out_depth`. The result is multiplied by the mask and clamped
    to [0, 1]. Samples with at most one valid pixel are logged as
    "empty depth!" and normalized with zero factors.

    Args:
        depths: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
            Non-contiguous tensors are accepted.
        masks: A tensor of shape `(B, 1, H, W)` containing a batch of foreground masks.
        max_quantile: The quantile of the input depth values which will
            be mapped to `max_out_depth`.
        min_quantile: The quantile of the input depth values which will
            be mapped to `min_out_depth`.
        min_out_depth: Output intensity assigned to the `min_quantile` depth.
        max_out_depth: Output intensity assigned to the `max_quantile` depth.

    Returns:
        depth_image: A tensor of shape `(B, 1, H, W)` containing a batch of
            grayscale depth images with values in [0, 1].
    """
    normfacs = []
    for d, m in zip(depths, masks):
        # `reshape` instead of `view`: `view(-1)` raises on non-contiguous
        # per-sample slices (e.g. transposed depth maps).
        ok = (d.reshape(-1) > 1e-6) & (m.reshape(-1) > 0.5)
        if ok.sum() <= 1:
            logger.info("empty depth!")
            # Zero factors; the clamp(1e-4) below keeps the division finite.
            normfacs.append(torch.zeros(2).type_as(depths))
            continue
        dok = d.reshape(-1)[ok]
        n_ok = dok.numel()
        # The k-th largest / k-th smallest valid depth approximates the
        # requested quantile; k >= 1 so topk never returns an empty tensor.
        _maxk = max(int(round((1 - max_quantile) * n_ok)), 1)
        _mink = max(int(round(min_quantile * n_ok)), 1)
        normfac_max = dok.topk(k=_maxk, dim=-1).values[-1]
        normfac_min = dok.topk(k=_mink, dim=-1, largest=False).values[-1]
        normfacs.append(torch.stack([normfac_min, normfac_max]))
    normfacs = torch.stack(normfacs)
    _min = normfacs[:, 0].view(-1, 1, 1, 1)
    _max = normfacs[:, 1].view(-1, 1, 1, 1)
    depths = (depths - _min) / (_max - _min).clamp(1e-4)
    depths = (
        (depths * (max_out_depth - min_out_depth) + min_out_depth) * masks.float()
    ).clamp(0.0, 1.0)
    return depths