mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-03 04:02:49 +08:00
fix bug in latest gradio
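
Relax the gradio pin from ">4.0.0,<=4.21.0" to ">=4.0.0", drop the per-component
"scale" kwargs in the web UI, switch the Runner generators to yield dicts keyed by
output components instead of positional tuples, draw the loss plot from the monitor
loop instead of an output_box.change handler, and remove the broken
llamafy_internlm2.py script.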
Former-commit-id: 5ddcecda50ccff93d51bebc9ac72c2a0dd483e9b
This commit is contained in:
parent a6d347726f
commit 54a4a8217a
@@ -4,7 +4,7 @@ datasets>=2.14.3
 accelerate>=0.27.2
 peft>=0.10.0
 trl>=0.8.1
-gradio>4.0.0,<=4.21.0
+gradio>=4.0.0
 scipy
 einops
 sentencepiece
@@ -1,114 +0,0 @@
-# coding=utf-8
-# Converts the InternLM2 model in the same format as LLaMA2.
-# Usage: python llamafy_internlm2.py --input_dir input --output_dir output
-# Warning: We have found that the converted model cannot infer correctly. It will be fixed later.
-
-import json
-import os
-from collections import OrderedDict
-from typing import Any, Dict, Optional
-
-import fire
-import torch
-from safetensors.torch import save_file
-from tqdm import tqdm
-from transformers.modeling_utils import (
-    SAFE_WEIGHTS_INDEX_NAME,
-    SAFE_WEIGHTS_NAME,
-    WEIGHTS_INDEX_NAME,
-    WEIGHTS_NAME,
-    shard_checkpoint,
-)
-
-
-CONFIG_NAME = "config.json"
-
-
-def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool):
-    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
-        internlm2_config_dict: Dict[str, Any] = json.load(f)
-
-    internlm2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
-    for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
-        if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
-            shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
-            internlm2_state_dict.update(shard_weight)
-
-    llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
-    for key, value in tqdm(internlm2_state_dict.items(), desc="Convert format"):
-        if "output" in key:
-            llama2_state_dict[key.replace("output", "lm_head")] = value
-        elif "tok_embeddings" in key:
-            llama2_state_dict[key.replace("tok_embeddings", "embed_tokens")] = value
-        elif "wqkv" in key:
-            num_q_heads = internlm2_config_dict["num_attention_heads"]
-            num_kv_heads = internlm2_config_dict["num_key_value_heads"]
-            q_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_q_heads
-            kv_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_kv_heads
-            llama2_state_dict[key.replace("attention.wqkv", "self_attn.q_proj")] = value[:q_size, ...]
-            llama2_state_dict[key.replace("attention.wqkv", "self_attn.k_proj")] = value[
-                q_size : q_size + kv_size, ...
-            ]
-            llama2_state_dict[key.replace("attention.wqkv", "self_attn.v_proj")] = value[q_size + kv_size :, ...]
-        elif "wo" in key:
-            llama2_state_dict[key.replace("attention.wo", "self_attn.o_proj")] = value
-        elif "attention_norm" in key:
-            llama2_state_dict[key.replace("attention_norm", "input_layernorm")] = value
-        elif "ffn_norm" in key:
-            llama2_state_dict[key.replace("ffn_norm", "post_attention_layernorm")] = value
-        elif "w1" in key:
-            llama2_state_dict[key.replace("feed_forward.w1", "mlp.gate_proj")] = value
-        elif "w2" in key:
-            llama2_state_dict[key.replace("feed_forward.w2", "mlp.down_proj")] = value
-        elif "w3" in key:
-            llama2_state_dict[key.replace("feed_forward.w3", "mlp.up_proj")] = value
-        else:
-            llama2_state_dict[key] = value
-
-    weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
-    shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
-
-    for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
-        if save_safetensors:
-            save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
-        else:
-            torch.save(shard, os.path.join(output_dir, shard_file))
-
-    if index is None:
-        print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
-    else:
-        index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
-        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
-            json.dump(index, f, indent=2, sort_keys=True)
-        print("Model weights saved in {}".format(output_dir))
-
-
-def save_config(input_dir: str, output_dir: str):
-    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
-        llama2_config_dict: Dict[str, Any] = json.load(f)
-
-    llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
-    llama2_config_dict.pop("auto_map", None)
-    llama2_config_dict.pop("bias", None)
-    llama2_config_dict.pop("rope_scaling", None)
-    llama2_config_dict["model_type"] = "llama"
-
-    with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
-        json.dump(llama2_config_dict, f, indent=2)
-    print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
-
-
-def llamafy_internlm2(
-    input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
-):
-    try:
-        os.makedirs(output_dir, exist_ok=False)
-    except Exception as e:
-        raise print("Output dir already exists", e)
-
-    save_weight(input_dir, output_dir, shard_size, save_safetensors)
-    save_config(input_dir, output_dir)
-
-
-if __name__ == "__main__":
-    fire.Fire(llamafy_internlm2)
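
For context on the script deleted above: its "wqkv" branch slices InternLM2's fused
QKV weight contiguously by head counts. A minimal standalone sketch with made-up
sizes (all numbers hypothetical; InternLM2 actually stores the fused heads
group-interleaved, which is likely why the script's header warns that converted
models cannot infer correctly):

import torch

# Hypothetical config: 8 query heads, 2 KV heads, head_dim 4, hidden_size 16.
num_q_heads, num_kv_heads, head_dim, hidden_size = 8, 2, 4, 16
wqkv = torch.randn((num_q_heads + 2 * num_kv_heads) * head_dim, hidden_size)

groups = num_q_heads + 2 * num_kv_heads            # 12 head slots in the fusion
q_size = wqkv.size(0) // groups * num_q_heads      # 48 // 12 * 8 = 32 rows
kv_size = wqkv.size(0) // groups * num_kv_heads    # 48 // 12 * 2 = 8 rows

q_proj = wqkv[:q_size, ...]                        # torch.Size([32, 16])
k_proj = wqkv[q_size : q_size + kv_size, ...]      # torch.Size([8, 16])
v_proj = wqkv[q_size + kv_size :, ...]             # torch.Size([8, 16])
print(q_proj.shape, k_proj.shape, v_proj.shape)
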
@@ -66,7 +66,7 @@ def check_dependencies() -> None:
     require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
     require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0")
     require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1")
-    require_version("gradio>4.0.0,<=4.21.0", "To fix: pip install gradio==4.21.0")
+    require_version("gradio>=4.0.0", "To fix: pip install gradio>=4.0.0")


 def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
@@ -21,8 +21,6 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
         dataset = gr.Dropdown(multiselect=True, scale=4)
         preview_elems = create_preview_box(dataset_dir, dataset)

-    dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
-
     input_elems.update({dataset_dir, dataset})
     elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems))

@@ -50,7 +48,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
         stop_btn = gr.Button(variant="stop")

     with gr.Row():
-        resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
+        resume_btn = gr.Checkbox(visible=False, interactive=False)
         process_bar = gr.Slider(visible=False, interactive=False)

     with gr.Row():
@@ -73,4 +71,6 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
     stop_btn.click(engine.runner.set_abort)
     resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)

+    dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
+
     return elem_dict
@@ -6,7 +6,6 @@ from transformers.trainer_utils import SchedulerType
 from ...extras.constants import TRAINING_STAGES
 from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
 from ..components.data import create_preview_box
-from ..utils import gen_plot


 if TYPE_CHECKING:
@@ -24,7 +23,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=1
         )
         dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
-        dataset = gr.Dropdown(multiselect=True, scale=2, allow_custom_value=True)
+        dataset = gr.Dropdown(multiselect=True, scale=4, allow_custom_value=True)
         preview_elems = create_preview_box(dataset_dir, dataset)

     input_elems.update({training_stage, dataset_dir, dataset})
@@ -121,8 +120,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:

     with gr.Accordion(open=False) as freeze_tab:
         with gr.Row():
-            num_layer_trainable = gr.Slider(value=3, minimum=1, maximum=128, step=1, scale=2)
-            name_module_trainable = gr.Textbox(value="all", scale=3)
+            num_layer_trainable = gr.Slider(value=3, minimum=1, maximum=128, step=1)
+            name_module_trainable = gr.Textbox(value="all")

     input_elems.update({num_layer_trainable, name_module_trainable})
     elem_dict.update(
@@ -140,8 +139,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             create_new_adapter = gr.Checkbox()

         with gr.Row():
-            use_rslora = gr.Checkbox(scale=1)
-            use_dora = gr.Checkbox(scale=1)
+            with gr.Column(scale=1):
+                use_rslora = gr.Checkbox()
+                use_dora = gr.Checkbox()
+
             lora_target = gr.Textbox(scale=2)
             additional_target = gr.Textbox(scale=2)

@@ -175,10 +176,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:

     with gr.Accordion(open=False) as rlhf_tab:
         with gr.Row():
-            dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
-            dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1)
-            orpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
-            reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=2)
+            dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
+            dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01)
+            orpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
+            reward_model = gr.Dropdown(multiselect=True, allow_custom_value=True)

     input_elems.update({dpo_beta, dpo_ftx, orpo_beta, reward_model})
     elem_dict.update(
@@ -187,11 +188,11 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:

     with gr.Accordion(open=False) as galore_tab:
         with gr.Row():
-            use_galore = gr.Checkbox(scale=1)
-            galore_rank = gr.Slider(value=16, minimum=1, maximum=1024, step=1, scale=2)
-            galore_update_interval = gr.Slider(value=200, minimum=1, maximum=1024, step=1, scale=2)
-            galore_scale = gr.Slider(value=0.25, minimum=0, maximum=1, step=0.01, scale=2)
-            galore_target = gr.Textbox(value="all", scale=3)
+            use_galore = gr.Checkbox()
+            galore_rank = gr.Slider(value=16, minimum=1, maximum=1024, step=1)
+            galore_update_interval = gr.Slider(value=200, minimum=1, maximum=1024, step=1)
+            galore_scale = gr.Slider(value=0.25, minimum=0, maximum=1, step=0.01)
+            galore_target = gr.Textbox(value="all")

     input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
     elem_dict.update(
@@ -228,29 +229,6 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         with gr.Column(scale=1):
             loss_viewer = gr.Plot()

-    input_elems.update({output_dir, config_path})
-    output_elems = [output_box, process_bar]
-
-    cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
-    arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
-    arg_load_btn.click(
-        engine.runner.load_args,
-        [engine.manager.get_elem_by_id("top.lang"), config_path],
-        list(input_elems),
-        concurrency_limit=None,
-    )
-    start_btn.click(engine.runner.run_train, input_elems, output_elems)
-    stop_btn.click(engine.runner.set_abort)
-    resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
-
-    dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
-    training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
-        list_adapters,
-        [engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
-        [reward_model],
-        queue=False,
-    ).then(autoset_packing, [training_stage], [packing], queue=False)
-
     elem_dict.update(
         dict(
             cmd_preview_btn=cmd_preview_btn,
@@ -267,15 +245,27 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         )
     )

-    output_box.change(
-        gen_plot,
-        [
-            engine.manager.get_elem_by_id("top.model_name"),
-            engine.manager.get_elem_by_id("top.finetuning_type"),
-            output_dir,
-        ],
-        loss_viewer,
-        queue=False,
-    )
+    input_elems.update({output_dir, config_path})
+    output_elems = [output_box, process_bar, loss_viewer]
+
+    cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
+    arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
+    arg_load_btn.click(
+        engine.runner.load_args,
+        [engine.manager.get_elem_by_id("top.lang"), config_path],
+        list(input_elems) + [output_box],
+        concurrency_limit=None,
+    )
+    start_btn.click(engine.runner.run_train, input_elems, output_elems)
+    stop_btn.click(engine.runner.set_abort)
+    resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
+
+    dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
+    training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
+        list_adapters,
+        [engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
+        [reward_model],
+        queue=False,
+    ).then(autoset_packing, [training_stage], [packing], queue=False)

     return elem_dict
@@ -1344,6 +1344,11 @@ ALERTS = {
         "ru": "Аргументы были сохранены по адресу: ",
         "zh": "训练参数已保存至:",
     },
+    "info_config_loaded": {
+        "en": "Arguments have been restored.",
+        "ru": "Аргументы были восстановлены.",
+        "zh": "训练参数已载入。",
+    },
     "info_loading": {
         "en": "Loading model...",
         "ru": "Загрузка модели...",
@@ -2,7 +2,7 @@ import logging
 import os
 import time
 from threading import Thread
-from typing import TYPE_CHECKING, Any, Dict, Generator, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Generator

 import gradio as gr
 import transformers
@@ -17,7 +17,7 @@ from ..extras.misc import get_device_count, torch_gc
 from ..train import run_exp
 from .common import get_module, get_save_dir, load_args, load_config, save_args
 from .locales import ALERTS
-from .utils import gen_cmd, get_eval_results, update_process_bar
+from .utils import gen_cmd, gen_plot, get_eval_results, update_process_bar


 if TYPE_CHECKING:
|
|||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Tuple[str, "gr.Slider"], None, None]:
|
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, str], None, None]:
|
||||||
|
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
|
||||||
error = self._initialize(data, do_train, from_preview=True)
|
error = self._initialize(data, do_train, from_preview=True)
|
||||||
if error:
|
if error:
|
||||||
gr.Warning(error)
|
gr.Warning(error)
|
||||||
yield error, gr.Slider(visible=False)
|
yield {output_box: error}
|
||||||
else:
|
else:
|
||||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||||
yield gen_cmd(args), gr.Slider(visible=False)
|
yield {output_box: gen_cmd(args)}
|
||||||
|
|
||||||
def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Tuple[str, "gr.Slider"], None, None]:
|
def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict[Component, Any], None, None]:
|
||||||
|
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
|
||||||
error = self._initialize(data, do_train, from_preview=False)
|
error = self._initialize(data, do_train, from_preview=False)
|
||||||
if error:
|
if error:
|
||||||
gr.Warning(error)
|
gr.Warning(error)
|
||||||
yield error, gr.Slider(visible=False)
|
yield {output_box: error}
|
||||||
else:
|
else:
|
||||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||||
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
|
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
|
||||||
@@ -261,54 +263,80 @@ class Runner:
         self.thread.start()
         yield from self.monitor()

-    def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
+    def preview_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]:
         yield from self._preview(data, do_train=True)

-    def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
+    def preview_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, str], None, None]:
         yield from self._preview(data, do_train=False)

-    def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
+    def run_train(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]:
         yield from self._launch(data, do_train=True)

-    def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
+    def run_eval(self, data: Dict[Component, Any]) -> Generator[Dict[Component, Any], None, None]:
         yield from self._launch(data, do_train=False)

-    def monitor(self) -> Generator[Tuple[str, "gr.Slider"], None, None]:
+    def monitor(self) -> Generator[Dict[Component, Any], None, None]:
         get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
         self.running = True

         lang = get("top.lang")
-        output_dir = get_save_dir(
-            get("top.model_name"),
-            get("top.finetuning_type"),
-            get("{}.output_dir".format("train" if self.do_train else "eval")),
-        )
+        model_name = get("top.model_name")
+        finetuning_type = get("top.finetuning_type")
+        output_dir = get("{}.output_dir".format("train" if self.do_train else "eval"))
+        output_path = get_save_dir(model_name, finetuning_type, output_dir)
+
+        output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if self.do_train else "eval"))
+        process_bar = self.manager.get_elem_by_id("{}.process_bar".format("train" if self.do_train else "eval"))
+        loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None

         while self.thread is not None and self.thread.is_alive():
             if self.aborted:
-                yield ALERTS["info_aborting"][lang], gr.Slider(visible=False)
+                yield {
+                    output_box: ALERTS["info_aborting"][lang],
+                    process_bar: gr.Slider(visible=False),
+                }
             else:
-                yield self.logger_handler.log, update_process_bar(self.trainer_callback)
+                return_dict = {
+                    output_box: self.logger_handler.log,
+                    process_bar: update_process_bar(self.trainer_callback),
+                }
+                if self.do_train:
+                    plot = gen_plot(output_path)
+                    if plot is not None:
+                        return_dict[loss_viewer] = plot
+
+                yield return_dict

             time.sleep(2)

         if self.do_train:
-            if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
+            if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)):
                 finish_info = ALERTS["info_finished"][lang]
             else:
                 finish_info = ALERTS["err_failed"][lang]
         else:
-            if os.path.exists(os.path.join(output_dir, "all_results.json")):
-                finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
+            if os.path.exists(os.path.join(output_path, "all_results.json")):
+                finish_info = get_eval_results(os.path.join(output_path, "all_results.json"))
             else:
                 finish_info = ALERTS["err_failed"][lang]

-        yield self._finalize(lang, finish_info), gr.Slider(visible=False)
+        return_dict = {
+            output_box: self._finalize(lang, finish_info),
+            process_bar: gr.Slider(visible=False),
+        }
+        if self.do_train:
+            plot = gen_plot(output_path)
+            if plot is not None:
+                return_dict[loss_viewer] = plot
+
+        yield return_dict

-    def save_args(self, data: Dict[Component, Any]) -> Tuple[str, "gr.Slider"]:
+    def save_args(self, data: Dict[Component, Any]) -> Dict[Component, str]:
+        output_box = self.manager.get_elem_by_id("train.output_box")
         error = self._initialize(data, do_train=True, from_preview=True)
         if error:
             gr.Warning(error)
-            return error, gr.Slider(visible=False)
+            return {output_box: error}

         config_dict: Dict[str, Any] = {}
         lang = data[self.manager.get_elem_by_id("top.lang")]
@@ -320,15 +348,16 @@ class Runner:
             config_dict[elem_id] = value

         save_path = save_args(config_path, config_dict)
-        return ALERTS["info_config_saved"][lang] + save_path, gr.Slider(visible=False)
+        return {output_box: ALERTS["info_config_saved"][lang] + save_path}

     def load_args(self, lang: str, config_path: str) -> Dict[Component, Any]:
+        output_box = self.manager.get_elem_by_id("train.output_box")
         config_dict = load_args(config_path)
         if config_dict is None:
             gr.Warning(ALERTS["err_config_not_found"][lang])
-            return {self.manager.get_elem_by_id("top.lang"): lang}
+            return {output_box: ALERTS["err_config_not_found"][lang]}

-        output_dict: Dict["Component", Any] = {}
+        output_dict: Dict["Component", Any] = {output_box: ALERTS["info_config_loaded"][lang]}
         for elem_id, value in config_dict.items():
             output_dict[self.manager.get_elem_by_id(elem_id)] = value

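
The runner changes above lean on a gradio pattern worth spelling out (a minimal
sketch, not code from this repo; all component and function names here are
illustrative): a generator bound to several outputs may yield a dict keyed by the
output components instead of a positional tuple, and only the components it names
get updated.

import time

import gradio as gr

with gr.Blocks() as demo:
    start_btn = gr.Button("start")
    output_box = gr.Markdown()
    process_bar = gr.Slider(visible=False, interactive=False)

    def monitor():
        # Each yielded dict updates only the components it mentions.
        for step in range(1, 4):
            yield {
                output_box: "step {}/3".format(step),
                process_bar: gr.Slider(value=step, maximum=3, visible=True),
            }
            time.sleep(1)
        yield {output_box: "done", process_bar: gr.Slider(visible=False)}

    # Dict-style returns require the components to be listed in outputs.
    start_btn.click(monitor, outputs=[output_box, process_bar])

demo.launch()
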
@@ -1,13 +1,12 @@
 import json
 import os
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Dict
+from typing import TYPE_CHECKING, Any, Dict, Optional

 import gradio as gr

 from ..extras.packages import is_matplotlib_available
 from ..extras.ploting import smooth
-from .common import get_save_dir
 from .locales import ALERTS


@@ -36,7 +35,7 @@ def get_time() -> str:

 def can_quantize(finetuning_type: str) -> "gr.Dropdown":
     if finetuning_type != "lora":
-        return gr.Dropdown(value="None", interactive=False)
+        return gr.Dropdown(value="none", interactive=False)
     else:
         return gr.Dropdown(interactive=True)

@@ -74,11 +73,9 @@ def get_eval_results(path: os.PathLike) -> str:
     return "```json\n{}\n```\n".format(result)


-def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> "matplotlib.figure.Figure":
-    if not base_model:
-        return
-    log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl")
-    if not os.path.isfile(log_file):
+def gen_plot(output_path: str) -> Optional["matplotlib.figure.Figure"]:
+    log_file = os.path.join(output_path, "trainer_log.jsonl")
+    if not os.path.isfile(log_file) or not is_matplotlib_available():
         return

     plt.close("all")
@@ -88,13 +85,13 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> "matplotlib.figure.Figure":
     steps, losses = [], []
     with open(log_file, "r", encoding="utf-8") as f:
         for line in f:
-            log_info = json.loads(line)
+            log_info: Dict[str, Any] = json.loads(line)
             if log_info.get("loss", None):
                 steps.append(log_info["current_steps"])
                 losses.append(log_info["loss"])

     if len(losses) == 0:
-        return None
+        return

     ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
     ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
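
A usage sketch of the reworked gen_plot (the "llmtuner" package path and the output
directory below are assumptions for illustration): the function now takes the
resolved output path directly and returns None when the log file is missing,
matplotlib is unavailable, or no loss has been recorded yet.

from llmtuner.webui.utils import gen_plot  # package name assumed

fig = gen_plot("saves/LLaMA2-7B/lora/train_2024-01-01")  # hypothetical path
if fig is not None:
    fig.savefig("training_loss.png")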