From 4f714ba314add77f8d4fa046e5ecf92565bee560 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 9 Aug 2023 00:26:11 +0800 Subject: [PATCH] update webui Former-commit-id: 3a720aac669708d17152d4e96c2018b5ccc27b75 --- requirements.txt | 1 + src/llmtuner/chat/stream_chat.py | 1 - src/llmtuner/extras/template.py | 1 + src/llmtuner/webui/locales.py | 4 ++++ src/llmtuner/webui/runner.py | 23 ++++++++++++++++++----- 5 files changed, 24 insertions(+), 6 deletions(-) diff --git a/requirements.txt b/requirements.txt index d99ce326..9b74b21d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ datasets>=2.12.0 accelerate>=0.21.0 peft>=0.4.0 trl>=0.4.7 +scipy sentencepiece tiktoken jieba diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py index 3be70616..79d3b92d 100644 --- a/src/llmtuner/chat/stream_chat.py +++ b/src/llmtuner/chat/stream_chat.py @@ -19,7 +19,6 @@ class ChatModel: self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer) self.source_prefix = data_args.source_prefix self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words) - self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words)) self.model.generate = MethodType(PreTrainedModel.generate, self.model) # disable custom method (for Qwen) def process_args( diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index 402aa3b5..95d34d47 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -185,6 +185,7 @@ def get_template_and_fix_tokenizer( if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None: tokenizer.pad_token = tokenizer.eos_token + tokenizer.add_special_tokens(dict(additional_special_tokens=template.stop_words)) return template diff --git a/src/llmtuner/webui/locales.py b/src/llmtuner/webui/locales.py index 5962d64a..ba9a73bd 100644 --- a/src/llmtuner/webui/locales.py +++ b/src/llmtuner/webui/locales.py @@ -513,6 
+513,10 @@ ALERTS = { "en": "Please provide export dir.", "zh": "请填写导出目录" }, + "err_failed": { + "en": "Failed.", + "zh": "训练出错。" + }, "info_aborting": { "en": "Aborted, wait for terminating...", "zh": "训练中断,正在等待线程结束……" diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index 43ba908b..53c18828 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -3,6 +3,7 @@ import os import threading import time import transformers +from transformers.trainer import TRAINING_ARGS_NAME from typing import Generator, List, Optional, Tuple from llmtuner.extras.callbacks import LogCallback @@ -53,14 +54,14 @@ class Runner: return model_name_or_path, "", logger_handler, trainer_callback def finalize( - self, lang: str, finish_info: Optional[str] = None + self, lang: str, finish_info: str ) -> str: self.running = False torch_gc() if self.aborted: return ALERTS["info_aborted"][lang] else: - return finish_info if finish_info is not None else ALERTS["info_finished"][lang] + return finish_info def run_train( self, @@ -104,6 +105,8 @@ class Runner: else: checkpoint_dir = None + output_dir = os.path.join(get_save_dir(model_name), finetuning_type, output_dir) + args = dict( stage="sft", model_name_or_path=model_name_or_path, @@ -133,7 +136,7 @@ class Runner: lora_rank=lora_rank, lora_dropout=lora_dropout, lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"), - output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir) + output_dir=output_dir ) if dev_ratio > 1e-6: @@ -153,7 +156,12 @@ class Runner: else: yield format_info(logger_handler.log, trainer_callback) - yield self.finalize(lang) + if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)): + finish_info = ALERTS["info_finished"][lang] + else: + finish_info = ALERTS["err_failed"][lang] + + yield self.finalize(lang, finish_info) def run_eval( self, @@ -221,4 +229,9 @@ class Runner: else: yield format_info(logger_handler.log,
trainer_callback) - yield self.finalize(lang, get_eval_results(os.path.join(output_dir, "all_results.json"))) + if os.path.exists(os.path.join(output_dir, "all_results.json")): + finish_info = get_eval_results(os.path.join(output_dir, "all_results.json")) + else: + finish_info = ALERTS["err_failed"][lang] + + yield self.finalize(lang, finish_info)