fix config, #1191

commit eafde5b73f (parent e5e740b54d)
Former-commit-id: a6a04be2e6700050c64e59c00a86698f1098e7cc
@@ -12,7 +12,7 @@ fire
 jieba
 rouge-chinese
 nltk
-gradio>=3.36.0
+gradio==3.38.0
 uvicorn
 pydantic==1.10.11
 fastapi==0.95.1
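The first hunk replaces a version floor with an exact pin, so the web UI is only vouched for against one gradio release. A minimal sketch (not part of the commit) of asserting that pin at runtime with the standard library:

    # Hypothetical startup check: importlib.metadata reads the installed
    # distribution's version; the "3.38.0" literal mirrors the pin above.
    from importlib.metadata import version

    if version("gradio") != "3.38.0":
        raise ImportError("this revision expects gradio==3.38.0")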
@@ -186,7 +186,7 @@ def get_train_args(

     # postprocess model_args
     model_args.compute_dtype = (
-        torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else torch.float32)
+        torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
     )
     model_args.model_max_length = data_args.cutoff_len

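The behavioral change here: when neither bf16 nor fp16 is requested, compute_dtype is now left as None instead of forcing float32, deferring the choice to downstream loading code. A standalone sketch of the selection logic (function name is illustrative):

    import torch

    def select_compute_dtype(bf16: bool, fp16: bool):
        # Mirrors the ternary above; None now means "let the loader decide".
        if bf16:
            return torch.bfloat16
        if fp16:
            return torch.float16
        return None  # was torch.float32 before this commit

    print(select_compute_dtype(bf16=False, fp16=True))  # torch.float16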
@@ -1,7 +1,7 @@
 import os
 import json
 import gradio as gr
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional
 from transformers.utils import (
     WEIGHTS_NAME,
     WEIGHTS_INDEX_NAME,
@@ -27,7 +27,6 @@ CKPT_NAMES = [
     ADAPTER_WEIGHTS_NAME,
     ADAPTER_SAFE_WEIGHTS_NAME
 ]
-CONFIG_CLASS = Dict[str, Union[str, Dict[str, str]]]


 def get_save_dir(*args) -> os.PathLike:
@@ -38,7 +37,7 @@ def get_config_path() -> os.PathLike:
     return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)


-def load_config() -> CONFIG_CLASS:
+def load_config() -> Dict[str, Any]:
     try:
         with open(get_config_path(), "r", encoding="utf-8") as f:
             return json.load(f)
@@ -46,20 +45,20 @@ def load_config() -> CONFIG_CLASS:
         return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}


-def save_config(
-    config: CONFIG_CLASS, lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None
-) -> None:
+def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
-    config["lang"] = lang or config["lang"]
+    user_config = load_config()
+    user_config["lang"] = lang or user_config["lang"]
     if model_name:
-        config["last_model"] = model_name
-        config["path_dict"][model_name] = model_path
+        user_config["last_model"] = model_name
+        user_config["path_dict"][model_name] = model_path
     with open(get_config_path(), "w", encoding="utf-8") as f:
-        json.dump(config, f, indent=2, ensure_ascii=False)
+        json.dump(user_config, f, indent=2, ensure_ascii=False)


-def get_model_path(config: Dict[str, Any], model_name: str) -> str:
-    return config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "")
+def get_model_path(model_name: str) -> str:
+    user_config = load_config()
+    return user_config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "")


 def get_module(model_name: str) -> str:
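The common-module change drops the config object that was threaded through callbacks as UI state: every writer now re-loads the JSON file, mutates it, and writes it back, so no callback can act on a stale copy. A self-contained sketch of the round trip (the path is illustrative, not the real DEFAULT_CACHE_DIR):

    import json
    import os
    from typing import Any, Dict, Optional

    CONFIG_PATH = os.path.join("/tmp/llmtuner-demo", "user.config")  # assumed path

    def load_config() -> Dict[str, Any]:
        try:
            with open(CONFIG_PATH, "r", encoding="utf-8") as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}

    def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
        os.makedirs(os.path.dirname(CONFIG_PATH), exist_ok=True)
        user_config = load_config()  # read-modify-write against the file, not a UI state
        user_config["lang"] = lang or user_config["lang"]
        if model_name:
            user_config["last_model"] = model_name
            user_config["path_dict"][model_name] = model_path
        with open(CONFIG_PATH, "w", encoding="utf-8") as f:
            json.dump(user_config, f, indent=2, ensure_ascii=False)

    save_config("en", "llama-7b", "/models/llama-7b")
    assert load_config()["last_model"] == "llama-7b"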
@@ -17,10 +17,7 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
         unload_btn = gr.Button()

     info_box = gr.Textbox(show_label=False, interactive=False)
-
-    elem_dict.update(dict(
-        info_box=info_box, load_btn=load_btn, unload_btn=unload_btn
-    ))
+    elem_dict.update(dict(load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))

     chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
     elem_dict.update(dict(chat_box=chat_box, **chat_elems))
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Dict

 from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
 from llmtuner.extras.template import templates
-from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, load_config, save_config
+from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, save_config
 from llmtuner.webui.utils import can_quantize

 if TYPE_CHECKING:
@@ -12,7 +12,6 @@ if TYPE_CHECKING:

 def create_top() -> Dict[str, "Component"]:
     available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
-    config = gr.State(value=load_config())

     with gr.Row():
         lang = gr.Dropdown(choices=["en", "zh"], scale=1)
@@ -39,17 +38,17 @@ def create_top() -> Dict[str, "Component"]:
     model_name.change(
         list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
     ).then(
-        get_model_path, [config, model_name], [model_path], queue=False
+        get_model_path, [model_name], [model_path], queue=False
     ).then(
         get_template, [model_name], [template], queue=False
     ) # do not save config since the below line will save

-    model_path.change(save_config, inputs=[config, lang, model_name, model_path])
+    model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)

     finetuning_type.change(
-        list_checkpoint, [model_name, finetuning_type], [checkpoints]
+        list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
     ).then(
-        can_quantize, [finetuning_type], [quantization_bit]
+        can_quantize, [finetuning_type], [quantization_bit], queue=False
     )

     refresh_btn.click(
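Every hop of these chains now passes queue=False so the lightweight UI callbacks bypass gradio's queue. A minimal standalone sketch of the pattern (component names and the lookup are illustrative):

    import gradio as gr

    def fake_model_path(name: str) -> str:
        return "/models/" + name  # stand-in for get_model_path

    with gr.Blocks() as demo:
        model_name = gr.Dropdown(choices=["llama-7b", "llama-13b"])
        model_path = gr.Textbox()
        template = gr.Textbox()
        # .change returns an event object, so handlers chain with .then;
        # queue=False runs each step immediately instead of through the queue.
        model_name.change(
            fake_model_path, [model_name], [model_path], queue=False
        ).then(
            lambda name: "default", [model_name], [template], queue=False
        )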
@@ -57,7 +56,6 @@ def create_top() -> Dict[str, "Component"]:
     )

     return dict(
-        config=config,
         lang=lang,
         model_name=model_name,
         model_path=model_path,
@@ -3,7 +3,7 @@ from gradio.components import Component # cannot use TYPE_CHECKING here
 from typing import Any, Dict, Generator, Optional

 from llmtuner.webui.chatter import WebChatModel
-from llmtuner.webui.common import get_model_path, list_dataset, CONFIG_CLASS
+from llmtuner.webui.common import get_model_path, list_dataset, load_config
 from llmtuner.webui.locales import LOCALES
 from llmtuner.webui.manager import Manager
 from llmtuner.webui.runner import Runner
@@ -21,8 +21,9 @@ class Engine:
     def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
         return {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()}

-    def resume(self, config: CONFIG_CLASS) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
-        lang = config.get("lang", None) or "en"
+    def resume(self) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
+        user_config = load_config()
+        lang = user_config.get("lang", None) or "en"

         resume_dict = {
             "top.lang": {"value": lang},
@@ -33,9 +34,9 @@ class Engine:
         resume_dict["train.dataset"] = {"choices": list_dataset()["choices"]}
         resume_dict["eval.dataset"] = {"choices": list_dataset()["choices"]}

-        if config.get("last_model", None):
-            resume_dict["top.model_name"] = {"value": config["last_model"]}
-            resume_dict["top.model_path"] = {"value": get_model_path(config, config["last_model"])}
+        if user_config.get("last_model", None):
+            resume_dict["top.model_name"] = {"value": user_config["last_model"]}
+            resume_dict["top.model_path"] = {"value": get_model_path(user_config["last_model"])}

         yield self._form_dict(resume_dict)

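resume() no longer receives the config through a gradio State input; it re-reads the persisted config itself and yields one component-to-update mapping. A condensed sketch of that flow (Manager internals stubbed out):

    import gradio as gr

    def resume(manager, load_config):
        # Stateless restore: read the saved config, then emit gr.update payloads
        # keyed by the components that should change.
        user_config = load_config()
        resume_dict = {"top.lang": {"value": user_config.get("lang") or "en"}}
        if user_config.get("last_model"):
            resume_dict["top.model_name"] = {"value": user_config["last_model"]}
        yield {manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()}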
@@ -9,12 +9,12 @@ from llmtuner.webui.components import (
     create_export_tab,
     create_chat_box
 )
-from llmtuner.webui.common import load_config, save_config
+from llmtuner.webui.common import save_config
 from llmtuner.webui.css import CSS
 from llmtuner.webui.engine import Engine


-require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
+require_version("gradio==3.38.0", "To fix: pip install gradio==3.38.0")


 def create_ui() -> gr.Blocks:
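The runtime guard is updated to match the new pin in requirements.txt. require_version is presumably the helper from transformers.utils.versions (its import sits outside this hunk); it raises ImportError at import time with the supplied hint when the installed version does not satisfy the requirement:

    # Assumed import location; only the call appears in the hunk above.
    from transformers.utils.versions import require_version

    require_version("gradio==3.38.0", "To fix: pip install gradio==3.38.0")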
@@ -23,9 +23,6 @@ def create_ui() -> gr.Blocks:
     with gr.Blocks(title="Web Tuner", css=CSS) as demo:
         engine.manager.all_elems["top"] = create_top()
         lang: "gr.Dropdown" = engine.manager.get_elem("top.lang")
-        config = engine.manager.get_elem("top.config")
-        model_name = engine.manager.get_elem("top.model_name")
-        model_path = engine.manager.get_elem("top.model_path")

         with gr.Tab("Train"):
             engine.manager.all_elems["train"] = create_train_tab(engine)
@@ -39,13 +36,9 @@ def create_ui() -> gr.Blocks:
         with gr.Tab("Export"):
             engine.manager.all_elems["export"] = create_export_tab(engine)

-        demo.load(engine.resume, [config], engine.manager.list_elems())
-
-        lang.change(
-            engine.change_lang, [lang], engine.manager.list_elems(), queue=False
-        ).then(
-            save_config, inputs=[config, lang, model_name, model_path]
-        )
+        demo.load(engine.resume, outputs=engine.manager.list_elems())
+        lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)
+        lang.input(save_config, inputs=[lang], queue=False)

     return demo

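Splitting the handlers means engine.change_lang still runs on any value change, while save_config moves to .input, which only fires on direct user interaction, so programmatic updates from resume() no longer write the config back. A tiny sketch of the distinction:

    import gradio as gr

    with gr.Blocks() as demo:
        lang = gr.Dropdown(choices=["en", "zh"], value="en")
        out = gr.Textbox()
        # .change fires on every value change, including ones set by other
        # callbacks; .input fires only when the user edits the component.
        lang.change(lambda v: "re-rendered in " + v, [lang], [out], queue=False)
        lang.input(lambda v: print("user chose", v), [lang], queue=False)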
@@ -54,21 +47,15 @@ def create_web_demo() -> gr.Blocks:
     engine = Engine(pure_chat=True)

     with gr.Blocks(title="Web Demo", css=CSS) as demo:
-        config = gr.State(value=load_config())
         lang = gr.Dropdown(choices=["en", "zh"])
-        engine.manager.all_elems["top"] = dict(config=config, lang=lang)
+        engine.manager.all_elems["top"] = dict(lang=lang)

         chat_box, _, _, chat_elems = create_chat_box(engine, visible=True)
         engine.manager.all_elems["infer"] = dict(chat_box=chat_box, **chat_elems)

-        demo.load(engine.resume, [config], engine.manager.list_elems())
-
-        lang.change(
-            engine.change_lang, [lang], engine.manager.list_elems(), queue=False
-        ).then(
-            save_config, inputs=[config, lang]
-        )
+        demo.load(engine.resume, outputs=engine.manager.list_elems())
+        lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)
+        lang.input(save_config, inputs=[lang], queue=False)

     return demo

@@ -18,7 +18,6 @@ class Manager:

     def get_base_elems(self):
         return {
-            self.all_elems["top"]["config"],
             self.all_elems["top"]["lang"],
             self.all_elems["top"]["model_name"],
             self.all_elems["top"]["model_path"],
@@ -14,7 +14,7 @@ from llmtuner.extras.constants import TRAINING_STAGES
 from llmtuner.extras.logging import LoggerHandler
 from llmtuner.extras.misc import torch_gc
 from llmtuner.tuner import run_exp
-from llmtuner.webui.common import get_module, get_save_dir
+from llmtuner.webui.common import get_module, get_save_dir, load_config
 from llmtuner.webui.locales import ALERTS
 from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar

@@ -74,6 +74,7 @@ class Runner:

     def _parse_train_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
         get = lambda name: data[self.manager.get_elem(name)]
+        user_config = load_config()

         if get("top.checkpoints"):
             checkpoint_dir = ",".join([
@@ -89,7 +90,7 @@ class Runner:
             model_name_or_path=get("top.model_path"),
             do_train=True,
             overwrite_cache=False,
-            cache_dir=get("top.config").get("cache_dir", None),
+            cache_dir=user_config.get("cache_dir", None),
             checkpoint_dir=checkpoint_dir,
             finetuning_type=get("top.finetuning_type"),
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
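Both argument builders now read cache_dir from a freshly loaded user config at parse time instead of reaching into a top.config UI element. A stripped-down sketch of the hand-off (load_config stubbed; the eval path below is analogous):

    from typing import Any, Dict

    def load_config() -> Dict[str, Any]:
        return {"cache_dir": "/data/hf-cache"}  # stand-in for the JSON-backed helper

    def parse_train_args(model_path: str) -> Dict[str, Any]:
        user_config = load_config()  # read once per run, at parse time
        return dict(
            model_name_or_path=model_path,
            do_train=True,
            overwrite_cache=False,
            cache_dir=user_config.get("cache_dir", None),
        )

    print(parse_train_args("/models/llama-7b")["cache_dir"])  # /data/hf-cache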
@@ -142,6 +143,7 @@ class Runner:

     def _parse_eval_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
         get = lambda name: data[self.manager.get_elem(name)]
+        user_config = load_config()

         if get("top.checkpoints"):
             checkpoint_dir = ",".join([
@@ -160,7 +162,7 @@ class Runner:
             do_eval=True,
             overwrite_cache=False,
             predict_with_generate=True,
-            cache_dir=get("top.config").get("cache_dir", None),
+            cache_dir=user_config.get("cache_dir", None),
             checkpoint_dir=checkpoint_dir,
             finetuning_type=get("top.finetuning_type"),
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,