mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Softly deprecate the get_str=False flag.
Summary: We don't want to use print directly in stats.print() method. Instead this method will return the output string to the caller. Reviewed By: shapovalov Differential Revision: D45356240 fbshipit-source-id: 2cabe3cdfb9206bf09aa7b3cdd2263148a5ba145
This commit is contained in:
parent
297020a4b1
commit
d08fe6d45a
@ -256,7 +256,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
|
||||
list(log_vars),
|
||||
plot_file=os.path.join(exp_dir, "train_stats.pdf"),
|
||||
visdom_env=visdom_env_charts,
|
||||
verbose=False,
|
||||
visdom_server=self.visdom_server,
|
||||
visdom_port=self.visdom_port,
|
||||
)
|
||||
@ -382,7 +381,8 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
|
||||
|
||||
# print textual status update
|
||||
if it % self.metric_print_interval == 0 or last_iter:
|
||||
stats.print(stat_set=trainmode, max_it=n_batches)
|
||||
std_out = stats.get_status_string(stat_set=trainmode, max_it=n_batches)
|
||||
logger.info(std_out)
|
||||
|
||||
# visualize results
|
||||
if (
|
||||
|
@ -6,6 +6,7 @@
|
||||
|
||||
import gzip
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import warnings
|
||||
from collections.abc import Iterable
|
||||
@ -17,6 +18,8 @@ import numpy as np
|
||||
from matplotlib import colors as mcolors
|
||||
from pytorch3d.implicitron.tools.vis_utils import get_visdom_connection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value"""
|
||||
@ -91,7 +94,9 @@ class Stats(object):
|
||||
# stats.update() automatically parses the 'objective' and 'top1e' from
|
||||
# the "output" dict and stores this into the db
|
||||
stats.update(output)
|
||||
stats.print() # prints the averages over given epoch
|
||||
# prints the metric averages over given epoch
|
||||
std_out = stats.get_status_string()
|
||||
logger.info(str_out)
|
||||
# stores the training plots into '/tmp/epoch_stats.pdf'
|
||||
# and plots into a visdom server running at localhost (if running)
|
||||
stats.plot_stats(plot_file='/tmp/epoch_stats.pdf')
|
||||
@ -101,7 +106,6 @@ class Stats(object):
|
||||
def __init__(
|
||||
self,
|
||||
log_vars,
|
||||
verbose=False,
|
||||
epoch=-1,
|
||||
visdom_env="main",
|
||||
do_plot=True,
|
||||
@ -110,7 +114,6 @@ class Stats(object):
|
||||
visdom_port=8097,
|
||||
):
|
||||
|
||||
self.verbose = verbose
|
||||
self.log_vars = log_vars
|
||||
self.visdom_env = visdom_env
|
||||
self.visdom_server = visdom_server
|
||||
@ -156,15 +159,14 @@ class Stats(object):
|
||||
iserr = type is not None and issubclass(type, Exception)
|
||||
iserr = iserr or (type is KeyboardInterrupt)
|
||||
if iserr:
|
||||
print("error inside 'with' block")
|
||||
logger.error("error inside 'with' block")
|
||||
return
|
||||
if self.do_plot:
|
||||
self.plot_stats(self.visdom_env)
|
||||
|
||||
def reset(self): # to be called after each epoch
|
||||
stat_sets = list(self.stats.keys())
|
||||
if self.verbose:
|
||||
print("stats: epoch %d - reset" % self.epoch)
|
||||
logger.debug(f"stats: epoch {self.epoch} - reset")
|
||||
self.it = {k: -1 for k in stat_sets}
|
||||
for stat_set in stat_sets:
|
||||
for stat in self.stats[stat_set]:
|
||||
@ -172,16 +174,14 @@ class Stats(object):
|
||||
|
||||
def hard_reset(self, epoch=-1): # to be called during object __init__
|
||||
self.epoch = epoch
|
||||
if self.verbose:
|
||||
print("stats: epoch %d - hard reset" % self.epoch)
|
||||
logger.debug(f"stats: epoch {self.epoch} - hard reset")
|
||||
self.stats = {}
|
||||
|
||||
# reset
|
||||
self.reset()
|
||||
|
||||
def new_epoch(self):
|
||||
if self.verbose:
|
||||
print("stats: new epoch %d" % (self.epoch + 1))
|
||||
logger.debug(f"stats: new epoch {(self.epoch + 1)}")
|
||||
self.epoch += 1
|
||||
self.reset() # zero the stats + increase epoch counter
|
||||
|
||||
@ -193,18 +193,17 @@ class Stats(object):
|
||||
val = float(val.sum())
|
||||
return val
|
||||
|
||||
def add_log_vars(self, added_log_vars, verbose=True):
|
||||
def add_log_vars(self, added_log_vars):
|
||||
for add_log_var in added_log_vars:
|
||||
if add_log_var not in self.stats:
|
||||
if verbose:
|
||||
print(f"Adding {add_log_var}")
|
||||
logger.debug(f"Adding {add_log_var}")
|
||||
self.log_vars.append(add_log_var)
|
||||
|
||||
def update(self, preds, time_start=None, freeze_iter=False, stat_set="train"):
|
||||
|
||||
if self.epoch == -1: # uninitialized
|
||||
print(
|
||||
"warning: epoch==-1 means uninitialized stats structure -> new_epoch() called"
|
||||
logger.warning(
|
||||
"epoch==-1 means uninitialized stats structure -> new_epoch() called"
|
||||
)
|
||||
self.new_epoch()
|
||||
|
||||
@ -284,6 +283,12 @@ class Stats(object):
|
||||
skip_nan=False,
|
||||
stat_format=lambda s: s.replace("loss_", "").replace("prev_stage_", "ps_"),
|
||||
):
|
||||
"""
|
||||
stats.print() is deprecated. Please use get_status_string() instead.
|
||||
example:
|
||||
std_out = stats.get_status_string()
|
||||
logger.info(str_out)
|
||||
"""
|
||||
|
||||
epoch = self.epoch
|
||||
stats = self.stats
|
||||
@ -311,8 +316,30 @@ class Stats(object):
|
||||
if get_str:
|
||||
return str_out
|
||||
else:
|
||||
warnings.warn(
|
||||
"get_str=False is deprecated."
|
||||
"Please enable this flag to get receive the output string.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
print(str_out)
|
||||
|
||||
def get_status_string(
|
||||
self,
|
||||
max_it=None,
|
||||
stat_set="train",
|
||||
vars_print=None,
|
||||
skip_nan=False,
|
||||
stat_format=lambda s: s.replace("loss_", "").replace("prev_stage_", "ps_"),
|
||||
):
|
||||
return self.print(
|
||||
max_it=max_it,
|
||||
stat_set=stat_set,
|
||||
vars_print=vars_print,
|
||||
get_str=True,
|
||||
skip_nan=skip_nan,
|
||||
stat_format=stat_format,
|
||||
)
|
||||
|
||||
def plot_stats(
|
||||
self, visdom_env=None, plot_file=None, visdom_server=None, visdom_port=None
|
||||
):
|
||||
@ -329,16 +356,15 @@ class Stats(object):
|
||||
|
||||
stat_sets = list(self.stats.keys())
|
||||
|
||||
print(
|
||||
"printing charts to visdom env '%s' (%s:%d)"
|
||||
% (visdom_env, visdom_server, visdom_port)
|
||||
logger.debug(
|
||||
f"printing charts to visdom env '{visdom_env}' ({visdom_server}:{visdom_port})"
|
||||
)
|
||||
|
||||
novisdom = False
|
||||
|
||||
viz = get_visdom_connection(server=visdom_server, port=visdom_port)
|
||||
if viz is None or not viz.check_connection():
|
||||
print("no visdom server! -> skipping visdom plots")
|
||||
logger.info("no visdom server! -> skipping visdom plots")
|
||||
novisdom = True
|
||||
|
||||
lines = []
|
||||
@ -385,7 +411,7 @@ class Stats(object):
|
||||
)
|
||||
|
||||
if plot_file:
|
||||
print("exporting stats to %s" % plot_file)
|
||||
logger.info(f"plotting stats to {plot_file}")
|
||||
ncol = 3
|
||||
nrow = int(np.ceil(float(len(lines)) / ncol))
|
||||
matplotlib.rcParams.update({"font.size": 5})
|
||||
@ -423,7 +449,7 @@ class Stats(object):
|
||||
except PermissionError:
|
||||
warnings.warn("Cant dump stats due to insufficient permissions!")
|
||||
|
||||
def synchronize_logged_vars(self, log_vars, default_val=float("NaN"), verbose=True):
|
||||
def synchronize_logged_vars(self, log_vars, default_val=float("NaN")):
|
||||
|
||||
stat_sets = list(self.stats.keys())
|
||||
|
||||
@ -431,7 +457,7 @@ class Stats(object):
|
||||
for stat_set in stat_sets:
|
||||
for stat in self.stats[stat_set].keys():
|
||||
if stat not in log_vars:
|
||||
print("additional stat %s:%s -> removing" % (stat_set, stat))
|
||||
logger.warning(f"additional stat {stat_set}:{stat} -> removing")
|
||||
|
||||
self.stats[stat_set] = {
|
||||
stat: v for stat, v in self.stats[stat_set].items() if stat in log_vars
|
||||
@ -442,21 +468,19 @@ class Stats(object):
|
||||
for stat_set in stat_sets:
|
||||
for stat in log_vars:
|
||||
if stat not in self.stats[stat_set]:
|
||||
if verbose:
|
||||
print(
|
||||
"missing stat %s:%s -> filling with default values (%1.2f)"
|
||||
% (stat_set, stat, default_val)
|
||||
)
|
||||
logger.info(
|
||||
"missing stat %s:%s -> filling with default values (%1.2f)"
|
||||
% (stat_set, stat, default_val)
|
||||
)
|
||||
elif len(self.stats[stat_set][stat].history) != self.epoch + 1:
|
||||
h = self.stats[stat_set][stat].history
|
||||
if len(h) == 0: # just never updated stat ... skip
|
||||
continue
|
||||
else:
|
||||
if verbose:
|
||||
print(
|
||||
"incomplete stat %s:%s -> reseting with default values (%1.2f)"
|
||||
% (stat_set, stat, default_val)
|
||||
)
|
||||
logger.info(
|
||||
"incomplete stat %s:%s -> reseting with default values (%1.2f)"
|
||||
% (stat_set, stat, default_val)
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user