mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Softly deprecate the get_str=False flag.
Summary: We don't want to use print directly in stats.print() method. Instead this method will return the output string to the caller. Reviewed By: shapovalov Differential Revision: D45356240 fbshipit-source-id: 2cabe3cdfb9206bf09aa7b3cdd2263148a5ba145
This commit is contained in:
		
							parent
							
								
									297020a4b1
								
							
						
					
					
						commit
						d08fe6d45a
					
				@ -256,7 +256,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
 | 
			
		||||
            list(log_vars),
 | 
			
		||||
            plot_file=os.path.join(exp_dir, "train_stats.pdf"),
 | 
			
		||||
            visdom_env=visdom_env_charts,
 | 
			
		||||
            verbose=False,
 | 
			
		||||
            visdom_server=self.visdom_server,
 | 
			
		||||
            visdom_port=self.visdom_port,
 | 
			
		||||
        )
 | 
			
		||||
@ -382,7 +381,8 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
 | 
			
		||||
 | 
			
		||||
            # print textual status update
 | 
			
		||||
            if it % self.metric_print_interval == 0 or last_iter:
 | 
			
		||||
                stats.print(stat_set=trainmode, max_it=n_batches)
 | 
			
		||||
                std_out = stats.get_status_string(stat_set=trainmode, max_it=n_batches)
 | 
			
		||||
                logger.info(std_out)
 | 
			
		||||
 | 
			
		||||
            # visualize results
 | 
			
		||||
            if (
 | 
			
		||||
 | 
			
		||||
@ -6,6 +6,7 @@
 | 
			
		||||
 | 
			
		||||
import gzip
 | 
			
		||||
import json
 | 
			
		||||
import logging
 | 
			
		||||
import time
 | 
			
		||||
import warnings
 | 
			
		||||
from collections.abc import Iterable
 | 
			
		||||
@ -17,6 +18,8 @@ import numpy as np
 | 
			
		||||
from matplotlib import colors as mcolors
 | 
			
		||||
from pytorch3d.implicitron.tools.vis_utils import get_visdom_connection
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AverageMeter(object):
 | 
			
		||||
    """Computes and stores the average and current value"""
 | 
			
		||||
@ -91,7 +94,9 @@ class Stats(object):
 | 
			
		||||
                # stats.update() automatically parses the 'objective' and 'top1e' from
 | 
			
		||||
                # the "output" dict and stores this into the db
 | 
			
		||||
                stats.update(output)
 | 
			
		||||
                stats.print() # prints the averages over given epoch
 | 
			
		||||
                # prints the metric averages over given epoch
 | 
			
		||||
                std_out = stats.get_status_string()
 | 
			
		||||
                logger.info(str_out)
 | 
			
		||||
            # stores the training plots into '/tmp/epoch_stats.pdf'
 | 
			
		||||
            # and plots into a visdom server running at localhost (if running)
 | 
			
		||||
            stats.plot_stats(plot_file='/tmp/epoch_stats.pdf')
 | 
			
		||||
@ -101,7 +106,6 @@ class Stats(object):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        log_vars,
 | 
			
		||||
        verbose=False,
 | 
			
		||||
        epoch=-1,
 | 
			
		||||
        visdom_env="main",
 | 
			
		||||
        do_plot=True,
 | 
			
		||||
@ -110,7 +114,6 @@ class Stats(object):
 | 
			
		||||
        visdom_port=8097,
 | 
			
		||||
    ):
 | 
			
		||||
 | 
			
		||||
        self.verbose = verbose
 | 
			
		||||
        self.log_vars = log_vars
 | 
			
		||||
        self.visdom_env = visdom_env
 | 
			
		||||
        self.visdom_server = visdom_server
 | 
			
		||||
@ -156,15 +159,14 @@ class Stats(object):
 | 
			
		||||
        iserr = type is not None and issubclass(type, Exception)
 | 
			
		||||
        iserr = iserr or (type is KeyboardInterrupt)
 | 
			
		||||
        if iserr:
 | 
			
		||||
            print("error inside 'with' block")
 | 
			
		||||
            logger.error("error inside 'with' block")
 | 
			
		||||
            return
 | 
			
		||||
        if self.do_plot:
 | 
			
		||||
            self.plot_stats(self.visdom_env)
 | 
			
		||||
 | 
			
		||||
    def reset(self):  # to be called after each epoch
 | 
			
		||||
        stat_sets = list(self.stats.keys())
 | 
			
		||||
        if self.verbose:
 | 
			
		||||
            print("stats: epoch %d - reset" % self.epoch)
 | 
			
		||||
        logger.debug(f"stats: epoch {self.epoch} - reset")
 | 
			
		||||
        self.it = {k: -1 for k in stat_sets}
 | 
			
		||||
        for stat_set in stat_sets:
 | 
			
		||||
            for stat in self.stats[stat_set]:
 | 
			
		||||
@ -172,16 +174,14 @@ class Stats(object):
 | 
			
		||||
 | 
			
		||||
    def hard_reset(self, epoch=-1):  # to be called during object __init__
 | 
			
		||||
        self.epoch = epoch
 | 
			
		||||
        if self.verbose:
 | 
			
		||||
            print("stats: epoch %d - hard reset" % self.epoch)
 | 
			
		||||
        logger.debug(f"stats: epoch {self.epoch} - hard reset")
 | 
			
		||||
        self.stats = {}
 | 
			
		||||
 | 
			
		||||
        # reset
 | 
			
		||||
        self.reset()
 | 
			
		||||
 | 
			
		||||
    def new_epoch(self):
 | 
			
		||||
        if self.verbose:
 | 
			
		||||
            print("stats: new epoch %d" % (self.epoch + 1))
 | 
			
		||||
        logger.debug(f"stats: new epoch {(self.epoch + 1)}")
 | 
			
		||||
        self.epoch += 1
 | 
			
		||||
        self.reset()  # zero the stats + increase epoch counter
 | 
			
		||||
 | 
			
		||||
@ -193,18 +193,17 @@ class Stats(object):
 | 
			
		||||
            val = float(val.sum())
 | 
			
		||||
        return val
 | 
			
		||||
 | 
			
		||||
    def add_log_vars(self, added_log_vars, verbose=True):
 | 
			
		||||
    def add_log_vars(self, added_log_vars):
 | 
			
		||||
        for add_log_var in added_log_vars:
 | 
			
		||||
            if add_log_var not in self.stats:
 | 
			
		||||
                if verbose:
 | 
			
		||||
                    print(f"Adding {add_log_var}")
 | 
			
		||||
                logger.debug(f"Adding {add_log_var}")
 | 
			
		||||
                self.log_vars.append(add_log_var)
 | 
			
		||||
 | 
			
		||||
    def update(self, preds, time_start=None, freeze_iter=False, stat_set="train"):
 | 
			
		||||
 | 
			
		||||
        if self.epoch == -1:  # uninitialized
 | 
			
		||||
            print(
 | 
			
		||||
                "warning: epoch==-1 means uninitialized stats structure -> new_epoch() called"
 | 
			
		||||
            logger.warning(
 | 
			
		||||
                "epoch==-1 means uninitialized stats structure -> new_epoch() called"
 | 
			
		||||
            )
 | 
			
		||||
            self.new_epoch()
 | 
			
		||||
 | 
			
		||||
@ -284,6 +283,12 @@ class Stats(object):
 | 
			
		||||
        skip_nan=False,
 | 
			
		||||
        stat_format=lambda s: s.replace("loss_", "").replace("prev_stage_", "ps_"),
 | 
			
		||||
    ):
 | 
			
		||||
        """
 | 
			
		||||
        stats.print() is deprecated. Please use get_status_string() instead.
 | 
			
		||||
        example:
 | 
			
		||||
        std_out = stats.get_status_string()
 | 
			
		||||
        logger.info(str_out)
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        epoch = self.epoch
 | 
			
		||||
        stats = self.stats
 | 
			
		||||
@ -311,8 +316,30 @@ class Stats(object):
 | 
			
		||||
        if get_str:
 | 
			
		||||
            return str_out
 | 
			
		||||
        else:
 | 
			
		||||
            warnings.warn(
 | 
			
		||||
                "get_str=False is deprecated."
 | 
			
		||||
                "Please enable this flag to get receive the output string.",
 | 
			
		||||
                DeprecationWarning,
 | 
			
		||||
            )
 | 
			
		||||
            print(str_out)
 | 
			
		||||
 | 
			
		||||
    def get_status_string(
 | 
			
		||||
        self,
 | 
			
		||||
        max_it=None,
 | 
			
		||||
        stat_set="train",
 | 
			
		||||
        vars_print=None,
 | 
			
		||||
        skip_nan=False,
 | 
			
		||||
        stat_format=lambda s: s.replace("loss_", "").replace("prev_stage_", "ps_"),
 | 
			
		||||
    ):
 | 
			
		||||
        return self.print(
 | 
			
		||||
            max_it=max_it,
 | 
			
		||||
            stat_set=stat_set,
 | 
			
		||||
            vars_print=vars_print,
 | 
			
		||||
            get_str=True,
 | 
			
		||||
            skip_nan=skip_nan,
 | 
			
		||||
            stat_format=stat_format,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def plot_stats(
 | 
			
		||||
        self, visdom_env=None, plot_file=None, visdom_server=None, visdom_port=None
 | 
			
		||||
    ):
 | 
			
		||||
@ -329,16 +356,15 @@ class Stats(object):
 | 
			
		||||
 | 
			
		||||
        stat_sets = list(self.stats.keys())
 | 
			
		||||
 | 
			
		||||
        print(
 | 
			
		||||
            "printing charts to visdom env '%s' (%s:%d)"
 | 
			
		||||
            % (visdom_env, visdom_server, visdom_port)
 | 
			
		||||
        logger.debug(
 | 
			
		||||
            f"printing charts to visdom env '{visdom_env}' ({visdom_server}:{visdom_port})"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        novisdom = False
 | 
			
		||||
 | 
			
		||||
        viz = get_visdom_connection(server=visdom_server, port=visdom_port)
 | 
			
		||||
        if viz is None or not viz.check_connection():
 | 
			
		||||
            print("no visdom server! -> skipping visdom plots")
 | 
			
		||||
            logger.info("no visdom server! -> skipping visdom plots")
 | 
			
		||||
            novisdom = True
 | 
			
		||||
 | 
			
		||||
        lines = []
 | 
			
		||||
@ -385,7 +411,7 @@ class Stats(object):
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
        if plot_file:
 | 
			
		||||
            print("exporting stats to %s" % plot_file)
 | 
			
		||||
            logger.info(f"plotting stats to {plot_file}")
 | 
			
		||||
            ncol = 3
 | 
			
		||||
            nrow = int(np.ceil(float(len(lines)) / ncol))
 | 
			
		||||
            matplotlib.rcParams.update({"font.size": 5})
 | 
			
		||||
@ -423,7 +449,7 @@ class Stats(object):
 | 
			
		||||
            except PermissionError:
 | 
			
		||||
                warnings.warn("Cant dump stats due to insufficient permissions!")
 | 
			
		||||
 | 
			
		||||
    def synchronize_logged_vars(self, log_vars, default_val=float("NaN"), verbose=True):
 | 
			
		||||
    def synchronize_logged_vars(self, log_vars, default_val=float("NaN")):
 | 
			
		||||
 | 
			
		||||
        stat_sets = list(self.stats.keys())
 | 
			
		||||
 | 
			
		||||
@ -431,7 +457,7 @@ class Stats(object):
 | 
			
		||||
        for stat_set in stat_sets:
 | 
			
		||||
            for stat in self.stats[stat_set].keys():
 | 
			
		||||
                if stat not in log_vars:
 | 
			
		||||
                    print("additional stat %s:%s -> removing" % (stat_set, stat))
 | 
			
		||||
                    logger.warning(f"additional stat {stat_set}:{stat} -> removing")
 | 
			
		||||
 | 
			
		||||
            self.stats[stat_set] = {
 | 
			
		||||
                stat: v for stat, v in self.stats[stat_set].items() if stat in log_vars
 | 
			
		||||
@ -442,21 +468,19 @@ class Stats(object):
 | 
			
		||||
        for stat_set in stat_sets:
 | 
			
		||||
            for stat in log_vars:
 | 
			
		||||
                if stat not in self.stats[stat_set]:
 | 
			
		||||
                    if verbose:
 | 
			
		||||
                        print(
 | 
			
		||||
                            "missing stat %s:%s -> filling with default values (%1.2f)"
 | 
			
		||||
                            % (stat_set, stat, default_val)
 | 
			
		||||
                        )
 | 
			
		||||
                    logger.info(
 | 
			
		||||
                        "missing stat %s:%s -> filling with default values (%1.2f)"
 | 
			
		||||
                        % (stat_set, stat, default_val)
 | 
			
		||||
                    )
 | 
			
		||||
                elif len(self.stats[stat_set][stat].history) != self.epoch + 1:
 | 
			
		||||
                    h = self.stats[stat_set][stat].history
 | 
			
		||||
                    if len(h) == 0:  # just never updated stat ... skip
 | 
			
		||||
                        continue
 | 
			
		||||
                    else:
 | 
			
		||||
                        if verbose:
 | 
			
		||||
                            print(
 | 
			
		||||
                                "incomplete stat %s:%s -> reseting with default values (%1.2f)"
 | 
			
		||||
                                % (stat_set, stat, default_val)
 | 
			
		||||
                            )
 | 
			
		||||
                        logger.info(
 | 
			
		||||
                            "incomplete stat %s:%s -> reseting with default values (%1.2f)"
 | 
			
		||||
                            % (stat_set, stat, default_val)
 | 
			
		||||
                        )
 | 
			
		||||
                else:
 | 
			
		||||
                    continue
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user