fix bug in latest gradio

Former-commit-id: 44a962862b4a74e50ef5786c8d5719faaa65f63f
This commit is contained in:
hiyouga
2024-04-04 00:55:31 +08:00
parent 43d134ba29
commit b1986a06b9
8 changed files with 111 additions and 204 deletions

View File

@@ -1,13 +1,12 @@
import json
import os
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any, Dict, Optional
import gradio as gr
from ..extras.packages import is_matplotlib_available
from ..extras.ploting import smooth
from .common import get_save_dir
from .locales import ALERTS
@@ -36,7 +35,7 @@ def get_time() -> str:
def can_quantize(finetuning_type: str) -> "gr.Dropdown":
if finetuning_type != "lora":
return gr.Dropdown(value="None", interactive=False)
return gr.Dropdown(value="none", interactive=False)
else:
return gr.Dropdown(interactive=True)
@@ -74,11 +73,9 @@ def get_eval_results(path: os.PathLike) -> str:
return "```json\n{}\n```\n".format(result)
def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> "matplotlib.figure.Figure":
if not base_model:
return
log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl")
if not os.path.isfile(log_file):
def gen_plot(output_path: str) -> Optional["matplotlib.figure.Figure"]:
log_file = os.path.join(output_path, "trainer_log.jsonl")
if not os.path.isfile(log_file) or not is_matplotlib_available():
return
plt.close("all")
@@ -88,13 +85,13 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> "matplot
steps, losses = [], []
with open(log_file, "r", encoding="utf-8") as f:
for line in f:
log_info = json.loads(line)
log_info: Dict[str, Any] = json.loads(line)
if log_info.get("loss", None):
steps.append(log_info["current_steps"])
losses.append(log_info["loss"])
if len(losses) == 0:
return None
return
ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")