Softly deprecate the get_str=False flag.

Summary: We don't want to use print directly in stats.print() method. Instead this method will return the output string to the caller.

Reviewed By: shapovalov

Differential Revision: D45356240

fbshipit-source-id: 2cabe3cdfb9206bf09aa7b3cdd2263148a5ba145
This commit is contained in:
Virendra Kumar Pathak 2023-05-14 01:24:31 -07:00 committed by Facebook GitHub Bot
parent 297020a4b1
commit d08fe6d45a
2 changed files with 58 additions and 34 deletions

View File

@ -256,7 +256,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
list(log_vars), list(log_vars),
plot_file=os.path.join(exp_dir, "train_stats.pdf"), plot_file=os.path.join(exp_dir, "train_stats.pdf"),
visdom_env=visdom_env_charts, visdom_env=visdom_env_charts,
verbose=False,
visdom_server=self.visdom_server, visdom_server=self.visdom_server,
visdom_port=self.visdom_port, visdom_port=self.visdom_port,
) )
@ -382,7 +381,8 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
# print textual status update # print textual status update
if it % self.metric_print_interval == 0 or last_iter: if it % self.metric_print_interval == 0 or last_iter:
stats.print(stat_set=trainmode, max_it=n_batches) std_out = stats.get_status_string(stat_set=trainmode, max_it=n_batches)
logger.info(std_out)
# visualize results # visualize results
if ( if (

View File

@ -6,6 +6,7 @@
import gzip import gzip
import json import json
import logging
import time import time
import warnings import warnings
from collections.abc import Iterable from collections.abc import Iterable
@ -17,6 +18,8 @@ import numpy as np
from matplotlib import colors as mcolors from matplotlib import colors as mcolors
from pytorch3d.implicitron.tools.vis_utils import get_visdom_connection from pytorch3d.implicitron.tools.vis_utils import get_visdom_connection
logger = logging.getLogger(__name__)
class AverageMeter(object): class AverageMeter(object):
"""Computes and stores the average and current value""" """Computes and stores the average and current value"""
@ -91,7 +94,9 @@ class Stats(object):
# stats.update() automatically parses the 'objective' and 'top1e' from # stats.update() automatically parses the 'objective' and 'top1e' from
# the "output" dict and stores this into the db # the "output" dict and stores this into the db
stats.update(output) stats.update(output)
stats.print() # prints the averages over given epoch # prints the metric averages over given epoch
std_out = stats.get_status_string()
logger.info(str_out)
# stores the training plots into '/tmp/epoch_stats.pdf' # stores the training plots into '/tmp/epoch_stats.pdf'
# and plots into a visdom server running at localhost (if running) # and plots into a visdom server running at localhost (if running)
stats.plot_stats(plot_file='/tmp/epoch_stats.pdf') stats.plot_stats(plot_file='/tmp/epoch_stats.pdf')
@ -101,7 +106,6 @@ class Stats(object):
def __init__( def __init__(
self, self,
log_vars, log_vars,
verbose=False,
epoch=-1, epoch=-1,
visdom_env="main", visdom_env="main",
do_plot=True, do_plot=True,
@ -110,7 +114,6 @@ class Stats(object):
visdom_port=8097, visdom_port=8097,
): ):
self.verbose = verbose
self.log_vars = log_vars self.log_vars = log_vars
self.visdom_env = visdom_env self.visdom_env = visdom_env
self.visdom_server = visdom_server self.visdom_server = visdom_server
@ -156,15 +159,14 @@ class Stats(object):
iserr = type is not None and issubclass(type, Exception) iserr = type is not None and issubclass(type, Exception)
iserr = iserr or (type is KeyboardInterrupt) iserr = iserr or (type is KeyboardInterrupt)
if iserr: if iserr:
print("error inside 'with' block") logger.error("error inside 'with' block")
return return
if self.do_plot: if self.do_plot:
self.plot_stats(self.visdom_env) self.plot_stats(self.visdom_env)
def reset(self): # to be called after each epoch def reset(self): # to be called after each epoch
stat_sets = list(self.stats.keys()) stat_sets = list(self.stats.keys())
if self.verbose: logger.debug(f"stats: epoch {self.epoch} - reset")
print("stats: epoch %d - reset" % self.epoch)
self.it = {k: -1 for k in stat_sets} self.it = {k: -1 for k in stat_sets}
for stat_set in stat_sets: for stat_set in stat_sets:
for stat in self.stats[stat_set]: for stat in self.stats[stat_set]:
@ -172,16 +174,14 @@ class Stats(object):
def hard_reset(self, epoch=-1): # to be called during object __init__ def hard_reset(self, epoch=-1): # to be called during object __init__
self.epoch = epoch self.epoch = epoch
if self.verbose: logger.debug(f"stats: epoch {self.epoch} - hard reset")
print("stats: epoch %d - hard reset" % self.epoch)
self.stats = {} self.stats = {}
# reset # reset
self.reset() self.reset()
def new_epoch(self): def new_epoch(self):
if self.verbose: logger.debug(f"stats: new epoch {(self.epoch + 1)}")
print("stats: new epoch %d" % (self.epoch + 1))
self.epoch += 1 self.epoch += 1
self.reset() # zero the stats + increase epoch counter self.reset() # zero the stats + increase epoch counter
@ -193,18 +193,17 @@ class Stats(object):
val = float(val.sum()) val = float(val.sum())
return val return val
def add_log_vars(self, added_log_vars, verbose=True): def add_log_vars(self, added_log_vars):
for add_log_var in added_log_vars: for add_log_var in added_log_vars:
if add_log_var not in self.stats: if add_log_var not in self.stats:
if verbose: logger.debug(f"Adding {add_log_var}")
print(f"Adding {add_log_var}")
self.log_vars.append(add_log_var) self.log_vars.append(add_log_var)
def update(self, preds, time_start=None, freeze_iter=False, stat_set="train"): def update(self, preds, time_start=None, freeze_iter=False, stat_set="train"):
if self.epoch == -1: # uninitialized if self.epoch == -1: # uninitialized
print( logger.warning(
"warning: epoch==-1 means uninitialized stats structure -> new_epoch() called" "epoch==-1 means uninitialized stats structure -> new_epoch() called"
) )
self.new_epoch() self.new_epoch()
@ -284,6 +283,12 @@ class Stats(object):
skip_nan=False, skip_nan=False,
stat_format=lambda s: s.replace("loss_", "").replace("prev_stage_", "ps_"), stat_format=lambda s: s.replace("loss_", "").replace("prev_stage_", "ps_"),
): ):
"""
stats.print() is deprecated. Please use get_status_string() instead.
example:
std_out = stats.get_status_string()
logger.info(str_out)
"""
epoch = self.epoch epoch = self.epoch
stats = self.stats stats = self.stats
@ -311,8 +316,30 @@ class Stats(object):
if get_str: if get_str:
return str_out return str_out
else: else:
warnings.warn(
"get_str=False is deprecated."
"Please enable this flag to get receive the output string.",
DeprecationWarning,
)
print(str_out) print(str_out)
def get_status_string(
self,
max_it=None,
stat_set="train",
vars_print=None,
skip_nan=False,
stat_format=lambda s: s.replace("loss_", "").replace("prev_stage_", "ps_"),
):
return self.print(
max_it=max_it,
stat_set=stat_set,
vars_print=vars_print,
get_str=True,
skip_nan=skip_nan,
stat_format=stat_format,
)
def plot_stats( def plot_stats(
self, visdom_env=None, plot_file=None, visdom_server=None, visdom_port=None self, visdom_env=None, plot_file=None, visdom_server=None, visdom_port=None
): ):
@ -329,16 +356,15 @@ class Stats(object):
stat_sets = list(self.stats.keys()) stat_sets = list(self.stats.keys())
print( logger.debug(
"printing charts to visdom env '%s' (%s:%d)" f"printing charts to visdom env '{visdom_env}' ({visdom_server}:{visdom_port})"
% (visdom_env, visdom_server, visdom_port)
) )
novisdom = False novisdom = False
viz = get_visdom_connection(server=visdom_server, port=visdom_port) viz = get_visdom_connection(server=visdom_server, port=visdom_port)
if viz is None or not viz.check_connection(): if viz is None or not viz.check_connection():
print("no visdom server! -> skipping visdom plots") logger.info("no visdom server! -> skipping visdom plots")
novisdom = True novisdom = True
lines = [] lines = []
@ -385,7 +411,7 @@ class Stats(object):
) )
if plot_file: if plot_file:
print("exporting stats to %s" % plot_file) logger.info(f"plotting stats to {plot_file}")
ncol = 3 ncol = 3
nrow = int(np.ceil(float(len(lines)) / ncol)) nrow = int(np.ceil(float(len(lines)) / ncol))
matplotlib.rcParams.update({"font.size": 5}) matplotlib.rcParams.update({"font.size": 5})
@ -423,7 +449,7 @@ class Stats(object):
except PermissionError: except PermissionError:
warnings.warn("Cant dump stats due to insufficient permissions!") warnings.warn("Cant dump stats due to insufficient permissions!")
def synchronize_logged_vars(self, log_vars, default_val=float("NaN"), verbose=True): def synchronize_logged_vars(self, log_vars, default_val=float("NaN")):
stat_sets = list(self.stats.keys()) stat_sets = list(self.stats.keys())
@ -431,7 +457,7 @@ class Stats(object):
for stat_set in stat_sets: for stat_set in stat_sets:
for stat in self.stats[stat_set].keys(): for stat in self.stats[stat_set].keys():
if stat not in log_vars: if stat not in log_vars:
print("additional stat %s:%s -> removing" % (stat_set, stat)) logger.warning(f"additional stat {stat_set}:{stat} -> removing")
self.stats[stat_set] = { self.stats[stat_set] = {
stat: v for stat, v in self.stats[stat_set].items() if stat in log_vars stat: v for stat, v in self.stats[stat_set].items() if stat in log_vars
@ -442,21 +468,19 @@ class Stats(object):
for stat_set in stat_sets: for stat_set in stat_sets:
for stat in log_vars: for stat in log_vars:
if stat not in self.stats[stat_set]: if stat not in self.stats[stat_set]:
if verbose: logger.info(
print( "missing stat %s:%s -> filling with default values (%1.2f)"
"missing stat %s:%s -> filling with default values (%1.2f)" % (stat_set, stat, default_val)
% (stat_set, stat, default_val) )
)
elif len(self.stats[stat_set][stat].history) != self.epoch + 1: elif len(self.stats[stat_set][stat].history) != self.epoch + 1:
h = self.stats[stat_set][stat].history h = self.stats[stat_set][stat].history
if len(h) == 0: # just never updated stat ... skip if len(h) == 0: # just never updated stat ... skip
continue continue
else: else:
if verbose: logger.info(
print( "incomplete stat %s:%s -> reseting with default values (%1.2f)"
"incomplete stat %s:%s -> reseting with default values (%1.2f)" % (stat_set, stat, default_val)
% (stat_set, stat, default_val) )
)
else: else:
continue continue