mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 06:12:50 +08:00
update webui
Former-commit-id: 3a720aac669708d17152d4e96c2018b5ccc27b75
This commit is contained in:
parent
77aa9853fb
commit
4f714ba314
@ -4,6 +4,7 @@ datasets>=2.12.0
|
|||||||
accelerate>=0.21.0
|
accelerate>=0.21.0
|
||||||
peft>=0.4.0
|
peft>=0.4.0
|
||||||
trl>=0.4.7
|
trl>=0.4.7
|
||||||
|
scipy
|
||||||
sentencepiece
|
sentencepiece
|
||||||
tiktoken
|
tiktoken
|
||||||
jieba
|
jieba
|
||||||
|
@ -19,7 +19,6 @@ class ChatModel:
|
|||||||
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
|
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
|
||||||
self.source_prefix = data_args.source_prefix
|
self.source_prefix = data_args.source_prefix
|
||||||
self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words)
|
self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words)
|
||||||
self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words))
|
|
||||||
self.model.generate = MethodType(PreTrainedModel.generate, self.model) # disable custom method (for Qwen)
|
self.model.generate = MethodType(PreTrainedModel.generate, self.model) # disable custom method (for Qwen)
|
||||||
|
|
||||||
def process_args(
|
def process_args(
|
||||||
|
@ -185,6 +185,7 @@ def get_template_and_fix_tokenizer(
|
|||||||
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
|
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
|
tokenizer.add_special_tokens(dict(additional_special_tokens=template.stop_words))
|
||||||
return template
|
return template
|
||||||
|
|
||||||
|
|
||||||
|
@ -513,6 +513,10 @@ ALERTS = {
|
|||||||
"en": "Please provide export dir.",
|
"en": "Please provide export dir.",
|
||||||
"zh": "请填写导出目录"
|
"zh": "请填写导出目录"
|
||||||
},
|
},
|
||||||
|
"err_failed": {
|
||||||
|
"en": "Failed.",
|
||||||
|
"zh": "训练出错。"
|
||||||
|
},
|
||||||
"info_aborting": {
|
"info_aborting": {
|
||||||
"en": "Aborted, wait for terminating...",
|
"en": "Aborted, wait for terminating...",
|
||||||
"zh": "训练中断,正在等待线程结束……"
|
"zh": "训练中断,正在等待线程结束……"
|
||||||
|
@ -3,6 +3,7 @@ import os
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import transformers
|
import transformers
|
||||||
|
from transformers.trainer import TRAINING_ARGS_NAME
|
||||||
from typing import Generator, List, Optional, Tuple
|
from typing import Generator, List, Optional, Tuple
|
||||||
|
|
||||||
from llmtuner.extras.callbacks import LogCallback
|
from llmtuner.extras.callbacks import LogCallback
|
||||||
@ -53,14 +54,14 @@ class Runner:
|
|||||||
return model_name_or_path, "", logger_handler, trainer_callback
|
return model_name_or_path, "", logger_handler, trainer_callback
|
||||||
|
|
||||||
def finalize(
|
def finalize(
|
||||||
self, lang: str, finish_info: Optional[str] = None
|
self, lang: str, finish_info: str
|
||||||
) -> str:
|
) -> str:
|
||||||
self.running = False
|
self.running = False
|
||||||
torch_gc()
|
torch_gc()
|
||||||
if self.aborted:
|
if self.aborted:
|
||||||
return ALERTS["info_aborted"][lang]
|
return ALERTS["info_aborted"][lang]
|
||||||
else:
|
else:
|
||||||
return finish_info if finish_info is not None else ALERTS["info_finished"][lang]
|
return finish_info
|
||||||
|
|
||||||
def run_train(
|
def run_train(
|
||||||
self,
|
self,
|
||||||
@ -104,6 +105,8 @@ class Runner:
|
|||||||
else:
|
else:
|
||||||
checkpoint_dir = None
|
checkpoint_dir = None
|
||||||
|
|
||||||
|
output_dir = os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
|
||||||
|
|
||||||
args = dict(
|
args = dict(
|
||||||
stage="sft",
|
stage="sft",
|
||||||
model_name_or_path=model_name_or_path,
|
model_name_or_path=model_name_or_path,
|
||||||
@ -133,7 +136,7 @@ class Runner:
|
|||||||
lora_rank=lora_rank,
|
lora_rank=lora_rank,
|
||||||
lora_dropout=lora_dropout,
|
lora_dropout=lora_dropout,
|
||||||
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
|
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
|
||||||
output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
|
output_dir=output_dir
|
||||||
)
|
)
|
||||||
|
|
||||||
if dev_ratio > 1e-6:
|
if dev_ratio > 1e-6:
|
||||||
@ -153,7 +156,12 @@ class Runner:
|
|||||||
else:
|
else:
|
||||||
yield format_info(logger_handler.log, trainer_callback)
|
yield format_info(logger_handler.log, trainer_callback)
|
||||||
|
|
||||||
yield self.finalize(lang)
|
if os.path.exists(os.path.join(output_dir), TRAINING_ARGS_NAME):
|
||||||
|
finish_info = ALERTS["info_finished"][lang]
|
||||||
|
else:
|
||||||
|
finish_info = ALERTS["err_failed"][lang]
|
||||||
|
|
||||||
|
yield self.finalize(lang, finish_info)
|
||||||
|
|
||||||
def run_eval(
|
def run_eval(
|
||||||
self,
|
self,
|
||||||
@ -221,4 +229,9 @@ class Runner:
|
|||||||
else:
|
else:
|
||||||
yield format_info(logger_handler.log, trainer_callback)
|
yield format_info(logger_handler.log, trainer_callback)
|
||||||
|
|
||||||
yield self.finalize(lang, get_eval_results(os.path.join(output_dir, "all_results.json")))
|
if os.path.exists(os.path.join(output_dir, "all_results.json")):
|
||||||
|
finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
|
||||||
|
else:
|
||||||
|
finish_info = ALERTS["err_failed"][lang]
|
||||||
|
|
||||||
|
yield self.finalize(lang, finish_info)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user