mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-01 03:12:49 +08:00
make visdom optional
Summary: Make Implicitron run without visdom installed. Reviewed By: shapovalov Differential Revision: D40587974 fbshipit-source-id: dc319596c7a4d10a4c54c556dabc89ad9d25c2fb
This commit is contained in:
parent
46cb5aaaae
commit
ff933ab82b
@ -41,7 +41,7 @@ The outputs of the experiment are saved and logged in multiple ways:
|
||||
Stats are logged and plotted to the file "train_stats.pdf" in the
|
||||
same directory. The stats are also saved as part of the checkpoint file.
|
||||
- Visualizations
|
||||
Prredictions are plotted to a visdom server running at the
|
||||
Predictions are plotted to a visdom server running at the
|
||||
port specified by the `visdom_server` and `visdom_port` keys in the
|
||||
config file.
|
||||
|
||||
|
@ -9,7 +9,7 @@ import copy
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -27,7 +27,9 @@ from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
||||
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
||||
from pytorch3d.vis.plotly_vis import plot_scene
|
||||
from tabulate import tabulate
|
||||
from visdom import Visdom
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from visdom import Visdom
|
||||
|
||||
|
||||
EVAL_N_SRC_VIEWS = [1, 3, 5, 7, 9]
|
||||
@ -43,7 +45,7 @@ class _Visualizer:
|
||||
|
||||
visdom_env: str = "eval_debug"
|
||||
|
||||
_viz: Visdom = field(init=False)
|
||||
_viz: Optional["Visdom"] = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
self._viz = vis_utils.get_visdom_connection()
|
||||
@ -51,6 +53,8 @@ class _Visualizer:
|
||||
def show_rgb(
|
||||
self, loss_value: float, metric_name: str, loss_mask_now: torch.Tensor
|
||||
):
|
||||
if self._viz is None:
|
||||
return
|
||||
self._viz.images(
|
||||
torch.cat(
|
||||
(
|
||||
@ -68,7 +72,10 @@ class _Visualizer:
|
||||
def show_depth(
|
||||
self, depth_loss: float, name_postfix: str, loss_mask_now: torch.Tensor
|
||||
):
|
||||
self._viz.images(
|
||||
if self._viz is None:
|
||||
return
|
||||
viz = self._viz
|
||||
viz.images(
|
||||
torch.cat(
|
||||
(
|
||||
make_depth_image(self.depth_render, loss_mask_now),
|
||||
@ -80,13 +87,13 @@ class _Visualizer:
|
||||
win="depth_abs" + name_postfix,
|
||||
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}"},
|
||||
)
|
||||
self._viz.images(
|
||||
viz.images(
|
||||
loss_mask_now,
|
||||
env=self.visdom_env,
|
||||
win="depth_abs" + name_postfix + "_mask",
|
||||
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_mask"},
|
||||
)
|
||||
self._viz.images(
|
||||
viz.images(
|
||||
self.depth_mask,
|
||||
env=self.visdom_env,
|
||||
win="depth_abs" + name_postfix + "_maskd",
|
||||
@ -126,7 +133,7 @@ class _Visualizer:
|
||||
pointcloud_max_points=10000,
|
||||
pointcloud_marker_size=1,
|
||||
)
|
||||
self._viz.plotlyplot(
|
||||
viz.plotlyplot(
|
||||
plotlyplot,
|
||||
env=self.visdom_env,
|
||||
win=f"pcl{name_postfix}",
|
||||
|
@ -12,7 +12,7 @@ import logging
|
||||
import math
|
||||
import warnings
|
||||
from dataclasses import field
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
@ -34,7 +34,9 @@ from pytorch3d.implicitron.tools.utils import cat_dataclass
|
||||
from pytorch3d.renderer import utils as rend_utils
|
||||
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
from visdom import Visdom
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from visdom import Visdom
|
||||
|
||||
from .base_model import ImplicitronModelBase, ImplicitronRender
|
||||
from .feature_extractor import FeatureExtractorBase
|
||||
@ -544,7 +546,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
||||
|
||||
def visualize(
|
||||
self,
|
||||
viz: Visdom,
|
||||
viz: Optional["Visdom"],
|
||||
visdom_env_imgs: str,
|
||||
preds: Dict[str, Any],
|
||||
prefix: str,
|
||||
@ -559,7 +561,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
||||
preds: predictions dict like returned by forward()
|
||||
prefix: prepended to the names of images
|
||||
"""
|
||||
if not viz.check_connection():
|
||||
if viz is None or not viz.check_connection():
|
||||
logger.info("no visdom server! -> skipping batch vis")
|
||||
return
|
||||
|
||||
|
@ -10,7 +10,7 @@ import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -27,7 +27,9 @@ from pytorch3d.implicitron.tools.vis_utils import (
|
||||
make_depth_image,
|
||||
)
|
||||
from tqdm import tqdm
|
||||
from visdom import Visdom
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from visdom import Visdom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -272,7 +274,7 @@ def _stack_images(ims: torch.Tensor, size: Optional[Tuple[int, int]]) -> torch.T
|
||||
def _show_predictions(
|
||||
preds: List[Dict[str, Any]],
|
||||
sequence_name: str,
|
||||
viz: Visdom,
|
||||
viz: "Visdom",
|
||||
viz_env: str = "visualizer",
|
||||
predicted_keys: Sequence[str] = (
|
||||
"images_render",
|
||||
@ -318,7 +320,7 @@ def _show_predictions(
|
||||
def _generate_prediction_videos(
|
||||
preds: List[Dict[str, Any]],
|
||||
sequence_name: str,
|
||||
viz: Optional[Visdom] = None,
|
||||
viz: Optional["Visdom"] = None,
|
||||
viz_env: str = "visualizer",
|
||||
predicted_keys: Sequence[str] = (
|
||||
"images_render",
|
||||
|
@ -337,7 +337,7 @@ class Stats(object):
|
||||
novisdom = False
|
||||
|
||||
viz = get_visdom_connection(server=visdom_server, port=visdom_port)
|
||||
if not viz.check_connection():
|
||||
if viz is None or not viz.check_connection():
|
||||
print("no visdom server! -> skipping visdom plots")
|
||||
novisdom = True
|
||||
|
||||
|
@ -5,10 +5,12 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Tuple
|
||||
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from visdom import Visdom
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from visdom import Visdom
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -40,9 +42,9 @@ _viz_singleton = None
|
||||
def get_visdom_connection(
|
||||
server: str = "http://localhost",
|
||||
port: int = 8097,
|
||||
) -> Visdom:
|
||||
) -> Optional["Visdom"]:
|
||||
"""
|
||||
Obtain a connection to a visdom server.
|
||||
Obtain a connection to a visdom server if visdom is installed.
|
||||
|
||||
Args:
|
||||
server: Server address.
|
||||
@ -51,6 +53,15 @@ def get_visdom_connection(
|
||||
Returns:
|
||||
connection: The connection object.
|
||||
"""
|
||||
try:
|
||||
from visdom import Visdom
|
||||
except ImportError:
|
||||
logger.debug("Cannot load visdom")
|
||||
return None
|
||||
|
||||
if server == "None":
|
||||
return None
|
||||
|
||||
global _viz_singleton
|
||||
if _viz_singleton is None:
|
||||
_viz_singleton = Visdom(server=server, port=port)
|
||||
@ -58,7 +69,7 @@ def get_visdom_connection(
|
||||
|
||||
|
||||
def visualize_basics(
|
||||
viz: Visdom,
|
||||
viz: "Visdom",
|
||||
preds: Dict[str, Any],
|
||||
visdom_env_imgs: str,
|
||||
title: str = "",
|
||||
|
Loading…
x
Reference in New Issue
Block a user