make visdom optional

Summary: Make Implicitron run without visdom installed.

Reviewed By: shapovalov

Differential Revision: D40587974

fbshipit-source-id: dc319596c7a4d10a4c54c556dabc89ad9d25c2fb
This commit is contained in:
Jeremy Reizenstein 2022-10-22 15:51:22 -07:00 committed by Facebook GitHub Bot
parent 46cb5aaaae
commit ff933ab82b
6 changed files with 44 additions and 22 deletions

View File

@ -41,7 +41,7 @@ The outputs of the experiment are saved and logged in multiple ways:
Stats are logged and plotted to the file "train_stats.pdf" in the
same directory. The stats are also saved as part of the checkpoint file.
- Visualizations
Prredictions are plotted to a visdom server running at the
Predictions are plotted to a visdom server running at the
port specified by the `visdom_server` and `visdom_port` keys in the
config file.

View File

@ -9,7 +9,7 @@ import copy
import warnings
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
import numpy as np
import torch
@ -27,6 +27,8 @@ from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.vis.plotly_vis import plot_scene
from tabulate import tabulate
if TYPE_CHECKING:
from visdom import Visdom
@ -43,7 +45,7 @@ class _Visualizer:
visdom_env: str = "eval_debug"
_viz: Visdom = field(init=False)
_viz: Optional["Visdom"] = field(init=False)
def __post_init__(self):
self._viz = vis_utils.get_visdom_connection()
@ -51,6 +53,8 @@ class _Visualizer:
def show_rgb(
self, loss_value: float, metric_name: str, loss_mask_now: torch.Tensor
):
if self._viz is None:
return
self._viz.images(
torch.cat(
(
@ -68,7 +72,10 @@ class _Visualizer:
def show_depth(
self, depth_loss: float, name_postfix: str, loss_mask_now: torch.Tensor
):
self._viz.images(
if self._viz is None:
return
viz = self._viz
viz.images(
torch.cat(
(
make_depth_image(self.depth_render, loss_mask_now),
@ -80,13 +87,13 @@ class _Visualizer:
win="depth_abs" + name_postfix,
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}"},
)
self._viz.images(
viz.images(
loss_mask_now,
env=self.visdom_env,
win="depth_abs" + name_postfix + "_mask",
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_mask"},
)
self._viz.images(
viz.images(
self.depth_mask,
env=self.visdom_env,
win="depth_abs" + name_postfix + "_maskd",
@ -126,7 +133,7 @@ class _Visualizer:
pointcloud_max_points=10000,
pointcloud_marker_size=1,
)
self._viz.plotlyplot(
viz.plotlyplot(
plotlyplot,
env=self.visdom_env,
win=f"pcl{name_postfix}",

View File

@ -12,7 +12,7 @@ import logging
import math
import warnings
from dataclasses import field
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import torch
import tqdm
@ -34,6 +34,8 @@ from pytorch3d.implicitron.tools.utils import cat_dataclass
from pytorch3d.renderer import utils as rend_utils
from pytorch3d.renderer.cameras import CamerasBase
if TYPE_CHECKING:
from visdom import Visdom
from .base_model import ImplicitronModelBase, ImplicitronRender
@ -544,7 +546,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
def visualize(
self,
viz: Visdom,
viz: Optional["Visdom"],
visdom_env_imgs: str,
preds: Dict[str, Any],
prefix: str,
@ -559,7 +561,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
preds: predictions dict like returned by forward()
prefix: prepended to the names of images
"""
if not viz.check_connection():
if viz is None or not viz.check_connection():
logger.info("no visdom server! -> skipping batch vis")
return

View File

@ -10,7 +10,7 @@ import logging
import math
import os
import random
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
import numpy as np
import torch
@ -27,6 +27,8 @@ from pytorch3d.implicitron.tools.vis_utils import (
make_depth_image,
)
from tqdm import tqdm
if TYPE_CHECKING:
from visdom import Visdom
logger = logging.getLogger(__name__)
@ -272,7 +274,7 @@ def _stack_images(ims: torch.Tensor, size: Optional[Tuple[int, int]]) -> torch.T
def _show_predictions(
preds: List[Dict[str, Any]],
sequence_name: str,
viz: Visdom,
viz: "Visdom",
viz_env: str = "visualizer",
predicted_keys: Sequence[str] = (
"images_render",
@ -318,7 +320,7 @@ def _show_predictions(
def _generate_prediction_videos(
preds: List[Dict[str, Any]],
sequence_name: str,
viz: Optional[Visdom] = None,
viz: Optional["Visdom"] = None,
viz_env: str = "visualizer",
predicted_keys: Sequence[str] = (
"images_render",

View File

@ -337,7 +337,7 @@ class Stats(object):
novisdom = False
viz = get_visdom_connection(server=visdom_server, port=visdom_port)
if not viz.check_connection():
if viz is None or not viz.check_connection():
print("no visdom server! -> skipping visdom plots")
novisdom = True

View File

@ -5,9 +5,11 @@
# LICENSE file in the root directory of this source tree.
import logging
from typing import Any, Dict, Tuple
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
import torch
if TYPE_CHECKING:
from visdom import Visdom
@ -40,9 +42,9 @@ _viz_singleton = None
def get_visdom_connection(
server: str = "http://localhost",
port: int = 8097,
) -> Visdom:
) -> Optional["Visdom"]:
"""
Obtain a connection to a visdom server.
Obtain a connection to a visdom server if visdom is installed.
Args:
server: Server address.
@ -51,6 +53,15 @@ def get_visdom_connection(
Returns:
connection: The connection object.
"""
try:
from visdom import Visdom
except ImportError:
logger.debug("Cannot load visdom")
return None
if server == "None":
return None
global _viz_singleton
if _viz_singleton is None:
_viz_singleton = Visdom(server=server, port=port)
@ -58,7 +69,7 @@ def get_visdom_connection(
def visualize_basics(
viz: Visdom,
viz: "Visdom",
preds: Dict[str, Any],
visdom_env_imgs: str,
title: str = "",