mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 11:50:35 +08:00
format style
This commit is contained in:
@@ -1,12 +1,14 @@
|
||||
import os
|
||||
import math
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from transformers.trainer import TRAINER_STATE_NAME
|
||||
|
||||
from .logging import get_logger
|
||||
from .packages import is_matplotlib_available
|
||||
|
||||
|
||||
if is_matplotlib_available():
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
@@ -20,7 +22,7 @@ def smooth(scalars: List[float]) -> List[float]:
|
||||
"""
|
||||
last = scalars[0]
|
||||
smoothed = list()
|
||||
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
|
||||
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
|
||||
for next_val in scalars:
|
||||
smoothed_val = last * weight + (1 - weight) * next_val
|
||||
smoothed.append(smoothed_val)
|
||||
@@ -29,7 +31,6 @@ def smooth(scalars: List[float]) -> List[float]:
|
||||
|
||||
|
||||
def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
|
||||
|
||||
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user