mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Softly deprecate the get_str=False flag.
Summary: We don't want to use print directly in stats.print() method. Instead this method will return the output string to the caller. Reviewed By: shapovalov Differential Revision: D45356240 fbshipit-source-id: 2cabe3cdfb9206bf09aa7b3cdd2263148a5ba145
This commit is contained in:
parent
297020a4b1
commit
d08fe6d45a
@ -256,7 +256,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
|
|||||||
list(log_vars),
|
list(log_vars),
|
||||||
plot_file=os.path.join(exp_dir, "train_stats.pdf"),
|
plot_file=os.path.join(exp_dir, "train_stats.pdf"),
|
||||||
visdom_env=visdom_env_charts,
|
visdom_env=visdom_env_charts,
|
||||||
verbose=False,
|
|
||||||
visdom_server=self.visdom_server,
|
visdom_server=self.visdom_server,
|
||||||
visdom_port=self.visdom_port,
|
visdom_port=self.visdom_port,
|
||||||
)
|
)
|
||||||
@ -382,7 +381,8 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
|
|||||||
|
|
||||||
# print textual status update
|
# print textual status update
|
||||||
if it % self.metric_print_interval == 0 or last_iter:
|
if it % self.metric_print_interval == 0 or last_iter:
|
||||||
stats.print(stat_set=trainmode, max_it=n_batches)
|
std_out = stats.get_status_string(stat_set=trainmode, max_it=n_batches)
|
||||||
|
logger.info(std_out)
|
||||||
|
|
||||||
# visualize results
|
# visualize results
|
||||||
if (
|
if (
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
import gzip
|
import gzip
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
@ -17,6 +18,8 @@ import numpy as np
|
|||||||
from matplotlib import colors as mcolors
|
from matplotlib import colors as mcolors
|
||||||
from pytorch3d.implicitron.tools.vis_utils import get_visdom_connection
|
from pytorch3d.implicitron.tools.vis_utils import get_visdom_connection
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AverageMeter(object):
|
class AverageMeter(object):
|
||||||
"""Computes and stores the average and current value"""
|
"""Computes and stores the average and current value"""
|
||||||
@ -91,7 +94,9 @@ class Stats(object):
|
|||||||
# stats.update() automatically parses the 'objective' and 'top1e' from
|
# stats.update() automatically parses the 'objective' and 'top1e' from
|
||||||
# the "output" dict and stores this into the db
|
# the "output" dict and stores this into the db
|
||||||
stats.update(output)
|
stats.update(output)
|
||||||
stats.print() # prints the averages over given epoch
|
# prints the metric averages over given epoch
|
||||||
|
std_out = stats.get_status_string()
|
||||||
|
logger.info(str_out)
|
||||||
# stores the training plots into '/tmp/epoch_stats.pdf'
|
# stores the training plots into '/tmp/epoch_stats.pdf'
|
||||||
# and plots into a visdom server running at localhost (if running)
|
# and plots into a visdom server running at localhost (if running)
|
||||||
stats.plot_stats(plot_file='/tmp/epoch_stats.pdf')
|
stats.plot_stats(plot_file='/tmp/epoch_stats.pdf')
|
||||||
@ -101,7 +106,6 @@ class Stats(object):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
log_vars,
|
log_vars,
|
||||||
verbose=False,
|
|
||||||
epoch=-1,
|
epoch=-1,
|
||||||
visdom_env="main",
|
visdom_env="main",
|
||||||
do_plot=True,
|
do_plot=True,
|
||||||
@ -110,7 +114,6 @@ class Stats(object):
|
|||||||
visdom_port=8097,
|
visdom_port=8097,
|
||||||
):
|
):
|
||||||
|
|
||||||
self.verbose = verbose
|
|
||||||
self.log_vars = log_vars
|
self.log_vars = log_vars
|
||||||
self.visdom_env = visdom_env
|
self.visdom_env = visdom_env
|
||||||
self.visdom_server = visdom_server
|
self.visdom_server = visdom_server
|
||||||
@ -156,15 +159,14 @@ class Stats(object):
|
|||||||
iserr = type is not None and issubclass(type, Exception)
|
iserr = type is not None and issubclass(type, Exception)
|
||||||
iserr = iserr or (type is KeyboardInterrupt)
|
iserr = iserr or (type is KeyboardInterrupt)
|
||||||
if iserr:
|
if iserr:
|
||||||
print("error inside 'with' block")
|
logger.error("error inside 'with' block")
|
||||||
return
|
return
|
||||||
if self.do_plot:
|
if self.do_plot:
|
||||||
self.plot_stats(self.visdom_env)
|
self.plot_stats(self.visdom_env)
|
||||||
|
|
||||||
def reset(self): # to be called after each epoch
|
def reset(self): # to be called after each epoch
|
||||||
stat_sets = list(self.stats.keys())
|
stat_sets = list(self.stats.keys())
|
||||||
if self.verbose:
|
logger.debug(f"stats: epoch {self.epoch} - reset")
|
||||||
print("stats: epoch %d - reset" % self.epoch)
|
|
||||||
self.it = {k: -1 for k in stat_sets}
|
self.it = {k: -1 for k in stat_sets}
|
||||||
for stat_set in stat_sets:
|
for stat_set in stat_sets:
|
||||||
for stat in self.stats[stat_set]:
|
for stat in self.stats[stat_set]:
|
||||||
@ -172,16 +174,14 @@ class Stats(object):
|
|||||||
|
|
||||||
def hard_reset(self, epoch=-1): # to be called during object __init__
|
def hard_reset(self, epoch=-1): # to be called during object __init__
|
||||||
self.epoch = epoch
|
self.epoch = epoch
|
||||||
if self.verbose:
|
logger.debug(f"stats: epoch {self.epoch} - hard reset")
|
||||||
print("stats: epoch %d - hard reset" % self.epoch)
|
|
||||||
self.stats = {}
|
self.stats = {}
|
||||||
|
|
||||||
# reset
|
# reset
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def new_epoch(self):
|
def new_epoch(self):
|
||||||
if self.verbose:
|
logger.debug(f"stats: new epoch {(self.epoch + 1)}")
|
||||||
print("stats: new epoch %d" % (self.epoch + 1))
|
|
||||||
self.epoch += 1
|
self.epoch += 1
|
||||||
self.reset() # zero the stats + increase epoch counter
|
self.reset() # zero the stats + increase epoch counter
|
||||||
|
|
||||||
@ -193,18 +193,17 @@ class Stats(object):
|
|||||||
val = float(val.sum())
|
val = float(val.sum())
|
||||||
return val
|
return val
|
||||||
|
|
||||||
def add_log_vars(self, added_log_vars, verbose=True):
|
def add_log_vars(self, added_log_vars):
|
||||||
for add_log_var in added_log_vars:
|
for add_log_var in added_log_vars:
|
||||||
if add_log_var not in self.stats:
|
if add_log_var not in self.stats:
|
||||||
if verbose:
|
logger.debug(f"Adding {add_log_var}")
|
||||||
print(f"Adding {add_log_var}")
|
|
||||||
self.log_vars.append(add_log_var)
|
self.log_vars.append(add_log_var)
|
||||||
|
|
||||||
def update(self, preds, time_start=None, freeze_iter=False, stat_set="train"):
|
def update(self, preds, time_start=None, freeze_iter=False, stat_set="train"):
|
||||||
|
|
||||||
if self.epoch == -1: # uninitialized
|
if self.epoch == -1: # uninitialized
|
||||||
print(
|
logger.warning(
|
||||||
"warning: epoch==-1 means uninitialized stats structure -> new_epoch() called"
|
"epoch==-1 means uninitialized stats structure -> new_epoch() called"
|
||||||
)
|
)
|
||||||
self.new_epoch()
|
self.new_epoch()
|
||||||
|
|
||||||
@ -284,6 +283,12 @@ class Stats(object):
|
|||||||
skip_nan=False,
|
skip_nan=False,
|
||||||
stat_format=lambda s: s.replace("loss_", "").replace("prev_stage_", "ps_"),
|
stat_format=lambda s: s.replace("loss_", "").replace("prev_stage_", "ps_"),
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
stats.print() is deprecated. Please use get_status_string() instead.
|
||||||
|
example:
|
||||||
|
std_out = stats.get_status_string()
|
||||||
|
logger.info(str_out)
|
||||||
|
"""
|
||||||
|
|
||||||
epoch = self.epoch
|
epoch = self.epoch
|
||||||
stats = self.stats
|
stats = self.stats
|
||||||
@ -311,8 +316,30 @@ class Stats(object):
|
|||||||
if get_str:
|
if get_str:
|
||||||
return str_out
|
return str_out
|
||||||
else:
|
else:
|
||||||
|
warnings.warn(
|
||||||
|
"get_str=False is deprecated."
|
||||||
|
"Please enable this flag to get receive the output string.",
|
||||||
|
DeprecationWarning,
|
||||||
|
)
|
||||||
print(str_out)
|
print(str_out)
|
||||||
|
|
||||||
|
def get_status_string(
|
||||||
|
self,
|
||||||
|
max_it=None,
|
||||||
|
stat_set="train",
|
||||||
|
vars_print=None,
|
||||||
|
skip_nan=False,
|
||||||
|
stat_format=lambda s: s.replace("loss_", "").replace("prev_stage_", "ps_"),
|
||||||
|
):
|
||||||
|
return self.print(
|
||||||
|
max_it=max_it,
|
||||||
|
stat_set=stat_set,
|
||||||
|
vars_print=vars_print,
|
||||||
|
get_str=True,
|
||||||
|
skip_nan=skip_nan,
|
||||||
|
stat_format=stat_format,
|
||||||
|
)
|
||||||
|
|
||||||
def plot_stats(
|
def plot_stats(
|
||||||
self, visdom_env=None, plot_file=None, visdom_server=None, visdom_port=None
|
self, visdom_env=None, plot_file=None, visdom_server=None, visdom_port=None
|
||||||
):
|
):
|
||||||
@ -329,16 +356,15 @@ class Stats(object):
|
|||||||
|
|
||||||
stat_sets = list(self.stats.keys())
|
stat_sets = list(self.stats.keys())
|
||||||
|
|
||||||
print(
|
logger.debug(
|
||||||
"printing charts to visdom env '%s' (%s:%d)"
|
f"printing charts to visdom env '{visdom_env}' ({visdom_server}:{visdom_port})"
|
||||||
% (visdom_env, visdom_server, visdom_port)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
novisdom = False
|
novisdom = False
|
||||||
|
|
||||||
viz = get_visdom_connection(server=visdom_server, port=visdom_port)
|
viz = get_visdom_connection(server=visdom_server, port=visdom_port)
|
||||||
if viz is None or not viz.check_connection():
|
if viz is None or not viz.check_connection():
|
||||||
print("no visdom server! -> skipping visdom plots")
|
logger.info("no visdom server! -> skipping visdom plots")
|
||||||
novisdom = True
|
novisdom = True
|
||||||
|
|
||||||
lines = []
|
lines = []
|
||||||
@ -385,7 +411,7 @@ class Stats(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if plot_file:
|
if plot_file:
|
||||||
print("exporting stats to %s" % plot_file)
|
logger.info(f"plotting stats to {plot_file}")
|
||||||
ncol = 3
|
ncol = 3
|
||||||
nrow = int(np.ceil(float(len(lines)) / ncol))
|
nrow = int(np.ceil(float(len(lines)) / ncol))
|
||||||
matplotlib.rcParams.update({"font.size": 5})
|
matplotlib.rcParams.update({"font.size": 5})
|
||||||
@ -423,7 +449,7 @@ class Stats(object):
|
|||||||
except PermissionError:
|
except PermissionError:
|
||||||
warnings.warn("Cant dump stats due to insufficient permissions!")
|
warnings.warn("Cant dump stats due to insufficient permissions!")
|
||||||
|
|
||||||
def synchronize_logged_vars(self, log_vars, default_val=float("NaN"), verbose=True):
|
def synchronize_logged_vars(self, log_vars, default_val=float("NaN")):
|
||||||
|
|
||||||
stat_sets = list(self.stats.keys())
|
stat_sets = list(self.stats.keys())
|
||||||
|
|
||||||
@ -431,7 +457,7 @@ class Stats(object):
|
|||||||
for stat_set in stat_sets:
|
for stat_set in stat_sets:
|
||||||
for stat in self.stats[stat_set].keys():
|
for stat in self.stats[stat_set].keys():
|
||||||
if stat not in log_vars:
|
if stat not in log_vars:
|
||||||
print("additional stat %s:%s -> removing" % (stat_set, stat))
|
logger.warning(f"additional stat {stat_set}:{stat} -> removing")
|
||||||
|
|
||||||
self.stats[stat_set] = {
|
self.stats[stat_set] = {
|
||||||
stat: v for stat, v in self.stats[stat_set].items() if stat in log_vars
|
stat: v for stat, v in self.stats[stat_set].items() if stat in log_vars
|
||||||
@ -442,21 +468,19 @@ class Stats(object):
|
|||||||
for stat_set in stat_sets:
|
for stat_set in stat_sets:
|
||||||
for stat in log_vars:
|
for stat in log_vars:
|
||||||
if stat not in self.stats[stat_set]:
|
if stat not in self.stats[stat_set]:
|
||||||
if verbose:
|
logger.info(
|
||||||
print(
|
"missing stat %s:%s -> filling with default values (%1.2f)"
|
||||||
"missing stat %s:%s -> filling with default values (%1.2f)"
|
% (stat_set, stat, default_val)
|
||||||
% (stat_set, stat, default_val)
|
)
|
||||||
)
|
|
||||||
elif len(self.stats[stat_set][stat].history) != self.epoch + 1:
|
elif len(self.stats[stat_set][stat].history) != self.epoch + 1:
|
||||||
h = self.stats[stat_set][stat].history
|
h = self.stats[stat_set][stat].history
|
||||||
if len(h) == 0: # just never updated stat ... skip
|
if len(h) == 0: # just never updated stat ... skip
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
if verbose:
|
logger.info(
|
||||||
print(
|
"incomplete stat %s:%s -> reseting with default values (%1.2f)"
|
||||||
"incomplete stat %s:%s -> reseting with default values (%1.2f)"
|
% (stat_set, stat, default_val)
|
||||||
% (stat_set, stat, default_val)
|
)
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user