Softly deprecate the get_str=False flag.

Summary: We don't want to use print directly in stats.print() method. Instead this method will return the output string to the caller.

Reviewed By: shapovalov

Differential Revision: D45356240

fbshipit-source-id: 2cabe3cdfb9206bf09aa7b3cdd2263148a5ba145
This commit is contained in:
Virendra Kumar Pathak 2023-05-14 01:24:31 -07:00 committed by Facebook GitHub Bot
parent 297020a4b1
commit d08fe6d45a
2 changed files with 58 additions and 34 deletions

View File

@ -256,7 +256,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
list(log_vars),
plot_file=os.path.join(exp_dir, "train_stats.pdf"),
visdom_env=visdom_env_charts,
verbose=False,
visdom_server=self.visdom_server,
visdom_port=self.visdom_port,
)
@ -382,7 +381,8 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
# print textual status update
if it % self.metric_print_interval == 0 or last_iter:
stats.print(stat_set=trainmode, max_it=n_batches)
std_out = stats.get_status_string(stat_set=trainmode, max_it=n_batches)
logger.info(std_out)
# visualize results
if (

View File

@ -6,6 +6,7 @@
import gzip
import json
import logging
import time
import warnings
from collections.abc import Iterable
@ -17,6 +18,8 @@ import numpy as np
from matplotlib import colors as mcolors
from pytorch3d.implicitron.tools.vis_utils import get_visdom_connection
logger = logging.getLogger(__name__)
class AverageMeter(object):
"""Computes and stores the average and current value"""
@ -91,7 +94,9 @@ class Stats(object):
# stats.update() automatically parses the 'objective' and 'top1e' from
# the "output" dict and stores this into the db
stats.update(output)
stats.print() # prints the averages over given epoch
# prints the metric averages over given epoch
std_out = stats.get_status_string()
logger.info(str_out)
# stores the training plots into '/tmp/epoch_stats.pdf'
# and plots into a visdom server running at localhost (if running)
stats.plot_stats(plot_file='/tmp/epoch_stats.pdf')
@ -101,7 +106,6 @@ class Stats(object):
def __init__(
self,
log_vars,
verbose=False,
epoch=-1,
visdom_env="main",
do_plot=True,
@ -110,7 +114,6 @@ class Stats(object):
visdom_port=8097,
):
self.verbose = verbose
self.log_vars = log_vars
self.visdom_env = visdom_env
self.visdom_server = visdom_server
@ -156,15 +159,14 @@ class Stats(object):
iserr = type is not None and issubclass(type, Exception)
iserr = iserr or (type is KeyboardInterrupt)
if iserr:
print("error inside 'with' block")
logger.error("error inside 'with' block")
return
if self.do_plot:
self.plot_stats(self.visdom_env)
def reset(self): # to be called after each epoch
stat_sets = list(self.stats.keys())
if self.verbose:
print("stats: epoch %d - reset" % self.epoch)
logger.debug(f"stats: epoch {self.epoch} - reset")
self.it = {k: -1 for k in stat_sets}
for stat_set in stat_sets:
for stat in self.stats[stat_set]:
@ -172,16 +174,14 @@ class Stats(object):
def hard_reset(self, epoch=-1): # to be called during object __init__
self.epoch = epoch
if self.verbose:
print("stats: epoch %d - hard reset" % self.epoch)
logger.debug(f"stats: epoch {self.epoch} - hard reset")
self.stats = {}
# reset
self.reset()
def new_epoch(self):
if self.verbose:
print("stats: new epoch %d" % (self.epoch + 1))
logger.debug(f"stats: new epoch {(self.epoch + 1)}")
self.epoch += 1
self.reset() # zero the stats + increase epoch counter
@ -193,18 +193,17 @@ class Stats(object):
val = float(val.sum())
return val
def add_log_vars(self, added_log_vars, verbose=True):
def add_log_vars(self, added_log_vars):
for add_log_var in added_log_vars:
if add_log_var not in self.stats:
if verbose:
print(f"Adding {add_log_var}")
logger.debug(f"Adding {add_log_var}")
self.log_vars.append(add_log_var)
def update(self, preds, time_start=None, freeze_iter=False, stat_set="train"):
if self.epoch == -1: # uninitialized
print(
"warning: epoch==-1 means uninitialized stats structure -> new_epoch() called"
logger.warning(
"epoch==-1 means uninitialized stats structure -> new_epoch() called"
)
self.new_epoch()
@ -284,6 +283,12 @@ class Stats(object):
skip_nan=False,
stat_format=lambda s: s.replace("loss_", "").replace("prev_stage_", "ps_"),
):
"""
stats.print() is deprecated. Please use get_status_string() instead.
example:
std_out = stats.get_status_string()
logger.info(str_out)
"""
epoch = self.epoch
stats = self.stats
@ -311,8 +316,30 @@ class Stats(object):
if get_str:
return str_out
else:
warnings.warn(
"get_str=False is deprecated."
"Please enable this flag to get receive the output string.",
DeprecationWarning,
)
print(str_out)
def get_status_string(
self,
max_it=None,
stat_set="train",
vars_print=None,
skip_nan=False,
stat_format=lambda s: s.replace("loss_", "").replace("prev_stage_", "ps_"),
):
return self.print(
max_it=max_it,
stat_set=stat_set,
vars_print=vars_print,
get_str=True,
skip_nan=skip_nan,
stat_format=stat_format,
)
def plot_stats(
self, visdom_env=None, plot_file=None, visdom_server=None, visdom_port=None
):
@ -329,16 +356,15 @@ class Stats(object):
stat_sets = list(self.stats.keys())
print(
"printing charts to visdom env '%s' (%s:%d)"
% (visdom_env, visdom_server, visdom_port)
logger.debug(
f"printing charts to visdom env '{visdom_env}' ({visdom_server}:{visdom_port})"
)
novisdom = False
viz = get_visdom_connection(server=visdom_server, port=visdom_port)
if viz is None or not viz.check_connection():
print("no visdom server! -> skipping visdom plots")
logger.info("no visdom server! -> skipping visdom plots")
novisdom = True
lines = []
@ -385,7 +411,7 @@ class Stats(object):
)
if plot_file:
print("exporting stats to %s" % plot_file)
logger.info(f"plotting stats to {plot_file}")
ncol = 3
nrow = int(np.ceil(float(len(lines)) / ncol))
matplotlib.rcParams.update({"font.size": 5})
@ -423,7 +449,7 @@ class Stats(object):
except PermissionError:
warnings.warn("Cant dump stats due to insufficient permissions!")
def synchronize_logged_vars(self, log_vars, default_val=float("NaN"), verbose=True):
def synchronize_logged_vars(self, log_vars, default_val=float("NaN")):
stat_sets = list(self.stats.keys())
@ -431,7 +457,7 @@ class Stats(object):
for stat_set in stat_sets:
for stat in self.stats[stat_set].keys():
if stat not in log_vars:
print("additional stat %s:%s -> removing" % (stat_set, stat))
logger.warning(f"additional stat {stat_set}:{stat} -> removing")
self.stats[stat_set] = {
stat: v for stat, v in self.stats[stat_set].items() if stat in log_vars
@ -442,8 +468,7 @@ class Stats(object):
for stat_set in stat_sets:
for stat in log_vars:
if stat not in self.stats[stat_set]:
if verbose:
print(
logger.info(
"missing stat %s:%s -> filling with default values (%1.2f)"
% (stat_set, stat, default_val)
)
@ -452,8 +477,7 @@ class Stats(object):
if len(h) == 0: # just never updated stat ... skip
continue
else:
if verbose:
print(
logger.info(
"incomplete stat %s:%s -> reseting with default values (%1.2f)"
% (stat_set, stat, default_val)
)