[webui] improve webui & reasoning mode (#6778)

Former-commit-id: 45e68b9f092879dda55023ebbcd8cf4660e3045a
This commit is contained in:
hoshi-hiyouga 2025-01-31 00:09:21 +08:00 committed by GitHub
parent f143360ee6
commit 245de012ca
18 changed files with 570 additions and 409 deletions

View File

@ -216,16 +216,15 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/685B | deepseek3 |
| [DeepSeek R1](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
| [DeepSeek R1](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseek3 |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
| [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [InternLM3](https://huggingface.co/internlm) | 8B | intern3 |
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
@ -830,7 +829,7 @@ If you have a project that should be incorporated, please contact via email or c
This repository is licensed under the [Apache-2.0 License](LICENSE).
Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / 
[Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / 
[Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
## Citation

View File

@ -218,16 +218,15 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/685B | deepseek3 |
| [DeepSeek R1](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
| [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
| [DeepSeek R1](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseek3 |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
| [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [InternLM3](https://huggingface.co/internlm) | 8B | intern3 |
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
@ -832,7 +831,7 @@ swanlab_run_name: test_run # 可选
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 
2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 
2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
## 引用

View File

@ -719,6 +719,13 @@ _register_template(
format_assistant=StringFormatter(slots=["{{content}}<eoa>\n"]),
format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
default_system=(
"You are an AI assistant whose name is InternLM (书生·浦语).\n"
"- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory "
"(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
"- InternLM (书生·浦语) can understand and communicate fluently in the language "
"chosen by the user such as English and 中文."
),
stop_words=["<eoa>"],
)
@ -729,17 +736,13 @@ _register_template(
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|im_end|>"],
)
# copied from intern2 template
_register_template(
name="intern3",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
default_system=(
"You are an AI assistant whose name is InternLM (书生·浦语).\n"
"- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory "
"(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
"- InternLM (书生·浦语) can understand and communicate fluently in the language "
"chosen by the user such as English and 中文."
),
stop_words=["<|im_end|>"],
)

View File

@ -105,7 +105,7 @@ def register_model_group(
) -> None:
for name, path in models.items():
SUPPORTED_MODELS[name] = path
if template is not None and (any(suffix in name for suffix in ("-Chat", "-Instruct")) or vision):
if template is not None and (any(suffix in name for suffix in ("-Chat", "-Distill", "-Instruct")) or vision):
DEFAULT_TEMPLATE[name] = template
if vision:
VISION_MODELS.add(name)
@ -485,11 +485,11 @@ register_model_group(
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2.5-1210",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2.5-1210",
},
"DeepSeek-V3-685B-Base": {
"DeepSeek-V3-671B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3-Base",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3-Base",
},
"DeepSeek-V3-685B-Chat": {
"DeepSeek-V3-671B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3",
},
@ -517,11 +517,11 @@ register_model_group(
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
},
"DeepSeek-R1-671B-Zero": {
"DeepSeek-R1-671B-Chat-Zero": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Zero",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Zero",
},
"DeepSeek-R1-671B": {
},
"DeepSeek-R1-671B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1",
},
@ -845,20 +845,15 @@ register_model_group(
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b-chat",
DownloadSource.OPENMIND: "Intern/internlm2_5-20b-chat",
},
},
template="intern2",
)
register_model_group(
models={
"InternLM3-8B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm3-8b-instruct",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm3-8b-instruct",
},
},
template="intern3",
template="intern2",
)
register_model_group(
models={
"Jamba-v0.1": {

View File

@ -36,6 +36,30 @@ if is_gradio_available():
import gradio as gr
def _format_response(text: str, lang: str) -> str:
r"""
Post-processes the response text.
Based on: https://huggingface.co/spaces/Lyte/DeepSeek-R1-Distill-Qwen-1.5B-Demo-GGUF/blob/main/app.py
"""
if "<think>" not in text:
return text
text = text.replace("<think>", "")
result = text.split("</think>", maxsplit=1)
if len(result) == 1:
summary = ALERTS["info_thinking"][lang]
thought, answer = text, ""
else:
summary = ALERTS["info_thought"][lang]
thought, answer = result
return (
f"<details open><summary class='thinking-summary'><span>{summary}</span></summary>\n\n"
f"<div class='thinking-container'>\n{thought}\n</div>\n</details>{answer}"
)
class WebChatModel(ChatModel):
def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
self.manager = manager
@ -124,19 +148,26 @@ class WebChatModel(ChatModel):
torch_gc()
yield ALERTS["info_unloaded"][lang]
def append(
    self,
    chatbot: List[Dict[str, str]],
    messages: List[Dict[str, str]],
    role: str,
    query: str,
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], str]:
    r"""
    Adds the user input to chatbot.

    This is a regular instance method: it takes ``self``, so it must not be wrapped in
    ``@staticmethod`` (otherwise the bound call would shift every argument by one position).

    Inputs: infer.chatbot, infer.messages, infer.role, infer.query
    Output: infer.chatbot, infer.messages, infer.query

    Returns:
        A new chatbot list with the query shown as a "user" message, a new message list with the
        query stored under the given role, and an empty string that clears the query box.
    """
    # Build new lists instead of mutating the inputs in place.
    new_chatbot = chatbot + [{"role": "user", "content": query}]
    new_messages = messages + [{"role": role, "content": query}]
    return new_chatbot, new_messages, ""
def stream(
self,
chatbot: List[Dict[str, str]],
messages: List[Dict[str, str]],
lang: str,
system: str,
tools: str,
image: Optional[Any],
@ -145,6 +176,12 @@ class WebChatModel(ChatModel):
top_p: float,
temperature: float,
) -> Generator[Tuple[List[Dict[str, str]], List[Dict[str, str]]], None, None]:
r"""
Generates output text in stream.
Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
Output: infer.chatbot, infer.messages
"""
chatbot.append({"role": "assistant", "content": ""})
response = ""
for new_text in self.stream_chat(
@ -157,7 +194,6 @@ class WebChatModel(ChatModel):
top_p=top_p,
temperature=temperature,
):
new_text = '' if any(t in new_text for t in ('<think>', '</think>')) else new_text
response += new_text
if tools:
result = self.engine.template.extract_tool(response)
@ -166,12 +202,12 @@ class WebChatModel(ChatModel):
if isinstance(result, list):
tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
tool_calls = json.dumps(tool_calls, indent=4, ensure_ascii=False)
tool_calls = json.dumps(tool_calls, ensure_ascii=False)
output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
bot_text = "```json\n" + tool_calls + "\n```"
else:
output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
bot_text = result
bot_text = _format_response(result, lang)
chatbot[-1] = {"role": "assistant", "content": bot_text}
yield chatbot, output_messages

View File

@ -14,34 +14,28 @@
import json
import os
import signal
from collections import defaultdict
from typing import Any, Dict, Optional, Tuple
from datetime import datetime
from typing import Any, Dict, Optional, Union
from psutil import Process
from yaml import safe_dump, safe_load
from ..extras import logging
from ..extras.constants import (
CHECKPOINT_NAMES,
DATA_CONFIG,
DEFAULT_TEMPLATE,
PEFT_METHODS,
STAGES_USE_PAIR_DATA,
SUPPORTED_MODELS,
TRAINING_STAGES,
TRAINING_ARGS,
VISION_MODELS,
DownloadSource,
)
from ..extras.misc import use_modelscope, use_openmind
from ..extras.packages import is_gradio_available
if is_gradio_available():
import gradio as gr
logger = logging.get_logger(__name__)
DEFAULT_CACHE_DIR = "cache"
DEFAULT_CONFIG_DIR = "config"
DEFAULT_DATA_DIR = "data"
@ -49,6 +43,21 @@ DEFAULT_SAVE_DIR = "saves"
USER_CONFIG = "user_config.yaml"
def abort_process(pid: int) -> None:
    r"""
    Sends SIGABRT to a process after recursively doing the same for each of its children.

    This is best-effort: any exception raised while inspecting or signalling is ignored.
    """
    try:
        # Children first, so the tree is torn down bottom-up.
        for child in Process(pid).children():
            abort_process(child.pid)

        os.kill(pid, signal.SIGABRT)
    except Exception:
        pass
def get_save_dir(*paths: str) -> os.PathLike:
r"""
Gets the path to saved model checkpoints.
@ -61,19 +70,19 @@ def get_save_dir(*paths: str) -> os.PathLike:
return os.path.join(DEFAULT_SAVE_DIR, *paths)
def get_config_path() -> os.PathLike:
def _get_config_path() -> os.PathLike:
r"""
Gets the path to user config.
"""
return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
def load_config() -> Dict[str, Any]:
def load_config() -> Dict[str, Union[str, Dict[str, Any]]]:
r"""
Loads user config if exists.
"""
try:
with open(get_config_path(), encoding="utf-8") as f:
with open(_get_config_path(), encoding="utf-8") as f:
return safe_load(f)
except Exception:
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
@ -92,7 +101,7 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
if model_name and model_path:
user_config["path_dict"][model_name] = model_path
with open(get_config_path(), "w", encoding="utf-8") as f:
with open(_get_config_path(), "w", encoding="utf-8") as f:
safe_dump(user_config, f)
@ -120,20 +129,9 @@ def get_model_path(model_name: str) -> str:
return model_path
def get_model_info(model_name: str) -> Tuple[str, str]:
r"""
Gets the necessary information of this model.
Returns:
model_path (str)
template (str)
"""
return get_model_path(model_name), get_template(model_name)
def get_template(model_name: str) -> str:
r"""
Gets the template name if the model is a chat model.
Gets the template name if the model is a chat/distill/instruct model.
"""
return DEFAULT_TEMPLATE.get(model_name, "default")
@ -145,24 +143,11 @@ def get_visual(model_name: str) -> bool:
return model_name in VISION_MODELS
def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
def get_time() -> str:
r"""
Lists all available checkpoints.
Gets current date and time.
"""
checkpoints = []
if model_name:
save_dir = get_save_dir(model_name, finetuning_type)
if save_dir and os.path.isdir(save_dir):
for checkpoint in os.listdir(save_dir):
if os.path.isdir(os.path.join(save_dir, checkpoint)) and any(
os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CHECKPOINT_NAMES
):
checkpoints.append(checkpoint)
if finetuning_type in PEFT_METHODS:
return gr.Dropdown(value=[], choices=checkpoints, multiselect=True)
else:
return gr.Dropdown(value=None, choices=checkpoints, multiselect=False)
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
@ -181,11 +166,135 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
return {}
def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
def load_args(config_path: str) -> Optional[Dict[str, Any]]:
r"""
Lists all available datasets in the dataset dir for the training stage.
Loads the training configuration from config path.
"""
dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
ranking = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA
datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
return gr.Dropdown(choices=datasets)
try:
with open(config_path, encoding="utf-8") as f:
return safe_load(f)
except Exception:
return None
def save_args(config_path: str, config_dict: Dict[str, Any]) -> None:
    r"""
    Dumps the training configuration dict to the file at config path, overwriting it if it exists.
    """
    with open(config_path, mode="w", encoding="utf-8") as config_file:
        safe_dump(config_dict, config_file)
def _clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
r"""
Removes args with NoneType or False or empty string value.
"""
no_skip_keys = ["packing"]
return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
def gen_cmd(args: Dict[str, Any]) -> str:
    r"""
    Builds a markdown code block showing the CLI command equivalent to the given arguments.

    Arguments are filtered through ``_clean_cmd`` first. Dict values are rendered as JSON, list
    values as space-separated items and everything else via ``str``. Lines are joined with the
    backtick continuation on Windows (``os.name == "nt"``) and with a backslash elsewhere.
    """

    def _render(value: Any) -> str:
        if isinstance(value, dict):
            return json.dumps(value, ensure_ascii=False)

        if isinstance(value, list):
            return " ".join(map(str, value))

        return str(value)

    pieces = ["llamafactory-cli train "]
    pieces.extend(f" --{name} {_render(value)} " for name, value in _clean_cmd(args).items())

    continuation = "`\n" if os.name == "nt" else "\\\n"
    command = continuation.join(pieces)
    return f"```bash\n{command}\n```"
def save_cmd(args: Dict[str, Any]) -> str:
    r"""
    Writes the cleaned training arguments into the output directory and returns the file path.

    The output directory (``args["output_dir"]``) is created if it does not exist yet.
    """
    output_dir = args["output_dir"]
    os.makedirs(output_dir, exist_ok=True)

    args_path = os.path.join(output_dir, TRAINING_ARGS)
    with open(args_path, "w", encoding="utf-8") as args_file:
        safe_dump(_clean_cmd(args), args_file)

    return args_path
def load_eval_results(path: os.PathLike) -> str:
    r"""
    Reads a JSON file of evaluation scores and returns it pretty-printed inside a markdown code block.
    """
    with open(path, encoding="utf-8") as result_file:
        scores = json.load(result_file)

    pretty = json.dumps(scores, indent=4)
    return "```json\n" + pretty + "\n```\n"
def create_ds_config() -> None:
    r"""
    Writes four DeepSpeed config files into the cache directory (created if missing).

    The files are ZeRO stage 2 and stage 3 configs, each with and without CPU offloading:
    ds_z2_config.json, ds_z2_offload_config.json, ds_z3_config.json and ds_z3_offload_config.json.
    """
    os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)

    offload_config = {
        "device": "cpu",
        "pin_memory": True,
    }
    zero2_config = {
        "stage": 2,
        "allgather_partitions": True,
        "allgather_bucket_size": 5e8,
        "overlap_comm": True,
        "reduce_scatter": True,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": True,
        "round_robin_gradients": True,
    }
    zero3_config = {
        "stage": 3,
        "overlap_comm": True,
        "contiguous_gradients": True,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": True,
    }
    ds_config = {
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
        "gradient_accumulation_steps": "auto",
        "gradient_clipping": "auto",
        "zero_allow_untested_optimizer": True,
        "fp16": {
            "enabled": "auto",
            "loss_scale": 0,
            "loss_scale_window": 1000,
            "initial_scale_power": 16,
            "hysteresis": 2,
            "min_loss_scale": 1,
        },
        "bf16": {"enabled": "auto"},
    }

    def _write_snapshot(file_name: str) -> None:
        # Serializes the current state of ds_config; it is mutated between the calls below.
        with open(os.path.join(DEFAULT_CACHE_DIR, file_name), "w", encoding="utf-8") as config_file:
            json.dump(ds_config, config_file, indent=2)

    ds_config["zero_optimization"] = zero2_config
    _write_snapshot("ds_z2_config.json")

    zero2_config["offload_optimizer"] = offload_config
    _write_snapshot("ds_z2_offload_config.json")

    ds_config["zero_optimization"] = zero3_config
    _write_snapshot("ds_z3_config.json")

    zero3_config["offload_optimizer"] = offload_config
    zero3_config["offload_param"] = offload_config
    _write_snapshot("ds_z3_offload_config.json")

View File

@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import TYPE_CHECKING, Dict, Tuple
from ...data import Role
from ...extras.packages import is_gradio_available
from ..utils import check_json_schema
from ..locales import ALERTS
if is_gradio_available():
@ -29,9 +30,27 @@ if TYPE_CHECKING:
from ..engine import Engine
def check_json_schema(text: str, lang: str) -> None:
    r"""
    Checks if the json schema is valid.

    The text must be a JSON document that is either empty/falsy or a list of tools in which every
    tool contains a "name" entry. Nothing is returned; problems are reported through ``gr.Warning``:
    ``err_tool_name`` when a tool lacks a name, ``err_json_schema`` for any other failure
    (unparsable JSON, a non-list payload, ...).

    Args:
        text: the raw tools definition entered by the user.
        lang: the language key used to pick the alert message.
    """
    try:
        tools = json.loads(text)
        if tools:
            # Raise explicitly instead of using `assert`, which is stripped under `python -O`
            # and would let a non-list payload be iterated as if it were a tool list.
            if not isinstance(tools, list):
                raise TypeError("Tools should be a list.")

            for tool in tools:
                if "name" not in tool:
                    raise NotImplementedError("Name not found.")
    except NotImplementedError:
        gr.Warning(ALERTS["err_tool_name"][lang])
    except Exception:
        gr.Warning(ALERTS["err_json_schema"][lang])
def create_chat_box(
engine: "Engine", visible: bool = False
) -> Tuple["Component", "Component", Dict[str, "Component"]]:
lang = engine.manager.get_elem_by_id("top.lang")
with gr.Column(visible=visible) as chat_box:
chatbot = gr.Chatbot(type="messages", show_copy_button=True)
messages = gr.State([])
@ -67,7 +86,7 @@ def create_chat_box(
[chatbot, messages, query],
).then(
engine.chatter.stream,
[chatbot, messages, system, tools, image, video, max_new_tokens, top_p, temperature],
[chatbot, messages, lang, system, tools, image, video, max_new_tokens, top_p, temperature],
[chatbot, messages],
)
clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])

View File

@ -40,6 +40,9 @@ def next_page(page_index: int, total_num: int) -> int:
def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
r"""
Checks if the dataset is a local dataset.
"""
try:
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
dataset_info = json.load(f)
@ -67,6 +70,9 @@ def _load_data_file(file_path: str) -> List[Any]:
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
r"""
Gets the preview samples from the dataset.
"""
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
dataset_info = json.load(f)

View File

@ -15,7 +15,8 @@
from typing import TYPE_CHECKING, Dict
from ...extras.packages import is_gradio_available
from ..common import DEFAULT_DATA_DIR, list_datasets
from ..common import DEFAULT_DATA_DIR
from ..control import list_datasets
from .data import create_preview_box

View File

@ -17,8 +17,8 @@ from typing import TYPE_CHECKING, Dict
from ...data import TEMPLATES
from ...extras.constants import METHODS, SUPPORTED_MODELS
from ...extras.packages import is_gradio_available
from ..common import get_model_info, list_checkpoints, save_config
from ..utils import can_quantize, can_quantize_to
from ..common import save_config
from ..control import can_quantize, can_quantize_to, get_model_info, list_checkpoints
if is_gradio_available():

View File

@ -19,8 +19,8 @@ from transformers.trainer_utils import SchedulerType
from ...extras.constants import TRAINING_STAGES
from ...extras.misc import get_device_count
from ...extras.packages import is_gradio_available
from ..common import DEFAULT_DATA_DIR, list_checkpoints, list_datasets
from ..utils import change_stage, list_config_paths, list_output_dirs
from ..common import DEFAULT_DATA_DIR
from ..control import change_stage, list_checkpoints, list_config_paths, list_datasets, list_output_dirs
from .data import create_preview_box

View File

@ -0,0 +1,201 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from typing import Any, Dict, List, Optional, Tuple
from transformers.trainer_utils import get_last_checkpoint
from ..extras.constants import (
CHECKPOINT_NAMES,
PEFT_METHODS,
RUNNING_LOG,
STAGES_USE_PAIR_DATA,
TRAINER_LOG,
TRAINING_STAGES,
)
from ..extras.packages import is_gradio_available, is_matplotlib_available
from ..extras.ploting import gen_loss_plot
from ..model import QuantizationMethod
from .common import DEFAULT_CONFIG_DIR, DEFAULT_DATA_DIR, get_model_path, get_save_dir, get_template, load_dataset_info
if is_gradio_available():
import gradio as gr
def can_quantize(finetuning_type: str) -> "gr.Dropdown":
    r"""
    Enables the quantization dropdown only for PEFT finetuning types.

    For any other finetuning type the dropdown is reset to "none" and made non-interactive.
    Inputs: top.finetuning_type
    Outputs: top.quantization_bit
    """
    if finetuning_type in PEFT_METHODS:
        return gr.Dropdown(interactive=True)

    return gr.Dropdown(value="none", interactive=False)
def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
    r"""
    Gets the available quantization bits.

    An unrecognized quantization method yields only the "none" choice instead of failing with
    an UnboundLocalError.
    Inputs: top.quantization_method
    Outputs: top.quantization_bit
    """
    if quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
        available_bits = ["none", "8", "4"]
    elif quantization_method == QuantizationMethod.HQQ.value:
        available_bits = ["none", "8", "6", "5", "4", "3", "2", "1"]
    elif quantization_method == QuantizationMethod.EETQ.value:
        available_bits = ["none", "8"]
    else:
        # Unknown method: offer no quantization rather than leaving the name unbound.
        available_bits = ["none"]

    return gr.Dropdown(choices=available_bits)
def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]:
    r"""
    Modifies states after changing the training stage.

    Returns an empty dataset selection and a flag that is True only when the stage maps to "pt".
    Inputs: train.training_stage
    Outputs: train.dataset, train.packing
    """
    is_pretraining = TRAINING_STAGES[training_stage] == "pt"
    return [], is_pretraining
def get_model_info(model_name: str) -> Tuple[str, str]:
    r"""
    Looks up the model path and the template registered for this model name.

    Inputs: top.model_name
    Outputs: top.model_path, top.template
    """
    model_path = get_model_path(model_name)
    template = get_template(model_name)
    return model_path, template
def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
    r"""
    Gets training information for the monitor.

    Reads the RUNNING_LOG and TRAINER_LOG files under output_path when they exist.
    Returns the raw running log text, a progress slider (hidden unless the trainer
    log has at least one entry) and a loss plot (None unless do_train is True,
    matplotlib is available and the trainer log has at least one entry).

    If do_train is True:
    Inputs: train.output_path
    Outputs: train.output_box, train.progress_bar, train.loss_viewer
    If do_train is False:
    Inputs: eval.output_path
    Outputs: eval.output_box, eval.progress_bar, None
    """
    log_text = ""
    progress_bar = gr.Slider(visible=False)
    loss_plot = None

    log_file = os.path.join(output_path, RUNNING_LOG)
    if os.path.isfile(log_file):
        with open(log_file, encoding="utf-8") as reader:
            log_text = reader.read()

    # The trainer log holds one JSON object per line.
    history: List[Dict[str, Any]] = []
    history_file = os.path.join(output_path, TRAINER_LOG)
    if os.path.isfile(history_file):
        with open(history_file, encoding="utf-8") as reader:
            history = [json.loads(row) for row in reader]

    if history:
        newest = history[-1]
        percentage = newest["percentage"]
        label = "Running {:d}/{:d}: {} < {}".format(
            newest["current_steps"],
            newest["total_steps"],
            newest["elapsed_time"],
            newest["remaining_time"],
        )
        progress_bar = gr.Slider(label=label, value=percentage, visible=True)

        if do_train and is_matplotlib_available():
            loss_plot = gr.Plot(gen_loss_plot(history))

    return log_text, progress_bar, loss_plot
def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
    r"""
    Lists all available checkpoints.

    A sub-directory of the save dir counts as a checkpoint when it contains at least
    one file named in CHECKPOINT_NAMES. The dropdown allows multiple selections for
    finetuning types in PEFT_METHODS and a single selection otherwise.

    Inputs: top.model_name, top.finetuning_type
    Outputs: top.checkpoint_path
    """

    def _holds_checkpoint(folder: str) -> bool:
        # A checkpoint folder must be a directory with at least one known checkpoint file.
        if not os.path.isdir(folder):
            return False

        return any(os.path.isfile(os.path.join(folder, name)) for name in CHECKPOINT_NAMES)

    checkpoints = []
    save_dir = get_save_dir(model_name, finetuning_type) if model_name else None
    if save_dir and os.path.isdir(save_dir):
        checkpoints = [entry for entry in os.listdir(save_dir) if _holds_checkpoint(os.path.join(save_dir, entry))]

    allow_multiple = finetuning_type in PEFT_METHODS
    initial_value = [] if allow_multiple else None
    return gr.Dropdown(value=initial_value, choices=checkpoints, multiselect=allow_multiple)
def list_config_paths(current_time: str) -> "gr.Dropdown":
    r"""
    Lists all the saved configuration files.

    The first choice is always "<current_time>.yaml"; it is followed by every other
    ".yaml" file found in DEFAULT_CONFIG_DIR, if that directory exists.

    Inputs: train.current_time
    Outputs: train.config_path
    """
    default_name = f"{current_time}.yaml"
    config_files = [default_name]
    if os.path.isdir(DEFAULT_CONFIG_DIR):
        config_files.extend(
            entry
            for entry in os.listdir(DEFAULT_CONFIG_DIR)
            if entry.endswith(".yaml") and entry != default_name  # directory entries are unique
        )

    return gr.Dropdown(choices=config_files)
def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
    r"""
    Lists all available datasets in the dataset dir for the training stage.

    Falls back to DEFAULT_DATA_DIR when dataset_dir is None. A dataset is offered only
    when its "ranking" flag (False if absent) equals whether the stage is listed in
    STAGES_USE_PAIR_DATA.

    Inputs: *.dataset_dir, *.training_stage
    Outputs: *.dataset
    """
    data_dir = DEFAULT_DATA_DIR if dataset_dir is None else dataset_dir
    dataset_info = load_dataset_info(data_dir)
    needs_pairs = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA

    choices = []
    for name, attrs in dataset_info.items():
        if attrs.get("ranking", False) == needs_pairs:
            choices.append(name)

    return gr.Dropdown(choices=choices)
def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown":
    r"""
    Lists all the directories that training can resume from.

    The first choice is always "train_<current_time>". When a model is selected, every
    sub-directory of its save dir for which get_last_checkpoint finds a checkpoint is
    appended.

    Inputs: top.model_name, top.finetuning_type, train.current_time
    Outputs: train.output_dir
    """
    output_dirs = [f"train_{current_time}"]
    if not model_name:
        return gr.Dropdown(choices=output_dirs)

    save_dir = get_save_dir(model_name, finetuning_type)
    if save_dir and os.path.isdir(save_dir):
        for entry in os.listdir(save_dir):
            candidate = os.path.join(save_dir, entry)
            if not os.path.isdir(candidate):
                continue

            if get_last_checkpoint(candidate) is not None:
                output_dirs.append(entry)

    return gr.Dropdown(choices=output_dirs)

View File

@ -20,6 +20,29 @@ CSS = r"""
border-radius: 100vh !important;
}
.thinking-summary {
padding: 8px !important;
}
.thinking-summary span {
border: 1px solid #e0e0e0 !important;
border-radius: 4px !important;
padding: 4px !important;
cursor: pointer !important;
font-size: 14px !important;
background: #333333 !important;
}
.thinking-container {
border-left: 2px solid #a6a6a6 !important;
padding-left: 10px !important;
margin: 4px 0 !important;
}
.thinking-container p {
color: #a6a6a6 !important;
}
.modal-box {
position: fixed !important;
top: 50%;

View File

@ -15,11 +15,10 @@
from typing import TYPE_CHECKING, Any, Dict
from .chatter import WebChatModel
from .common import load_config
from .common import create_ds_config, get_time, load_config
from .locales import LOCALES
from .manager import Manager
from .runner import Runner
from .utils import create_ds_config, get_time
if TYPE_CHECKING:
@ -27,6 +26,10 @@ if TYPE_CHECKING:
class Engine:
r"""
A general engine to control the behaviors of Web UI.
"""
def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None:
self.demo_mode = demo_mode
self.pure_chat = pure_chat
@ -38,7 +41,7 @@ class Engine:
def _update_component(self, input_dict: Dict[str, Dict[str, Any]]) -> Dict["Component", "Component"]:
r"""
Gets the dict to update the components.
Updates gradio components according to the (elem_id, properties) mapping.
"""
output_dict: Dict["Component", "Component"] = {}
for elem_id, elem_attr in input_dict.items():
@ -48,9 +51,11 @@ class Engine:
return output_dict
def resume(self):
user_config = load_config() if not self.demo_mode else {}
r"""
Gets the initial value of gradio components and restores training status if necessary.
"""
user_config = load_config() if not self.demo_mode else {} # do not use config in demo mode
lang = user_config.get("lang", None) or "en"
init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}}
if not self.pure_chat:
@ -74,6 +79,9 @@ class Engine:
yield self._update_component({"eval.resume_btn": {"value": True}})
def change_lang(self, lang: str):
r"""
Updates the displayed language of gradio components.
"""
return {
elem: elem.__class__(**LOCALES[elem_name][lang])
for elem_name, elem in self.manager.get_elem_iter()

View File

@ -2786,6 +2786,20 @@ ALERTS = {
"ko": "모델이 언로드되었습니다.",
"ja": "モデルがアンロードされました。",
},
"info_thinking": {
"en": "🌀 Thinking...",
"ru": "🌀 Думаю...",
"zh": "🌀 思考中...",
"ko": "🌀 생각 중...",
"ja": "🌀 考えています...",
},
"info_thought": {
"en": "✅ Thought",
"ru": "✅ Думать закончено",
"zh": "✅ 思考完成",
"ko": "✅ 생각이 완료되었습니다",
"ja": "✅ 思考完了",
},
"info_exporting": {
"en": "Exporting model...",
"ru": "Экспорт модели...",

View File

@ -20,6 +20,10 @@ if TYPE_CHECKING:
class Manager:
r"""
A class to manage all the gradio components in Web UI.
"""
def __init__(self) -> None:
self._id_to_elem: Dict[str, "Component"] = {}
self._elem_to_id: Dict["Component", str] = {}

View File

@ -24,9 +24,20 @@ from transformers.utils import is_torch_npu_available
from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray
from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir, load_config
from .common import (
DEFAULT_CACHE_DIR,
DEFAULT_CONFIG_DIR,
abort_process,
gen_cmd,
get_save_dir,
load_args,
load_config,
load_eval_results,
save_args,
save_cmd,
)
from .control import get_trainer_info
from .locales import ALERTS, LOCALES
from .utils import abort_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
if is_gradio_available():
@ -40,6 +51,10 @@ if TYPE_CHECKING:
class Runner:
r"""
A class to manage the running status of the trainers.
"""
def __init__(self, manager: "Manager", demo_mode: bool = False) -> None:
self.manager = manager
self.demo_mode = demo_mode
@ -57,6 +72,9 @@ class Runner:
abort_process(self.trainer.pid)
def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
r"""
Validates the configuration.
"""
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
dataset = get("train.dataset") if do_train else get("eval.dataset")
@ -98,6 +116,9 @@ class Runner:
return ""
def _finalize(self, lang: str, finish_info: str) -> str:
r"""
Cleans the cached memory and resets the runner.
"""
finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
gr.Info(finish_info)
self.trainer = None
@ -108,6 +129,9 @@ class Runner:
return finish_info
def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
r"""
Builds and validates the training arguments.
"""
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
user_config = load_config()
@ -268,6 +292,9 @@ class Runner:
return args
def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
r"""
Builds and validates the evaluation arguments.
"""
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
user_config = load_config()
@ -319,6 +346,9 @@ class Runner:
return args
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
r"""
Previews the training commands.
"""
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
error = self._initialize(data, do_train, from_preview=True)
if error:
@ -329,6 +359,9 @@ class Runner:
yield {output_box: gen_cmd(args)}
def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", Any], None, None]:
r"""
Starts the training process.
"""
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
error = self._initialize(data, do_train, from_preview=False)
if error:
@ -339,7 +372,7 @@ class Runner:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
os.makedirs(args["output_dir"], exist_ok=True)
save_args(os.path.join(args["output_dir"], LLAMABOARD_CONFIG), self._form_config_dict(data))
save_args(os.path.join(args["output_dir"], LLAMABOARD_CONFIG), self._build_config_dict(data))
env = deepcopy(os.environ)
env["LLAMABOARD_ENABLED"] = "1"
@ -350,7 +383,10 @@ class Runner:
self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env)
yield from self.monitor()
def _form_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]:
def _build_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]:
r"""
Builds a dictionary containing the current training configuration.
"""
config_dict = {}
skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"]
for elem, value in data.items():
@ -373,6 +409,9 @@ class Runner:
yield from self._launch(data, do_train=False)
def monitor(self):
r"""
Monitors the training progress and logs.
"""
self.aborted = False
self.running = True
@ -416,7 +455,7 @@ class Runner:
finish_info = ALERTS["err_failed"][lang]
else:
if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray():
finish_info = get_eval_results(os.path.join(output_path, "all_results.json"))
finish_info = load_eval_results(os.path.join(output_path, "all_results.json"))
else:
finish_info = ALERTS["err_failed"][lang]
@ -427,6 +466,9 @@ class Runner:
yield return_dict
def save_args(self, data):
r"""
Saves the training configuration to config path.
"""
output_box = self.manager.get_elem_by_id("train.output_box")
error = self._initialize(data, do_train=True, from_preview=True)
if error:
@ -438,10 +480,13 @@ class Runner:
os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
save_path = os.path.join(DEFAULT_CONFIG_DIR, config_path)
save_args(save_path, self._form_config_dict(data))
save_args(save_path, self._build_config_dict(data))
return {output_box: ALERTS["info_config_saved"][lang] + save_path}
def load_args(self, lang: str, config_path: str):
r"""
Loads the training configuration from config path.
"""
output_box = self.manager.get_elem_by_id("train.output_box")
config_dict = load_args(os.path.join(DEFAULT_CONFIG_DIR, config_path))
if config_dict is None:
@ -455,6 +500,9 @@ class Runner:
return output_dict
def check_output_dir(self, lang: str, model_name: str, finetuning_type: str, output_dir: str):
r"""
Restore the training status if output_dir exists.
"""
output_box = self.manager.get_elem_by_id("train.output_box")
output_dict: Dict["Component", Any] = {output_box: LOCALES["output_box"][lang]["value"]}
if model_name and output_dir and os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)):

View File

@ -1,304 +0,0 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import signal
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
import psutil
from transformers.trainer_utils import get_last_checkpoint
from yaml import safe_dump, safe_load
from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_LOG, TRAINING_ARGS, TRAINING_STAGES
from ..extras.packages import is_gradio_available, is_matplotlib_available
from ..extras.ploting import gen_loss_plot
from ..model import QuantizationMethod
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir
from .locales import ALERTS
if is_gradio_available():
import gradio as gr
def abort_process(pid: int) -> None:
    r"""
    Aborts the process tree rooted at pid in a bottom-up way.

    Children are aborted recursively before SIGABRT is sent to pid itself. This is a
    best-effort cleanup: any exception raised while walking or signalling is ignored.
    """
    try:
        for child in psutil.Process(pid).children():
            abort_process(child.pid)

        os.kill(pid, signal.SIGABRT)
    except Exception:
        # Deliberately swallowed so that a failed abort never breaks the caller.
        pass
def can_quantize(finetuning_type: str) -> "gr.Dropdown":
    r"""
    Judges if the quantization is available in this finetuning type.

    Types in PEFT_METHODS get an interactive dropdown; all other types get a
    dropdown that is reset to "none" and made non-interactive.
    """
    if finetuning_type in PEFT_METHODS:
        dropdown_kwargs = {"interactive": True}
    else:
        dropdown_kwargs = {"value": "none", "interactive": False}

    return gr.Dropdown(**dropdown_kwargs)
def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
    r"""
    Returns the available quantization bits for the given quantization method.

    A method that matches none of the known QuantizationMethod values falls back to
    the single choice "none" instead of raising UnboundLocalError.
    """
    if quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
        available_bits = ["none", "8", "4"]
    elif quantization_method == QuantizationMethod.HQQ.value:
        available_bits = ["none", "8", "6", "5", "4", "3", "2", "1"]
    elif quantization_method == QuantizationMethod.EETQ.value:
        available_bits = ["none", "8"]
    else:
        # Unknown method: only "no quantization" can safely be offered.
        available_bits = ["none"]

    return gr.Dropdown(choices=available_bits)
def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]:
    r"""
    Modifies states after changing the training stage.

    Returns an empty list and a flag telling whether the stage maps to "pt".
    """
    is_pretraining = TRAINING_STAGES[training_stage] == "pt"
    return [], is_pretraining
def check_json_schema(text: str, lang: str) -> None:
    r"""
    Checks if the json schema is valid.

    The text must parse as JSON. A non-empty result must be a list of objects that
    each carry a "name" key. A missing name shows the "err_tool_name" warning; any
    other problem (invalid JSON, wrong container or entry type) shows the
    "err_json_schema" warning. Nothing is raised to the caller by the checks themselves.
    """
    try:
        tools = json.loads(text)
        if tools:
            # Explicit checks instead of `assert`, which is stripped under `python -O`.
            if not isinstance(tools, list):
                raise TypeError("Tools should be a list.")

            for tool in tools:
                # Reject non-object entries: `"name" in "some string"` would be a substring test.
                if not isinstance(tool, dict):
                    raise TypeError("Each tool should be an object.")

                if "name" not in tool:
                    raise NotImplementedError("Name not found.")
    except NotImplementedError:
        gr.Warning(ALERTS["err_tool_name"][lang])
    except Exception:
        gr.Warning(ALERTS["err_json_schema"][lang])
def clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
    r"""
    Removes args whose value is None, False or an empty string.

    Keys listed in the local keep-list ("packing") are always retained, whatever
    their value. The relative order of the surviving keys is preserved.
    """
    always_keep = ["packing"]

    def _has_value(value: Any) -> bool:
        # Note that 0 and empty containers are kept: only None, False and "" are dropped.
        return value is not None and value is not False and value != ""

    cleaned = {}
    for key, value in args.items():
        if key in always_keep or _has_value(value):
            cleaned[key] = value

    return cleaned
def gen_cmd(args: Dict[str, Any]) -> str:
    r"""
    Generates the command text shown when previewing a run.

    Arguments are first filtered by clean_cmd. Dict values are rendered as JSON, list
    values as space-separated items and everything else via str(). Lines are joined
    with a backtick continuation on Windows ("nt") and a backslash elsewhere, and the
    result is wrapped in a bash code fence.
    """
    cmd_lines = ["llamafactory-cli train "]
    for key, value in clean_cmd(args).items():
        if isinstance(value, dict):
            rendered = json.dumps(value, ensure_ascii=False)
        elif isinstance(value, list):
            rendered = " ".join(map(str, value))
        else:
            rendered = str(value)

        cmd_lines.append(f" --{key} {rendered} ")

    line_continuation = "`\n" if os.name == "nt" else "\\\n"
    cmd_text = line_continuation.join(cmd_lines)
    return f"```bash\n{cmd_text}\n```"
def save_cmd(args: Dict[str, Any]) -> str:
    r"""
    Saves the arguments used to launch training.

    Creates args["output_dir"] if needed, dumps the arguments filtered by clean_cmd as
    YAML into the TRAINING_ARGS file inside it, and returns the path of that file.
    """
    output_dir = args["output_dir"]
    os.makedirs(output_dir, exist_ok=True)

    args_path = os.path.join(output_dir, TRAINING_ARGS)
    with open(args_path, "w", encoding="utf-8") as writer:
        safe_dump(clean_cmd(args), writer)

    return args_path
def get_eval_results(path: os.PathLike) -> str:
    r"""
    Gets the scores after evaluation.

    Loads the JSON file at path and returns its content, re-indented with four spaces,
    wrapped in a json code fence.
    """
    with open(path, encoding="utf-8") as reader:
        scores = json.load(reader)

    result = json.dumps(scores, indent=4)
    return f"```json\n{result}\n```\n"
def get_time() -> str:
    r"""
    Gets the current local date and time formatted as "YYYY-MM-DD-HH-MM-SS".
    """
    now = datetime.now()
    return now.strftime(r"%Y-%m-%d-%H-%M-%S")
def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
    r"""
    Gets training information for the monitor.

    Returns the text of the RUNNING_LOG file (empty if missing), a progress slider
    built from the last TRAINER_LOG entry (hidden if there is none), and a loss plot
    that is only produced when do_train is True and matplotlib is available.
    """
    running_log = ""
    hidden_progress = gr.Slider(visible=False)

    running_log_path = os.path.join(output_path, RUNNING_LOG)
    if os.path.isfile(running_log_path):
        with open(running_log_path, encoding="utf-8") as log_file:
            running_log = log_file.read()

    trainer_log_path = os.path.join(output_path, TRAINER_LOG)
    if not os.path.isfile(trainer_log_path):
        return running_log, hidden_progress, None

    # One JSON record per line.
    with open(trainer_log_path, encoding="utf-8") as log_file:
        records: List[Dict[str, Any]] = [json.loads(entry) for entry in log_file]

    if len(records) == 0:
        return running_log, hidden_progress, None

    last_record = records[-1]
    percentage = last_record["percentage"]
    label = "Running {:d}/{:d}: {} < {}".format(
        last_record["current_steps"],
        last_record["total_steps"],
        last_record["elapsed_time"],
        last_record["remaining_time"],
    )
    running_progress = gr.Slider(label=label, value=percentage, visible=True)

    running_loss = None
    if do_train and is_matplotlib_available():
        running_loss = gr.Plot(gen_loss_plot(records))

    return running_log, running_progress, running_loss
def load_args(config_path: str) -> Optional[Dict[str, Any]]:
    r"""
    Loads saved arguments from a YAML file.

    Returns the parsed content, or None when the file cannot be opened or parsed
    (every exception is treated as "no saved arguments").
    """
    loaded = None
    try:
        with open(config_path, encoding="utf-8") as reader:
            loaded = safe_load(reader)
    except Exception:
        loaded = None

    return loaded
def save_args(config_path: str, config_dict: Dict[str, Any]):
    r"""
    Saves arguments as YAML.

    Writes config_dict to config_path, replacing any existing file.
    """
    with open(config_path, mode="w", encoding="utf-8") as writer:
        safe_dump(config_dict, writer)
def list_config_paths(current_time: str) -> "gr.Dropdown":
    r"""
    Lists all the saved configuration files.

    "<current_time>.yaml" always comes first, followed by the other ".yaml" files in
    DEFAULT_CONFIG_DIR when that directory exists.
    """
    config_files = [f"{current_time}.yaml"]
    saved_files = os.listdir(DEFAULT_CONFIG_DIR) if os.path.isdir(DEFAULT_CONFIG_DIR) else []
    for saved_file in saved_files:
        if not saved_file.endswith(".yaml"):
            continue

        if saved_file in config_files:
            continue

        config_files.append(saved_file)

    return gr.Dropdown(choices=config_files)
def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown":
    r"""
    Lists all the directories that can be resumed from.

    "train_<current_time>" always comes first. With a model selected, each folder in
    its save dir for which get_last_checkpoint returns a checkpoint is added.
    """

    def _is_resumable(folder_path: str) -> bool:
        # Only real directories that already contain a trainer checkpoint qualify.
        return os.path.isdir(folder_path) and get_last_checkpoint(folder_path) is not None

    output_dirs = [f"train_{current_time}"]
    save_dir = get_save_dir(model_name, finetuning_type) if model_name else None
    if save_dir and os.path.isdir(save_dir):
        output_dirs.extend(
            folder for folder in os.listdir(save_dir) if _is_resumable(os.path.join(save_dir, folder))
        )

    return gr.Dropdown(choices=output_dirs)
def create_ds_config() -> None:
    r"""
    Creates the deepspeed config files.

    Ensures DEFAULT_CACHE_DIR exists and writes four JSON files into it: ZeRO stage 2
    and stage 3 configs, each in a plain variant and a variant with CPU offload
    (optimizer only for stage 2, optimizer and parameters for stage 3).
    """
    os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)

    # Settings shared by every generated file.
    base_config = {
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
        "gradient_accumulation_steps": "auto",
        "gradient_clipping": "auto",
        "zero_allow_untested_optimizer": True,
        "fp16": {
            "enabled": "auto",
            "loss_scale": 0,
            "loss_scale_window": 1000,
            "initial_scale_power": 16,
            "hysteresis": 2,
            "min_loss_scale": 1,
        },
        "bf16": {"enabled": "auto"},
    }
    offload_config = {
        "device": "cpu",
        "pin_memory": True,
    }
    zero2 = {
        "stage": 2,
        "allgather_partitions": True,
        "allgather_bucket_size": 5e8,
        "overlap_comm": True,
        "reduce_scatter": True,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": True,
        "round_robin_gradients": True,
    }
    zero3 = {
        "stage": 3,
        "overlap_comm": True,
        "contiguous_gradients": True,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": True,
    }

    # Offload keys are appended after the stage keys so the emitted key order is stable.
    variants = [
        ("ds_z2_config.json", zero2),
        ("ds_z2_offload_config.json", {**zero2, "offload_optimizer": offload_config}),
        ("ds_z3_config.json", zero3),
        (
            "ds_z3_offload_config.json",
            {**zero3, "offload_optimizer": offload_config, "offload_param": offload_config},
        ),
    ]
    for file_name, zero_optimization in variants:
        ds_config = {**base_config, "zero_optimization": zero_optimization}
        with open(os.path.join(DEFAULT_CACHE_DIR, file_name), "w", encoding="utf-8") as writer:
            json.dump(ds_config, writer, indent=2)