From 5b7491188156955a7a6fb9fe333350d43405f6cb Mon Sep 17 00:00:00 2001
From: David Novotny
Date: Tue, 2 Feb 2021 05:42:59 -0800
Subject: [PATCH] NeRF training stats logger.

Summary: Implements the `Stats` class that handles logging of the training statistics.

Reviewed By: nikhilaravi

Differential Revision: D25684430

fbshipit-source-id: 920a1c65917ab5d047988494d92173da60cfd64b
---
 projects/nerf/nerf/stats.py | 341 ++++++++++++++++++++++++++++++++++++
 1 file changed, 341 insertions(+)
 create mode 100644 projects/nerf/nerf/stats.py

diff --git a/projects/nerf/nerf/stats.py b/projects/nerf/nerf/stats.py
new file mode 100644
index 00000000..e2347ebc
--- /dev/null
+++ b/projects/nerf/nerf/stats.py
@@ -0,0 +1,341 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+import time
+import warnings
+from itertools import cycle
+from typing import List, Optional
+
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib import colors as mcolors
+from visdom import Visdom
+
+
+class AverageMeter(object):
+    """
+    Computes and stores the average and current value.
+    Tracks the exact history of the added values in every epoch.
+    """
+
+    def __init__(self):
+        """
+        Initialize the structure with empty history and zero-ed moving average.
+        """
+        self.history = []
+        self.reset()
+
+    def reset(self):
+        """
+        Reset the running average meter.
+        """
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val: float, n: int = 1, epoch: int = 0):
+        """
+        Updates the average meter with a value `val`.
+
+        Args:
+            val: A float to be added to the meter.
+            n: Represents the number of entities to be added.
+            epoch: The epoch to which the number should be added.
+        """
+        # make sure the history is of the same len as epoch
+        while len(self.history) <= epoch:
+            self.history.append([])
+        self.history[epoch].append(val / n)
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+    def get_epoch_averages(self):
+        """
+        Returns:
+            averages: A list of average values of the metric for each epoch
+                in the history buffer.
+        """
+        if len(self.history) == 0:
+            return None
+        return [
+            (float(np.array(h).mean()) if len(h) > 0 else float("NaN"))
+            for h in self.history
+        ]
+
+
+class Stats(object):
+    """
+    Stats logging object useful for gathering statistics of training
+    a deep network in PyTorch.
+
+    Example:
+        ```
+        # Init stats structure that logs statistics 'objective' and 'top1e'.
+        stats = Stats( ('objective','top1e') )
+
+        network = init_net()  # init a pytorch module (=neural network)
+        dataloader = init_dataloader()  # init a dataloader
+
+        for epoch in range(10):
+
+            # start of epoch -> call new_epoch
+            stats.new_epoch()
+
+            # Iterate over batches.
+            for batch in dataloader:
+                # Run a model and save into a dict of output variables "output"
+                output = network(batch)
+
+                # stats.update() automatically parses the 'objective' and 'top1e'
+                # from the "output" dict and stores this into the db.
+                stats.update(output)
+                stats.print()  # prints the averages over given epoch
+
+            # Stores the training plots into '/tmp/epoch_stats.pdf'
+            # and plots into a visdom server running at localhost (if running).
+            stats.plot_stats(plot_file='/tmp/epoch_stats.pdf')
+        ```
+    """
+
+    def __init__(
+        self,
+        log_vars: List[str],
+        verbose: bool = False,
+        epoch: int = -1,
+        plot_file: Optional[str] = None,
+    ):
+        """
+        Args:
+            log_vars: The list of variable names to be logged.
+            verbose: Print status messages.
+            epoch: The initial epoch of the object.
+            plot_file: The path to the file that will hold the training plots.
+        """
+        self.verbose = verbose
+        self.log_vars = log_vars
+        self.plot_file = plot_file
+        self.hard_reset(epoch=epoch)
+
+    def reset(self):
+        """
+        Called before an epoch to clear current epoch buffers.
+        """
+        stat_sets = list(self.stats.keys())
+        if self.verbose:
+            print("stats: epoch %d - reset" % self.epoch)
+        self.it = {k: -1 for k in stat_sets}
+        for stat_set in stat_sets:
+            for stat in self.stats[stat_set]:
+                self.stats[stat_set][stat].reset()
+
+        # Set a new timestamp.
+        self._epoch_start = time.time()
+
+    def hard_reset(self, epoch: int = -1):
+        """
+        Erases all logged data.
+        """
+        self._epoch_start = None
+        self.epoch = epoch
+        if self.verbose:
+            print("stats: epoch %d - hard reset" % self.epoch)
+        self.stats = {}
+        self.reset()
+
+    def new_epoch(self):
+        """
+        Initializes a new epoch.
+        """
+        if self.verbose:
+            print("stats: new epoch %d" % (self.epoch + 1))
+        self.epoch += 1  # increase epoch counter
+        self.reset()  # zero the stats
+
+    def _gather_value(self, val):
+        if isinstance(val, float):
+            pass
+        else:
+            val = val.data.cpu().numpy()
+            val = float(val.sum())
+        return val
+
+    def update(self, preds: dict, stat_set: str = "train"):
+        """
+        Update the internal logs with metrics of a training step.
+
+        Each metric is stored as an instance of an AverageMeter.
+
+        Args:
+            preds: Dict of values to be added to the logs.
+            stat_set: The set of statistics to be updated (e.g. "train", "val").
+        """
+
+        if self.epoch == -1:  # uninitialized
+            warnings.warn(
+                "self.epoch==-1 means uninitialized stats structure"
+                " -> new_epoch() called"
+            )
+            self.new_epoch()
+
+        if stat_set not in self.stats:
+            self.stats[stat_set] = {}
+            self.it[stat_set] = -1
+
+        self.it[stat_set] += 1
+
+        epoch = self.epoch
+        it = self.it[stat_set]
+
+        for stat in self.log_vars:
+
+            if stat not in self.stats[stat_set]:
+                self.stats[stat_set][stat] = AverageMeter()
+
+            if stat == "sec/it":  # compute speed
+                elapsed = time.time() - self._epoch_start
+                time_per_it = float(elapsed) / float(it + 1)
+                val = time_per_it
+            else:
+                if stat in preds:
+                    val = self._gather_value(preds[stat])
+                else:
+                    val = None
+
+            if val is not None:
+                self.stats[stat_set][stat].update(val, epoch=epoch, n=1)
+
+    def print(self, max_it: Optional[int] = None, stat_set: str = "train"):
+        """
+        Print the current values of all stored stats.
+
+        Args:
+            max_it: Maximum iteration number to be displayed.
+                If None, the maximum iteration number is not displayed.
+            stat_set: The set of statistics to be printed.
+        """
+
+        epoch = self.epoch
+        stats = self.stats
+
+        str_out = ""
+
+        it = self.it[stat_set]
+        stat_str = ""
+        stats_print = sorted(stats[stat_set].keys())
+        for stat in stats_print:
+            if stats[stat_set][stat].count == 0:
+                continue
+            stat_str += " {0:.12}: {1:1.3f} |".format(stat, stats[stat_set][stat].avg)
+
+        head_str = f"[{stat_set}] | epoch {epoch} | it {it}"
+        if max_it:
+            head_str += f"/ {max_it}"
+
+        str_out = f"{head_str} | {stat_str}"
+
+        print(str_out)
+
+    def plot_stats(
+        self,
+        viz: Visdom = None,
+        visdom_env: Optional[str] = None,
+        plot_file: Optional[str] = None,
+    ):
+        """
+        Plot the line charts of the history of the stats.
+
+        Args:
+            viz: The Visdom object holding the connection to a Visdom server.
+            visdom_env: The visdom environment for storing the graphs.
+            plot_file: The path to a file with training plots.
+        """
+
+        stat_sets = list(self.stats.keys())
+
+        if viz is None:
+            withvisdom = False
+        elif not viz.check_connection():
+            warnings.warn("Cannot connect to the visdom server! Skipping visdom plots.")
+            withvisdom = False
+        else:
+            withvisdom = True
+
+        lines = []
+
+        for stat in self.log_vars:
+            vals = []
+            stat_sets_now = []
+            for stat_set in stat_sets:
+                val = self.stats[stat_set][stat].get_epoch_averages()
+                if val is None:
+                    continue
+                else:
+                    val = np.array(val).reshape(-1)
+                    stat_sets_now.append(stat_set)
+                    vals.append(val)
+
+            if len(vals) == 0:
+                continue
+
+            vals = np.stack(vals, axis=1)
+            x = np.arange(vals.shape[0])
+
+            lines.append((stat_sets_now, stat, x, vals))
+
+        if withvisdom:
+            for tmodes, stat, x, vals in lines:
+                title = "%s" % stat
+                opts = {"title": title, "legend": list(tmodes)}
+                for i, (tmode, val) in enumerate(zip(tmodes, vals.T)):
+                    update = "append" if i > 0 else None
+                    valid = np.where(np.isfinite(val))
+                    if len(valid) == 0:
+                        continue
+                    viz.line(
+                        Y=val[valid],
+                        X=x[valid],
+                        env=visdom_env,
+                        opts=opts,
+                        win=f"stat_plot_{title}",
+                        name=tmode,
+                        update=update,
+                    )
+
+        if plot_file is None:
+            plot_file = self.plot_file
+
+        if plot_file is not None:
+            print("Exporting stats to %s" % plot_file)
+            ncol = 3
+            nrow = int(np.ceil(float(len(lines)) / ncol))
+            matplotlib.rcParams.update({"font.size": 5})
+            color = cycle(plt.cm.tab10(np.linspace(0, 1, 10)))
+            fig = plt.figure(1)
+            plt.clf()
+            for idx, (tmodes, stat, x, vals) in enumerate(lines):
+                c = next(color)
+                plt.subplot(nrow, ncol, idx + 1)
+                for vali, vals_ in enumerate(vals.T):
+                    c_ = c * (1.0 - float(vali) * 0.3)
+                    valid = np.where(np.isfinite(vals_))
+                    if len(valid) == 0:
+                        continue
+                    plt.plot(x[valid], vals_[valid], c=c_, linewidth=1)
+                plt.ylabel(stat)
+                plt.xlabel("epoch")
+                plt.gca().yaxis.label.set_color(c[0:3] * 0.75)
+                plt.legend(tmodes)
+                gcolor = np.array(mcolors.to_rgba("lightgray"))
+                plt.grid(
+                    b=True, which="major", color=gcolor, linestyle="-", linewidth=0.4
+                )
+                plt.grid(
+                    b=True, which="minor", color=gcolor, linestyle="--", linewidth=0.2
+                )
+                plt.minorticks_on()
+
+            plt.tight_layout()
+            plt.show()
+            fig.savefig(plot_file)
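For reviewers who want to exercise the logger outside a real NeRF run, here is a minimal usage sketch. Only the `Stats` API itself comes from this patch; the metric names, dummy values, and output path are illustrative, the `nerf.stats` import path assumes the `projects/nerf` package layout above, and the `plt.grid(b=...)` call in `plot_stats` targets matplotlib releases that predate the rename of `b` to `visible`.

```
import math

from nerf.stats import Stats  # import path is an assumption based on projects/nerf/nerf/stats.py

# Track two metrics plus the built-in "sec/it" speed statistic that
# Stats.update() derives from the per-epoch timestamp.
stats = Stats(["loss", "psnr", "sec/it"])

for epoch in range(3):
    stats.new_epoch()  # advances the epoch counter and resets the per-epoch meters
    for it in range(5):
        # Dummy values standing in for real model outputs.
        loss = math.exp(-0.1 * (epoch * 5 + it))
        psnr = 10.0 * (epoch + 1) + 0.1 * it
        stats.update({"loss": loss, "psnr": psnr}, stat_set="train")
        stats.print(max_it=5, stat_set="train")

# Export the per-epoch averages as a grid of line plots; with viz=None the
# visdom branch is skipped entirely.
stats.plot_stats(plot_file="/tmp/nerf_stats_demo.pdf")
```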
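When a Visdom server is available, the same call can additionally mirror the charts to a live dashboard. The sketch below reuses the `stats` object from the previous snippet; the environment name is arbitrary, and `Visdom()` with no arguments targets the default `localhost:8097` server.

```
from visdom import Visdom

# Connect to the default Visdom server (localhost:8097). If the connection
# check inside plot_stats() fails, the method warns and skips the Visdom plots.
viz = Visdom()
stats.plot_stats(viz=viz, visdom_env="nerf_stats_demo", plot_file=None)
```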