From 5a5d450648e8781083cd3b4ab96dad06806acecb Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 14 May 2024 20:37:21 +0800 Subject: [PATCH] fix #3728 Former-commit-id: cfaee8b4cf5f89d767a20a057d2335bd30ec83a2 --- src/llmtuner/extras/ploting.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/llmtuner/extras/ploting.py b/src/llmtuner/extras/ploting.py index e53f1f89..dea23bbe 100644 --- a/src/llmtuner/extras/ploting.py +++ b/src/llmtuner/extras/ploting.py @@ -21,6 +21,9 @@ def smooth(scalars: List[float]) -> List[float]: r""" EMA implementation according to TensorBoard. """ + if len(scalars) == 0: + return [] + last = scalars[0] smoothed = [] weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function @@ -32,6 +35,9 @@ def smooth(scalars: List[float]) -> List[float]: def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure": + r""" + Plots loss curves in LlamaBoard. + """ plt.close("all") plt.switch_backend("agg") fig = plt.figure() @@ -51,6 +57,9 @@ def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figur def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None: + r""" + Plots loss curves and saves the image. + """ plt.switch_backend("agg") with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f: data = json.load(f)