# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import gzip
import json
import time
import warnings
from collections.abc import Iterable
from itertools import cycle

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import colors as mcolors
from pytorch3d.implicitron.tools.vis_utils import get_visdom_connection


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.history = []
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1, epoch=0):
        # make sure the history is at least as long as the epoch index
        while len(self.history) <= epoch:
            self.history.append([])
        self.history[epoch].append(val / n)
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def get_epoch_averages(self, epoch=-1):
        if len(self.history) == 0:
            # no stats here
            return None
        elif epoch == -1:
            # per-epoch averages for all epochs
            return [
                (float(np.array(x).mean()) if len(x) > 0 else float("NaN"))
                for x in self.history
            ]
        else:
            return float(np.array(self.history[epoch]).mean())

    def get_all_values(self):
        all_vals = [np.array(x) for x in self.history]
        all_vals = np.concatenate(all_vals)
        return all_vals

    def get_epoch(self):
        return len(self.history)

    @staticmethod
    def from_json_str(json_str):
        self = AverageMeter()
        self.__dict__.update(json.loads(json_str))
        return self


class Stats(object):
    # TODO: update this with context manager
    """
    Stats logging object useful for gathering statistics of training
    a deep network in PyTorch.

    Example:
        # init a stats structure that logs the statistics 'objective' and 'top1e'
        stats = Stats(('objective', 'top1e'))
        network = init_net()  # init a pytorch module (= neural network)
        dataloader = init_dataloader()  # init a dataloader
        for epoch in range(10):
            # start of epoch -> call new_epoch
            stats.new_epoch()
            # iterate over batches
            for batch in dataloader:
                # run the network and store its outputs in the dict "output"
                output = network(batch)
                # stats.update() automatically parses 'objective' and 'top1e'
                # from the "output" dict and stores them in the database
                stats.update(output)
                stats.print()  # prints the averages over the current epoch
        # stores the training plots into '/tmp/epoch_stats.pdf'
        # and plots into a visdom server running at localhost (if running)
        stats.plot_stats(plot_file='/tmp/epoch_stats.pdf')
    """

    def __init__(
        self,
        log_vars,
        verbose=False,
        epoch=-1,
        visdom_env="main",
        do_plot=True,
        plot_file=None,
        visdom_server="http://localhost",
        visdom_port=8097,
    ):
        self.verbose = verbose
        self.log_vars = log_vars
        self.visdom_env = visdom_env
        self.visdom_server = visdom_server
        self.visdom_port = visdom_port
        self.plot_file = plot_file
        self.do_plot = do_plot
        self.hard_reset(epoch=epoch)

    @staticmethod
    def from_json_str(json_str):
        self = Stats([])
        # load the global state
        self.__dict__.update(json.loads(json_str))
        # recover the AverageMeters
        for stat_set in self.stats:
            self.stats[stat_set] = {
                log_var: AverageMeter.from_json_str(log_vals_json_str)
                for log_var, log_vals_json_str in self.stats[stat_set].items()
            }
        return self

    @staticmethod
    def load(flpath, postfix=".jgz"):
        flpath = _get_postfixed_filename(flpath, postfix)
        with gzip.open(flpath, "r") as fin:
            data = json.loads(fin.read().decode("utf-8"))
        return Stats.from_json_str(data)

    def save(self, flpath, postfix=".jgz"):
        flpath = _get_postfixed_filename(flpath, postfix)
        # store into a gzipped json file
        with gzip.open(flpath, "w") as fout:
            fout.write(json.dumps(self, cls=StatsJSONEncoder).encode("utf-8"))

    # some sugar to be used with "with stats:" at the beginning of the epoch
    def __enter__(self):
        if self.do_plot and self.epoch >= 0:
            self.plot_stats(self.visdom_env)
        self.new_epoch()

    def __exit__(self, type, value, traceback):
        iserr = type is not None and issubclass(type, Exception)
        iserr = iserr or (type is KeyboardInterrupt)
        if iserr:
            print("error inside 'with' block")
            return
        if self.do_plot:
            self.plot_stats(self.visdom_env)

    def reset(self):  # to be called after each epoch
        stat_sets = list(self.stats.keys())
        if self.verbose:
            print("stats: epoch %d - reset" % self.epoch)
        self.it = {k: -1 for k in stat_sets}
        for stat_set in stat_sets:
            for stat in self.stats[stat_set]:
                self.stats[stat_set][stat].reset()

    def hard_reset(self, epoch=-1):  # to be called during object __init__
        self.epoch = epoch
        if self.verbose:
            print("stats: epoch %d - hard reset" % self.epoch)
        self.stats = {}
        # reset
        self.reset()

    def new_epoch(self):
        if self.verbose:
            print("stats: new epoch %d" % (self.epoch + 1))
        self.epoch += 1
        self.reset()  # zero the stats + increase epoch counter

    def gather_value(self, val):
        if isinstance(val, (float, int)):
            val = float(val)
        else:
            val = val.data.cpu().numpy()
            val = float(val.sum())
        return val

    def add_log_vars(self, added_log_vars, verbose=True):
        for add_log_var in added_log_vars:
            if add_log_var not in self.stats:
                if verbose:
                    print(f"Adding {add_log_var}")
                self.log_vars.append(add_log_var)
        # self.synchronize_logged_vars(self.log_vars, verbose=verbose)

    def update(self, preds, time_start=None, freeze_iter=False, stat_set="train"):
        if self.epoch == -1:  # uninitialized
            print(
                "warning: epoch==-1 means uninitialized stats structure"
                " -> new_epoch() called"
            )
            self.new_epoch()
        if stat_set not in self.stats:
            self.stats[stat_set] = {}
            self.it[stat_set] = -1
        if not freeze_iter:
            self.it[stat_set] += 1
        epoch = self.epoch
        it = self.it[stat_set]
        for stat in self.log_vars:
            if stat not in self.stats[stat_set]:
                self.stats[stat_set][stat] = AverageMeter()
            if stat == "sec/it":  # compute speed
                if time_start is None:
                    elapsed = 0.0
                else:
                    elapsed = time.time() - time_start
                time_per_it = float(elapsed) / float(it + 1)
                val = time_per_it
                # self.stats[stat_set]['sec/it'].update(time_per_it,epoch=epoch,n=1)
            else:
                if stat in preds:
                    try:
                        val = self.gather_value(preds[stat])
                    except KeyError:
                        raise ValueError(
                            "could not extract prediction %s"
                            " from the prediction dictionary" % stat
                        ) from None
                else:
                    val = None
            if val is not None:
                self.stats[stat_set][stat].update(val, epoch=epoch, n=1)

    def get_epoch_averages(self, epoch=None):
        stat_sets = list(self.stats.keys())
        if epoch is None:
            epoch = self.epoch
        if epoch == -1:
            epoch = list(range(self.epoch))
        outvals = {}
        for stat_set in stat_sets:
            outvals[stat_set] = {
                "epoch": epoch,
                "it": self.it[stat_set],
                "epoch_max": self.epoch,
            }
            for stat in self.stats[stat_set].keys():
                if self.stats[stat_set][stat].count == 0:
                    continue
                if isinstance(epoch, Iterable):
                    avgs = self.stats[stat_set][stat].get_epoch_averages()
                    avgs = [avgs[e] for e in epoch]
                else:
                    avgs = self.stats[stat_set][stat].get_epoch_averages(epoch=epoch)
                outvals[stat_set][stat] = avgs
        return outvals

    def print(
        self,
        max_it=None,
        stat_set="train",
        vars_print=None,
        get_str=False,
        skip_nan=False,
        stat_format=lambda s: s.replace("loss_", "").replace("prev_stage_", "ps_"),
    ):
        epoch = self.epoch
        stats = self.stats
        str_out = ""
        it = self.it[stat_set]
        stat_str = ""
        stats_print = sorted(stats[stat_set].keys())
        for stat in stats_print:
            if stats[stat_set][stat].count == 0:
                continue
            if skip_nan and not np.isfinite(stats[stat_set][stat].avg):
                continue
            stat_str += " {0:.12}: {1:1.3f} |".format(
                stat_format(stat), stats[stat_set][stat].avg
            )
        head_str = "[%s] | epoch %3d | it %5d" % (stat_set, epoch, it)
        if max_it:
            head_str += "/ %d" % max_it
        str_out = "%s | %s" % (head_str, stat_str)
        if get_str:
            return str_out
        else:
            print(str_out)

    def plot_stats(
        self, visdom_env=None, plot_file=None, visdom_server=None, visdom_port=None
    ):
        # use the cached visdom env if none supplied
        if visdom_env is None:
            visdom_env = self.visdom_env
        if visdom_server is None:
            visdom_server = self.visdom_server
        if visdom_port is None:
            visdom_port = self.visdom_port
        if plot_file is None:
            plot_file = self.plot_file
        stat_sets = list(self.stats.keys())
        print(
            "printing charts to visdom env '%s' (%s:%d)"
            % (visdom_env, visdom_server, visdom_port)
        )
        novisdom = False
        viz = get_visdom_connection(server=visdom_server, port=visdom_port)
        if not viz.check_connection():
            print("no visdom server! -> skipping visdom plots")
            novisdom = True
        lines = []
        # plot metrics
        if not novisdom:
            viz.close(env=visdom_env, win=None)
        for stat in self.log_vars:
            vals = []
            stat_sets_now = []
            for stat_set in stat_sets:
                val = self.stats[stat_set][stat].get_epoch_averages()
                if val is None:
                    continue
                else:
                    val = np.array(val).reshape(-1)
                    stat_sets_now.append(stat_set)
                vals.append(val)
            if len(vals) == 0:
                continue
            lines.append((stat_sets_now, stat, vals))
        if not novisdom:
            for tmodes, stat, vals in lines:
                title = "%s" % stat
                opts = {"title": title, "legend": list(tmodes)}
                for i, (tmode, val) in enumerate(zip(tmodes, vals)):
                    update = "append" if i > 0 else None
                    valid = np.where(np.isfinite(val))[0]
                    if len(valid) == 0:
                        continue
                    x = np.arange(len(val))
                    viz.line(
                        Y=val[valid],
                        X=x[valid],
                        env=visdom_env,
                        opts=opts,
                        win=f"stat_plot_{title}",
                        name=tmode,
                        update=update,
                    )
        if plot_file:
            print("exporting stats to %s" % plot_file)
            ncol = 3
            nrow = int(np.ceil(float(len(lines)) / ncol))
            matplotlib.rcParams.update({"font.size": 5})
            color = cycle(plt.cm.tab10(np.linspace(0, 1, 10)))
            fig = plt.figure(1)
            plt.clf()
            for idx, (tmodes, stat, vals) in enumerate(lines):
                c = next(color)
                plt.subplot(nrow, ncol, idx + 1)
                plt.gca()
                for vali, vals_ in enumerate(vals):
                    c_ = c * (1.0 - float(vali) * 0.3)
                    valid = np.where(np.isfinite(vals_))[0]
                    if len(valid) == 0:
                        continue
                    x = np.arange(len(vals_))
                    plt.plot(x[valid], vals_[valid], c=c_, linewidth=1)
                plt.ylabel(stat)
                plt.xlabel("epoch")
                plt.gca().yaxis.label.set_color(c[0:3] * 0.75)
                plt.legend(tmodes)
                gcolor = np.array(mcolors.to_rgba("lightgray"))
                # pass the grid visibility positionally for compatibility across
                # matplotlib versions (the keyword was renamed from `b` to
                # `visible` in matplotlib 3.5)
                plt.grid(
                    True, which="major", color=gcolor, linestyle="-", linewidth=0.4
                )
                plt.grid(
                    True, which="minor", color=gcolor, linestyle="--", linewidth=0.2
                )
                plt.minorticks_on()
            plt.tight_layout()
            plt.show()
            try:
                fig.savefig(plot_file)
            except PermissionError:
                warnings.warn("Can't dump stats due to insufficient permissions!")

    def synchronize_logged_vars(self, log_vars, default_val=float("NaN"), verbose=True):
        stat_sets = list(self.stats.keys())
        # remove the additional log_vars
        for stat_set in stat_sets:
            for stat in self.stats[stat_set].keys():
                if stat not in log_vars:
                    print("additional stat %s:%s -> removing" % (stat_set, stat))
            self.stats[stat_set] = {
                stat: v for stat, v in self.stats[stat_set].items() if stat in log_vars
            }
        self.log_vars = log_vars  # note: this overwrites the tracked variable list
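        # For every stat set, make sure each logged variable has a complete
        # per-epoch history: variables that are missing, or whose history is
        # out of sync with the current epoch, are re-created below and padded
        # with `default_val` for every elapsed epoch.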
        for stat_set in stat_sets:
            reference_stat = list(self.stats[stat_set].keys())[0]
            for stat in log_vars:
                if stat not in self.stats[stat_set]:
                    if verbose:
                        print(
                            "missing stat %s:%s -> filling with default values (%1.2f)"
                            % (stat_set, stat, default_val)
                        )
                elif len(self.stats[stat_set][stat].history) != self.epoch + 1:
                    h = self.stats[stat_set][stat].history
                    if len(h) == 0:
                        # just a never-updated stat ... skip
                        continue
                    else:
                        if verbose:
                            print(
                                "incomplete stat %s:%s -> resetting with default"
                                " values (%1.2f)" % (stat_set, stat, default_val)
                            )
                else:
                    continue
                self.stats[stat_set][stat] = AverageMeter()
                self.stats[stat_set][stat].reset()
                lastep = self.epoch + 1
                for ep in range(lastep):
                    self.stats[stat_set][stat].update(default_val, n=1, epoch=ep)
                epoch_self = self.stats[stat_set][reference_stat].get_epoch()
                epoch_generated = self.stats[stat_set][stat].get_epoch()
                assert (
                    epoch_self == epoch_generated
                ), "bad epoch of synchronized log_var! %d vs %d" % (
                    epoch_self,
                    epoch_generated,
                )


class StatsJSONEncoder(json.JSONEncoder):
    def default(self, o):
        if isinstance(o, (AverageMeter, Stats)):
            # serialize stats objects by JSON-encoding their attribute dicts
            enc = self.encode(o.__dict__)
            return enc
        else:
            raise TypeError(
                f"Object of type {o.__class__.__name__} is not JSON serializable"
            )


def _get_postfixed_filename(fl, postfix):
    return fl if fl.endswith(postfix) else fl + postfix
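

# What follows is a minimal, hypothetical usage sketch (not part of the
# library API): it logs a fake "loss" metric for two epochs, prints the
# per-epoch averages, and round-trips the stats through the gzipped-JSON
# (de)serialization. The metric values and the /tmp output path are
# illustrative assumptions.
if __name__ == "__main__":
    stats = Stats(["loss", "sec/it"], do_plot=False)
    for _epoch in range(2):
        stats.new_epoch()
        start = time.time()
        for _it in range(5):
            # stand-in for the output dict of a real network
            preds = {"loss": float(np.random.rand())}
            stats.update(preds, time_start=start)
        # prints e.g. "[train] | epoch   0 | it     4 |  loss: 0.512 | sec/it: 0.000 |"
        stats.print()
    stats.save("/tmp/stats_demo")  # writes /tmp/stats_demo.jgz
    reloaded = Stats.load("/tmp/stats_demo")
    assert reloaded.epoch == stats.epoch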