make visdom optional

Summary: Make Implicitron run without visdom installed.

Reviewed By: shapovalov

Differential Revision: D40587974

fbshipit-source-id: dc319596c7a4d10a4c54c556dabc89ad9d25c2fb
This commit is contained in:
Jeremy Reizenstein 2022-10-22 15:51:22 -07:00 committed by Facebook GitHub Bot
parent 46cb5aaaae
commit ff933ab82b
6 changed files with 44 additions and 22 deletions

View File

@ -41,7 +41,7 @@ The outputs of the experiment are saved and logged in multiple ways:
Stats are logged and plotted to the file "train_stats.pdf" in the Stats are logged and plotted to the file "train_stats.pdf" in the
same directory. The stats are also saved as part of the checkpoint file. same directory. The stats are also saved as part of the checkpoint file.
- Visualizations - Visualizations
Prredictions are plotted to a visdom server running at the Predictions are plotted to a visdom server running at the
port specified by the `visdom_server` and `visdom_port` keys in the port specified by the `visdom_server` and `visdom_port` keys in the
config file. config file.

View File

@ -9,7 +9,7 @@ import copy
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
import numpy as np import numpy as np
import torch import torch
@ -27,7 +27,9 @@ from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.vis.plotly_vis import plot_scene from pytorch3d.vis.plotly_vis import plot_scene
from tabulate import tabulate from tabulate import tabulate
from visdom import Visdom
if TYPE_CHECKING:
from visdom import Visdom
EVAL_N_SRC_VIEWS = [1, 3, 5, 7, 9] EVAL_N_SRC_VIEWS = [1, 3, 5, 7, 9]
@ -43,7 +45,7 @@ class _Visualizer:
visdom_env: str = "eval_debug" visdom_env: str = "eval_debug"
_viz: Visdom = field(init=False) _viz: Optional["Visdom"] = field(init=False)
def __post_init__(self): def __post_init__(self):
self._viz = vis_utils.get_visdom_connection() self._viz = vis_utils.get_visdom_connection()
@ -51,6 +53,8 @@ class _Visualizer:
def show_rgb( def show_rgb(
self, loss_value: float, metric_name: str, loss_mask_now: torch.Tensor self, loss_value: float, metric_name: str, loss_mask_now: torch.Tensor
): ):
if self._viz is None:
return
self._viz.images( self._viz.images(
torch.cat( torch.cat(
( (
@ -68,7 +72,10 @@ class _Visualizer:
def show_depth( def show_depth(
self, depth_loss: float, name_postfix: str, loss_mask_now: torch.Tensor self, depth_loss: float, name_postfix: str, loss_mask_now: torch.Tensor
): ):
self._viz.images( if self._viz is None:
return
viz = self._viz
viz.images(
torch.cat( torch.cat(
( (
make_depth_image(self.depth_render, loss_mask_now), make_depth_image(self.depth_render, loss_mask_now),
@ -80,13 +87,13 @@ class _Visualizer:
win="depth_abs" + name_postfix, win="depth_abs" + name_postfix,
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}"}, opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}"},
) )
self._viz.images( viz.images(
loss_mask_now, loss_mask_now,
env=self.visdom_env, env=self.visdom_env,
win="depth_abs" + name_postfix + "_mask", win="depth_abs" + name_postfix + "_mask",
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_mask"}, opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_mask"},
) )
self._viz.images( viz.images(
self.depth_mask, self.depth_mask,
env=self.visdom_env, env=self.visdom_env,
win="depth_abs" + name_postfix + "_maskd", win="depth_abs" + name_postfix + "_maskd",
@ -126,7 +133,7 @@ class _Visualizer:
pointcloud_max_points=10000, pointcloud_max_points=10000,
pointcloud_marker_size=1, pointcloud_marker_size=1,
) )
self._viz.plotlyplot( viz.plotlyplot(
plotlyplot, plotlyplot,
env=self.visdom_env, env=self.visdom_env,
win=f"pcl{name_postfix}", win=f"pcl{name_postfix}",

View File

@ -12,7 +12,7 @@ import logging
import math import math
import warnings import warnings
from dataclasses import field from dataclasses import field
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import torch import torch
import tqdm import tqdm
@ -34,7 +34,9 @@ from pytorch3d.implicitron.tools.utils import cat_dataclass
from pytorch3d.renderer import utils as rend_utils from pytorch3d.renderer import utils as rend_utils
from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.cameras import CamerasBase
from visdom import Visdom
if TYPE_CHECKING:
from visdom import Visdom
from .base_model import ImplicitronModelBase, ImplicitronRender from .base_model import ImplicitronModelBase, ImplicitronRender
from .feature_extractor import FeatureExtractorBase from .feature_extractor import FeatureExtractorBase
@ -544,7 +546,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
def visualize( def visualize(
self, self,
viz: Visdom, viz: Optional["Visdom"],
visdom_env_imgs: str, visdom_env_imgs: str,
preds: Dict[str, Any], preds: Dict[str, Any],
prefix: str, prefix: str,
@ -559,7 +561,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
preds: predictions dict like returned by forward() preds: predictions dict like returned by forward()
prefix: prepended to the names of images prefix: prepended to the names of images
""" """
if not viz.check_connection(): if viz is None or not viz.check_connection():
logger.info("no visdom server! -> skipping batch vis") logger.info("no visdom server! -> skipping batch vis")
return return

View File

@ -10,7 +10,7 @@ import logging
import math import math
import os import os
import random import random
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
import numpy as np import numpy as np
import torch import torch
@ -27,7 +27,9 @@ from pytorch3d.implicitron.tools.vis_utils import (
make_depth_image, make_depth_image,
) )
from tqdm import tqdm from tqdm import tqdm
from visdom import Visdom
if TYPE_CHECKING:
from visdom import Visdom
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -272,7 +274,7 @@ def _stack_images(ims: torch.Tensor, size: Optional[Tuple[int, int]]) -> torch.T
def _show_predictions( def _show_predictions(
preds: List[Dict[str, Any]], preds: List[Dict[str, Any]],
sequence_name: str, sequence_name: str,
viz: Visdom, viz: "Visdom",
viz_env: str = "visualizer", viz_env: str = "visualizer",
predicted_keys: Sequence[str] = ( predicted_keys: Sequence[str] = (
"images_render", "images_render",
@ -318,7 +320,7 @@ def _show_predictions(
def _generate_prediction_videos( def _generate_prediction_videos(
preds: List[Dict[str, Any]], preds: List[Dict[str, Any]],
sequence_name: str, sequence_name: str,
viz: Optional[Visdom] = None, viz: Optional["Visdom"] = None,
viz_env: str = "visualizer", viz_env: str = "visualizer",
predicted_keys: Sequence[str] = ( predicted_keys: Sequence[str] = (
"images_render", "images_render",

View File

@ -337,7 +337,7 @@ class Stats(object):
novisdom = False novisdom = False
viz = get_visdom_connection(server=visdom_server, port=visdom_port) viz = get_visdom_connection(server=visdom_server, port=visdom_port)
if not viz.check_connection(): if viz is None or not viz.check_connection():
print("no visdom server! -> skipping visdom plots") print("no visdom server! -> skipping visdom plots")
novisdom = True novisdom = True

View File

@ -5,10 +5,12 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import logging import logging
from typing import Any, Dict, Tuple from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
import torch import torch
from visdom import Visdom
if TYPE_CHECKING:
from visdom import Visdom
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -40,9 +42,9 @@ _viz_singleton = None
def get_visdom_connection( def get_visdom_connection(
server: str = "http://localhost", server: str = "http://localhost",
port: int = 8097, port: int = 8097,
) -> Visdom: ) -> Optional["Visdom"]:
""" """
Obtain a connection to a visdom server. Obtain a connection to a visdom server if visdom is installed.
Args: Args:
server: Server address. server: Server address.
@ -51,6 +53,15 @@ def get_visdom_connection(
Returns: Returns:
connection: The connection object. connection: The connection object.
""" """
try:
from visdom import Visdom
except ImportError:
logger.debug("Cannot load visdom")
return None
if server == "None":
return None
global _viz_singleton global _viz_singleton
if _viz_singleton is None: if _viz_singleton is None:
_viz_singleton = Visdom(server=server, port=port) _viz_singleton = Visdom(server=server, port=port)
@ -58,7 +69,7 @@ def get_visdom_connection(
def visualize_basics( def visualize_basics(
viz: Visdom, viz: "Visdom",
preds: Dict[str, Any], preds: Dict[str, Any],
visdom_env_imgs: str, visdom_env_imgs: str,
title: str = "", title: str = "",