mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	make visdom optional
Summary: Make Implicitron run without visdom installed. Reviewed By: shapovalov Differential Revision: D40587974 fbshipit-source-id: dc319596c7a4d10a4c54c556dabc89ad9d25c2fb
This commit is contained in:
		
							parent
							
								
									46cb5aaaae
								
							
						
					
					
						commit
						ff933ab82b
					
				@ -41,7 +41,7 @@ The outputs of the experiment are saved and logged in multiple ways:
 | 
			
		||||
        Stats are logged and plotted to the file "train_stats.pdf" in the
 | 
			
		||||
        same directory. The stats are also saved as part of the checkpoint file.
 | 
			
		||||
  - Visualizations
 | 
			
		||||
        Prredictions are plotted to a visdom server running at the
 | 
			
		||||
        Predictions are plotted to a visdom server running at the
 | 
			
		||||
        port specified by the `visdom_server` and `visdom_port` keys in the
 | 
			
		||||
        config file.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -9,7 +9,7 @@ import copy
 | 
			
		||||
import warnings
 | 
			
		||||
from collections import OrderedDict
 | 
			
		||||
from dataclasses import dataclass, field
 | 
			
		||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 | 
			
		||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
@ -27,7 +27,9 @@ from pytorch3d.renderer.camera_utils import join_cameras_as_batch
 | 
			
		||||
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
 | 
			
		||||
from pytorch3d.vis.plotly_vis import plot_scene
 | 
			
		||||
from tabulate import tabulate
 | 
			
		||||
from visdom import Visdom
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from visdom import Visdom
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
EVAL_N_SRC_VIEWS = [1, 3, 5, 7, 9]
 | 
			
		||||
@ -43,7 +45,7 @@ class _Visualizer:
 | 
			
		||||
 | 
			
		||||
    visdom_env: str = "eval_debug"
 | 
			
		||||
 | 
			
		||||
    _viz: Visdom = field(init=False)
 | 
			
		||||
    _viz: Optional["Visdom"] = field(init=False)
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        self._viz = vis_utils.get_visdom_connection()
 | 
			
		||||
@ -51,6 +53,8 @@ class _Visualizer:
 | 
			
		||||
    def show_rgb(
 | 
			
		||||
        self, loss_value: float, metric_name: str, loss_mask_now: torch.Tensor
 | 
			
		||||
    ):
 | 
			
		||||
        if self._viz is None:
 | 
			
		||||
            return
 | 
			
		||||
        self._viz.images(
 | 
			
		||||
            torch.cat(
 | 
			
		||||
                (
 | 
			
		||||
@ -68,7 +72,10 @@ class _Visualizer:
 | 
			
		||||
    def show_depth(
 | 
			
		||||
        self, depth_loss: float, name_postfix: str, loss_mask_now: torch.Tensor
 | 
			
		||||
    ):
 | 
			
		||||
        self._viz.images(
 | 
			
		||||
        if self._viz is None:
 | 
			
		||||
            return
 | 
			
		||||
        viz = self._viz
 | 
			
		||||
        viz.images(
 | 
			
		||||
            torch.cat(
 | 
			
		||||
                (
 | 
			
		||||
                    make_depth_image(self.depth_render, loss_mask_now),
 | 
			
		||||
@ -80,13 +87,13 @@ class _Visualizer:
 | 
			
		||||
            win="depth_abs" + name_postfix,
 | 
			
		||||
            opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}"},
 | 
			
		||||
        )
 | 
			
		||||
        self._viz.images(
 | 
			
		||||
        viz.images(
 | 
			
		||||
            loss_mask_now,
 | 
			
		||||
            env=self.visdom_env,
 | 
			
		||||
            win="depth_abs" + name_postfix + "_mask",
 | 
			
		||||
            opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_mask"},
 | 
			
		||||
        )
 | 
			
		||||
        self._viz.images(
 | 
			
		||||
        viz.images(
 | 
			
		||||
            self.depth_mask,
 | 
			
		||||
            env=self.visdom_env,
 | 
			
		||||
            win="depth_abs" + name_postfix + "_maskd",
 | 
			
		||||
@ -126,7 +133,7 @@ class _Visualizer:
 | 
			
		||||
            pointcloud_max_points=10000,
 | 
			
		||||
            pointcloud_marker_size=1,
 | 
			
		||||
        )
 | 
			
		||||
        self._viz.plotlyplot(
 | 
			
		||||
        viz.plotlyplot(
 | 
			
		||||
            plotlyplot,
 | 
			
		||||
            env=self.visdom_env,
 | 
			
		||||
            win=f"pcl{name_postfix}",
 | 
			
		||||
 | 
			
		||||
@ -12,7 +12,7 @@ import logging
 | 
			
		||||
import math
 | 
			
		||||
import warnings
 | 
			
		||||
from dataclasses import field
 | 
			
		||||
from typing import Any, Dict, List, Optional, Tuple, Union
 | 
			
		||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import tqdm
 | 
			
		||||
@ -34,7 +34,9 @@ from pytorch3d.implicitron.tools.utils import cat_dataclass
 | 
			
		||||
from pytorch3d.renderer import utils as rend_utils
 | 
			
		||||
 | 
			
		||||
from pytorch3d.renderer.cameras import CamerasBase
 | 
			
		||||
from visdom import Visdom
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from visdom import Visdom
 | 
			
		||||
 | 
			
		||||
from .base_model import ImplicitronModelBase, ImplicitronRender
 | 
			
		||||
from .feature_extractor import FeatureExtractorBase
 | 
			
		||||
@ -544,7 +546,7 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
 | 
			
		||||
 | 
			
		||||
    def visualize(
 | 
			
		||||
        self,
 | 
			
		||||
        viz: Visdom,
 | 
			
		||||
        viz: Optional["Visdom"],
 | 
			
		||||
        visdom_env_imgs: str,
 | 
			
		||||
        preds: Dict[str, Any],
 | 
			
		||||
        prefix: str,
 | 
			
		||||
@ -559,7 +561,7 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
 | 
			
		||||
            preds: predictions dict like returned by forward()
 | 
			
		||||
            prefix: prepended to the names of images
 | 
			
		||||
        """
 | 
			
		||||
        if not viz.check_connection():
 | 
			
		||||
        if viz is None or not viz.check_connection():
 | 
			
		||||
            logger.info("no visdom server! -> skipping batch vis")
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -10,7 +10,7 @@ import logging
 | 
			
		||||
import math
 | 
			
		||||
import os
 | 
			
		||||
import random
 | 
			
		||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 | 
			
		||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
@ -27,7 +27,9 @@ from pytorch3d.implicitron.tools.vis_utils import (
 | 
			
		||||
    make_depth_image,
 | 
			
		||||
)
 | 
			
		||||
from tqdm import tqdm
 | 
			
		||||
from visdom import Visdom
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from visdom import Visdom
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
@ -272,7 +274,7 @@ def _stack_images(ims: torch.Tensor, size: Optional[Tuple[int, int]]) -> torch.T
 | 
			
		||||
def _show_predictions(
 | 
			
		||||
    preds: List[Dict[str, Any]],
 | 
			
		||||
    sequence_name: str,
 | 
			
		||||
    viz: Visdom,
 | 
			
		||||
    viz: "Visdom",
 | 
			
		||||
    viz_env: str = "visualizer",
 | 
			
		||||
    predicted_keys: Sequence[str] = (
 | 
			
		||||
        "images_render",
 | 
			
		||||
@ -318,7 +320,7 @@ def _show_predictions(
 | 
			
		||||
def _generate_prediction_videos(
 | 
			
		||||
    preds: List[Dict[str, Any]],
 | 
			
		||||
    sequence_name: str,
 | 
			
		||||
    viz: Optional[Visdom] = None,
 | 
			
		||||
    viz: Optional["Visdom"] = None,
 | 
			
		||||
    viz_env: str = "visualizer",
 | 
			
		||||
    predicted_keys: Sequence[str] = (
 | 
			
		||||
        "images_render",
 | 
			
		||||
 | 
			
		||||
@ -337,7 +337,7 @@ class Stats(object):
 | 
			
		||||
        novisdom = False
 | 
			
		||||
 | 
			
		||||
        viz = get_visdom_connection(server=visdom_server, port=visdom_port)
 | 
			
		||||
        if not viz.check_connection():
 | 
			
		||||
        if viz is None or not viz.check_connection():
 | 
			
		||||
            print("no visdom server! -> skipping visdom plots")
 | 
			
		||||
            novisdom = True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -5,10 +5,12 @@
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
import logging
 | 
			
		||||
from typing import Any, Dict, Tuple
 | 
			
		||||
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from visdom import Visdom
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from visdom import Visdom
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
@ -40,9 +42,9 @@ _viz_singleton = None
 | 
			
		||||
def get_visdom_connection(
 | 
			
		||||
    server: str = "http://localhost",
 | 
			
		||||
    port: int = 8097,
 | 
			
		||||
) -> Visdom:
 | 
			
		||||
) -> Optional["Visdom"]:
 | 
			
		||||
    """
 | 
			
		||||
    Obtain a connection to a visdom server.
 | 
			
		||||
    Obtain a connection to a visdom server if visdom is installed.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        server: Server address.
 | 
			
		||||
@ -51,6 +53,15 @@ def get_visdom_connection(
 | 
			
		||||
    Returns:
 | 
			
		||||
        connection: The connection object.
 | 
			
		||||
    """
 | 
			
		||||
    try:
 | 
			
		||||
        from visdom import Visdom
 | 
			
		||||
    except ImportError:
 | 
			
		||||
        logger.debug("Cannot load visdom")
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    if server == "None":
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    global _viz_singleton
 | 
			
		||||
    if _viz_singleton is None:
 | 
			
		||||
        _viz_singleton = Visdom(server=server, port=port)
 | 
			
		||||
@ -58,7 +69,7 @@ def get_visdom_connection(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def visualize_basics(
 | 
			
		||||
    viz: Visdom,
 | 
			
		||||
    viz: "Visdom",
 | 
			
		||||
    preds: Dict[str, Any],
 | 
			
		||||
    visdom_env_imgs: str,
 | 
			
		||||
    title: str = "",
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user