mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			96 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			96 lines
		
	
	
		
			3.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Copyright 2025 the LlamaFactory team.
 | 
						|
#
 | 
						|
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
# you may not use this file except in compliance with the License.
 | 
						|
# You may obtain a copy of the License at
 | 
						|
#
 | 
						|
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
#
 | 
						|
# Unless required by applicable law or agreed to in writing, software
 | 
						|
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
# See the License for the specific language governing permissions and
 | 
						|
# limitations under the License.
 | 
						|
 | 
						|
import json
 | 
						|
import math
 | 
						|
import os
 | 
						|
from typing import Any
 | 
						|
 | 
						|
from transformers.trainer import TRAINER_STATE_NAME
 | 
						|
 | 
						|
from . import logging
 | 
						|
from .packages import is_matplotlib_available
 | 
						|
 | 
						|
 | 
						|
if is_matplotlib_available():
 | 
						|
    import matplotlib.figure
 | 
						|
    import matplotlib.pyplot as plt
 | 
						|
 | 
						|
 | 
						|
logger = logging.get_logger(__name__)
 | 
						|
 | 
						|
 | 
						|
def smooth(scalars: list[float]) -> list[float]:
 | 
						|
    r"""EMA implementation according to TensorBoard."""
 | 
						|
    if len(scalars) == 0:
 | 
						|
        return []
 | 
						|
 | 
						|
    last = scalars[0]
 | 
						|
    smoothed = []
 | 
						|
    weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5)  # a sigmoid function
 | 
						|
    for next_val in scalars:
 | 
						|
        smoothed_val = last * weight + (1 - weight) * next_val
 | 
						|
        smoothed.append(smoothed_val)
 | 
						|
        last = smoothed_val
 | 
						|
    return smoothed
 | 
						|
 | 
						|
 | 
						|
def gen_loss_plot(trainer_log: list[dict[str, Any]]) -> "matplotlib.figure.Figure":
 | 
						|
    r"""Plot loss curves in LlamaBoard."""
 | 
						|
    plt.close("all")
 | 
						|
    plt.switch_backend("agg")
 | 
						|
    fig = plt.figure()
 | 
						|
    ax = fig.add_subplot(111)
 | 
						|
    steps, losses = [], []
 | 
						|
    for log in trainer_log:
 | 
						|
        if log.get("loss", None):
 | 
						|
            steps.append(log["current_steps"])
 | 
						|
            losses.append(log["loss"])
 | 
						|
 | 
						|
    ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
 | 
						|
    ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
 | 
						|
    ax.legend()
 | 
						|
    ax.set_xlabel("step")
 | 
						|
    ax.set_ylabel("loss")
 | 
						|
    return fig
 | 
						|
 | 
						|
 | 
						|
def plot_loss(save_dictionary: str, keys: list[str] = ["loss"]) -> None:
 | 
						|
    r"""Plot loss curves and saves the image."""
 | 
						|
    plt.switch_backend("agg")
 | 
						|
    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
 | 
						|
        data = json.load(f)
 | 
						|
 | 
						|
    for key in keys:
 | 
						|
        steps, metrics = [], []
 | 
						|
        for i in range(len(data["log_history"])):
 | 
						|
            if key in data["log_history"][i]:
 | 
						|
                steps.append(data["log_history"][i]["step"])
 | 
						|
                metrics.append(data["log_history"][i][key])
 | 
						|
 | 
						|
        if len(metrics) == 0:
 | 
						|
            logger.warning_rank0(f"No metric {key} to plot.")
 | 
						|
            continue
 | 
						|
 | 
						|
        plt.figure()
 | 
						|
        plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
 | 
						|
        plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
 | 
						|
        plt.title(f"training {key} of {save_dictionary}")
 | 
						|
        plt.xlabel("step")
 | 
						|
        plt.ylabel(key)
 | 
						|
        plt.legend()
 | 
						|
        figure_path = os.path.join(save_dictionary, "training_{}.png".format(key.replace("/", "_")))
 | 
						|
        plt.savefig(figure_path, format="png", dpi=100)
 | 
						|
        print("Figure saved at:", figure_path)
 |