[webui] improve webui & reasoning mode (#6778)

Former-commit-id: 45e68b9f092879dda55023ebbcd8cf4660e3045a
This commit is contained in:
hoshi-hiyouga 2025-01-31 00:09:21 +08:00 committed by GitHub
parent f143360ee6
commit 245de012ca
18 changed files with 570 additions and 409 deletions

View File

@ -216,16 +216,15 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
- | [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/685B | deepseek3 |
+ | [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
- | [DeepSeek R1](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
+ | [DeepSeek R1](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseek3 |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
| [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
- | [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
- | [InternLM3](https://huggingface.co/internlm) | 8B | intern3 |
+ | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
@ -830,7 +829,7 @@ If you have a project that should be incorporated, please contact via email or c
This repository is licensed under the [Apache-2.0 License](LICENSE).
- Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+ Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
## Citation

View File

@ -218,16 +218,15 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
- | [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/685B | deepseek3 |
+ | [DeepSeek 2.5/3](https://huggingface.co/deepseek-ai) | 236B/671B | deepseek3 |
- | [DeepSeek R1](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseekr1 |
+ | [DeepSeek R1](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseek3 |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
| [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
- | [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
- | [InternLM3](https://huggingface.co/internlm) | 8B | intern3 |
+ | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
@ -832,7 +831,7 @@ swanlab_run_name: test_run # 可选
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
- 使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+ 使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [GPT-2](https://github.com/openai/gpt-2/blob/master/LICENSE) / [Granite](LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3/Phi-4](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [Skywork](https://huggingface.co/Skywork/Skywork-13B-base/blob/main/Skywork%20Community%20License.pdf) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [TeleChat2](https://huggingface.co/Tele-AI/telechat-7B/blob/main/TeleChat%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
## 引用

View File

@ -719,6 +719,13 @@ _register_template(
format_assistant=StringFormatter(slots=["{{content}}<eoa>\n"]), format_assistant=StringFormatter(slots=["{{content}}<eoa>\n"]),
format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]), format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
default_system=(
"You are an AI assistant whose name is InternLM (书生·浦语).\n"
"- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory "
"(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
"- InternLM (书生·浦语) can understand and communicate fluently in the language "
"chosen by the user such as English and 中文."
),
stop_words=["<eoa>"], stop_words=["<eoa>"],
) )
@ -729,17 +736,13 @@ _register_template(
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]), format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|im_end|>"], default_system=(
) "You are an AI assistant whose name is InternLM (书生·浦语).\n"
"- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory "
"(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
# copied from intern2 template "- InternLM (书生·浦语) can understand and communicate fluently in the language "
_register_template( "chosen by the user such as English and 中文."
name="intern3", ),
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
) )

View File

@ -105,7 +105,7 @@ def register_model_group(
) -> None:
for name, path in models.items():
SUPPORTED_MODELS[name] = path
- if template is not None and (any(suffix in name for suffix in ("-Chat", "-Instruct")) or vision):
+ if template is not None and (any(suffix in name for suffix in ("-Chat", "-Distill", "-Instruct")) or vision):
DEFAULT_TEMPLATE[name] = template
if vision:
VISION_MODELS.add(name)
@ -485,11 +485,11 @@ register_model_group(
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2.5-1210", DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2.5-1210",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2.5-1210", DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2.5-1210",
}, },
"DeepSeek-V3-685B-Base": { "DeepSeek-V3-671B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3-Base", DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3-Base",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3-Base", DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3-Base",
}, },
"DeepSeek-V3-685B-Chat": { "DeepSeek-V3-671B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3", DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3", DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3",
}, },
@ -517,11 +517,11 @@ register_model_group(
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
}, },
"DeepSeek-R1-671B-Zero": { "DeepSeek-R1-671B-Chat-Zero": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Zero", DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1-Zero",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Zero", DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1-Zero",
}, },
"DeepSeek-R1-671B": { "DeepSeek-R1-671B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1", DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-R1",
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1", DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-R1",
}, },
@ -845,20 +845,15 @@ register_model_group(
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b-chat", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b-chat",
DownloadSource.OPENMIND: "Intern/internlm2_5-20b-chat", DownloadSource.OPENMIND: "Intern/internlm2_5-20b-chat",
}, },
},
template="intern2",
)
register_model_group(
models={
"InternLM3-8B-Chat": { "InternLM3-8B-Chat": {
DownloadSource.DEFAULT: "internlm/internlm3-8b-instruct", DownloadSource.DEFAULT: "internlm/internlm3-8b-instruct",
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm3-8b-instruct", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm3-8b-instruct",
}, },
}, },
template="intern3", template="intern2",
) )
register_model_group( register_model_group(
models={ models={
"Jamba-v0.1": { "Jamba-v0.1": {

View File

@ -36,6 +36,30 @@ if is_gradio_available():
import gradio as gr
def _format_response(text: str, lang: str) -> str:
r"""
Post-processes the response text.
Based on: https://huggingface.co/spaces/Lyte/DeepSeek-R1-Distill-Qwen-1.5B-Demo-GGUF/blob/main/app.py
"""
if "<think>" not in text:
return text
text = text.replace("<think>", "")
result = text.split("</think>", maxsplit=1)
if len(result) == 1:
summary = ALERTS["info_thinking"][lang]
thought, answer = text, ""
else:
summary = ALERTS["info_thought"][lang]
thought, answer = result
return (
f"<details open><summary class='thinking-summary'><span>{summary}</span></summary>\n\n"
f"<div class='thinking-container'>\n{thought}\n</div>\n</details>{answer}"
)
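# A minimal usage sketch of the helper above (the response text is hypothetical, not part of
# this commit): a reply that carries <think> reasoning is folded into a collapsible block.
# formatted = _format_response("<think>Compare the decimals digit by digit.</think>The answer is 9.9.", "en")
# formatted.startswith("<details open><summary class='thinking-summary'>")  # True
# formatted.endswith("The answer is 9.9.")  # True
# If "</think>" has not been emitted yet, the whole text stays under the "Thinking..." summary
# with an empty answer part.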
class WebChatModel(ChatModel):
def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
self.manager = manager
@ -124,19 +148,26 @@ class WebChatModel(ChatModel):
torch_gc()
yield ALERTS["info_unloaded"][lang]
- @staticmethod
def append(
+ self,
chatbot: List[Dict[str, str]],
messages: List[Dict[str, str]],
role: str,
query: str,
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], str]:
+ r"""
+ Adds the user input to chatbot.
+ Inputs: infer.chatbot, infer.messages, infer.role, infer.query
+ Output: infer.chatbot, infer.messages
+ """
return chatbot + [{"role": "user", "content": query}], messages + [{"role": role, "content": query}], "" return chatbot + [{"role": "user", "content": query}], messages + [{"role": role, "content": query}], ""
def stream( def stream(
self, self,
chatbot: List[Dict[str, str]], chatbot: List[Dict[str, str]],
messages: List[Dict[str, str]], messages: List[Dict[str, str]],
lang: str,
system: str, system: str,
tools: str, tools: str,
image: Optional[Any], image: Optional[Any],
@ -145,6 +176,12 @@ class WebChatModel(ChatModel):
top_p: float,
temperature: float,
) -> Generator[Tuple[List[Dict[str, str]], List[Dict[str, str]]], None, None]:
+ r"""
+ Generates output text in stream.
+ Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
+ Output: infer.chatbot, infer.messages
+ """
chatbot.append({"role": "assistant", "content": ""})
response = ""
for new_text in self.stream_chat(
@ -157,7 +194,6 @@ class WebChatModel(ChatModel):
top_p=top_p,
temperature=temperature,
):
- new_text = '' if any(t in new_text for t in ('<think>', '</think>')) else new_text
response += new_text
if tools:
result = self.engine.template.extract_tool(response)
@ -166,12 +202,12 @@ class WebChatModel(ChatModel):
if isinstance(result, list):
tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
- tool_calls = json.dumps(tool_calls, indent=4, ensure_ascii=False)
+ tool_calls = json.dumps(tool_calls, ensure_ascii=False)
output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
bot_text = "```json\n" + tool_calls + "\n```"
else:
output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
- bot_text = result
+ bot_text = _format_response(result, lang)
chatbot[-1] = {"role": "assistant", "content": bot_text}
yield chatbot, output_messages

View File

@ -14,34 +14,28 @@
import json
import os
+ import signal
from collections import defaultdict
+ from datetime import datetime
- from typing import Any, Dict, Optional, Tuple
+ from typing import Any, Dict, Optional, Union
+ from psutil import Process
from yaml import safe_dump, safe_load
from ..extras import logging
from ..extras.constants import (
- CHECKPOINT_NAMES,
DATA_CONFIG,
DEFAULT_TEMPLATE,
- PEFT_METHODS,
- STAGES_USE_PAIR_DATA,
SUPPORTED_MODELS,
- TRAINING_STAGES,
+ TRAINING_ARGS,
VISION_MODELS,
DownloadSource,
)
from ..extras.misc import use_modelscope, use_openmind
- from ..extras.packages import is_gradio_available
- if is_gradio_available():
- import gradio as gr
logger = logging.get_logger(__name__)
DEFAULT_CACHE_DIR = "cache" DEFAULT_CACHE_DIR = "cache"
DEFAULT_CONFIG_DIR = "config" DEFAULT_CONFIG_DIR = "config"
DEFAULT_DATA_DIR = "data" DEFAULT_DATA_DIR = "data"
@ -49,6 +43,21 @@ DEFAULT_SAVE_DIR = "saves"
USER_CONFIG = "user_config.yaml" USER_CONFIG = "user_config.yaml"
def abort_process(pid: int) -> None:
r"""
Aborts the processes recursively in a bottom-up way.
"""
try:
children = Process(pid).children()
if children:
for child in children:
abort_process(child.pid)
os.kill(pid, signal.SIGABRT)
except Exception:
pass
def get_save_dir(*paths: str) -> os.PathLike:
r"""
Gets the path to saved model checkpoints.
@ -61,19 +70,19 @@ def get_save_dir(*paths: str) -> os.PathLike:
return os.path.join(DEFAULT_SAVE_DIR, *paths)
- def get_config_path() -> os.PathLike:
+ def _get_config_path() -> os.PathLike:
r"""
Gets the path to user config.
"""
return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
- def load_config() -> Dict[str, Any]:
+ def load_config() -> Dict[str, Union[str, Dict[str, Any]]]:
r"""
Loads user config if exists.
"""
try:
- with open(get_config_path(), encoding="utf-8") as f:
+ with open(_get_config_path(), encoding="utf-8") as f:
return safe_load(f)
except Exception:
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
@ -92,7 +101,7 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
if model_name and model_path:
user_config["path_dict"][model_name] = model_path
- with open(get_config_path(), "w", encoding="utf-8") as f:
+ with open(_get_config_path(), "w", encoding="utf-8") as f:
safe_dump(user_config, f)
@ -120,20 +129,9 @@ def get_model_path(model_name: str) -> str:
return model_path
- def get_model_info(model_name: str) -> Tuple[str, str]:
- r"""
- Gets the necessary information of this model.
- Returns:
- model_path (str)
- template (str)
- """
- return get_model_path(model_name), get_template(model_name)
def get_template(model_name: str) -> str:
r"""
- Gets the template name if the model is a chat model.
+ Gets the template name if the model is a chat/distill/instruct model.
"""
return DEFAULT_TEMPLATE.get(model_name, "default")
@ -145,24 +143,11 @@ def get_visual(model_name: str) -> bool:
return model_name in VISION_MODELS
- def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
- r"""
- Lists all available checkpoints.
- """
- checkpoints = []
- if model_name:
- save_dir = get_save_dir(model_name, finetuning_type)
- if save_dir and os.path.isdir(save_dir):
- for checkpoint in os.listdir(save_dir):
- if os.path.isdir(os.path.join(save_dir, checkpoint)) and any(
- os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CHECKPOINT_NAMES
- ):
- checkpoints.append(checkpoint)
- if finetuning_type in PEFT_METHODS:
- return gr.Dropdown(value=[], choices=checkpoints, multiselect=True)
- else:
- return gr.Dropdown(value=None, choices=checkpoints, multiselect=False)
+ def get_time() -> str:
+ r"""
+ Gets current date and time.
+ """
+ return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
@ -181,11 +166,135 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
return {}
- def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
- r"""
- Lists all available datasets in the dataset dir for the training stage.
- """
- dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
- ranking = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA
- datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
- return gr.Dropdown(choices=datasets)
+ def load_args(config_path: str) -> Optional[Dict[str, Any]]:
+ r"""
+ Loads the training configuration from config path.
+ """
+ try:
+ with open(config_path, encoding="utf-8") as f:
+ return safe_load(f)
+ except Exception:
+ return None
def save_args(config_path: str, config_dict: Dict[str, Any]) -> None:
r"""
Saves the training configuration to config path.
"""
with open(config_path, "w", encoding="utf-8") as f:
safe_dump(config_dict, f)
def _clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
r"""
Removes args with NoneType or False or empty string value.
"""
no_skip_keys = ["packing"]
return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
def gen_cmd(args: Dict[str, Any]) -> str:
r"""
Generates CLI commands for previewing.
"""
cmd_lines = ["llamafactory-cli train "]
for k, v in _clean_cmd(args).items():
if isinstance(v, dict):
cmd_lines.append(f" --{k} {json.dumps(v, ensure_ascii=False)} ")
elif isinstance(v, list):
cmd_lines.append(f" --{k} {' '.join(map(str, v))} ")
else:
cmd_lines.append(f" --{k} {str(v)} ")
if os.name == "nt":
cmd_text = "`\n".join(cmd_lines)
else:
cmd_text = "\\\n".join(cmd_lines)
cmd_text = f"```bash\n{cmd_text}\n```"
return cmd_text
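# A quick sketch of what the preview looks like (argument values are made up, not from this
# commit): gen_cmd({"stage": "sft", "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct", "do_train": True})
# returns a fenced bash block whose body reads, on non-Windows systems:
#   llamafactory-cli train \
#    --stage sft \
#    --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
#    --do_train True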
def save_cmd(args: Dict[str, Any]) -> str:
r"""
Saves CLI commands to launch training.
"""
output_dir = args["output_dir"]
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f:
safe_dump(_clean_cmd(args), f)
return os.path.join(output_dir, TRAINING_ARGS)
def load_eval_results(path: os.PathLike) -> str:
r"""
Gets scores after evaluation.
"""
with open(path, encoding="utf-8") as f:
result = json.dumps(json.load(f), indent=4)
return f"```json\n{result}\n```\n"
def create_ds_config() -> None:
r"""
Creates deepspeed config in the current directory.
"""
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
ds_config = {
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_allow_untested_optimizer": True,
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1,
},
"bf16": {"enabled": "auto"},
}
offload_config = {
"device": "cpu",
"pin_memory": True,
}
ds_config["zero_optimization"] = {
"stage": 2,
"allgather_partitions": True,
"allgather_bucket_size": 5e8,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 5e8,
"contiguous_gradients": True,
"round_robin_gradients": True,
}
with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_config.json"), "w", encoding="utf-8") as f:
json.dump(ds_config, f, indent=2)
ds_config["zero_optimization"]["offload_optimizer"] = offload_config
with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_offload_config.json"), "w", encoding="utf-8") as f:
json.dump(ds_config, f, indent=2)
ds_config["zero_optimization"] = {
"stage": 3,
"overlap_comm": True,
"contiguous_gradients": True,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": True,
}
with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_config.json"), "w", encoding="utf-8") as f:
json.dump(ds_config, f, indent=2)
ds_config["zero_optimization"]["offload_optimizer"] = offload_config
ds_config["zero_optimization"]["offload_param"] = offload_config
with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_offload_config.json"), "w", encoding="utf-8") as f:
json.dump(ds_config, f, indent=2)
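# Usage note (an assumption about how the presets are consumed, not stated in this diff):
# calling create_ds_config() once at WebUI startup materializes ds_z2_config.json,
# ds_z2_offload_config.json, ds_z3_config.json and ds_z3_offload_config.json under the cache
# directory, so the training form can point the deepspeed argument at one of them.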

View File

@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+ import json
from typing import TYPE_CHECKING, Dict, Tuple
from ...data import Role
from ...extras.packages import is_gradio_available
- from ..utils import check_json_schema
+ from ..locales import ALERTS
if is_gradio_available():
@ -29,9 +30,27 @@ if TYPE_CHECKING:
from ..engine import Engine
def check_json_schema(text: str, lang: str) -> None:
r"""
Checks if the json schema is valid.
"""
try:
tools = json.loads(text)
if tools:
assert isinstance(tools, list)
for tool in tools:
if "name" not in tool:
raise NotImplementedError("Name not found.")
except NotImplementedError:
gr.Warning(ALERTS["err_tool_name"][lang])
except Exception:
gr.Warning(ALERTS["err_json_schema"][lang])
def create_chat_box(
engine: "Engine", visible: bool = False
) -> Tuple["Component", "Component", Dict[str, "Component"]]:
+ lang = engine.manager.get_elem_by_id("top.lang")
with gr.Column(visible=visible) as chat_box:
chatbot = gr.Chatbot(type="messages", show_copy_button=True)
messages = gr.State([])
@ -67,7 +86,7 @@ def create_chat_box(
[chatbot, messages, query],
).then(
engine.chatter.stream,
- [chatbot, messages, system, tools, image, video, max_new_tokens, top_p, temperature],
+ [chatbot, messages, lang, system, tools, image, video, max_new_tokens, top_p, temperature],
[chatbot, messages],
)
clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])

View File

@ -40,6 +40,9 @@ def next_page(page_index: int, total_num: int) -> int:
def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
+ r"""
+ Checks if the dataset is a local dataset.
+ """
try:
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
dataset_info = json.load(f)
@ -67,6 +70,9 @@ def _load_data_file(file_path: str) -> List[Any]:
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
+ r"""
+ Gets the preview samples from the dataset.
+ """
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
dataset_info = json.load(f)

View File

@ -15,7 +15,8 @@
from typing import TYPE_CHECKING, Dict
from ...extras.packages import is_gradio_available
- from ..common import DEFAULT_DATA_DIR, list_datasets
+ from ..common import DEFAULT_DATA_DIR
+ from ..control import list_datasets
from .data import create_preview_box

View File

@ -17,8 +17,8 @@ from typing import TYPE_CHECKING, Dict
from ...data import TEMPLATES
from ...extras.constants import METHODS, SUPPORTED_MODELS
from ...extras.packages import is_gradio_available
- from ..common import get_model_info, list_checkpoints, save_config
+ from ..common import save_config
- from ..utils import can_quantize, can_quantize_to
+ from ..control import can_quantize, can_quantize_to, get_model_info, list_checkpoints
if is_gradio_available():

View File

@ -19,8 +19,8 @@ from transformers.trainer_utils import SchedulerType
from ...extras.constants import TRAINING_STAGES
from ...extras.misc import get_device_count
from ...extras.packages import is_gradio_available
- from ..common import DEFAULT_DATA_DIR, list_checkpoints, list_datasets
+ from ..common import DEFAULT_DATA_DIR
- from ..utils import change_stage, list_config_paths, list_output_dirs
+ from ..control import change_stage, list_checkpoints, list_config_paths, list_datasets, list_output_dirs
from .data import create_preview_box

View File

@ -0,0 +1,201 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from typing import Any, Dict, List, Optional, Tuple
from transformers.trainer_utils import get_last_checkpoint
from ..extras.constants import (
CHECKPOINT_NAMES,
PEFT_METHODS,
RUNNING_LOG,
STAGES_USE_PAIR_DATA,
TRAINER_LOG,
TRAINING_STAGES,
)
from ..extras.packages import is_gradio_available, is_matplotlib_available
from ..extras.ploting import gen_loss_plot
from ..model import QuantizationMethod
from .common import DEFAULT_CONFIG_DIR, DEFAULT_DATA_DIR, get_model_path, get_save_dir, get_template, load_dataset_info
if is_gradio_available():
import gradio as gr
def can_quantize(finetuning_type: str) -> "gr.Dropdown":
r"""
Judges if the quantization is available in this finetuning type.
Inputs: top.finetuning_type
Outputs: top.quantization_bit
"""
if finetuning_type not in PEFT_METHODS:
return gr.Dropdown(value="none", interactive=False)
else:
return gr.Dropdown(interactive=True)
def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
r"""
Gets the available quantization bits.
Inputs: top.quantization_method
Outputs: top.quantization_bit
"""
if quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
available_bits = ["none", "8", "4"]
elif quantization_method == QuantizationMethod.HQQ.value:
available_bits = ["none", "8", "6", "5", "4", "3", "2", "1"]
elif quantization_method == QuantizationMethod.EETQ.value:
available_bits = ["none", "8"]
return gr.Dropdown(choices=available_bits)
def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]:
r"""
Modifies states after changing the training stage.
Inputs: train.training_stage
Outputs: train.dataset, train.packing
"""
return [], TRAINING_STAGES[training_stage] == "pt"
def get_model_info(model_name: str) -> Tuple[str, str]:
r"""
Gets the necessary information of this model.
Inputs: top.model_name
Outputs: top.model_path, top.template
"""
return get_model_path(model_name), get_template(model_name)
def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
r"""
Gets training information for monitor.
If do_train is True:
Inputs: train.output_path
Outputs: train.output_box, train.progress_bar, train.loss_viewer
If do_train is False:
Inputs: eval.output_path
Outputs: eval.output_box, eval.progress_bar, None
"""
running_log = ""
running_progress = gr.Slider(visible=False)
running_loss = None
running_log_path = os.path.join(output_path, RUNNING_LOG)
if os.path.isfile(running_log_path):
with open(running_log_path, encoding="utf-8") as f:
running_log = f.read()
trainer_log_path = os.path.join(output_path, TRAINER_LOG)
if os.path.isfile(trainer_log_path):
trainer_log: List[Dict[str, Any]] = []
with open(trainer_log_path, encoding="utf-8") as f:
for line in f:
trainer_log.append(json.loads(line))
if len(trainer_log) != 0:
latest_log = trainer_log[-1]
percentage = latest_log["percentage"]
label = "Running {:d}/{:d}: {} < {}".format(
latest_log["current_steps"],
latest_log["total_steps"],
latest_log["elapsed_time"],
latest_log["remaining_time"],
)
running_progress = gr.Slider(label=label, value=percentage, visible=True)
if do_train and is_matplotlib_available():
running_loss = gr.Plot(gen_loss_plot(trainer_log))
return running_log, running_progress, running_loss
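# For reference, each line of trainer_log.jsonl parsed by the loop above is a standalone JSON
# object; the keys read here are current_steps, total_steps, percentage, elapsed_time and
# remaining_time (the values below are illustrative only):
# {"current_steps": 50, "total_steps": 1000, "percentage": 5.0, "elapsed_time": "0:02:10", "remaining_time": "0:41:20", "loss": 1.37}
# A loss field, when present in these records, is what gen_loss_plot turns into the loss viewer.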
def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
r"""
Lists all available checkpoints.
Inputs: top.model_name, top.finetuning_type
Outputs: top.checkpoint_path
"""
checkpoints = []
if model_name:
save_dir = get_save_dir(model_name, finetuning_type)
if save_dir and os.path.isdir(save_dir):
for checkpoint in os.listdir(save_dir):
if os.path.isdir(os.path.join(save_dir, checkpoint)) and any(
os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CHECKPOINT_NAMES
):
checkpoints.append(checkpoint)
if finetuning_type in PEFT_METHODS:
return gr.Dropdown(value=[], choices=checkpoints, multiselect=True)
else:
return gr.Dropdown(value=None, choices=checkpoints, multiselect=False)
def list_config_paths(current_time: str) -> "gr.Dropdown":
r"""
Lists all the saved configuration files.
Inputs: train.current_time
Outputs: train.config_path
"""
config_files = [f"{current_time}.yaml"]
if os.path.isdir(DEFAULT_CONFIG_DIR):
for file_name in os.listdir(DEFAULT_CONFIG_DIR):
if file_name.endswith(".yaml") and file_name not in config_files:
config_files.append(file_name)
return gr.Dropdown(choices=config_files)
def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
r"""
Lists all available datasets in the dataset dir for the training stage.
Inputs: *.dataset_dir, *.training_stage
Outputs: *.dataset
"""
dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
ranking = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA
datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
return gr.Dropdown(choices=datasets)
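# Illustrative dataset_info.json entries (dataset names are placeholders, not part of this
# commit): the "ranking" flag is what splits the choices, so the first entry is listed for the
# non-pairwise stages and the second for the stages in STAGES_USE_PAIR_DATA.
# {
#   "identity_demo": {"file_name": "identity_demo.json"},
#   "preference_demo": {"file_name": "preference_demo.json", "ranking": true}
# }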
def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown":
r"""
Lists all the directories that can resume from.
Inputs: top.model_name, top.finetuning_type, train.current_time
Outputs: train.output_dir
"""
output_dirs = [f"train_{current_time}"]
if model_name:
save_dir = get_save_dir(model_name, finetuning_type)
if save_dir and os.path.isdir(save_dir):
for folder in os.listdir(save_dir):
output_dir = os.path.join(save_dir, folder)
if os.path.isdir(output_dir) and get_last_checkpoint(output_dir) is not None:
output_dirs.append(folder)
return gr.Dropdown(choices=output_dirs)

View File

@ -20,6 +20,29 @@ CSS = r"""
border-radius: 100vh !important;
}
.thinking-summary {
padding: 8px !important;
}
.thinking-summary span {
border: 1px solid #e0e0e0 !important;
border-radius: 4px !important;
padding: 4px !important;
cursor: pointer !important;
font-size: 14px !important;
background: #333333 !important;
}
.thinking-container {
border-left: 2px solid #a6a6a6 !important;
padding-left: 10px !important;
margin: 4px 0 !important;
}
.thinking-container p {
color: #a6a6a6 !important;
}
.modal-box {
position: fixed !important;
top: 50%;

View File

@ -15,11 +15,10 @@
from typing import TYPE_CHECKING, Any, Dict
from .chatter import WebChatModel
- from .common import load_config
+ from .common import create_ds_config, get_time, load_config
from .locales import LOCALES
from .manager import Manager
from .runner import Runner
- from .utils import create_ds_config, get_time
if TYPE_CHECKING:
@ -27,6 +26,10 @@ if TYPE_CHECKING:
class Engine:
+ r"""
+ A general engine to control the behaviors of Web UI.
+ """
def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None:
self.demo_mode = demo_mode
self.pure_chat = pure_chat
@ -38,7 +41,7 @@ class Engine:
def _update_component(self, input_dict: Dict[str, Dict[str, Any]]) -> Dict["Component", "Component"]:
r"""
- Gets the dict to update the components.
+ Updates gradio components according to the (elem_id, properties) mapping.
"""
output_dict: Dict["Component", "Component"] = {}
for elem_id, elem_attr in input_dict.items():
@ -48,9 +51,11 @@ class Engine:
return output_dict
def resume(self):
+ r"""
+ Gets the initial value of gradio components and restores training status if necessary.
+ """
- user_config = load_config() if not self.demo_mode else {}
+ user_config = load_config() if not self.demo_mode else {}  # do not use config in demo mode
lang = user_config.get("lang", None) or "en"
init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}}
if not self.pure_chat:
@ -74,6 +79,9 @@ class Engine:
yield self._update_component({"eval.resume_btn": {"value": True}}) yield self._update_component({"eval.resume_btn": {"value": True}})
def change_lang(self, lang: str): def change_lang(self, lang: str):
r"""
Updates the displayed language of gradio components.
"""
return { return {
elem: elem.__class__(**LOCALES[elem_name][lang]) elem: elem.__class__(**LOCALES[elem_name][lang])
for elem_name, elem in self.manager.get_elem_iter() for elem_name, elem in self.manager.get_elem_iter()

View File

@ -2786,6 +2786,20 @@ ALERTS = {
"ko": "모델이 언로드되었습니다.", "ko": "모델이 언로드되었습니다.",
"ja": "モデルがアンロードされました。", "ja": "モデルがアンロードされました。",
}, },
"info_thinking": {
"en": "🌀 Thinking...",
"ru": "🌀 Думаю...",
"zh": "🌀 思考中...",
"ko": "🌀 생각 중...",
"ja": "🌀 考えています...",
},
"info_thought": {
"en": "✅ Thought",
"ru": "✅ Думать закончено",
"zh": "✅ 思考完成",
"ko": "✅ 생각이 완료되었습니다",
"ja": "✅ 思考完了",
},
"info_exporting": { "info_exporting": {
"en": "Exporting model...", "en": "Exporting model...",
"ru": "Экспорт модели...", "ru": "Экспорт модели...",

View File

@ -20,6 +20,10 @@ if TYPE_CHECKING:
class Manager:
+ r"""
+ A class to manage all the gradio components in Web UI.
+ """
def __init__(self) -> None:
self._id_to_elem: Dict[str, "Component"] = {}
self._elem_to_id: Dict["Component", str] = {}

View File

@ -24,9 +24,20 @@ from transformers.utils import is_torch_npu_available
from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray
from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46
- from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir, load_config
+ from .common import (
+ DEFAULT_CACHE_DIR,
+ DEFAULT_CONFIG_DIR,
+ abort_process,
+ gen_cmd,
+ get_save_dir,
+ load_args,
+ load_config,
+ load_eval_results,
+ save_args,
+ save_cmd,
+ )
+ from .control import get_trainer_info
from .locales import ALERTS, LOCALES
- from .utils import abort_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
if is_gradio_available():
@ -40,6 +51,10 @@ if TYPE_CHECKING:
class Runner:
+ r"""
+ A class to manage the running status of the trainers.
+ """
def __init__(self, manager: "Manager", demo_mode: bool = False) -> None:
self.manager = manager
self.demo_mode = demo_mode
@ -57,6 +72,9 @@ class Runner:
abort_process(self.trainer.pid)
def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
+ r"""
+ Validates the configuration.
+ """
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
dataset = get("train.dataset") if do_train else get("eval.dataset")
@ -98,6 +116,9 @@ class Runner:
return "" return ""
def _finalize(self, lang: str, finish_info: str) -> str: def _finalize(self, lang: str, finish_info: str) -> str:
r"""
Cleans the cached memory and resets the runner.
"""
finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
gr.Info(finish_info) gr.Info(finish_info)
self.trainer = None self.trainer = None
@ -108,6 +129,9 @@ class Runner:
return finish_info
def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
+ r"""
+ Builds and validates the training arguments.
+ """
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
user_config = load_config()
@ -268,6 +292,9 @@ class Runner:
return args
def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
+ r"""
+ Builds and validates the evaluation arguments.
+ """
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
user_config = load_config()
@ -319,6 +346,9 @@ class Runner:
return args
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
+ r"""
+ Previews the training commands.
+ """
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
error = self._initialize(data, do_train, from_preview=True)
if error:
@ -329,6 +359,9 @@ class Runner:
yield {output_box: gen_cmd(args)}
def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", Any], None, None]:
+ r"""
+ Starts the training process.
+ """
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
error = self._initialize(data, do_train, from_preview=False)
if error:
@ -339,7 +372,7 @@ class Runner:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
os.makedirs(args["output_dir"], exist_ok=True)
- save_args(os.path.join(args["output_dir"], LLAMABOARD_CONFIG), self._form_config_dict(data))
+ save_args(os.path.join(args["output_dir"], LLAMABOARD_CONFIG), self._build_config_dict(data))
env = deepcopy(os.environ)
env["LLAMABOARD_ENABLED"] = "1"
@ -350,7 +383,10 @@ class Runner:
self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env) self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env)
yield from self.monitor() yield from self.monitor()
def _form_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]: def _build_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]:
r"""
Builds a dictionary containing the current training configuration.
"""
config_dict = {} config_dict = {}
skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"] skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"]
for elem, value in data.items(): for elem, value in data.items():
@ -373,6 +409,9 @@ class Runner:
yield from self._launch(data, do_train=False)
def monitor(self):
+ r"""
+ Monitors the training progress and logs.
+ """
self.aborted = False
self.running = True
@ -416,7 +455,7 @@ class Runner:
finish_info = ALERTS["err_failed"][lang] finish_info = ALERTS["err_failed"][lang]
else: else:
if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray(): if os.path.exists(os.path.join(output_path, "all_results.json")) or use_ray():
finish_info = get_eval_results(os.path.join(output_path, "all_results.json")) finish_info = load_eval_results(os.path.join(output_path, "all_results.json"))
else: else:
finish_info = ALERTS["err_failed"][lang] finish_info = ALERTS["err_failed"][lang]
@ -427,6 +466,9 @@ class Runner:
yield return_dict
def save_args(self, data):
+ r"""
+ Saves the training configuration to config path.
+ """
output_box = self.manager.get_elem_by_id("train.output_box")
error = self._initialize(data, do_train=True, from_preview=True)
if error:
@ -438,10 +480,13 @@ class Runner:
os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
save_path = os.path.join(DEFAULT_CONFIG_DIR, config_path)
- save_args(save_path, self._form_config_dict(data))
+ save_args(save_path, self._build_config_dict(data))
return {output_box: ALERTS["info_config_saved"][lang] + save_path}
def load_args(self, lang: str, config_path: str):
+ r"""
+ Loads the training configuration from config path.
+ """
output_box = self.manager.get_elem_by_id("train.output_box")
config_dict = load_args(os.path.join(DEFAULT_CONFIG_DIR, config_path))
if config_dict is None:
@ -455,6 +500,9 @@ class Runner:
return output_dict
def check_output_dir(self, lang: str, model_name: str, finetuning_type: str, output_dir: str):
+ r"""
+ Restore the training status if output_dir exists.
+ """
output_box = self.manager.get_elem_by_id("train.output_box")
output_dict: Dict["Component", Any] = {output_box: LOCALES["output_box"][lang]["value"]}
if model_name and output_dir and os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)):

View File

@ -1,304 +0,0 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import signal
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
import psutil
from transformers.trainer_utils import get_last_checkpoint
from yaml import safe_dump, safe_load
from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_LOG, TRAINING_ARGS, TRAINING_STAGES
from ..extras.packages import is_gradio_available, is_matplotlib_available
from ..extras.ploting import gen_loss_plot
from ..model import QuantizationMethod
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir
from .locales import ALERTS
if is_gradio_available():
import gradio as gr
def abort_process(pid: int) -> None:
r"""
Aborts the process and all of its child processes recursively, bottom-up.
"""
try:
children = psutil.Process(pid).children()
if children:
for child in children:
abort_process(child.pid)
os.kill(pid, signal.SIGABRT)
except Exception:
pass
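# Illustrative usage sketch (hypothetical, not part of the original utils module):
# the Runner above launches training with subprocess.Popen, so aborting by PID tears
# down the whole process tree, children first:
# >>> trainer = Popen(["llamafactory-cli", "train", "train_args.yaml"], env=env)
# >>> abort_process(trainer.pid)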
def can_quantize(finetuning_type: str) -> "gr.Dropdown":
r"""
Checks whether quantization is available for this finetuning type.
"""
if finetuning_type not in PEFT_METHODS:
return gr.Dropdown(value="none", interactive=False)
else:
return gr.Dropdown(interactive=True)
def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
r"""
Returns the available quantization bits.
"""
if quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
available_bits = ["none", "8", "4"]
elif quantization_method == QuantizationMethod.HQQ.value:
available_bits = ["none", "8", "6", "5", "4", "3", "2", "1"]
elif quantization_method == QuantizationMethod.EETQ.value:
available_bits = ["none", "8"]
else:  # unknown quantization method: offer no quantization so available_bits is always bound
available_bits = ["none"]
return gr.Dropdown(choices=available_bits)
def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]:
r"""
Modifies states after changing the training stage.
"""
return [], TRAINING_STAGES[training_stage] == "pt"
def check_json_schema(text: str, lang: str) -> None:
r"""
Checks whether the tools JSON schema is valid.
"""
try:
tools = json.loads(text)
if tools:
assert isinstance(tools, list)
for tool in tools:
if "name" not in tool:
raise NotImplementedError("Name not found.")
except NotImplementedError:
gr.Warning(ALERTS["err_tool_name"][lang])
except Exception:
gr.Warning(ALERTS["err_json_schema"][lang])
def clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
r"""
Removes args whose value is None, False, or an empty string.
"""
no_skip_keys = ["packing"]
return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
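# Illustrative usage sketch (hypothetical argument values, not part of the original utils module):
# None and empty-string values are dropped, but "packing" survives even when it is False
# because it is listed in no_skip_keys:
# >>> clean_cmd({"stage": "sft", "quantization_bit": None, "dataset_dir": "", "packing": False})
# {'stage': 'sft', 'packing': False}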
def gen_cmd(args: Dict[str, Any]) -> str:
r"""
Generates the command text for previewing the training arguments.
"""
cmd_lines = ["llamafactory-cli train "]
for k, v in clean_cmd(args).items():
if isinstance(v, dict):
cmd_lines.append(f" --{k} {json.dumps(v, ensure_ascii=False)} ")
elif isinstance(v, list):
cmd_lines.append(f" --{k} {' '.join(map(str, v))} ")
else:
cmd_lines.append(f" --{k} {str(v)} ")
if os.name == "nt":
cmd_text = "`\n".join(cmd_lines)
else:
cmd_text = "\\\n".join(cmd_lines)
cmd_text = f"```bash\n{cmd_text}\n```"
return cmd_text
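# Illustrative usage sketch (hypothetical argument values, not part of the original utils module):
# on a non-Windows host the preview joins the flags with backslash-newline continuations
# and wraps everything in a Markdown bash fence:
# >>> print(gen_cmd({"stage": "sft", "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct"}))
# ```bash
# llamafactory-cli train \
#  --stage sft \
#  --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct
# ```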
def save_cmd(args: Dict[str, Any]) -> str:
r"""
Saves the training arguments to a YAML file and returns its path.
"""
output_dir = args["output_dir"]
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f:
safe_dump(clean_cmd(args), f)
return os.path.join(output_dir, TRAINING_ARGS)
def get_eval_results(path: os.PathLike) -> str:
r"""
Gets the evaluation scores as a Markdown-formatted JSON block.
"""
with open(path, encoding="utf-8") as f:
result = json.dumps(json.load(f), indent=4)
return f"```json\n{result}\n```\n"
def get_time() -> str:
r"""
Gets current date and time.
"""
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]:
r"""
Gets training information for the monitor.
"""
running_log = ""
running_progress = gr.Slider(visible=False)
running_loss = None
running_log_path = os.path.join(output_path, RUNNING_LOG)
if os.path.isfile(running_log_path):
with open(running_log_path, encoding="utf-8") as f:
running_log = f.read()
trainer_log_path = os.path.join(output_path, TRAINER_LOG)
if os.path.isfile(trainer_log_path):
trainer_log: List[Dict[str, Any]] = []
with open(trainer_log_path, encoding="utf-8") as f:
for line in f:
trainer_log.append(json.loads(line))
if len(trainer_log) != 0:
latest_log = trainer_log[-1]
percentage = latest_log["percentage"]
label = "Running {:d}/{:d}: {} < {}".format(
latest_log["current_steps"],
latest_log["total_steps"],
latest_log["elapsed_time"],
latest_log["remaining_time"],
)
running_progress = gr.Slider(label=label, value=percentage, visible=True)
if do_train and is_matplotlib_available():
running_loss = gr.Plot(gen_loss_plot(trainer_log))
return running_log, running_progress, running_loss
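# Illustrative example (hypothetical record, not part of the original utils module):
# trainer_log is read line by line from a JSONL file; a record containing the keys this
# function consumes could look like:
# {"current_steps": 50, "total_steps": 1000, "percentage": 5.0,
#  "elapsed_time": "0:02:10", "remaining_time": "0:41:10"}
# which would render the progress slider label "Running 50/1000: 0:02:10 < 0:41:10".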
def load_args(config_path: str) -> Optional[Dict[str, Any]]:
r"""
Loads saved arguments.
"""
try:
with open(config_path, encoding="utf-8") as f:
return safe_load(f)
except Exception:
return None
def save_args(config_path: str, config_dict: Dict[str, Any]):
r"""
Saves arguments.
"""
with open(config_path, "w", encoding="utf-8") as f:
safe_dump(config_dict, f)
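# Illustrative usage sketch (hypothetical path and values, not part of the original utils module):
# a round trip through a YAML config file returns the same dictionary:
# >>> save_args("config/2025-01-31-00-00-00.yaml", {"top.model_name": "Llama-3-8B-Instruct", "train.learning_rate": "5e-5"})
# >>> load_args("config/2025-01-31-00-00-00.yaml")
# {'top.model_name': 'Llama-3-8B-Instruct', 'train.learning_rate': '5e-5'}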
def list_config_paths(current_time: str) -> "gr.Dropdown":
r"""
Lists all the saved configuration files.
"""
config_files = [f"{current_time}.yaml"]
if os.path.isdir(DEFAULT_CONFIG_DIR):
for file_name in os.listdir(DEFAULT_CONFIG_DIR):
if file_name.endswith(".yaml") and file_name not in config_files:
config_files.append(file_name)
return gr.Dropdown(choices=config_files)
def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown":
r"""
Lists all the directories that training can resume from.
"""
output_dirs = [f"train_{current_time}"]
if model_name:
save_dir = get_save_dir(model_name, finetuning_type)
if save_dir and os.path.isdir(save_dir):
for folder in os.listdir(save_dir):
output_dir = os.path.join(save_dir, folder)
if os.path.isdir(output_dir) and get_last_checkpoint(output_dir) is not None:
output_dirs.append(folder)
return gr.Dropdown(choices=output_dirs)
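# Illustrative usage sketch (hypothetical save tree, not part of the original utils module):
# if an earlier run directory under the model's save dir holds a checkpoint, the dropdown
# offers a fresh directory for the current time plus that resumable one:
# >>> list_output_dirs("Llama-3-8B-Instruct", "lora", "2025-01-31-00-00-00")
# returns a gr.Dropdown whose choices include 'train_2025-01-31-00-00-00' and the resumable folder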
def create_ds_config() -> None:
r"""
Creates the DeepSpeed configuration files.
"""
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
ds_config = {
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_allow_untested_optimizer": True,
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1,
},
"bf16": {"enabled": "auto"},
}
offload_config = {
"device": "cpu",
"pin_memory": True,
}
ds_config["zero_optimization"] = {
"stage": 2,
"allgather_partitions": True,
"allgather_bucket_size": 5e8,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 5e8,
"contiguous_gradients": True,
"round_robin_gradients": True,
}
with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_config.json"), "w", encoding="utf-8") as f:
json.dump(ds_config, f, indent=2)
ds_config["zero_optimization"]["offload_optimizer"] = offload_config
with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z2_offload_config.json"), "w", encoding="utf-8") as f:
json.dump(ds_config, f, indent=2)
ds_config["zero_optimization"] = {
"stage": 3,
"overlap_comm": True,
"contiguous_gradients": True,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": True,
}
with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_config.json"), "w", encoding="utf-8") as f:
json.dump(ds_config, f, indent=2)
ds_config["zero_optimization"]["offload_optimizer"] = offload_config
ds_config["zero_optimization"]["offload_param"] = offload_config
with open(os.path.join(DEFAULT_CACHE_DIR, "ds_z3_offload_config.json"), "w", encoding="utf-8") as f:
json.dump(ds_config, f, indent=2)
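# Illustrative usage sketch (not part of the original utils module): calling the function
# writes four DeepSpeed presets under DEFAULT_CACHE_DIR, which can then be selected as the
# DeepSpeed config when launching training:
# >>> create_ds_config()
# >>> sorted(f for f in os.listdir(DEFAULT_CACHE_DIR) if f.startswith("ds_"))
# ['ds_z2_config.json', 'ds_z2_offload_config.json', 'ds_z3_config.json', 'ds_z3_offload_config.json']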