mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-22 21:52:51 +08:00
parent
1c089ccfee
commit
5a5d450648
@ -21,6 +21,9 @@ def smooth(scalars: List[float]) -> List[float]:
|
|||||||
r"""
|
r"""
|
||||||
EMA implementation according to TensorBoard.
|
EMA implementation according to TensorBoard.
|
||||||
"""
|
"""
|
||||||
|
if len(scalars) == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
last = scalars[0]
|
last = scalars[0]
|
||||||
smoothed = []
|
smoothed = []
|
||||||
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
|
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
|
||||||
@ -32,6 +35,9 @@ def smooth(scalars: List[float]) -> List[float]:
|
|||||||
|
|
||||||
|
|
||||||
def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
|
def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
|
||||||
|
r"""
|
||||||
|
Plots loss curves in LlamaBoard.
|
||||||
|
"""
|
||||||
plt.close("all")
|
plt.close("all")
|
||||||
plt.switch_backend("agg")
|
plt.switch_backend("agg")
|
||||||
fig = plt.figure()
|
fig = plt.figure()
|
||||||
@ -51,6 +57,9 @@ def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figur
|
|||||||
|
|
||||||
|
|
||||||
def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
|
def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
|
||||||
|
r"""
|
||||||
|
Plots loss curves and saves the image.
|
||||||
|
"""
|
||||||
plt.switch_backend("agg")
|
plt.switch_backend("agg")
|
||||||
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
|
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user