fix bug in latest gradio

Former-commit-id: 5ddcecda50ccff93d51bebc9ac72c2a0dd483e9b
This commit is contained in:
hiyouga 2024-04-04 00:55:31 +08:00
parent a6d347726f
commit 54a4a8217a
8 changed files with 111 additions and 204 deletions

View File

@ -4,7 +4,7 @@ datasets>=2.14.3
accelerate>=0.27.2 accelerate>=0.27.2
peft>=0.10.0 peft>=0.10.0
trl>=0.8.1 trl>=0.8.1
gradio>4.0.0,<=4.21.0 gradio>=4.0.0
scipy scipy
einops einops
sentencepiece sentencepiece

View File

@ -1,114 +0,0 @@
# coding=utf-8
# Converts the InternLM2 model in the same format as LLaMA2.
# Usage: python llamafy_internlm2.py --input_dir input --output_dir output
# Warning: We have found that the converted model cannot infer correctly. It will be fixed later.
import json
import os
from collections import OrderedDict
from typing import Any, Dict, Optional
import fire
import torch
from safetensors.torch import save_file
from tqdm import tqdm
from transformers.modeling_utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
shard_checkpoint,
)
CONFIG_NAME = "config.json"
def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool):
    """Load InternLM2 ``*.bin`` shards, rename/split tensors into LLaMA2 layout, and save shards.

    Args:
        input_dir: directory containing the InternLM2 ``config.json`` and ``*.bin`` weight files.
        output_dir: destination directory for the converted checkpoint.
        shard_size: maximum shard size passed to ``shard_checkpoint`` (e.g. ``"2GB"``).
        save_safetensors: if True, write ``.safetensors`` shards instead of PyTorch ``.bin`` files.
    """
    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
        source_config: Dict[str, Any] = json.load(f)

    # Gather every tensor from the sharded PyTorch checkpoint files.
    source_state: Dict[str, torch.Tensor] = OrderedDict()
    for name in tqdm(os.listdir(input_dir), desc="Load weights"):
        path = os.path.join(input_dir, name)
        if os.path.isfile(path) and name.endswith(".bin"):
            source_state.update(torch.load(path, map_location="cpu"))

    # Rename InternLM2 parameter names to their LLaMA2 equivalents; the order of
    # these substring checks matters (e.g. "wo" must be tested before "w1"/"w2"/"w3").
    target_state: Dict[str, torch.Tensor] = OrderedDict()
    for key, tensor in tqdm(source_state.items(), desc="Convert format"):
        if "output" in key:
            target_state[key.replace("output", "lm_head")] = tensor
        elif "tok_embeddings" in key:
            target_state[key.replace("tok_embeddings", "embed_tokens")] = tensor
        elif "wqkv" in key:
            # Split the fused QKV projection row-wise into separate q/k/v projections.
            num_q_heads = source_config["num_attention_heads"]
            num_kv_heads = source_config["num_key_value_heads"]
            rows_per_head_group = tensor.size(0) // (num_q_heads + 2 * num_kv_heads)
            q_size = rows_per_head_group * num_q_heads
            kv_size = rows_per_head_group * num_kv_heads
            q_proj, k_proj, v_proj = torch.split(tensor, [q_size, kv_size, kv_size], dim=0)
            target_state[key.replace("attention.wqkv", "self_attn.q_proj")] = q_proj
            target_state[key.replace("attention.wqkv", "self_attn.k_proj")] = k_proj
            target_state[key.replace("attention.wqkv", "self_attn.v_proj")] = v_proj
        elif "wo" in key:
            target_state[key.replace("attention.wo", "self_attn.o_proj")] = tensor
        elif "attention_norm" in key:
            target_state[key.replace("attention_norm", "input_layernorm")] = tensor
        elif "ffn_norm" in key:
            target_state[key.replace("ffn_norm", "post_attention_layernorm")] = tensor
        elif "w1" in key:
            target_state[key.replace("feed_forward.w1", "mlp.gate_proj")] = tensor
        elif "w2" in key:
            target_state[key.replace("feed_forward.w2", "mlp.down_proj")] = tensor
        elif "w3" in key:
            target_state[key.replace("feed_forward.w3", "mlp.up_proj")] = tensor
        else:
            target_state[key] = tensor

    weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
    shards, index = shard_checkpoint(target_state, max_shard_size=shard_size, weights_name=weights_name)
    for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
        if save_safetensors:
            save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
        else:
            torch.save(shard, os.path.join(output_dir, shard_file))

    if index is None:
        # Single-shard checkpoint: no index file is needed.
        print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
    else:
        index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2, sort_keys=True)
        print("Model weights saved in {}".format(output_dir))
def save_config(input_dir: str, output_dir: str):
    """Rewrite the InternLM2 ``config.json`` as a LLaMA2-compatible ``config.json``.

    Args:
        input_dir: directory containing the InternLM2 ``config.json``.
        output_dir: destination directory for the converted config.
    """
    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
        config: Dict[str, Any] = json.load(f)

    config["architectures"] = ["LlamaForCausalLM"]
    # Drop InternLM2-specific keys that LlamaConfig does not understand.
    for obsolete_key in ("auto_map", "bias", "rope_scaling"):
        config.pop(obsolete_key, None)
    config["model_type"] = "llama"

    with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
    print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
def llamafy_internlm2(
    input_dir: str, output_dir: str, shard_size: str = "2GB", save_safetensors: bool = False
):
    """Convert an InternLM2 checkpoint in ``input_dir`` into LLaMA2 format in ``output_dir``.

    Args:
        input_dir: directory with the InternLM2 ``config.json`` and ``*.bin`` weights.
        output_dir: destination directory; must not already exist.
        shard_size: maximum size per weight shard (e.g. ``"2GB"``).
        save_safetensors: if True, save ``.safetensors`` files instead of ``.bin`` files.

    Raises:
        RuntimeError: if ``output_dir`` already exists.
    """
    try:
        os.makedirs(output_dir, exist_ok=False)
    except FileExistsError as e:
        # Bug fix: the original did ``raise print("Output dir already exists", e)``,
        # which raises ``None`` and crashes with "TypeError: exceptions must derive
        # from BaseException". Raise a real exception with the cause attached.
        raise RuntimeError("Output directory already exists: {}".format(output_dir)) from e

    save_weight(input_dir, output_dir, shard_size, save_safetensors)
    save_config(input_dir, output_dir)


if __name__ == "__main__":
    # Expose the converter as a CLI: python llamafy_internlm2.py --input_dir in --output_dir out
    fire.Fire(llamafy_internlm2)

View File

@ -66,7 +66,7 @@ def check_dependencies() -> None:
require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2") require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0") require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0")
require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1") require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1")
require_version("gradio>4.0.0,<=4.21.0", "To fix: pip install gradio==4.21.0") require_version("gradio>=4.0.0", "To fix: pip install gradio>=4.0.0")
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]: def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:

View File

@ -21,8 +21,6 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
dataset = gr.Dropdown(multiselect=True, scale=4) dataset = gr.Dropdown(multiselect=True, scale=4)
preview_elems = create_preview_box(dataset_dir, dataset) preview_elems = create_preview_box(dataset_dir, dataset)
dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
input_elems.update({dataset_dir, dataset}) input_elems.update({dataset_dir, dataset})
elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems)) elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
@ -50,7 +48,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
stop_btn = gr.Button(variant="stop") stop_btn = gr.Button(variant="stop")
with gr.Row(): with gr.Row():
resume_btn = gr.Checkbox(visible=False, interactive=False, value=False) resume_btn = gr.Checkbox(visible=False, interactive=False)
process_bar = gr.Slider(visible=False, interactive=False) process_bar = gr.Slider(visible=False, interactive=False)
with gr.Row(): with gr.Row():
@ -73,4 +71,6 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
stop_btn.click(engine.runner.set_abort) stop_btn.click(engine.runner.set_abort)
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None) resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
return elem_dict return elem_dict

View File

@ -6,7 +6,6 @@ from transformers.trainer_utils import SchedulerType
from ...extras.constants import TRAINING_STAGES from ...extras.constants import TRAINING_STAGES
from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
from ..components.data import create_preview_box from ..components.data import create_preview_box
from ..utils import gen_plot
if TYPE_CHECKING: if TYPE_CHECKING:
@ -24,7 +23,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1 choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
) )
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1) dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
dataset = gr.Dropdown(multiselect=True, scale=2, allow_custom_value=True) dataset = gr.Dropdown(multiselect=True, scale=4, allow_custom_value=True)
preview_elems = create_preview_box(dataset_dir, dataset) preview_elems = create_preview_box(dataset_dir, dataset)
input_elems.update({training_stage, dataset_dir, dataset}) input_elems.update({training_stage, dataset_dir, dataset})
@ -121,8 +120,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Accordion(open=False) as freeze_tab: with gr.Accordion(open=False) as freeze_tab:
with gr.Row(): with gr.Row():
num_layer_trainable = gr.Slider(value=3, minimum=1, maximum=128, step=1, scale=2) num_layer_trainable = gr.Slider(value=3, minimum=1, maximum=128, step=1)
name_module_trainable = gr.Textbox(value="all", scale=3) name_module_trainable = gr.Textbox(value="all")
input_elems.update({num_layer_trainable, name_module_trainable}) input_elems.update({num_layer_trainable, name_module_trainable})
elem_dict.update( elem_dict.update(
@ -140,8 +139,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
create_new_adapter = gr.Checkbox() create_new_adapter = gr.Checkbox()
with gr.Row(): with gr.Row():
use_rslora = gr.Checkbox(scale=1) with gr.Column(scale=1):
use_dora = gr.Checkbox(scale=1) use_rslora = gr.Checkbox()
use_dora = gr.Checkbox()
lora_target = gr.Textbox(scale=2) lora_target = gr.Textbox(scale=2)
additional_target = gr.Textbox(scale=2) additional_target = gr.Textbox(scale=2)
@ -175,10 +176,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Accordion(open=False) as rlhf_tab: with gr.Accordion(open=False) as rlhf_tab:
with gr.Row(): with gr.Row():
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1) dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1) dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01)
orpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1) orpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=2) reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)
input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model}) input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model})
elem_dict.update( elem_dict.update(
@ -187,11 +188,11 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Accordion(open=False) as galore_tab: with gr.Accordion(open=False) as galore_tab:
with gr.Row(): with gr.Row():
use_galore = gr.Checkbox(scale=1) use_galore = gr.Checkbox()
galore_rank = gr.Slider(value=16, minimum=1, maximum=1024, step=1, scale=2) galore_rank = gr.Slider(value=16, minimum=1, maximum=1024, step=1)
galore_update_interval = gr.Slider(value=200, minimum=1, maximum=1024, step=1, scale=2) galore_update_interval = gr.Slider(value=200, minimum=1, maximum=1024, step=1)
galore_scale = gr.Slider(value=0.25, minimum=0, maximum=1, step=0.01, scale=2) galore_scale = gr.Slider(value=0.25, minimum=0, maximum=1, step=0.01)
galore_target = gr.Textbox(value="all", scale=3) galore_target = gr.Textbox(value="all")
input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target}) input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
elem_dict.update( elem_dict.update(
@ -228,29 +229,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
with gr.Column(scale=1): with gr.Column(scale=1):
loss_viewer = gr.Plot() loss_viewer = gr.Plot()
input_elems.update({output_dir, config_path})
output_elems = [output_box, process_bar]
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
arg_load_btn.click(
engine.runner.load_args,
[engine.manager.get_elem_by_id("top.lang"), config_path],
list(input_elems),
concurrency_limit=None,
)
start_btn.click(engine.runner.run_train, input_elems, output_elems)
stop_btn.click(engine.runner.set_abort)
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
list_adapters,
[engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
[reward_model],
queue=False,
).then(autoset_packing, [training_stage], [packing], queue=False)
elem_dict.update( elem_dict.update(
dict( dict(
cmd_preview_btn=cmd_preview_btn, cmd_preview_btn=cmd_preview_btn,
@ -267,15 +245,27 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
) )
) )
output_box.change( input_elems.update({output_dir, config_path})
gen_plot, output_elems = [output_box, process_bar, loss_viewer]
[
engine.manager.get_elem_by_id("top.model_name"), cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
engine.manager.get_elem_by_id("top.finetuning_type"), arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
output_dir, arg_load_btn.click(
], engine.runner.load_args,
loss_viewer, [engine.manager.get_elem_by_id("top.lang"), config_path],
queue=False, list(input_elems) + [output_box],
concurrency_limit=None,
) )
start_btn.click(engine.runner.run_train, input_elems, output_elems)
stop_btn.click(engine.runner.set_abort)
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
list_adapters,
[engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
[reward_model],
queue=False,
).then(autoset_packing, [training_stage], [packing], queue=False)
return elem_dict return elem_dict

View File

@ -1344,6 +1344,11 @@ ALERTS = {
"ru": "Аргументы были сохранены по адресу: ", "ru": "Аргументы были сохранены по адресу: ",
"zh": "训练参数已保存至:", "zh": "训练参数已保存至:",
}, },
"info_config_loaded": {
"en": "Arguments have been restored.",
"ru": "Аргументы были восстановлены.",
"zh": "训练参数已载入。",
},
"info_loading": { "info_loading": {
"en": "Loading model...", "en": "Loading model...",
"ru": "Загрузка модели...", "ru": "Загрузка модели...",

View File

@ -2,7 +2,7 @@ import logging
import os import os
import time import time
from threading import Thread from threading import Thread
from typing import TYPE_CHECKING, Any, Dict, Generator, Tuple from typing import TYPE_CHECKING, Any, Dict, Generator
import gradio as gr import gradio as gr
import transformers import transformers
@ -17,7 +17,7 @@ from ..extras.misc import get_device_count, torch_gc
from ..train import run_exp from ..train import run_exp
from .common import get_module, get_save_dir, load_args, load_config, save_args from .common import get_module, get_save_dir, load_args, load_config, save_args
from .locales import ALERTS from .locales import ALERTS
from .utils import gen_cmd, get_eval_results, update_process_bar from .utils import gen_cmd, gen_plot, get_eval_results, update_process_bar
if TYPE_CHECKING: if TYPE_CHECKING:
@ -239,20 +239,22 @@ class Runner:
return args return args
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Tuple[str, "gr.Slider"], None, None]: def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, str], None, None]:
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
error = self._initialize(data, do_train, from_preview=True) error = self._initialize(data, do_train, from_preview=True)
if error: if error:
gr.Warning(error) gr.Warning(error)
yield error, gr.Slider(visible=False) yield {output_box: error}
else: else:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data) args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
yield gen_cmd(args), gr.Slider(visible=False) yield {output_box: gen_cmd(args)}
def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Tuple[str, "gr.Slider"], None, None]: def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, Any], None, None]:
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
error = self._initialize(data, do_train, from_preview=False) error = self._initialize(data, do_train, from_preview=False)
if error: if error:
gr.Warning(error) gr.Warning(error)
yield error, gr.Slider(visible=False) yield {output_box: error}
else: else:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data) args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
run_kwargs = dict(args=args, callbacks=[self.trainer_callback]) run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
@ -261,54 +263,80 @@ class Runner:
self.thread.start() self.thread.start()
yield from self.monitor() yield from self.monitor()
def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]: def preview_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]:
yield from self._preview(data, do_train=True) yield from self._preview(data, do_train=True)
def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]: def preview_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]:
yield from self._preview(data, do_train=False) yield from self._preview(data, do_train=False)
def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]: def run_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]:
yield from self._launch(data, do_train=True) yield from self._launch(data, do_train=True)
def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]: def run_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]:
yield from self._launch(data, do_train=False) yield from self._launch(data, do_train=False)
def monitor(self) -> Generator[Tuple[str, "gr.Slider"], None, None]: def monitor(self) -> Generator[Dict[Component, Any], None, None]:
get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)] get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
self.running = True self.running = True
lang = get("top.lang") lang = get("top.lang")
output_dir = get_save_dir( model_name = get("top.model_name")
get("top.model_name"), finetuning_type = get("top.finetuning_type")
get("top.finetuning_type"), output_dir = get("{}.output_dir".format("train" if self.do_train else "eval"))
get("{}.output_dir".format("train" if self.do_train else "eval")), output_path = get_save_dir(model_name, finetuning_type, output_dir)
)
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if self.do_train else "eval"))
process_bar = self.manager.get_elem_by_id("{}.process_bar".format("train" if self.do_train else "eval"))
loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None
while self.thread is not None and self.thread.is_alive(): while self.thread is not None and self.thread.is_alive():
if self.aborted: if self.aborted:
yield ALERTS["info_aborting"][lang], gr.Slider(visible=False) yield {
output_box: ALERTS["info_aborting"][lang],
process_bar: gr.Slider(visible=False),
}
else: else:
yield self.logger_handler.log, update_process_bar(self.trainer_callback) return_dict = {
output_box: self.logger_handler.log,
process_bar: update_process_bar(self.trainer_callback),
}
if self.do_train:
plot = gen_plot(output_path)
if plot is not None:
return_dict[loss_viewer] = plot
yield return_dict
time.sleep(2) time.sleep(2)
if self.do_train: if self.do_train:
if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)): if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)):
finish_info = ALERTS["info_finished"][lang] finish_info = ALERTS["info_finished"][lang]
else: else:
finish_info = ALERTS["err_failed"][lang] finish_info = ALERTS["err_failed"][lang]
else: else:
if os.path.exists(os.path.join(output_dir, "all_results.json")): if os.path.exists(os.path.join(output_path, "all_results.json")):
finish_info = get_eval_results(os.path.join(output_dir, "all_results.json")) finish_info = get_eval_results(os.path.join(output_path, "all_results.json"))
else: else:
finish_info = ALERTS["err_failed"][lang] finish_info = ALERTS["err_failed"][lang]
yield self._finalize(lang, finish_info), gr.Slider(visible=False) return_dict = {
output_box: self._finalize(lang, finish_info),
process_bar: gr.Slider(visible=False),
}
if self.do_train:
plot = gen_plot(output_path)
if plot is not None:
return_dict[loss_viewer] = plot
def save_args(self, data: Dict[Component, Any]) -> Tuple[str, "gr.Slider"]: yield return_dict
def save_args(self, data: Dict[Component, Any]) -> Dict[Component, str]:
output_box = self.manager.get_elem_by_id("train.output_box")
error = self._initialize(data, do_train=True, from_preview=True) error = self._initialize(data, do_train=True, from_preview=True)
if error: if error:
gr.Warning(error) gr.Warning(error)
return error, gr.Slider(visible=False) return {output_box: error}
config_dict: Dict[str, Any] = {} config_dict: Dict[str, Any] = {}
lang = data[self.manager.get_elem_by_id("top.lang")] lang = data[self.manager.get_elem_by_id("top.lang")]
@ -320,15 +348,16 @@ class Runner:
config_dict[elem_id] = value config_dict[elem_id] = value
save_path = save_args(config_path, config_dict) save_path = save_args(config_path, config_dict)
return ALERTS["info_config_saved"][lang] + save_path, gr.Slider(visible=False) return {output_box: ALERTS["info_config_saved"][lang] + save_path}
def load_args(self, lang: str, config_path: str) -> Dict[Component, Any]: def load_args(self, lang: str, config_path: str) -> Dict[Component, Any]:
output_box = self.manager.get_elem_by_id("train.output_box")
config_dict = load_args(config_path) config_dict = load_args(config_path)
if config_dict is None: if config_dict is None:
gr.Warning(ALERTS["err_config_not_found"][lang]) gr.Warning(ALERTS["err_config_not_found"][lang])
return {self.manager.get_elem_by_id("top.lang"): lang} return {output_box: ALERTS["err_config_not_found"][lang]}
output_dict: Dict["Component", Any] = {} output_dict: Dict["Component", Any] = {output_box: ALERTS["info_config_loaded"][lang]}
for elem_id, value in config_dict.items(): for elem_id, value in config_dict.items():
output_dict[self.manager.get_elem_by_id(elem_id)] = value output_dict[self.manager.get_elem_by_id(elem_id)] = value

View File

@ -1,13 +1,12 @@
import json import json
import os import os
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict from typing import TYPE_CHECKING, Any, Dict, Optional
import gradio as gr import gradio as gr
from ..extras.packages import is_matplotlib_available from ..extras.packages import is_matplotlib_available
from ..extras.ploting import smooth from ..extras.ploting import smooth
from .common import get_save_dir
from .locales import ALERTS from .locales import ALERTS
@ -36,7 +35,7 @@ def get_time() -> str:
def can_quantize(finetuning_type: str) -> "gr.Dropdown": def can_quantize(finetuning_type: str) -> "gr.Dropdown":
if finetuning_type != "lora": if finetuning_type != "lora":
return gr.Dropdown(value="None", interactive=False) return gr.Dropdown(value="none", interactive=False)
else: else:
return gr.Dropdown(interactive=True) return gr.Dropdown(interactive=True)
@ -74,11 +73,9 @@ def get_eval_results(path: os.PathLike) -> str:
return "```json\n{}\n```\n".format(result) return "```json\n{}\n```\n".format(result)
def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> "matplotlib.figure.Figure": def gen_plot(output_path: str) -> Optional["matplotlib.figure.Figure"]:
if not base_model: log_file = os.path.join(output_path, "trainer_log.jsonl")
return if not os.path.isfile(log_file) or not is_matplotlib_available():
log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl")
if not os.path.isfile(log_file):
return return
plt.close("all") plt.close("all")
@ -88,13 +85,13 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> "matplot
steps, losses = [], [] steps, losses = [], []
with open(log_file, "r", encoding="utf-8") as f: with open(log_file, "r", encoding="utf-8") as f:
for line in f: for line in f:
log_info = json.loads(line) log_info: Dict[str, Any] = json.loads(line)
if log_info.get("loss", None): if log_info.get("loss", None):
steps.append(log_info["current_steps"]) steps.append(log_info["current_steps"])
losses.append(log_info["loss"]) losses.append(log_info["loss"])
if len(losses) == 0: if len(losses) == 0:
return None return
ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original") ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed") ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")