diff --git a/pytorch3d/implicitron/tools/stats.py b/pytorch3d/implicitron/tools/stats.py index 8b7c0bcc..012ab54a 100644 --- a/pytorch3d/implicitron/tools/stats.py +++ b/pytorch3d/implicitron/tools/stats.py @@ -118,6 +118,7 @@ class Stats(object): self.plot_file = plot_file self.do_plot = do_plot self.hard_reset(epoch=epoch) + self._t_last_update = None @staticmethod def from_json_str(json_str): @@ -215,7 +216,6 @@ class Stats(object): self.it[stat_set] += 1 epoch = self.epoch - it = self.it[stat_set] for stat in self.log_vars: @@ -224,10 +224,11 @@ class Stats(object): if stat == "sec/it": # compute speed if time_start is None: - elapsed = 0.0 + time_per_it = 0.0 else: - elapsed = time.time() - time_start - time_per_it = float(elapsed) / float(it + 1) + now = time.time() + time_per_it = now - (self._t_last_update or time_start) + self._t_last_update = now val = time_per_it else: if stat in preds: