diff --git a/.gitignore b/.gitignore
new file mode 100644
index 00000000..51e9d59e
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,160 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
diff --git a/README.md b/README.md
index cca8fdc3..be43a481 100644
--- a/README.md
+++ b/README.md
@@ -12,19 +12,23 @@
## Changelog
+[23/08/12] Now we support **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.
+
+[23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models (experimental feature).
+
[23/08/03] Now we support training the **Qwen-7B** model in this repo. Try `--model_name_or_path Qwen/Qwen-7B-Chat` and `--lora_target c_attn` arguments to train the Qwen-7B model. Remember to use `--template chatml` argument when you are using the Qwen-7B-Chat model.
-[23/07/31] Now we support dataset streaming. Try `--streaming` and `--max_steps 100` arguments to stream your dataset.
+[23/07/31] Now we support **dataset streaming**. Try `--streaming` and `--max_steps 10000` arguments to load your dataset in streaming mode.
[23/07/29] We release two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft)) for details.
[23/07/19] Now we support training the **LLaMA-2** models in this repo. Try `--model_name_or_path meta-llama/Llama-2-7b-hf` argument to use the LLaMA-2 model. Remember to use `--template llama2` argument when you are using the LLaMA-2-chat model.
-[23/07/18] Now we develop an all-in-one Web UI for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
+[23/07/18] Now we develop an **all-in-one Web UI** for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thanks to [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-13B-Base` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--template baichuan` argument when you are using the Baichuan-13B-Chat model.
-[23/07/09] Now we release [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
+[23/07/09] Now we release **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
[23/07/07] Now we support training the **InternLM-7B** model in this repo. Try `--model_name_or_path internlm/internlm-7b` argument to use the InternLM model. Remember to use `--template intern` argument when you are using the InternLM-chat model.
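Editor's note: the `--rope_scaling` entry above corresponds, on the Hugging Face side, to the `rope_scaling` field that recent transformers releases expose on LLaMA configs. A minimal sketch of that mechanism, assuming transformers>=4.31; the model id and the factor 2.0 are illustrative only:

```python
# Illustrative sketch only: passing linear/dynamic RoPE scaling to a LLaMA model via its config.
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
config.rope_scaling = {"type": "linear", "factor": 2.0}  # "dynamic" for NTK-aware scaling at inference
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", config=config)
```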
@@ -53,25 +57,22 @@
| [InternLM](https://github.com/InternLM/InternLM) | 7B | q_proj,v_proj | intern |
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | - |
+| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
-> * **Default module** is used for the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options.
-> * For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc.
+- **Default module** is used for the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options.
+- For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the corresponding template for the "chat" models.
## Supported Training Approaches
-- [(Continually) pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
- - Full-parameter tuning
- - Partial-parameter tuning
- - [LoRA](https://arxiv.org/abs/2106.09685)
- - [QLoRA](https://arxiv.org/abs/2305.14314)
-- [Supervised fine-tuning](https://arxiv.org/abs/2109.01652)
- - Full-parameter tuning
- - Partial-parameter tuning
- - [LoRA](https://arxiv.org/abs/2106.09685)
- - [QLoRA](https://arxiv.org/abs/2305.14314)
-- [RLHF](https://arxiv.org/abs/2203.02155)
- - [LoRA](https://arxiv.org/abs/2106.09685)
- - [QLoRA](https://arxiv.org/abs/2305.14314)
+| Approach | Full-parameter | Partial-parameter | LoRA | QLoRA |
+| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
+| Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| Reward Modeling | | | :white_check_mark: | :white_check_mark: |
+| PPO Training | | | :white_check_mark: | :white_check_mark: |
+| DPO Training | :white_check_mark: | | :white_check_mark: | :white_check_mark: |
+
+- Use `--quantization_bit 4/8` argument to enable QLoRA.
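Editor's note: as a point of reference, a minimal sketch of what enabling 4-bit QLoRA amounts to with transformers and peft; this is not the repo's exact code path, and the model id and hyperparameters are placeholders:

```python
# Minimal QLoRA sketch (assumes bitsandbytes is installed and a LLaMA-style model
# with q_proj/v_proj attention projections); all values are illustrative.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", quantization_config=bnb_config)
lora_config = LoraConfig(r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM")
model = get_peft_model(model, lora_config)  # only the LoRA adapters remain trainable
```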
## Provided Datasets
@@ -88,7 +89,6 @@
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Self-cognition (zh)](data/self_cognition.json)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- - [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
@@ -103,7 +103,7 @@
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
-- For reward modelling:
+- For reward modeling or DPO training:
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
@@ -139,7 +139,6 @@ Note: please update `data/dataset_info.json` to use your custom dataset. About t
### Dependence Installation (optional)
```bash
-git lfs install
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
conda create -n llama_etuning python=3.10
conda activate llama_etuning
@@ -161,7 +160,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_web.py
Currently the web UI only supports training on **a single GPU**.
-### (Continually) Pre-Training
+### Pre-Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -207,9 +206,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--fp16
```
-Remember to specify `--lora_target W_pack` if you are using Baichuan models.
-
-### Reward Model Training
+### Reward Modeling
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -222,7 +219,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_rm_checkpoint \
- --per_device_train_batch_size 4 \
+ --per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
@@ -233,7 +230,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--fp16
```
-### PPO Training (RLHF)
+### PPO Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -257,14 +254,40 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--plot_loss
```
+### DPO Training
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+ --stage dpo \
+ --model_name_or_path path_to_your_model \
+ --do_train \
+ --dataset comparison_gpt4_en \
+ --template default \
+ --finetuning_type lora \
+ --resume_lora_training False \
+ --checkpoint_dir path_to_sft_checkpoint \
+ --output_dir path_to_dpo_checkpoint \
+ --per_device_train_batch_size 2 \
+ --gradient_accumulation_steps 4 \
+ --lr_scheduler_type cosine \
+ --logging_steps 10 \
+ --save_steps 1000 \
+ --learning_rate 1e-5 \
+ --num_train_epochs 1.0 \
+ --plot_loss \
+ --fp16
+```
+
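Editor's note: the DPO stage builds on trl (bumped to `trl>=0.5.0` later in this diff), whose `DPOTrainer` consumes prompt/chosen/rejected triples. A rough, self-contained sketch of that trainer call; the paths, beta value, and toy dataset are placeholders rather than the repo's actual wiring:

```python
# Rough DPO sketch with trl's DPOTrainer; everything below is illustrative.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model = AutoModelForCausalLM.from_pretrained("path_to_your_model")      # policy to optimize
ref_model = AutoModelForCausalLM.from_pretrained("path_to_your_model")  # frozen reference copy
tokenizer = AutoTokenizer.from_pretrained("path_to_your_model")

dataset = Dataset.from_dict({
    "prompt": ["What is DPO?"],
    "chosen": ["DPO optimizes a policy directly from preference pairs."],
    "rejected": ["I don't know."],
})

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    beta=0.1,  # strength of the implicit KL constraint toward the reference model
    args=TrainingArguments(
        output_dir="path_to_dpo_checkpoint",
        per_device_train_batch_size=2,
        remove_unused_columns=False,  # keep the prompt/chosen/rejected columns for the collator
    ),
    train_dataset=dataset,
    tokenizer=tokenizer,
)
trainer.train()
```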
### Distributed Training
+#### Use Hugging Face Accelerate
+
```bash
accelerate config # configure the environment
accelerate launch src/train_bash.py # arguments (same as above)
```
-Example configuration for full-tuning with DeepSpeed ZeRO-2
+Example config.yaml for training with DeepSpeed ZeRO-2
```yaml
compute_environment: LOCAL_MACHINE
@@ -292,6 +315,44 @@ use_cpu: false
+#### Use DeepSpeed
+
+```bash
+deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
+ --deepspeed ds_config.json \
+ ... # arguments (same as above)
+```
+
+Example ds_config.json for training with DeepSpeed ZeRO-2
+
+```json
+{
+ "train_micro_batch_size_per_gpu": "auto",
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "zero_allow_untested_optimizer": true,
+ "fp16": {
+ "enabled": "auto",
+ "loss_scale": 0,
+ "initial_scale_power": 16,
+ "loss_scale_window": 1000,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "zero_optimization": {
+ "stage": 2,
+ "allgather_partitions": true,
+ "allgather_bucket_size": 5e8,
+ "reduce_scatter": true,
+ "reduce_bucket_size": 5e8,
+ "overlap_comm": false,
+ "contiguous_gradients": true
+ }
+}
+```
+
+
+
### Evaluation (BLEU and ROUGE_CHINESE)
```bash
@@ -390,6 +451,8 @@ Please follow the model licenses to use the corresponding model weights:
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
- [InternLM](https://github.com/InternLM/InternLM#open-source-license)
- [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE)
+- [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
+- [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE)
## Citation
diff --git a/README_zh.md b/README_zh.md
index d5eca99d..2a84e697 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -12,31 +12,35 @@
## 更新日志
-[23/08/03] 现在我们支持了 **Qwen-7B** 模型的训练。请尝试使用 `--model_name_or_path Qwen/Qwen-7B-Chat` 和 `--lora_target c_attn` 参数。请注意使用 Qwen-7B-Chat 模型需要添加 `--template chatml` 参数。
+[23/08/12] 现在我们支持了 **RoPE 插值**来扩展 LLaMA 模型的上下文长度。请尝试使用 `--rope_scaling linear` 参数训练模型或使用 `--rope_scaling dynamic` 参数评估模型。
-[23/07/31] 现在我们支持了训练数据流式加载。请尝试使用 `--streaming` 和 `--max_steps 100` 参数来流式加载数据集。
+[23/08/11] 现在我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。详情请参阅[此示例](#dpo-训练)(实验性功能)。
+
+[23/08/03] 现在我们支持了 **Qwen-7B** 模型的训练。请尝试使用 `--model_name_or_path Qwen/Qwen-7B-Chat` 和 `--lora_target c_attn` 参数。使用 Qwen-7B-Chat 模型时请添加 `--template chatml` 参数。
+
+[23/07/31] 现在我们支持了**数据流式加载**。请尝试使用 `--streaming` 和 `--max_steps 10000` 参数来流式加载数据集。
[23/07/29] 我们在 Hugging Face 发布了两个 13B 指令微调模型。详细内容请查阅我们的 Hugging Face 项目([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft))。
-[23/07/19] 现在我们支持了 **LLaMA-2** 模型的训练。请尝试使用 `--model_name_or_path meta-llama/Llama-2-7b-hf` 参数。请注意使用 LLaMA-2-chat 模型需要添加 `--template llama2` 参数。
+[23/07/19] 现在我们支持了 **LLaMA-2** 模型的训练。请尝试使用 `--model_name_or_path meta-llama/Llama-2-7b-hf` 参数。使用 LLaMA-2-chat 模型时请添加 `--template llama2` 参数。
-[23/07/18] 我们开发了支持训练和测试的浏览器一键微调界面。请尝试使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。
+[23/07/18] 我们开发了支持训练和测试的**一体化浏览器界面**。请尝试使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。
-[23/07/11] 现在我们支持了 **Baichuan-13B** 模型的训练。请尝试使用 `--model_name_or_path baichuan-inc/Baichuan-13B-Base` 和 `--lora_target W_pack` 参数。请注意使用 Baichuan-13B-Chat 模型需要添加 `--template baichuan` 参数。
+[23/07/11] 现在我们支持了 **Baichuan-13B** 模型的训练。请尝试使用 `--model_name_or_path baichuan-inc/Baichuan-13B-Base` 和 `--lora_target W_pack` 参数。使用 Baichuan-13B-Chat 模型时请添加 `--template baichuan` 参数。
-[23/07/09] 我们开源了 [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹,一个简单易用的、能迅速编辑大模型事实记忆的工具包。如果您感兴趣请关注我们的 [FastEdit](https://github.com/hiyouga/FastEdit) 项目。
+[23/07/09] 我们开源了 **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹,一个简单易用的、能迅速编辑大模型事实记忆的工具包。如果您感兴趣请关注我们的 [FastEdit](https://github.com/hiyouga/FastEdit) 项目。
-[23/07/07] 现在我们支持了 **InternLM-7B** 模型的训练。请尝试使用 `--model_name_or_path internlm/internlm-7b` 参数。请注意使用 InternLM-chat 模型需要添加 `--template intern` 参数。
+[23/07/07] 现在我们支持了 **InternLM-7B** 模型的训练。请尝试使用 `--model_name_or_path internlm/internlm-7b` 参数。使用 InternLM-chat 模型时请添加 `--template intern` 参数。
[23/07/05] 现在我们支持了 **Falcon-7B/40B** 模型的训练。请尝试使用 `--model_name_or_path tiiuae/falcon-7b` 和 `--lora_target query_key_value` 参数。
[23/06/29] 我们提供了一个**可复现的**指令模型微调示例,详细内容请查阅 [Hugging Face 项目](https://huggingface.co/hiyouga/baichuan-7b-sft)。
-[23/06/22] 我们对齐了[示例 API](src/api_demo.py) 与 [OpenAI API](https://platform.openai.com/docs/api-reference/chat) 的格式,您可以将微调模型接入任意基于 ChatGPT 的应用中。
+[23/06/22] 我们对齐了[示例 API](src/api_demo.py) 与 [OpenAI API](https://platform.openai.com/docs/api-reference/chat) 的格式,您可以将微调模型接入**任意基于 ChatGPT 的应用**中。
[23/06/15] 现在我们支持了 **Baichuan-7B** 模型的训练。请尝试使用 `--model_name_or_path baichuan-inc/Baichuan-7B` 和 `--lora_target W_pack` 参数。
-[23/06/03] 现在我们实现了 4 比特的 LoRA 训练(也称 [QLoRA](https://github.com/artidoro/qlora))。请尝试使用 `--quantization_bit 4` 参数进行 4 比特量化微调。
+[23/06/03] 现在我们实现了 4 比特的 LoRA 训练(也称 **[QLoRA](https://github.com/artidoro/qlora)**)。请尝试使用 `--quantization_bit 4` 参数进行 4 比特量化微调。
[23/05/31] 现在我们支持了 **BLOOM & BLOOMZ** 模型的训练。请尝试使用 `--model_name_or_path bigscience/bloomz-7b1-mt` 和 `--lora_target query_key_value` 参数。
@@ -53,42 +57,38 @@
| [InternLM](https://github.com/InternLM/InternLM) | 7B | q_proj,v_proj | intern |
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | - |
+| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
-> * **默认模块**是 `--lora_target` 参数的默认值。请使用 `python src/train_bash.py -h` 查看全部可选项。
-> * 对于所有“基座”模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等值。
+- **默认模块**是 `--lora_target` 参数的部分可选项。请使用 `python src/train_bash.py -h` 查看全部可选项。
+- 对于所有“基座”(Base)模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Chat)模型请务必使用对应的模板。
-## 微调方法
+## 训练方法
-- [二次预训练](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
- - 全参数微调
- - 部分参数微调
- - [LoRA](https://arxiv.org/abs/2106.09685)
- - [QLoRA](https://arxiv.org/abs/2305.14314)
-- [指令监督微调](https://arxiv.org/abs/2109.01652)
- - 全参数微调
- - 部分参数微调
- - [LoRA](https://arxiv.org/abs/2106.09685)
- - [QLoRA](https://arxiv.org/abs/2305.14314)
-- [人类反馈的强化学习(RLHF)](https://arxiv.org/abs/2203.02155)
- - [LoRA](https://arxiv.org/abs/2106.09685)
- - [QLoRA](https://arxiv.org/abs/2305.14314)
+| 方法 | 全参数训练 | 部分参数训练 | LoRA | QLoRA |
+| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
+| 预训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| 指令监督微调 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| 奖励模型训练 | | | :white_check_mark: | :white_check_mark: |
+| PPO 训练 | | | :white_check_mark: | :white_check_mark: |
+| DPO 训练 | :white_check_mark: | | :white_check_mark: | :white_check_mark: |
+
+- 使用 `--quantization_bit 4/8` 参数来启用 QLoRA 训练。
## 数据集
-- 用于二次预训练:
+- 用于预训练:
- [Wiki Demo (en)](data/wiki_demo.txt)
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
-- 用于指令监督微调:
+- 用于指令监督微调:
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Self-cognition (zh)](data/self_cognition.json)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- - [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
@@ -103,7 +103,7 @@
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
-- 用于奖励模型训练:
+- 用于奖励模型或 DPO 训练:
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
@@ -139,7 +139,6 @@ huggingface-cli login
### 环境搭建(可跳过)
```bash
-git lfs install
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
conda create -n llama_etuning python=3.10
conda activate llama_etuning
@@ -161,7 +160,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_web.py
目前网页 UI 仅支持**单卡训练**。
-### 二次预训练
+### 预训练
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -207,8 +206,6 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--fp16
```
-使用 Baichuan 模型时请指定 `--lora_target W_pack` 参数。
-
### 奖励模型训练
```bash
@@ -222,7 +219,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_rm_checkpoint \
- --per_device_train_batch_size 4 \
+ --per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
@@ -233,7 +230,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--fp16
```
-### RLHF 训练
+### PPO 训练
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -257,8 +254,34 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--plot_loss
```
+### DPO 训练
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+ --stage dpo \
+ --model_name_or_path path_to_your_model \
+ --do_train \
+ --dataset comparison_gpt4_zh \
+ --template default \
+ --finetuning_type lora \
+ --resume_lora_training False \
+ --checkpoint_dir path_to_sft_checkpoint \
+ --output_dir path_to_dpo_checkpoint \
+ --per_device_train_batch_size 2 \
+ --gradient_accumulation_steps 4 \
+ --lr_scheduler_type cosine \
+ --logging_steps 10 \
+ --save_steps 1000 \
+ --learning_rate 1e-5 \
+ --num_train_epochs 1.0 \
+ --plot_loss \
+ --fp16
+```
+
### 多 GPU 分布式训练
+#### 使用 Hugging Face Accelerate
+
```bash
accelerate config # 首先配置分布式环境
accelerate launch src/train_bash.py # 参数同上
@@ -292,7 +315,45 @@ use_cpu: false
-### 指标评估(BLEU分数和汉语ROUGE分数)
+#### 使用 DeepSpeed
+
+```bash
+deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
+ --deepspeed ds_config.json \
+ ... # 参数同上
+```
+
+使用 DeepSpeed ZeRO-2 进行全参数微调的 DeepSpeed 配置示例
+
+```json
+{
+ "train_micro_batch_size_per_gpu": "auto",
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "zero_allow_untested_optimizer": true,
+ "fp16": {
+ "enabled": "auto",
+ "loss_scale": 0,
+ "initial_scale_power": 16,
+ "loss_scale_window": 1000,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "zero_optimization": {
+ "stage": 2,
+ "allgather_partitions": true,
+ "allgather_bucket_size": 5e8,
+ "reduce_scatter": true,
+ "reduce_bucket_size": 5e8,
+ "overlap_comm": false,
+ "contiguous_gradients": true
+ }
+}
+```
+
+
+
+### 指标评估(BLEU 分数和汉语 ROUGE 分数)
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -309,7 +370,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--predict_with_generate
```
-我们建议在量化模型的评估中使用 `--per_device_eval_batch_size=1` 和 `--max_target_length 128` 参数。
+我们建议在量化模型的评估中使用 `--per_device_eval_batch_size=1` 和 `--max_target_length 128`。
### 模型预测
diff --git a/assets/wechat.jpg b/assets/wechat.jpg
index 741cef3c..da80faae 100644
Binary files a/assets/wechat.jpg and b/assets/wechat.jpg differ
diff --git a/data/dataset_info.json b/data/dataset_info.json
index 3a3b4e76..3eaf920e 100644
--- a/data/dataset_info.json
+++ b/data/dataset_info.json
@@ -49,26 +49,6 @@
"history": "history"
}
},
- "refgpt_zh_p1": {
- "file_name": "refgpt_zh_50k_p1.json",
- "file_sha1": "b40f4f4d0ffacd16da7c275b056d5b6670021752",
- "columns": {
- "prompt": "instruction",
- "query": "input",
- "response": "output",
- "history": "history"
- }
- },
- "refgpt_zh_p2": {
- "file_name": "refgpt_zh_50k_p2.json",
- "file_sha1": "181f32b2c60264a29f81f59d3c76095793eae1b0",
- "columns": {
- "prompt": "instruction",
- "query": "input",
- "response": "output",
- "history": "history"
- }
- },
"lima": {
"file_name": "lima.json",
"file_sha1": "9db59f6b7007dc4b17529fc63379b9cd61640f37",
diff --git a/data/refgpt_zh_50k_p1.json.REMOVED.git-id b/data/refgpt_zh_50k_p1.json.REMOVED.git-id
deleted file mode 100644
index 3e8a9e41..00000000
--- a/data/refgpt_zh_50k_p1.json.REMOVED.git-id
+++ /dev/null
@@ -1 +0,0 @@
-f967a4f6d04a11308a15524aa9a846a19a8d1e83
\ No newline at end of file
diff --git a/data/refgpt_zh_50k_p2.json.REMOVED.git-id b/data/refgpt_zh_50k_p2.json.REMOVED.git-id
deleted file mode 100644
index a6525b27..00000000
--- a/data/refgpt_zh_50k_p2.json.REMOVED.git-id
+++ /dev/null
@@ -1 +0,0 @@
-0a4f0d74fd1c5cab2eb6d84a3a3fe669847becd8
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 9b74b21d..fb5fa72f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,7 @@ transformers>=4.29.1
datasets>=2.12.0
accelerate>=0.21.0
peft>=0.4.0
-trl>=0.4.7
+trl>=0.5.0
scipy
sentencepiece
tiktoken
diff --git a/src/api_demo.py b/src/api_demo.py
index c0ca9760..777f9dcf 100644
--- a/src/api_demo.py
+++ b/src/api_demo.py
@@ -7,7 +7,7 @@ def main():
chat_model = ChatModel()
app = create_app(chat_model)
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
- # Visit http://localhost:8000/docs for document.
+ print("Visit http://localhost:8000/docs for API document.")
if __name__ == "__main__":
diff --git a/src/llmtuner/__init__.py b/src/llmtuner/__init__.py
index e647b92b..bbc1420b 100644
--- a/src/llmtuner/__init__.py
+++ b/src/llmtuner/__init__.py
@@ -6,4 +6,4 @@ from llmtuner.tuner import export_model, run_exp
from llmtuner.webui import create_ui, create_web_demo
-__version__ = "0.1.5"
+__version__ = "0.1.6"
diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py
index 4fc5fc43..41a7fe9a 100644
--- a/src/llmtuner/api/app.py
+++ b/src/llmtuner/api/app.py
@@ -47,15 +47,15 @@ def create_app(chat_model: ChatModel) -> FastAPI:
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
- if request.messages[-1].role != Role.USER:
+ if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
raise HTTPException(status_code=400, detail="Invalid request")
- query = request.messages[-1].content
+ query = request.messages[-1].content
prev_messages = request.messages[:-1]
if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
- prefix = prev_messages.pop(0).content
+ system = prev_messages.pop(0).content
else:
- prefix = None
+ system = None
history = []
if len(prev_messages) % 2 == 0:
@@ -64,11 +64,11 @@ def create_app(chat_model: ChatModel) -> FastAPI:
history.append([prev_messages[i].content, prev_messages[i+1].content])
if request.stream:
- generate = predict(query, history, prefix, request)
+ generate = predict(query, history, system, request)
return EventSourceResponse(generate, media_type="text/event-stream")
response, (prompt_length, response_length) = chat_model.chat(
- query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
+ query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
)
usage = ChatCompletionResponseUsage(
@@ -85,7 +85,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)
- async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest):
+ async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role=Role.ASSISTANT),
@@ -95,7 +95,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
yield chunk.json(exclude_unset=True, ensure_ascii=False)
for new_text in chat_model.stream_chat(
- query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
+ query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
):
if len(new_text) == 0:
continue
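Editor's note: given the renamed `system` handling above (a leading system-role message becomes the system prompt), one hypothetical way to smoke-test the endpoint with the pre-1.0 `openai` client, assuming `api_demo.py` is running locally:

```python
# Hypothetical client-side check; assumes the openai<1.0 package and the demo server on port 8000.
import openai

openai.api_base = "http://localhost:8000/v1"
openai.api_key = "none"  # placeholder; key handling is outside this diff

response = openai.ChatCompletion.create(
    model="default",  # placeholder name; the demo serves the single loaded model
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},  # parsed into `system`
        {"role": "user", "content": "Hello!"},                          # must be last (Role.USER)
    ],
)
print(response.choices[0].message.content)
```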
diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py
index 79d3b92d..0d22bb5a 100644
--- a/src/llmtuner/chat/stream_chat.py
+++ b/src/llmtuner/chat/stream_chat.py
@@ -1,10 +1,9 @@
import torch
-from types import MethodType
from typing import Any, Dict, Generator, List, Optional, Tuple
from threading import Thread
-from transformers import PreTrainedModel, TextIteratorStreamer
+from transformers import TextIteratorStreamer
-from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopping_criteria
+from llmtuner.extras.misc import dispatch_model, get_logits_processor
from llmtuner.extras.template import get_template_and_fix_tokenizer
from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
@@ -15,23 +14,21 @@ class ChatModel:
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
self.model = dispatch_model(self.model)
- self.model = self.model.eval() # change to eval mode
+ self.model = self.model.eval() # enable evaluation mode
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
- self.source_prefix = data_args.source_prefix
- self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words)
- self.model.generate = MethodType(PreTrainedModel.generate, self.model) # disable custom method (for Qwen)
+ self.system_prompt = data_args.system_prompt
def process_args(
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
- prefix: Optional[str] = None,
+ system: Optional[str] = None,
**input_kwargs
) -> Tuple[Dict[str, Any], int]:
- prefix = prefix or self.source_prefix
+ system = system or self.system_prompt
prompt, _ = self.template.encode_oneturn(
- tokenizer=self.tokenizer, query=query, resp="", history=history, prefix=prefix
+ tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
)
input_ids = torch.tensor([prompt], device=self.model.device)
prompt_length = len(input_ids[0])
@@ -52,8 +49,9 @@ class ChatModel:
top_p=top_p or gen_kwargs["top_p"],
top_k=top_k or gen_kwargs["top_k"],
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
- logits_processor=get_logits_processor(),
- stopping_criteria=get_stopping_criteria(self.stop_ids)
+ eos_token_id=list(set([self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids)),
+ pad_token_id=self.tokenizer.pad_token_id,
+ logits_processor=get_logits_processor()
))
if max_length:
@@ -71,10 +69,10 @@ class ChatModel:
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
- prefix: Optional[str] = None,
+ system: Optional[str] = None,
**input_kwargs
) -> Tuple[str, Tuple[int, int]]:
- gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
+ gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
generation_output = self.model.generate(**gen_kwargs)
outputs = generation_output.tolist()[0][prompt_length:]
response = self.tokenizer.decode(outputs, skip_special_tokens=True)
@@ -86,10 +84,10 @@ class ChatModel:
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
- prefix: Optional[str] = None,
+ system: Optional[str] = None,
**input_kwargs
) -> Generator[str, None, None]:
- gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
+ gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs)
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer
diff --git a/src/llmtuner/dsets/loader.py b/src/llmtuner/dsets/loader.py
index 90a4212f..08c35f27 100644
--- a/src/llmtuner/dsets/loader.py
+++ b/src/llmtuner/dsets/loader.py
@@ -1,48 +1,25 @@
import os
-import hashlib
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Union
-from datasets import Value, concatenate_datasets, interleave_datasets, load_dataset
+from datasets import concatenate_datasets, interleave_datasets, load_dataset
+from llmtuner.dsets.utils import checksum, EXT2TYPE
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
- from datasets import Dataset
+ from datasets import Dataset, IterableDataset
from llmtuner.hparams import ModelArguments, DataArguments
logger = get_logger(__name__)
-EXT2TYPE = {
- "csv": "csv",
- "json": "json",
- "jsonl": "json",
- "txt": "text"
-}
-
-
-def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
- if file_sha1 is None:
- logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
- return
-
- if len(data_files) != 1:
- logger.warning("Checksum failed: too many files.")
- return
-
- with open(data_files[0], "rb") as f:
- sha1 = hashlib.sha1(f.read()).hexdigest()
- if sha1 != file_sha1:
- logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
-
-
def get_dataset(
model_args: "ModelArguments",
data_args: "DataArguments"
-) -> "Dataset":
+) -> Union["Dataset", "IterableDataset"]:
max_samples = data_args.max_samples
- all_datasets: List["Dataset"] = [] # support multiple datasets
+ all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
for dataset_attr in data_args.dataset_list:
logger.info("Loading dataset {}...".format(dataset_attr))
@@ -92,12 +69,11 @@ def get_dataset(
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
- if dataset_attr.source_prefix: # add prefix
- features = None
+ if dataset_attr.system_prompt: # add system prompt
if data_args.streaming:
- features = dataset.features
- features["prefix"] = Value(dtype="string", id=None)
- dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix}, features=features)
+ dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt})
+ else:
+ dataset = dataset.add_column("system", [dataset_attr.system_prompt] * len(dataset))
all_datasets.append(dataset)
diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py
index d2150dbc..1fc146f8 100644
--- a/src/llmtuner/dsets/preprocess.py
+++ b/src/llmtuner/dsets/preprocess.py
@@ -1,24 +1,25 @@
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal
+import tiktoken
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
from itertools import chain
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.template import get_template_and_fix_tokenizer
if TYPE_CHECKING:
- from datasets import Dataset
+ from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from llmtuner.hparams import DataArguments
def preprocess_dataset(
- dataset: "Dataset",
+ dataset: Union["Dataset", "IterableDataset"],
tokenizer: "PreTrainedTokenizer",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"]
-) -> "Dataset":
- column_names = list(dataset.column_names)
+) -> Union["Dataset", "IterableDataset"]:
+ column_names = list(next(iter(dataset)).keys())
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
@@ -26,15 +27,16 @@ def preprocess_dataset(
query, response = examples["prompt"][i], examples["response"][i]
query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
history = examples["history"][i] if "history" in examples else None
- prefix = examples["prefix"][i] if "prefix" in examples else None
- yield query, response, history, prefix
+ system = examples["system"][i] if "system" in examples else None
+ yield query, response, history, system
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
        # build grouped texts with format `X1 X2 X3 ...` (without `<eos>`)
- if hasattr(tokenizer, "tokenizer"): # for tiktoken tokenizer (Qwen)
+ if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
else:
kwargs = dict(add_special_tokens=False)
+
tokenized_examples = tokenizer(examples["prompt"], **kwargs)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
@@ -46,7 +48,6 @@ def preprocess_dataset(
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
- result["labels"] = result["input_ids"].copy()
return result
def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
@@ -55,10 +56,10 @@ def preprocess_dataset(
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
max_length = data_args.max_source_length + data_args.max_target_length
- for query, response, history, prefix in construct_example(examples):
+ for query, response, history, system in construct_example(examples):
input_ids, labels = [], []
- for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, prefix):
+ for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, system):
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
if len(target_ids) > data_args.max_target_length:
@@ -77,11 +78,11 @@ def preprocess_dataset(
return model_inputs
def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
-        # build inputs with format `<bos> X` and labels with format `<bos> Y`
+        # build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
- for query, response, history, prefix in construct_example(examples):
- source_ids, target_ids = template.encode_oneturn(tokenizer, query, response, history, prefix)
+ for query, response, history, system in construct_example(examples):
+ source_ids, target_ids = template.encode_oneturn(tokenizer, query, response, history, system)
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
@@ -95,24 +96,22 @@ def preprocess_dataset(
return model_inputs
def preprocess_pairwise_dataset(examples):
-        # build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
- model_inputs = {"accept_ids": [], "reject_ids": []}
- for query, response, history, prefix in construct_example(examples):
- source_ids, accept_ids = template.encode_oneturn(tokenizer, query, response[0], history, prefix)
- source_ids, reject_ids = template.encode_oneturn(tokenizer, query, response[1], history, prefix)
+        # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
+ model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
+ for query, response, history, system in construct_example(examples):
+ prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
+ _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
- if len(source_ids) > data_args.max_source_length:
- source_ids = source_ids[:data_args.max_source_length]
- if len(accept_ids) > data_args.max_target_length:
- accept_ids = accept_ids[:data_args.max_target_length - 1]
- if len(reject_ids) > data_args.max_target_length:
- reject_ids = reject_ids[:data_args.max_target_length - 1]
+ if len(prompt_ids) > data_args.max_source_length:
+ prompt_ids = prompt_ids[:data_args.max_source_length]
+ if len(chosen_ids) > data_args.max_target_length:
+ chosen_ids = chosen_ids[:data_args.max_target_length]
+ if len(rejected_ids) > data_args.max_target_length:
+ rejected_ids = rejected_ids[:data_args.max_target_length]
- accept_ids = source_ids + accept_ids
- reject_ids = source_ids + reject_ids
-
- model_inputs["accept_ids"].append(accept_ids)
- model_inputs["reject_ids"].append(reject_ids)
+ model_inputs["prompt_ids"].append(prompt_ids)
+ model_inputs["chosen_ids"].append(chosen_ids)
+ model_inputs["rejected_ids"].append(rejected_ids)
return model_inputs
def print_supervised_dataset_example(example):
@@ -124,10 +123,12 @@ def preprocess_dataset(
], skip_special_tokens=False)))
def print_pairwise_dataset_example(example):
- print("accept_ids:\n{}".format(example["accept_ids"]))
- print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
- print("reject_ids:\n{}".format(example["reject_ids"]))
- print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))
+ print("prompt_ids:\n{}".format(example["prompt_ids"]))
+ print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
+ print("chosen_ids:\n{}".format(example["chosen_ids"]))
+ print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
+ print("rejected_ids:\n{}".format(example["rejected_ids"]))
+ print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
def print_unsupervised_dataset_example(example):
print("input_ids:\n{}".format(example["input_ids"]))
@@ -166,8 +167,5 @@ def preprocess_dataset(
**kwargs
)
- if data_args.streaming:
- dataset = dataset.shuffle(buffer_size=data_args.buffer_size)
-
print_function(next(iter(dataset)))
return dataset
diff --git a/src/llmtuner/dsets/utils.py b/src/llmtuner/dsets/utils.py
index 31c48222..bf337014 100644
--- a/src/llmtuner/dsets/utils.py
+++ b/src/llmtuner/dsets/utils.py
@@ -1,15 +1,59 @@
-from typing import TYPE_CHECKING, Dict
+import hashlib
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
+
+from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
- from datasets import Dataset
+ from datasets import Dataset, IterableDataset
+ from transformers import TrainingArguments
+ from llmtuner.hparams import DataArguments
-def split_dataset(dataset: "Dataset", dev_ratio: float, do_train: bool) -> Dict[str, "Dataset"]:
- if do_train:
- if dev_ratio > 1e-6: # Split the dataset
- dataset = dataset.train_test_split(test_size=dev_ratio)
- return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
+logger = get_logger(__name__)
+
+
+EXT2TYPE = {
+ "csv": "csv",
+ "json": "json",
+ "jsonl": "json",
+ "txt": "text"
+}
+
+
+def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
+ if file_sha1 is None:
+ logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
+ return
+
+ if len(data_files) != 1:
+ logger.warning("Checksum failed: too many files.")
+ return
+
+ with open(data_files[0], "rb") as f:
+ sha1 = hashlib.sha1(f.read()).hexdigest()
+ if sha1 != file_sha1:
+ logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
+
+
+def split_dataset(
+ dataset: Union["Dataset", "IterableDataset"],
+ data_args: "DataArguments",
+ training_args: "TrainingArguments"
+) -> Dict[str, "Dataset"]:
+ if training_args.do_train:
+ if data_args.val_size > 1e-6: # Split the dataset
+ if data_args.streaming:
+ val_set = dataset.take(int(data_args.val_size))
+ train_set = dataset.skip(int(data_args.val_size))
+                train_set = train_set.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
+ return {"train_dataset": train_set, "eval_dataset": val_set}
+ else:
+ val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
+ dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed)
+ return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
+ if data_args.streaming:
+ dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
return {"train_dataset": dataset}
else: # do_eval or do_predict
return {"eval_dataset": dataset}
diff --git a/src/llmtuner/extras/callbacks.py b/src/llmtuner/extras/callbacks.py
index d325b0a8..61deae25 100644
--- a/src/llmtuner/extras/callbacks.py
+++ b/src/llmtuner/extras/callbacks.py
@@ -7,10 +7,16 @@ from datetime import timedelta
from transformers import TrainerCallback
from transformers.trainer_utils import has_length
+from llmtuner.extras.constants import LOG_FILE_NAME
+from llmtuner.extras.logging import get_logger
+
if TYPE_CHECKING:
from transformers import TrainingArguments, TrainerState, TrainerControl
+logger = get_logger(__name__)
+
+
class LogCallback(TrainerCallback):
def __init__(self, runner=None):
@@ -38,6 +44,9 @@ class LogCallback(TrainerCallback):
self.in_training = True
self.start_time = time.time()
self.max_steps = state.max_steps
+ if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)):
+ logger.warning("Previous log file in this folder will be deleted.")
+ os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py
index 6f6dbdd7..cd22943f 100644
--- a/src/llmtuner/extras/constants.py
+++ b/src/llmtuner/extras/constants.py
@@ -1,13 +1,23 @@
IGNORE_INDEX = -100
+LOG_FILE_NAME = "trainer_log.jsonl"
+
VALUE_HEAD_FILE_NAME = "value_head.bin"
FINETUNING_ARGS_NAME = "finetuning_args.json"
-LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings
+LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]
METHODS = ["full", "freeze", "lora"]
+STAGES = [
+ "SFT",
+ "Reward Modeling",
+ "PPO",
+ "DPO",
+ "Pre-Training"
+]
+
SUPPORTED_MODELS = {
"LLaMA-7B": "huggyllama/llama-7b",
"LLaMA-13B": "huggyllama/llama-13b",
@@ -19,6 +29,10 @@ SUPPORTED_MODELS = {
"LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
"LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
"LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf",
+ "ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
+ "ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
+ "ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
+ "ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b",
"BLOOM-560M": "bigscience/bloom-560m",
"BLOOM-3B": "bigscience/bloom-3b",
"BLOOM-7B1": "bigscience/bloom-7b1",
@@ -35,16 +49,30 @@ SUPPORTED_MODELS = {
"InternLM-7B": "internlm/internlm-7b",
"InternLM-7B-Chat": "internlm/internlm-chat-7b",
"Qwen-7B": "Qwen/Qwen-7B",
- "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat"
+ "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
+ "XVERSE-13B": "xverse/XVERSE-13B",
+ "ChatGLM2-6B-Chat": "THUDM/chatglm2-6b"
}
DEFAULT_MODULE = {
"LLaMA": "q_proj,v_proj",
"LLaMA2": "q_proj,v_proj",
+ "ChineseLLaMA2": "q_proj,v_proj",
"BLOOM": "query_key_value",
"BLOOMZ": "query_key_value",
"Falcon": "query_key_value",
"Baichuan": "W_pack",
"InternLM": "q_proj,v_proj",
- "Qwen": "c_attn"
+ "Qwen": "c_attn",
+ "XVERSE": "q_proj,v_proj",
+ "ChatGLM2": "query_key_value"
+}
+
+DEFAULT_TEMPLATE = {
+ "LLaMA2": "llama2",
+ "ChineseLLaMA2": "llama2_zh",
+ "Baichuan": "baichuan",
+ "InternLM": "intern",
+ "Qwen": "chatml",
+ "ChatGLM2": "chatglm2"
}
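Editor's note: the new `DEFAULT_TEMPLATE` map presumably lets the UI pre-select a chat template from the model family. A hypothetical lookup helper (not part of this diff) to show the intended mapping:

```python
# Hypothetical helper, not in the diff: map a SUPPORTED_MODELS key to a chat template.
def guess_template(model_name: str) -> str:
    for family, template in DEFAULT_TEMPLATE.items():
        if model_name.startswith(family):  # e.g. "Qwen-7B-Chat" -> "chatml"
            return template
    return "default"
```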
diff --git a/src/llmtuner/extras/logging.py b/src/llmtuner/extras/logging.py
index 0b1a68f6..d6f185e6 100644
--- a/src/llmtuner/extras/logging.py
+++ b/src/llmtuner/extras/logging.py
@@ -8,6 +8,9 @@ class LoggerHandler(logging.Handler):
super().__init__()
self.log = ""
+ def reset(self):
+ self.log = ""
+
def emit(self, record):
if record.name == "httpx":
return
diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py
index e1fbb156..b57b1c8f 100644
--- a/src/llmtuner/extras/misc.py
+++ b/src/llmtuner/extras/misc.py
@@ -1,7 +1,6 @@
import torch
from typing import TYPE_CHECKING, List, Optional, Tuple
-
-from transformers import LogitsProcessor, LogitsProcessorList, StoppingCriteria, StoppingCriteriaList
+from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
from llmtuner.extras.constants import LAYERNORM_NAMES
@@ -29,37 +28,12 @@ class AverageMeter:
self.avg = self.sum / self.count
-class InvalidScoreLogitsProcessor(LogitsProcessor):
-
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
- if torch.isnan(scores).any() or torch.isinf(scores).any():
- scores.zero_()
- scores[..., 0] = 1.0
- return scores
-
-
def get_logits_processor() -> LogitsProcessorList:
logits_processor = LogitsProcessorList()
- logits_processor.append(InvalidScoreLogitsProcessor())
+ logits_processor.append(InfNanRemoveLogitsProcessor())
return logits_processor
-class StopWordsCriteria(StoppingCriteria):
-
- def __init__(self, stop_ids: List[int]) -> None:
- super().__init__()
- self.stop_ids = stop_ids
-
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
- return any([stop_id in input_ids[:, -1] for stop_id in self.stop_ids])
-
-
-def get_stopping_criteria(stop_ids: List[int]) -> StoppingCriteriaList:
- stopping_criteria = StoppingCriteriaList()
- stopping_criteria.append(StopWordsCriteria(stop_ids))
- return stopping_criteria
-
-
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
r"""
Returns the number of trainable parameters and number of all parameters in the model.
@@ -91,7 +65,6 @@ def prepare_model_for_training(
use_gradient_checkpointing: Optional[bool] = True,
layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
) -> "PreTrainedModel":
-
for name, param in model.named_parameters():
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
param.data = param.data.to(torch.float32)
@@ -108,9 +81,6 @@ def prepare_model_for_training(
model.config.use_cache = False # turn off when gradient checkpointing is enabled
if finetuning_type != "full" and hasattr(model, output_layer_name):
- if hasattr(model, "config") and hasattr(model.config, "pretraining_tp"):
- model.config.pretraining_tp = 1 # disable TP for LoRA (https://github.com/huggingface/peft/pull/728)
-
output_layer: torch.nn.Linear = getattr(model, output_layer_name)
input_dtype = output_layer.weight.dtype
@@ -138,6 +108,9 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
Dispatches a pre-trained model to GPUs with balanced memory.
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
"""
+ if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
+ return model
+
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory
diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py
index 5b00af03..5d5a03fb 100644
--- a/src/llmtuner/extras/template.py
+++ b/src/llmtuner/extras/template.py
@@ -1,15 +1,22 @@
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+import tiktoken
from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
+logger = get_logger(__name__)
+
+
@dataclass
class Template:
prefix: List[Union[str, Dict[str, str]]]
prompt: List[Union[str, Dict[str, str]]]
+ system: str
sep: List[Union[str, Dict[str, str]]]
stop_words: List[str]
use_history: bool
@@ -20,18 +27,18 @@ class Template:
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
- prefix: Optional[str] = None
+ system: Optional[str] = None
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
"""
- prefix, history = self._format(query, resp, history, prefix)
- encoded_pairs = self._encode(tokenizer, prefix, history)
+ system, history = self._format(query, resp, history, system)
+ encoded_pairs = self._encode(tokenizer, system, history)
prompt_ids = []
for query_ids, resp_ids in encoded_pairs[:-1]:
prompt_ids = prompt_ids + query_ids + resp_ids
- prompt_ids = prompt_ids + encoded_pairs[-1][0]
- return prompt_ids, encoded_pairs[-1][1]
+ prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1]
+ return prompt_ids, answer_ids
def encode_multiturn(
self,
@@ -39,13 +46,13 @@ class Template:
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
- prefix: Optional[str] = None
+ system: Optional[str] = None
) -> List[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
- prefix, history = self._format(query, resp, history, prefix)
- encoded_pairs = self._encode(tokenizer, prefix, history)
+ system, history = self._format(query, resp, history, system)
+ encoded_pairs = self._encode(tokenizer, system, history)
return encoded_pairs
def _format(
@@ -53,26 +60,29 @@ class Template:
query: str,
resp: str,
history: Optional[List[Tuple[str, str]]] = None,
- prefix: Optional[str] = None
- ) -> Tuple[List[Union[str, Dict[str, str]]], List[Tuple[str, str]]]:
+ system: Optional[str] = None
+ ) -> Tuple[str, List[Tuple[str, str]]]:
r"""
- Aligns inputs to a special format.
+ Aligns inputs to the standard format.
"""
- prefix = [prefix] if prefix else self.prefix # use prefix if provided
+ system = system or self.system # use system if provided
history = history if (history and self.use_history) else []
history = history + [(query, resp)]
- return prefix, history
+ return system, history
def _get_special_ids(
self,
tokenizer: "PreTrainedTokenizer"
) -> Tuple[List[int], List[int]]:
- if tokenizer.bos_token_id:
+ if (
+ tokenizer.bos_token_id is not None
+ and getattr(tokenizer, "add_bos_token", True)
+ ): # baichuan-13b has no bos token
bos_ids = [tokenizer.bos_token_id]
else:
bos_ids = [] # bos token is optional
- if tokenizer.eos_token_id:
+ if tokenizer.eos_token_id is not None:
eos_ids = [tokenizer.eos_token_id]
else:
raise ValueError("EOS token is required.")
@@ -82,35 +92,44 @@ class Template:
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
- prefix: List[Union[str, Dict[str, str]]],
+ system: str,
history: List[Tuple[str, str]]
) -> List[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
+ Turn 0: bos + prefix + sep + query resp + eos
+ Turn t: sep + bos + query resp + eos
"""
bos_ids, eos_ids = self._get_special_ids(tokenizer)
sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
encoded_pairs = []
for turn_idx, (query, resp) in enumerate(history):
if turn_idx == 0:
- prefix_ids = self._convert_inputs_to_ids(tokenizer, context=prefix) + eos_ids + sep_ids
+ prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system)
+ if len(prefix_ids) != 0: # has prefix
+ prefix_ids = bos_ids + prefix_ids + sep_ids
+ else:
+ prefix_ids = bos_ids
else:
- prefix_ids = sep_ids
- query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
+ prefix_ids = sep_ids + bos_ids
+
+ query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx))
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
- encoded_pairs.append((bos_ids + prefix_ids + query_ids, resp_ids + eos_ids))
+ encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids))
return encoded_pairs
def _convert_inputs_to_ids(
self,
tokenizer: "PreTrainedTokenizer",
context: List[Union[str, Dict[str, str]]],
- query: Optional[str] = ""
+ system: Optional[str] = None,
+ query: Optional[str] = None,
+ idx: Optional[str] = None
) -> List[int]:
r"""
Converts context to token ids.
"""
- if hasattr(tokenizer, "tokenizer"): # for tiktoken tokenizer (Qwen)
+ if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
else:
kwargs = dict(add_special_tokens=False)
@@ -118,12 +137,15 @@ class Template:
token_ids = []
for elem in context:
if isinstance(elem, str):
- elem = elem.replace("{{query}}", query, 1)
+ elem = elem.replace("{{system}}", system, 1) if system is not None else elem
+ elem = elem.replace("{{query}}", query, 1) if query is not None else elem
+ elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
token_ids = token_ids + tokenizer.encode(elem, **kwargs)
elif isinstance(elem, dict):
token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
else:
raise NotImplementedError
+
return token_ids
@@ -133,18 +155,19 @@ class Llama2Template(Template):
def _encode(
self,
tokenizer: "PreTrainedTokenizer",
- prefix: List[Union[str, Dict[str, str]]],
+ system: str,
history: List[Tuple[str, str]]
) -> List[Tuple[List[int], List[int]]]:
r"""
Encodes formatted inputs to pairs of token ids.
+        Turn 0: prompt = bos + prefix + query,  response = resp + eos
+        Turn t: prompt = bos + query,           response = resp + eos
"""
bos_ids, eos_ids = self._get_special_ids(tokenizer)
encoded_pairs = []
- assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single str."
for turn_idx, (query, resp) in enumerate(history):
- if turn_idx == 0: # llama2 template has not sep_ids
- query = prefix[0] + query
+ if turn_idx == 0: # llama2 template has no sep_ids
+ query = self.prefix[0].replace("{{system}}", system) + query
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
@@ -158,14 +181,16 @@ def register_template(
name: str,
prefix: List[Union[str, Dict[str, str]]],
prompt: List[Union[str, Dict[str, str]]],
+ system: str,
sep: List[Union[str, Dict[str, str]]],
- stop_words: List[str],
- use_history: bool
+ stop_words: Optional[List[str]] = [],
+ use_history: Optional[bool] = True
) -> None:
- template_class = Llama2Template if name == "llama2" else Template
+ template_class = Llama2Template if "llama2" in name else Template
templates[name] = template_class(
prefix=prefix,
prompt=prompt,
+ system=system,
sep=sep,
stop_words=stop_words,
use_history=use_history
@@ -179,13 +204,27 @@ def get_template_and_fix_tokenizer(
template = templates.get(name, None)
assert template is not None, "Template {} does not exist.".format(name)
- if tokenizer.eos_token_id is None and len(template.stop_words): # inplace method
- tokenizer.eos_token = template.stop_words[0]
+ additional_special_tokens = template.stop_words
+ if len(template.stop_words): # inplace method
+ if tokenizer.eos_token_id is not None:
+ additional_special_tokens.append(tokenizer.eos_token)
- if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
- tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.eos_token = additional_special_tokens[0] # use the first stop word as eos token
+ additional_special_tokens.pop(0)
+ logger.info("Replace eos token: {}".format(tokenizer.eos_token))
- tokenizer.add_special_tokens(dict(additional_special_tokens=template.stop_words))
+ if tokenizer.eos_token_id is None:
+ tokenizer.eos_token = "<|endoftext|>"
+ logger.info("Add eos token: {}".format(tokenizer.eos_token))
+
+ if tokenizer.pad_token_id is None:
+ if tokenizer.unk_token_id is not None:
+ tokenizer.pad_token = tokenizer.unk_token
+ else:
+ tokenizer.pad_token = tokenizer.eos_token
+ logger.info("Add pad token: {}".format(tokenizer.pad_token))
+
+ tokenizer.add_special_tokens(dict(additional_special_tokens=additional_special_tokens))
return template
@@ -198,8 +237,8 @@ register_template(
prompt=[
"{{query}}"
],
+ system="",
sep=[],
- stop_words=[],
use_history=False
)
@@ -210,17 +249,18 @@ Default template.
register_template(
name="default",
prefix=[
- "A chat between a curious user and an artificial intelligence assistant. "
- "The assistant gives helpful, detailed, and polite answers to the user's questions."
+ "{{system}}"
],
prompt=[
"Human: {{query}}\nAssistant: "
],
+ system=(
+ "A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the user's questions."
+ ),
sep=[
"\n"
- ],
- stop_words=[],
- use_history=True
+ ]
)
@@ -232,21 +272,39 @@ Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
register_template(
name="llama2",
prefix=[
- "<>\nYou are a helpful, respectful and honest assistant. "
+ "<>\n{{system}}\n<>\n\n"
+ ],
+ prompt=[
+ "[INST] {{query}} [/INST] "
+ ],
+ system=(
+ "You are a helpful, respectful and honest assistant. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
- "If you don't know the answer to a question, please don't share false information.\n<>\n\n"
+ "If you don't know the answer to a question, please don't share false information."
+ ),
+ sep=[]
+)
+
+
+r"""
+Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
+ https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
+"""
+register_template(
+ name="llama2_zh",
+ prefix=[
+ "<>\n{{system}}\n<>\n\n"
],
prompt=[
"[INST] {{query}} [/INST] "
],
- sep=[],
- stop_words=[],
- use_history=True
+ system="You are a helpful assistant. 你是一个乐于助人的助手。",
+ sep=[]
)
@@ -257,17 +315,18 @@ Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
register_template(
name="alpaca",
prefix=[
- "Below is an instruction that describes a task. "
- "Write a response that appropriately completes the request."
+ "{{system}}"
],
prompt=[
"### Instruction:\n{{query}}\n\n### Response:\n"
],
+ system=(
+ "Below is an instruction that describes a task. "
+ "Write a response that appropriately completes the request."
+ ),
sep=[
"\n\n"
- ],
- stop_words=[],
- use_history=True
+ ]
)
@@ -278,15 +337,16 @@ Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
register_template(
name="vicuna",
prefix=[
- "A chat between a curious user and an artificial intelligence assistant. "
- "The assistant gives helpful, detailed, and polite answers to the user's questions."
+ "{{system}}"
],
prompt=[
"USER: {{query}} ASSISTANT: "
],
- sep=[],
- stop_words=[],
- use_history=True
+ system=(
+ "A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the user's questions."
+ ),
+ sep=[]
)
@@ -295,15 +355,16 @@ Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
"""
register_template(
name="belle",
- prefix=[],
+ prefix=[
+ "{{system}}"
+ ],
prompt=[
"Human: {{query}}\n\nBelle: "
],
+ system="",
sep=[
"\n\n"
- ],
- stop_words=[],
- use_history=True
+ ]
)
@@ -312,15 +373,16 @@ Supports: https://github.com/CVI-SZU/Linly
"""
register_template(
name="linly",
- prefix=[],
+ prefix=[
+ "{{system}}"
+ ],
prompt=[
"User: {{query}}\nBot: "
],
+ system="",
sep=[
"\n"
- ],
- stop_words=[],
- use_history=True
+ ]
)
@@ -329,15 +391,16 @@ Supports: https://github.com/Neutralzz/BiLLa
"""
register_template(
name="billa",
- prefix=[],
+ prefix=[
+ "{{system}}"
+ ],
prompt=[
"Human: {{query}}\nAssistant: "
],
+ system="",
sep=[
"\n"
- ],
- stop_words=[],
- use_history=True
+ ]
)
@@ -346,18 +409,19 @@ Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
"""
register_template(
name="ziya",
- prefix=[],
+ prefix=[
+ "{{system}}"
+ ],
prompt=[
{"token": ""},
":{{query}}\n",
{"token": ""},
":"
],
+ system="",
sep=[
"\n"
- ],
- stop_words=[],
- use_history=True
+ ]
)
@@ -367,17 +431,18 @@ Supports: https://huggingface.co/qhduan/aquilachat-7b
register_template(
name="aquila",
prefix=[
- "A chat between a curious human and an artificial intelligence assistant. "
- "The assistant gives helpful, detailed, and polite answers to the human's questions."
+ "{{system}}"
],
prompt=[
"Human: {{query}}###Assistant: "
],
+ system=(
+ "A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions."
+ ),
sep=[
"###"
- ],
- stop_words=[],
- use_history=True
+ ]
)
@@ -386,19 +451,22 @@ Supports: https://huggingface.co/internlm/internlm-chat-7b
"""
register_template(
name="intern",
- prefix=[],
+ prefix=[
+ "{{system}}"
+ ],
prompt=[
"<|User|>:{{query}}",
{"token": ""},
"\n<|Bot|>:"
],
+ system="",
sep=[
"\n"
],
stop_words=[
+ "", # internlm cannot replace eos token
""
- ],
- use_history=True
+ ]
)
@@ -407,15 +475,19 @@ Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
"""
register_template(
name="baichuan",
- prefix=[],
- prompt=[
- {"token": ""},
- "{{query}}",
- {"token": ""}
+ prefix=[
+ "{{system}}",
+ {"token": ""} # user token
],
+ prompt=[
+ "{{query}}",
+ {"token": ""} # assistant token
+ ],
+ system="",
sep=[],
- stop_words=[],
- use_history=True
+ stop_words=[
+ "" # user token
+ ]
)
@@ -427,7 +499,8 @@ register_template(
name="starchat",
prefix=[
{"token": "<|system|>"},
- "\n"
+ "\n{{system}}",
+ {"token": "<|end|>"}
],
prompt=[
{"token": "<|user|>"},
@@ -436,13 +509,13 @@ register_template(
"\n",
{"token": "<|assistant|>"}
],
+ system="",
sep=[
"\n"
],
stop_words=[
"<|end|>"
- ],
- use_history=True
+ ]
)
@@ -453,7 +526,8 @@ register_template(
name="chatml",
prefix=[
{"token": "<|im_start|>"},
- "system\nYou are a helpful assistant."
+ "system\n{{system}}",
+ {"token": "<|im_end|>"}
],
prompt=[
{"token": "<|im_start|>"},
@@ -463,11 +537,31 @@ register_template(
{"token": "<|im_start|>"},
"assistant\n"
],
+ system="You are a helpful assistant.",
sep=[
"\n"
],
stop_words=[
"<|im_end|>"
- ],
- use_history=True
+ ]
+)
+
+
+r"""
+Supports: https://huggingface.co/THUDM/chatglm2-6b
+"""
+register_template(
+ name="chatglm2",
+ prefix=[
+ {"token": "[gMASK]"},
+ {"token": "sop"},
+ "{{system}}"
+ ],
+ prompt=[
+ "[Round {{idx}}]\n\n问:{{query}}\n\n答:"
+ ],
+ system="",
+ sep=[
+ "\n\n"
+ ]
)
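For orientation, here is a minimal standalone sketch of what the default template above produces once the system prompt, history, and separator are stitched together. It is an illustration with hypothetical helper names, not code from the patch; the real Template class operates on token ids rather than strings.

from typing import List, Tuple

def render_default_prompt(system: str, history: List[Tuple[str, str]]) -> str:
    # Mirrors the Turn-0 / Turn-t layout documented in Template._encode, using strings
    # instead of token ids: turn 0 starts with the system prefix, later turns with the separator.
    sep = "\n"  # sep=["\n"] in the default template
    pieces = []
    for turn_idx, (query, resp) in enumerate(history):
        head = system + sep if turn_idx == 0 else sep
        pieces.append(head + "Human: {}\nAssistant: {}".format(query, resp))
    return "".join(pieces)

print(render_default_prompt(
    "A chat between a curious user and an artificial intelligence assistant.",
    [("Hello", "Hi there!"), ("Who wrote you?", "An open-source community.")]
))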
diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py
index 60945b60..374d03c6 100644
--- a/src/llmtuner/hparams/data_args.py
+++ b/src/llmtuner/hparams/data_args.py
@@ -10,7 +10,7 @@ class DatasetAttr:
load_from: str
dataset_name: Optional[str] = None
dataset_sha1: Optional[str] = None
- source_prefix: Optional[str] = None
+ system_prompt: Optional[str] = None
def __repr__(self) -> str:
return self.dataset_name
@@ -24,7 +24,7 @@ class DatasetAttr:
@dataclass
class DataArguments:
- """
+ r"""
Arguments pertaining to what data we are going to input our model for training and evaluation.
"""
template: str = field(
@@ -86,13 +86,13 @@ class DataArguments:
default=True,
metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
)
- source_prefix: Optional[str] = field(
+ system_prompt: Optional[str] = field(
default=None,
- metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
+ metadata={"help": "System prompt to add before the user query. Use `|` to separate multiple prompts in training."}
)
- dev_ratio: Optional[float] = field(
+ val_size: Optional[float] = field(
default=0,
- metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
+ metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
)
def init_for_training(self): # support mixing multiple datasets
@@ -100,12 +100,9 @@ class DataArguments:
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
dataset_info = json.load(f)
- if self.source_prefix is not None:
- prefix_list = self.source_prefix.split("|")
- prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
- assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical with datasets or 1."
- else:
- prefix_list = [None] * len(dataset_names)
+ prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
+ prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
+ assert len(prompt_list) == len(dataset_names), "Number of system prompts should be equal to datasets or 1."
if self.interleave_probs is not None:
self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
@@ -126,12 +123,11 @@ class DataArguments:
dataset_sha1=dataset_info[name].get("file_sha1", None)
)
- dataset_attr.source_prefix = prefix_list[i]
-
if "columns" in dataset_info[name]:
dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
dataset_attr.query = dataset_info[name]["columns"].get("query", None)
dataset_attr.response = dataset_info[name]["columns"].get("response", None)
dataset_attr.history = dataset_info[name]["columns"].get("history", None)
+ dataset_attr.system_prompt = prompt_list[i]
self.dataset_list.append(dataset_attr)
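The `|`-separated `system_prompt` handling above is easy to sanity-check in isolation. A sketch of the broadcast rule with a hypothetical helper name (a single prompt is reused for every dataset; otherwise the counts must match):

def expand_system_prompts(system_prompt, dataset_names):
    # Reproduces the logic in init_for_training: split on "|", then broadcast a single prompt.
    prompt_list = system_prompt.split("|") if system_prompt else [None]
    prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
    assert len(prompt_list) == len(dataset_names), "Number of system prompts should be equal to datasets or 1."
    return prompt_list

print(expand_system_prompts("You are a helpful assistant.", ["alpaca_en", "alpaca_zh"]))  # broadcast to both
print(expand_system_prompts("prompt A|prompt B", ["alpaca_en", "alpaca_zh"]))             # one prompt per dataset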
diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py
index 277602ae..d7d651dd 100644
--- a/src/llmtuner/hparams/finetuning_args.py
+++ b/src/llmtuner/hparams/finetuning_args.py
@@ -5,7 +5,7 @@ from dataclasses import asdict, dataclass, field
@dataclass
class FinetuningArguments:
- """
+ r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
@@ -14,7 +14,7 @@ class FinetuningArguments:
)
num_hidden_layers: Optional[int] = field(
default=32,
- metadata={"help": "Number of decoder blocks in the model. \
+ metadata={"help": "Number of decoder blocks in the model for partial-parameter (freeze) fine-tuning. \
LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
LLaMA-2 choices: [\"32\", \"40\", \"80\"], \
BLOOM choices: [\"24\", \"30\", \"70\"], \
@@ -25,16 +25,16 @@ class FinetuningArguments:
)
num_layer_trainable: Optional[int] = field(
default=3,
- metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
+ metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
)
name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
default="mlp",
- metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
- LLaMA & LLaMA-2 choices: [\"mlp\", \"self_attn\"], \
+ metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
+ LLaMA choices: [\"mlp\", \"self_attn\"], \
BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
Baichuan choices: [\"mlp\", \"self_attn\"], \
Qwen choices: [\"mlp\", \"attn\"], \
- InternLM, XVERSE choices: the same as LLaMA."}
+ LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
)
lora_rank: Optional[int] = field(
default=8,
@@ -51,11 +51,19 @@ class FinetuningArguments:
lora_target: Optional[str] = field(
default="q_proj,v_proj",
metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
- LLaMA & LLaMA-2 choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
+ LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
- InternLM, XVERSE choices: the same as LLaMA."}
+ LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
+ )
+ resume_lora_training: Optional[bool] = field(
+ default=True,
+ metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
+ )
+ dpo_beta: Optional[float] = field(
+ default=0.1,
+ metadata={"help": "The beta parameter for the DPO loss."}
)
def __post_init__(self):
@@ -72,14 +80,14 @@ class FinetuningArguments:
assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
def save_to_json(self, json_path: str):
- """Saves the content of this instance in JSON format inside `json_path`."""
+ r"""Saves the content of this instance in JSON format inside `json_path`."""
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
with open(json_path, "w", encoding="utf-8") as f:
f.write(json_string)
@classmethod
def load_from_json(cls, json_path: str):
- """Creates an instance from the content of `json_path`."""
+ r"""Creates an instance from the content of `json_path`."""
with open(json_path, "r", encoding="utf-8") as f:
text = f.read()
return cls(**json.loads(text))
diff --git a/src/llmtuner/hparams/general_args.py b/src/llmtuner/hparams/general_args.py
index 397d3019..c0c1a0de 100644
--- a/src/llmtuner/hparams/general_args.py
+++ b/src/llmtuner/hparams/general_args.py
@@ -4,10 +4,10 @@ from dataclasses import dataclass, field
@dataclass
class GeneralArguments:
- """
+ r"""
Arguments pertaining to which stage we are going to perform.
"""
- stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = field(
+ stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."}
)
diff --git a/src/llmtuner/hparams/generating_args.py b/src/llmtuner/hparams/generating_args.py
index e25ff4b9..f8b935fb 100644
--- a/src/llmtuner/hparams/generating_args.py
+++ b/src/llmtuner/hparams/generating_args.py
@@ -4,7 +4,7 @@ from dataclasses import asdict, dataclass, field
@dataclass
class GeneratingArguments:
- """
+ r"""
Arguments pertaining to specify the decoding parameters.
"""
do_sample: Optional[bool] = field(
diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py
index 253d9839..dc515f51 100644
--- a/src/llmtuner/hparams/model_args.py
+++ b/src/llmtuner/hparams/model_args.py
@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
@dataclass
class ModelArguments:
- """
+ r"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
"""
model_name_or_path: str = field(
@@ -43,9 +43,9 @@ class ModelArguments:
default=True,
metadata={"help": "Whether to use double quantization in int4 training or not."}
)
- compute_dtype: Optional[torch.dtype] = field(
+ rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
default=None,
- metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
+ metadata={"help": "Adopt scaled rotary positional embeddings."}
)
checkpoint_dir: Optional[str] = field(
default=None,
@@ -55,18 +55,33 @@ class ModelArguments:
default=None,
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
)
- resume_lora_training: Optional[bool] = field(
- default=True,
- metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
- )
plot_loss: Optional[bool] = field(
default=False,
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
)
+ hf_auth_token: Optional[str] = field(
+ default=None,
+ metadata={"help": "Auth token to log in with Hugging Face Hub."}
+ )
+ compute_dtype: Optional[torch.dtype] = field(
+ default=None,
+ metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
+ )
+ model_max_length: Optional[int] = field(
+ default=None,
+ metadata={"help": "Used in rope scaling. Do not specify this argument manually."}
+ )
def __post_init__(self):
+ if self.compute_dtype is not None or self.model_max_length is not None:
+ raise ValueError("These arguments cannot be specified.")
+
if self.checkpoint_dir is not None: # support merging multiple lora weights
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
if self.quantization_bit is not None:
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
+
+ if self.use_auth_token == True and self.hf_auth_token is not None:
+ from huggingface_hub.hf_api import HfFolder # lazy load
+ HfFolder.save_token(self.hf_auth_token)
diff --git a/src/llmtuner/tuner/core/adapter.py b/src/llmtuner/tuner/core/adapter.py
index 4afad13a..5db56876 100644
--- a/src/llmtuner/tuner/core/adapter.py
+++ b/src/llmtuner/tuner/core/adapter.py
@@ -39,7 +39,7 @@ def init_adapter(
if finetuning_args.finetuning_type == "none" and is_trainable:
raise ValueError("You cannot use finetuning_type=none while training.")
- if finetuning_args.finetuning_type == "full":
+ if finetuning_args.finetuning_type == "full" and is_trainable:
logger.info("Fine-tuning method: Full")
model = model.float()
@@ -65,7 +65,7 @@ def init_adapter(
assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
"The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
- if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
+ if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable): # continually fine-tuning
checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
else:
checkpoints_to_merge = model_args.checkpoint_dir
diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index c06eabfa..4bf767a6 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -1,5 +1,7 @@
import os
+import math
import torch
+from types import MethodType
from typing import TYPE_CHECKING, Literal, Optional, Tuple
from transformers import (
@@ -34,7 +36,7 @@ check_min_version("4.29.1")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
-require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7")
+require_version("trl>=0.5.0", "To fix: pip install trl>=0.5.0")
def load_model_and_tokenizer(
@@ -52,9 +54,6 @@ def load_model_and_tokenizer(
logger.warning("Checkpoint is not found at evaluation, load the original model.")
finetuning_args = FinetuningArguments(finetuning_type="none")
- assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
- "RM and PPO training can only be performed with the LoRA method."
-
config_kwargs = {
"trust_remote_code": True,
"cache_dir": model_args.cache_dir,
@@ -69,15 +68,58 @@ def load_model_and_tokenizer(
**config_kwargs
)
- if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
+ if finetuning_args.finetuning_type == "full" and model_args.checkpoint_dir is not None:
model_to_load = model_args.checkpoint_dir[0]
else:
model_to_load = model_args.model_name_or_path
config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
- is_mergeable = True
+
+ if hasattr(config, "fp16") and hasattr(config, "bf16"): # fix Qwen config
+ if model_args.compute_dtype == torch.bfloat16:
+ setattr(config, "bf16", True)
+ else:
+ setattr(config, "fp16", True)
+
+ # Set RoPE scaling
+ if model_args.rope_scaling is not None:
+ if hasattr(config, "use_dynamic_ntk"): # for Qwen models
+ if is_trainable:
+ logger.warning("Qwen model does not support RoPE scaling in training.")
+ else:
+ setattr(config, "use_dynamic_ntk", True)
+ setattr(config, "use_logn_attn", True)
+ logger.info("Using dynamic NTK scaling.")
+
+ elif hasattr(config, "rope_scaling"): # for LLaMA models
+ require_version("transformers>=4.31.0", "RoPE scaling requires transformers>=4.31.0")
+
+ if is_trainable:
+ if model_args.rope_scaling == "dynamic":
+ logger.warning(
+ "Dynamic NTK may not work well with fine-tuning. "
+ "See: https://github.com/huggingface/transformers/pull/24653"
+ )
+
+ current_max_length = getattr(config, "max_position_embeddings", None)
+ if current_max_length and model_args.model_max_length > current_max_length:
+ scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
+ else:
+ logger.warning("Input length is smaller than max length. Consider increase input length.")
+ scaling_factor = 1.0
+ else:
+ scaling_factor = 2.0
+
+ setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
+ logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
+ model_args.rope_scaling, scaling_factor
+ ))
+
+ else:
+ logger.warning("Current model does not support RoPE scaling.")
# Quantization configurations (using bitsandbytes library).
+ is_mergeable = True
if model_args.quantization_bit is not None:
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
@@ -95,10 +137,10 @@ def load_model_and_tokenizer(
)
is_mergeable = False
- config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
+ config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
- # Load and prepare pretrained models (without valuehead).
+ # Load and prepare pre-trained models (without valuehead).
model = AutoModelForCausalLM.from_pretrained(
model_to_load,
config=config,
@@ -107,6 +149,14 @@ def load_model_and_tokenizer(
**config_kwargs
)
+ # Disable custom generate method (for Qwen)
+ if "GenerationMixin" not in str(model.generate.__func__):
+ model.generate = MethodType(PreTrainedModel.generate, model)
+
+ # Fix LM head (for ChatGLM2)
+ if not hasattr(model, "lm_head") and hasattr(model, "transformer"):
+ setattr(model, "lm_head", model.transformer.output_layer)
+
# Register auto class to save the custom code files.
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()
@@ -119,10 +169,10 @@ def load_model_and_tokenizer(
model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
- if stage == "rm" or stage == "ppo": # add value head
- model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
+ # Prepare model with valuehead for RLHF
+ if stage == "rm" or stage == "ppo":
+ model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(model)
reset_logging()
-
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
@@ -132,15 +182,15 @@ def load_model_and_tokenizer(
})
if stage == "ppo": # load reward model
- assert is_trainable, "PPO stage cannot be performed at evaluation."
- assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
logger.info("Load reward model from {}".format(model_args.reward_model))
model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
+ # Prepare model for inference
if not is_trainable:
model.requires_grad_(False) # fix all model params
- model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
+ infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # detect cuda capability
+ model = model.to(infer_dtype) if model_args.quantization_bit is None else model
trainable_params, all_param = count_parameters(model)
logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py
index d872afcc..e039513d 100644
--- a/src/llmtuner/tuner/core/parser.py
+++ b/src/llmtuner/tuner/core/parser.py
@@ -19,7 +19,7 @@ from llmtuner.hparams import (
logger = get_logger(__name__)
-def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None):
+def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
if args is not None:
return parser.parse_dict(args)
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
@@ -32,26 +32,53 @@ def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None)
def parse_train_args(
args: Optional[Dict[str, Any]] = None
-) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
+) -> Tuple[
+ ModelArguments,
+ DataArguments,
+ Seq2SeqTrainingArguments,
+ FinetuningArguments,
+ GeneratingArguments,
+ GeneralArguments
+]:
parser = HfArgumentParser((
- ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments
+ ModelArguments,
+ DataArguments,
+ Seq2SeqTrainingArguments,
+ FinetuningArguments,
+ GeneratingArguments,
+ GeneralArguments
))
return _parse_args(parser, args)
def parse_infer_args(
args: Optional[Dict[str, Any]] = None
-) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
+) -> Tuple[
+ ModelArguments,
+ DataArguments,
+ FinetuningArguments,
+ GeneratingArguments
+]:
parser = HfArgumentParser((
- ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+ ModelArguments,
+ DataArguments,
+ FinetuningArguments,
+ GeneratingArguments
))
return _parse_args(parser, args)
def get_train_args(
args: Optional[Dict[str, Any]] = None
-) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
- model_args, data_args, training_args, finetuning_args, general_args = parse_train_args(args)
+) -> Tuple[
+ ModelArguments,
+ DataArguments,
+ Seq2SeqTrainingArguments,
+ FinetuningArguments,
+ GeneratingArguments,
+ GeneralArguments
+]:
+ model_args, data_args, training_args, finetuning_args, generating_args, general_args = parse_train_args(args)
# Setup logging
if training_args.should_log:
@@ -67,33 +94,42 @@ def get_train_args(
# Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
data_args.init_for_training()
- assert general_args.stage == "sft" or (not training_args.predict_with_generate), \
- "`predict_with_generate` cannot be set as True at PT, RM and PPO stages."
+ if general_args.stage != "sft" and training_args.predict_with_generate:
+        raise ValueError("`predict_with_generate` cannot be set as True except at the SFT stage.")
- assert not (training_args.do_train and training_args.predict_with_generate), \
- "`predict_with_generate` cannot be set as True while training."
+ if training_args.do_train and training_args.predict_with_generate:
+ raise ValueError("`predict_with_generate` cannot be set as True while training.")
- assert general_args.stage != "sft" or (not training_args.do_predict) or training_args.predict_with_generate, \
- "Please enable `predict_with_generate` to save model predictions."
+ if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
+ raise ValueError("Please enable `predict_with_generate` to save model predictions.")
- assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
- "Quantization is only compatible with the LoRA method."
+ if general_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
+ raise ValueError("RM and PPO training can only be performed with the LoRA method.")
- assert not (training_args.max_steps == -1 and data_args.streaming), \
- "Please specify `max_steps` in streaming mode."
+ if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
+ raise ValueError("PPO and DPO stage can only be performed at training.")
- assert training_args.evaluation_strategy == "no" or (not data_args.streaming), \
- "Streaming mode does not support evaluation currently."
+ if general_args.stage == "ppo" and model_args.reward_model is None:
+ raise ValueError("Reward model is necessary for PPO training.")
- assert not (general_args.stage == "ppo" and data_args.streaming), \
- "Streaming mode does not suppport PPO training currently."
+ if training_args.max_steps == -1 and data_args.streaming:
+ raise ValueError("Please specify `max_steps` in streaming mode.")
+
+ if general_args.stage == "ppo" and data_args.streaming:
+ raise ValueError("Streaming mode does not suppport PPO training currently.")
+
+ if data_args.val_size > 1e-6 and data_args.val_size < 1 and data_args.streaming:
+ raise ValueError("Streaming mode should have an integer val size.")
+
+ if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+ raise ValueError("Quantization is only compatible with the LoRA method.")
if model_args.checkpoint_dir is not None:
if finetuning_args.finetuning_type != "lora":
- assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
- else:
- assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
- "Quantized model only accepts a single checkpoint."
+ if len(model_args.checkpoint_dir) != 1:
+ raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+ elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
+ raise ValueError("Quantized model only accepts a single checkpoint.")
if model_args.quantization_bit is not None and (not training_args.do_train):
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
@@ -113,46 +149,48 @@ def get_train_args(
logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
data_args.max_samples = None
- if data_args.dev_ratio > 1e-6 and data_args.streaming:
- logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")
- data_args.dev_ratio = 0
-
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
- if model_args.quantization_bit is not None:
- if training_args.fp16:
- model_args.compute_dtype = torch.float16
- elif training_args.bf16:
- model_args.compute_dtype = torch.bfloat16
- else:
- model_args.compute_dtype = torch.float32
+ if training_args.bf16:
+ if not torch.cuda.is_bf16_supported():
+ raise ValueError("Current device does not support bf16 training.")
+ model_args.compute_dtype = torch.bfloat16
+ else:
+ model_args.compute_dtype = torch.float16
+
+ model_args.model_max_length = data_args.max_source_length + data_args.max_target_length
# Log on each process the small summary:
- logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, 16-bits training: {}".format(
+ logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
training_args.local_rank, training_args.device, training_args.n_gpu,
- bool(training_args.local_rank != -1), training_args.fp16
+ bool(training_args.local_rank != -1), str(model_args.compute_dtype)
))
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model.
transformers.set_seed(training_args.seed)
- return model_args, data_args, training_args, finetuning_args, general_args
+ return model_args, data_args, training_args, finetuning_args, generating_args, general_args
def get_infer_args(
args: Optional[Dict[str, Any]] = None
-) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
+) -> Tuple[
+ ModelArguments,
+ DataArguments,
+ FinetuningArguments,
+ GeneratingArguments
+]:
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
- assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
- "Quantization is only compatible with the LoRA method."
+ if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+ raise ValueError("Quantization is only compatible with the LoRA method.")
if model_args.checkpoint_dir is not None:
if finetuning_args.finetuning_type != "lora":
- assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
- else:
- assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
- "Quantized model only accepts a single checkpoint."
+ if len(model_args.checkpoint_dir) != 1:
+ raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+ elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
+ raise ValueError("Quantized model only accepts a single checkpoint.")
return model_args, data_args, finetuning_args, generating_args
diff --git a/src/llmtuner/tuner/core/trainer.py b/src/llmtuner/tuner/core/trainer.py
index ae80f32f..058bb740 100644
--- a/src/llmtuner/tuner/core/trainer.py
+++ b/src/llmtuner/tuner/core/trainer.py
@@ -13,26 +13,25 @@ from llmtuner.extras.logging import get_logger
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params
if TYPE_CHECKING:
+ from transformers import PreTrainedTokenizer, Seq2SeqTrainingArguments, TrainerState
from llmtuner.hparams import FinetuningArguments
logger = get_logger(__name__)
-class PeftTrainer(Seq2SeqTrainer):
+class PeftModelMixin:
r"""
- Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
+ Patches the save and load methods in Hugging Face Trainer for PeftModel and ModelWithValueHead.
"""
- def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
- super().__init__(**kwargs)
- self.finetuning_args = finetuning_args
- self._remove_log()
-
- def _remove_log(self):
- if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
- logger.warning("Previous log file in this folder will be deleted.")
- os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))
+ def __init__(self) -> None: # for type checking
+ self.model: PreTrainedModel = None
+ self.tokenizer: "PreTrainedTokenizer" = None
+ self.args: "Seq2SeqTrainingArguments" = None
+ self.finetuning_args: "FinetuningArguments" = None
+ self.state: "TrainerState" = None
+ raise AssertionError("Mixin should not be initialized.")
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
r"""
@@ -96,3 +95,13 @@ class PeftTrainer(Seq2SeqTrainer):
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
else: # freeze/full-tuning
load_trainable_params(model, self.state.best_model_checkpoint)
+
+
+class PeftTrainer(PeftModelMixin, Seq2SeqTrainer):
+ r"""
+ Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
+ """
+
+ def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
+ Seq2SeqTrainer.__init__(self, **kwargs)
+ self.finetuning_args = finetuning_args
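A design note on the class above: because PeftModelMixin is listed before Seq2SeqTrainer, the mixin's patched save/load methods come first in Python's method resolution order. A toy sketch with hypothetical class names:

class HFTrainer:
    def _save(self) -> str:
        return "default trainer save"

class SaveMixin:
    def _save(self) -> str:
        return "parameter-efficient save"

class PatchedTrainer(SaveMixin, HFTrainer):  # mixin first, so its _save shadows HFTrainer's
    pass

print([cls.__name__ for cls in PatchedTrainer.__mro__])  # ['PatchedTrainer', 'SaveMixin', 'HFTrainer', 'object']
print(PatchedTrainer()._save())                           # parameter-efficient save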
diff --git a/src/llmtuner/tuner/dpo/__init__.py b/src/llmtuner/tuner/dpo/__init__.py
new file mode 100644
index 00000000..f2b5cfb5
--- /dev/null
+++ b/src/llmtuner/tuner/dpo/__init__.py
@@ -0,0 +1 @@
+from llmtuner.tuner.dpo.workflow import run_dpo
diff --git a/src/llmtuner/tuner/dpo/collator.py b/src/llmtuner/tuner/dpo/collator.py
new file mode 100644
index 00000000..2f0f4bdc
--- /dev/null
+++ b/src/llmtuner/tuner/dpo/collator.py
@@ -0,0 +1,51 @@
+import torch
+from dataclasses import dataclass
+from typing import Any, Dict, List, Sequence, Tuple
+from transformers import DataCollatorForSeq2Seq
+
+
+@dataclass
+class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
+ r"""
+ Data collator for pairwise data.
+ """
+
+ def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
+ padded_labels = []
+ for feature, (prompt_len, answer_len) in zip(batch, positions):
+ if self.tokenizer.padding_side == "left":
+ start, end = feature.size(0) - answer_len, feature.size(0)
+ else:
+                start, end = prompt_len, prompt_len + answer_len
+ padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
+ padded_tensor[start:end] = feature[start:end]
+ padded_labels.append(padded_tensor)
+ return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
+
+ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
+ r"""
+ Pads batched data to the longest sequence in the batch.
+
+ We generate 2 * n examples where the first n examples represent chosen examples and
+ the last n examples represent rejected examples.
+ """
+ concatenated_features = []
+ label_positions = []
+ for key in ("chosen_ids", "rejected_ids"):
+ for feature in features:
+ prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
+ concatenated_features.append({
+ "input_ids": feature["prompt_ids"] + feature[key],
+ "attention_mask": [1] * (prompt_len + answer_len)
+ })
+ label_positions.append((prompt_len, answer_len))
+
+ batch = self.tokenizer.pad(
+ concatenated_features,
+ padding=self.padding,
+ max_length=self.max_length,
+ pad_to_multiple_of=self.pad_to_multiple_of,
+ return_tensors=self.return_tensors,
+ )
+ batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
+ return batch
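The 2 * n layout the collator promises can be seen with plain lists before any padding happens. A sketch with made-up token ids (not the collator itself):

features = [
    {"prompt_ids": [1, 2], "chosen_ids": [3, 4], "rejected_ids": [5]},
    {"prompt_ids": [6],    "chosen_ids": [7],    "rejected_ids": [8, 9]},
]

rows = []
for key in ("chosen_ids", "rejected_ids"):  # chosen examples first, rejected examples last
    for feature in features:
        rows.append(feature["prompt_ids"] + feature[key])

print(rows)  # [[1, 2, 3, 4], [6, 7], [1, 2, 5], [6, 8, 9]] -> later split in half along dim 0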
diff --git a/src/llmtuner/tuner/dpo/trainer.py b/src/llmtuner/tuner/dpo/trainer.py
new file mode 100644
index 00000000..458e99db
--- /dev/null
+++ b/src/llmtuner/tuner/dpo/trainer.py
@@ -0,0 +1,77 @@
+import torch
+from collections import defaultdict
+from peft import PeftModel
+from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
+from transformers import BatchEncoding, Trainer
+from trl import DPOTrainer
+
+from llmtuner.extras.constants import IGNORE_INDEX
+from llmtuner.tuner.core.trainer import PeftModelMixin
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedModel
+ from llmtuner.hparams import FinetuningArguments, GeneratingArguments
+
+
+class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
+
+ def __init__(
+ self,
+ finetuning_args: "FinetuningArguments",
+ generating_args: "GeneratingArguments",
+ ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
+ **kwargs
+ ):
+ self.finetuning_args = finetuning_args
+ self.generating_args = generating_args
+ self.ref_model = ref_model
+ self.use_dpo_data_collator = True # hack to avoid warning
+ self.label_pad_token_id = IGNORE_INDEX
+ self.padding_value = 0
+ self.beta = finetuning_args.dpo_beta
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
+
+ Trainer.__init__(self, **kwargs)
+ if not hasattr(self, "accelerator"):
+ raise AttributeError("Please update `transformers`.")
+
+ if ref_model is not None:
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+
+ def concatenated_forward(
+ self,
+ model: Optional[torch.nn.Module] = None,
+ batch: Optional[Dict[str, torch.Tensor]] = None
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+ batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
+ unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
+
+ if not torch.is_grad_enabled():
+ unwrapped_model.gradient_checkpointing_disable()
+
+ if model is None and isinstance(unwrapped_model, PeftModel): # peft model has no ref_model
+ with unwrapped_model.disable_adapter():
+ all_logits = self.model(
+ input_ids=batch_copied["input_ids"],
+ attention_mask=batch_copied["attention_mask"],
+ return_dict=True
+ ).logits.to(torch.float32)
+ else:
+ all_logits = model(
+ input_ids=batch_copied["input_ids"],
+ attention_mask=batch_copied["attention_mask"],
+ return_dict=True
+ ).logits.to(torch.float32)
+
+ if not torch.is_grad_enabled():
+ unwrapped_model.gradient_checkpointing_enable()
+
+ all_logps = self._get_batch_logps(
+ all_logits,
+ batch["labels"],
+ average_log_prob=False
+ )
+ batch_size = batch["input_ids"].size(0) // 2
+ chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
+ chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
+ return chosen_logps, rejected_logps, chosen_logits, rejected_logits
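For context on `dpo_beta`: the log-probabilities returned above feed the standard DPO objective, which trl's DPOTrainer computes from policy and reference log-probs. A minimal sketch of that loss (an illustration, not the trl implementation):

import torch
import torch.nn.functional as F

def dpo_loss(chosen_logps, rejected_logps, ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # -log sigmoid(beta * ((pi_chosen - pi_rejected) - (ref_chosen - ref_rejected)))
    logits = (chosen_logps - rejected_logps) - (ref_chosen_logps - ref_rejected_logps)
    return -F.logsigmoid(beta * logits).mean()

loss = dpo_loss(torch.tensor([-12.0]), torch.tensor([-15.0]),
                torch.tensor([-13.0]), torch.tensor([-14.0]))
print(loss)  # ~0.598; shrinks as the policy prefers chosen responses more strongly than the reference does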
diff --git a/src/llmtuner/tuner/dpo/workflow.py b/src/llmtuner/tuner/dpo/workflow.py
new file mode 100644
index 00000000..350d64a7
--- /dev/null
+++ b/src/llmtuner/tuner/dpo/workflow.py
@@ -0,0 +1,59 @@
+# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
+
+from copy import deepcopy
+from peft import PeftModel
+from typing import TYPE_CHECKING, Optional, List
+
+from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
+from llmtuner.extras.constants import IGNORE_INDEX
+from llmtuner.extras.ploting import plot_loss
+from llmtuner.tuner.core import load_model_and_tokenizer
+from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
+from llmtuner.tuner.dpo.trainer import DPOPeftTrainer
+
+if TYPE_CHECKING:
+ from transformers import Seq2SeqTrainingArguments, TrainerCallback
+ from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+
+
+def run_dpo(
+ model_args: "ModelArguments",
+ data_args: "DataArguments",
+ training_args: "Seq2SeqTrainingArguments",
+ finetuning_args: "FinetuningArguments",
+ generating_args: "GeneratingArguments",
+ callbacks: Optional[List["TrainerCallback"]] = None
+):
+ dataset = get_dataset(model_args, data_args)
+ model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
+ dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
+ data_collator = DPODataCollatorWithPadding(
+ tokenizer=tokenizer,
+ label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
+ )
+
+ training_args.remove_unused_columns = False # important for pairwise dataset
+ ref_model = deepcopy(model) if not isinstance(model, PeftModel) else None
+
+ # Initialize our Trainer
+ trainer = DPOPeftTrainer(
+ finetuning_args=finetuning_args,
+ generating_args=generating_args,
+ ref_model=ref_model,
+ model=model,
+ args=training_args,
+ tokenizer=tokenizer,
+ data_collator=data_collator,
+ callbacks=callbacks,
+ **split_dataset(dataset, data_args, training_args)
+ )
+
+ # Training
+ if training_args.do_train:
+ train_result = trainer.train()
+ trainer.log_metrics("train", train_result.metrics)
+ trainer.save_metrics("train", train_result.metrics)
+ trainer.save_state()
+ trainer.save_model()
+ if trainer.is_world_process_zero() and model_args.plot_loss:
+ plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py
index 3392aa4d..f929b6ec 100644
--- a/src/llmtuner/tuner/ppo/trainer.py
+++ b/src/llmtuner/tuner/ppo/trainer.py
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
from transformers import TrainerState, TrainerControl
from trl import PPOTrainer
-from trl.core import LengthSampler
+from trl.core import LengthSampler, PPODecorators, logprobs_from_logits
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
@@ -18,7 +18,7 @@ if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.extras.callbacks import LogCallback
- from llmtuner.hparams import FinetuningArguments
+ from llmtuner.hparams import FinetuningArguments, GeneratingArguments
logger = get_logger(__name__)
@@ -33,16 +33,19 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
self,
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
+ generating_args: "GeneratingArguments",
callbacks: List["LogCallback"],
+ compute_dtype: torch.dtype,
**kwargs
):
PPOTrainer.__init__(self, **kwargs)
self.args = training_args
self.finetuning_args = finetuning_args
+ self.generating_args = generating_args
self.log_callback = callbacks[0]
+ self.compute_dtype = compute_dtype
self.state = TrainerState()
self.control = TrainerControl()
- self._remove_log()
def ppo_train(self, max_target_length: int) -> None:
r"""
@@ -72,14 +75,11 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")
# Keyword arguments for `model.generate`
- gen_kwargs = {
- "top_k": 0.0,
- "top_p": 1.0,
- "do_sample": True,
- "pad_token_id": self.tokenizer.pad_token_id,
- "eos_token_id": self.tokenizer.eos_token_id,
- "logits_processor": get_logits_processor()
- }
+ gen_kwargs = self.generating_args.to_dict()
+ gen_kwargs["eos_token_id"] = list(set([self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids))
+ gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
+ gen_kwargs["logits_processor"] = get_logits_processor()
+
length_sampler = LengthSampler(max_target_length // 2, max_target_length)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
@@ -185,10 +185,74 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
replace_model(unwrapped_model, target="reward")
batch = self.prepare_model_inputs(queries, responses)
_, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
+ if values.size(0) != batch["input_ids"].size(0): # adapt chatglm2
+ values = torch.transpose(values, 0, 1)
rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type
replace_model(unwrapped_model, target="default")
return rewards
+ @PPODecorators.empty_cuda_cache()
+ def batched_forward_pass(
+ self,
+ model: "AutoModelForCausalLMWithValueHead",
+ queries: torch.Tensor,
+ responses: torch.Tensor,
+ model_inputs: dict,
+ return_logits: Optional[bool] = False
+ ):
+ r"""
+ Calculates model outputs in multiple batches.
+
+ Subclass and override to inject custom behavior.
+ """
+ bs = len(queries)
+ fbs = self.config.mini_batch_size
+ all_logprobs = []
+ all_logits = []
+ all_masks = []
+ all_values = []
+
+ for i in range(math.ceil(bs / fbs)):
+ input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
+ query_batch = queries[i * fbs : (i + 1) * fbs]
+ response_batch = responses[i * fbs : (i + 1) * fbs]
+ input_ids = input_kwargs["input_ids"]
+ attention_mask = input_kwargs["attention_mask"]
+
+ with torch.cuda.amp.autocast(dtype=self.compute_dtype): # support bf16
+ logits, _, values = model(**input_kwargs)
+
+ if values.size(0) != input_ids.size(0): # adapt chatglm2
+ values = torch.transpose(values, 0, 1)
+
+ logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
+ masks = torch.zeros_like(attention_mask)
+ masks[:, :-1] = attention_mask[:, 1:]
+
+ for j in range(len(query_batch)):
+ start = len(query_batch[j]) - 1
+ if attention_mask[j, 0] == 0: # offset left padding
+ start += attention_mask[j, :].nonzero()[0]
+ end = start + len(response_batch[j])
+
+ masks[j, :start] = 0
+ masks[j, end:] = 0
+
+ if return_logits:
+ all_logits.append(logits)
+ else:
+ del logits
+ all_values.append(values)
+ all_logprobs.append(logprobs)
+ all_masks.append(masks)
+
+ return (
+ torch.cat(all_logprobs),
+ torch.cat(all_logits)[:, :-1] if return_logits else None,
+ torch.cat(all_values)[:, :-1],
+ torch.cat(all_masks)[:, :-1],
+ )
+
def save_model(self, output_dir: Optional[str] = None) -> None:
r"""
Saves model checkpoint.
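The start/end bookkeeping in batched_forward_pass above is the subtle part: log-probabilities are shifted by one position, and left padding moves the response window. A toy sketch of the mask for one left-padded row (made-up tensors, not the trainer's real inputs):

import torch

attention_mask = torch.tensor([[0, 1, 1, 1, 1]])  # [pad, q1, q2, r1, r2]
query, response = [101, 102], [201, 202]          # made-up token ids

masks = torch.zeros_like(attention_mask)
masks[:, :-1] = attention_mask[:, 1:]    # logprob at position t scores the token at position t+1

start = len(query) - 1
if attention_mask[0, 0] == 0:            # offset left padding
    start += attention_mask[0, :].nonzero()[0]
end = start + len(response)

masks[0, :start] = 0
masks[0, end:] = 0
print(masks)  # tensor([[0, 0, 1, 1, 0]]) -> only the response tokens contribute to the PPO update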
diff --git a/src/llmtuner/tuner/ppo/workflow.py b/src/llmtuner/tuner/ppo/workflow.py
index 0ca8cbd4..12fcdef1 100644
--- a/src/llmtuner/tuner/ppo/workflow.py
+++ b/src/llmtuner/tuner/ppo/workflow.py
@@ -1,11 +1,9 @@
-# Inspired by:
-# https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
+# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
import math
-from typing import TYPE_CHECKING
from trl import PPOConfig
from torch.optim import AdamW
-from typing import Optional, List
+from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq
from transformers.optimization import get_scheduler
@@ -16,7 +14,7 @@ from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
- from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+ from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
def run_ppo(
@@ -24,6 +22,7 @@ def run_ppo(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
+ generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
@@ -38,24 +37,30 @@ def run_ppo(
batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
ppo_epochs=1,
- max_grad_norm=training_args.max_grad_norm
+ max_grad_norm=training_args.max_grad_norm,
+ seed=training_args.seed,
+ optimize_cuda_cache=True
)
- optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=ppo_config.learning_rate)
- total_train_batch_size = \
+ optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
+ total_train_batch_size = (
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
+ )
+ num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
lr_scheduler = get_scheduler(
training_args.lr_scheduler_type,
optimizer=optimizer,
- num_warmup_steps=training_args.warmup_steps,
- num_training_steps=(training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size))
+ num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
+ num_training_steps=num_training_steps
)
# Initialize our Trainer
ppo_trainer = PPOPeftTrainer(
training_args=training_args,
finetuning_args=finetuning_args,
+ generating_args=generating_args,
callbacks=callbacks,
+ compute_dtype=model_args.compute_dtype,
config=ppo_config,
model=model,
ref_model=None,
@@ -66,8 +71,10 @@ def run_ppo(
lr_scheduler=lr_scheduler
)
- ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
- ppo_trainer.save_model()
- ppo_trainer.save_state() # must be after save_model
- if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
- plot_loss(training_args.output_dir, keys=["loss", "reward"])
+ # Training
+ if training_args.do_train:
+ ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
+ ppo_trainer.save_model()
+ ppo_trainer.save_state() # must be called after save_model to have a folder
+ if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
+ plot_loss(training_args.output_dir, keys=["loss", "reward"])
diff --git a/src/llmtuner/tuner/pt/workflow.py b/src/llmtuner/tuner/pt/workflow.py
index 2a9f8279..865ec218 100644
--- a/src/llmtuner/tuner/pt/workflow.py
+++ b/src/llmtuner/tuner/pt/workflow.py
@@ -2,10 +2,9 @@
import math
from typing import TYPE_CHECKING, Optional, List
-from transformers import DataCollatorForSeq2Seq
+from transformers import DataCollatorForLanguageModeling
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.core.trainer import PeftTrainer
@@ -25,10 +24,7 @@ def run_pt(
dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
- data_collator = DataCollatorForSeq2Seq(
- tokenizer=tokenizer,
- label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
- )
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
# Initialize our Trainer
trainer = PeftTrainer(
@@ -38,7 +34,7 @@ def run_pt(
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
- **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
+ **split_dataset(dataset, data_args, training_args)
)
# Training
@@ -60,6 +56,5 @@ def run_pt(
perplexity = float("inf")
metrics["perplexity"] = perplexity
-
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
diff --git a/src/llmtuner/tuner/rm/collator.py b/src/llmtuner/tuner/rm/collator.py
index c0da0579..161f003d 100644
--- a/src/llmtuner/tuner/rm/collator.py
+++ b/src/llmtuner/tuner/rm/collator.py
@@ -1,8 +1,10 @@
import torch
+from dataclasses import dataclass
from typing import Any, Dict, Sequence
from transformers import DataCollatorWithPadding
+@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
r"""
Data collator for pairwise data.
@@ -16,7 +18,10 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
the last n examples represent rejected examples.
"""
features = [
- {"input_ids": feature[key], "attention_mask": [1] * len(feature[key])}
- for key in ("accept_ids", "reject_ids") for feature in features
+ {
+ "input_ids": feature["prompt_ids"] + feature[key],
+ "attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key]))
+ }
+ for key in ("chosen_ids", "rejected_ids") for feature in features
]
return super().__call__(features)
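# Hedged sketch of the pairing logic above: every example now ships prompt_ids together with
# chosen_ids/rejected_ids, and the collator flattens them so the first half of the batch holds
# the chosen sequences and the second half the rejected ones. Toy token ids only.
examples = [
    {"prompt_ids": [1, 2], "chosen_ids": [3], "rejected_ids": [4, 5]},
    {"prompt_ids": [6], "chosen_ids": [7, 8], "rejected_ids": [9]},
]
paired = [
    {
        "input_ids": ex["prompt_ids"] + ex[key],
        "attention_mask": [1] * (len(ex["prompt_ids"]) + len(ex[key])),
    }
    for key in ("chosen_ids", "rejected_ids") for ex in examples
]
assert [f["input_ids"] for f in paired] == [[1, 2, 3], [6, 7, 8], [1, 2, 4, 5], [6, 9]]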
diff --git a/src/llmtuner/tuner/rm/trainer.py b/src/llmtuner/tuner/rm/trainer.py
index e69d48a8..99b4b152 100644
--- a/src/llmtuner/tuner/rm/trainer.py
+++ b/src/llmtuner/tuner/rm/trainer.py
@@ -42,6 +42,8 @@ class PairwisePeftTrainer(PeftTrainer):
"""
batch_size = inputs["input_ids"].size(0) // 2
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
+        if values.size(0) != inputs["input_ids"].size(0): # adapt to chatglm2, whose values come back transposed
+ values = torch.transpose(values, 0, 1)
r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
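# Hedged sketch of the pairwise loss computed above: the last-token value-head scores are
# split into chosen and rejected halves and pushed apart with a log-sigmoid ranking loss.
import torch

last_token_values = torch.tensor([1.2, 0.3, -0.5, 0.1])  # toy scores for 2 chosen + 2 rejected
batch_size = last_token_values.size(0) // 2
r_accept, r_reject = last_token_values.split(batch_size, dim=0)
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
# the loss shrinks as chosen responses are scored higher than their rejected counterparts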
diff --git a/src/llmtuner/tuner/rm/workflow.py b/src/llmtuner/tuner/rm/workflow.py
index 19527ce8..b19a13e6 100644
--- a/src/llmtuner/tuner/rm/workflow.py
+++ b/src/llmtuner/tuner/rm/workflow.py
@@ -39,7 +39,7 @@ def run_rm(
data_collator=data_collator,
callbacks=callbacks,
compute_metrics=compute_accuracy,
- **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
+ **split_dataset(dataset, data_args, training_args)
)
# Training
diff --git a/src/llmtuner/tuner/sft/metric.py b/src/llmtuner/tuner/sft/metric.py
index 663b037d..812896ee 100644
--- a/src/llmtuner/tuner/sft/metric.py
+++ b/src/llmtuner/tuner/sft/metric.py
@@ -25,7 +25,7 @@ class ComputeMetrics:
Uses the model predictions to compute metrics.
"""
preds, labels = eval_preds
- score_dict = {"accuracy": [], "rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
+ score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
@@ -49,6 +49,5 @@ class ComputeMetrics:
bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
score_dict["bleu-4"].append(round(bleu_score * 100, 4))
- score_dict["accuracy"].append(float(len(label) != 0 and pred[:len(label)] == label))
return {k: float(np.mean(v)) for k, v in score_dict.items()}
diff --git a/src/llmtuner/tuner/sft/trainer.py b/src/llmtuner/tuner/sft/trainer.py
index 21739ac1..1ddaec1f 100644
--- a/src/llmtuner/tuner/sft/trainer.py
+++ b/src/llmtuner/tuner/sft/trainer.py
@@ -50,9 +50,10 @@ class Seq2SeqPeftTrainer(PeftTrainer):
loss, generated_tokens, labels = super().prediction_step(
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
)
- generated_tokens = (
- generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None
- )
+ if generated_tokens is not None:
+ generated_tokens[:, :max(prompt_len, label_len)] = (
+ self.tokenizer.pad_token_id * torch.ones_like(generated_tokens[:, :max(prompt_len, label_len)])
+ )
return (loss, generated_tokens, labels)
@@ -72,14 +73,11 @@ class Seq2SeqPeftTrainer(PeftTrainer):
assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
pad_token_id = self.tokenizer.pad_token_id
else:
- if self.model.config.pad_token_id is not None:
- pad_token_id = self.model.config.pad_token_id
- else:
- raise ValueError("Pad_token_id must be set in the configuration of the model.")
+ raise ValueError("PAD token is required.")
padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
- return padded_tensor.contiguous()
+ return padded_tensor.contiguous() # in contiguous memory
def save_predictions(
self,
diff --git a/src/llmtuner/tuner/sft/workflow.py b/src/llmtuner/tuner/sft/workflow.py
index 693fbd52..ebb16edd 100644
--- a/src/llmtuner/tuner/sft/workflow.py
+++ b/src/llmtuner/tuner/sft/workflow.py
@@ -16,7 +16,7 @@ from llmtuner.extras.logging import reset_logging, get_logger
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
- from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+ from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
logger = get_logger(__name__)
@@ -25,6 +25,7 @@ def run_sft(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
+ generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
@@ -50,31 +51,15 @@ def run_sft(
data_collator=data_collator,
callbacks=callbacks,
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
- **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
+ **split_dataset(dataset, data_args, training_args)
)
# Keyword arguments for `model.generate`
- gen_kwargs = {
- "do_sample": True,
- "top_p": 0.7,
- "max_new_tokens": data_args.max_target_length + 1,
- "temperature": 0.95,
- "logits_processor": get_logits_processor()
- }
- # Detecting last checkpoint.
- last_checkpoint = None
- if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
- last_checkpoint = get_last_checkpoint(training_args.output_dir)
- if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
- raise ValueError(
- f"Output directory ({training_args.output_dir}) already exists and is not empty. "
- "Use --overwrite_output_dir to overcome."
- )
- elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
- logger.info(
- f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
- "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
- )
+ gen_kwargs = generating_args.to_dict()
+ gen_kwargs["eos_token_id"] = list(set([tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids))
+ gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
+ gen_kwargs["logits_processor"] = get_logits_processor()
+
# Training
if training_args.do_train:
checkpoint = None
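# Hedged sketch of the gen_kwargs assembly above: generation defaults now come from
# GeneratingArguments.to_dict() and are augmented with tokenizer-derived ids. The fields and
# default values below are assumptions for illustration, not the project's actual dataclass.
from dataclasses import asdict, dataclass
from typing import Any, Dict, List

@dataclass
class GeneratingArgsSketch:
    do_sample: bool = True
    temperature: float = 0.95
    top_p: float = 0.7
    max_new_tokens: int = 512

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

def build_gen_kwargs(gen_args: GeneratingArgsSketch, eos_id: int,
                     extra_special_ids: List[int], pad_id: int) -> Dict[str, Any]:
    gen_kwargs = gen_args.to_dict()
    # additional special tokens may also terminate generation, so merge and de-duplicate
    gen_kwargs["eos_token_id"] = list(set([eos_id] + extra_special_ids))
    gen_kwargs["pad_token_id"] = pad_id
    return gen_kwargs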
diff --git a/src/llmtuner/tuner/tune.py b/src/llmtuner/tuner/tune.py
index 99f5d2a9..a4a4c2a1 100644
--- a/src/llmtuner/tuner/tune.py
+++ b/src/llmtuner/tuner/tune.py
@@ -1,35 +1,47 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from llmtuner.extras.callbacks import LogCallback
+from llmtuner.extras.logging import get_logger
from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer
from llmtuner.tuner.pt import run_pt
from llmtuner.tuner.sft import run_sft
from llmtuner.tuner.rm import run_rm
from llmtuner.tuner.ppo import run_ppo
+from llmtuner.tuner.dpo import run_dpo
if TYPE_CHECKING:
from transformers import TrainerCallback
+logger = get_logger(__name__)
+
+
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
- model_args, data_args, training_args, finetuning_args, general_args = get_train_args(args)
+ model_args, data_args, training_args, finetuning_args, generating_args, general_args = get_train_args(args)
callbacks = [LogCallback()] if callbacks is None else callbacks
if general_args.stage == "pt":
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
elif general_args.stage == "sft":
- run_sft(model_args, data_args, training_args, finetuning_args, callbacks)
+ run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif general_args.stage == "rm":
run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
elif general_args.stage == "ppo":
- run_ppo(model_args, data_args, training_args, finetuning_args, callbacks)
+ run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
+ elif general_args.stage == "dpo":
+ run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
+ else:
+ raise ValueError("Unknown task.")
def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"):
- model_args, _, training_args, finetuning_args, _ = get_train_args(args)
+ model_args, _, training_args, finetuning_args, _, _ = get_train_args(args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size)
- tokenizer.save_pretrained(training_args.output_dir)
+ try:
+ tokenizer.save_pretrained(training_args.output_dir)
+    except Exception:
+ logger.warning("Cannot save tokenizer, please copy the files manually.")
if __name__ == "__main__":
diff --git a/src/llmtuner/webui/chat.py b/src/llmtuner/webui/chat.py
index d0eb61df..154efa5a 100644
--- a/src/llmtuner/webui/chat.py
+++ b/src/llmtuner/webui/chat.py
@@ -26,7 +26,7 @@ class WebChatModel(ChatModel):
finetuning_type: str,
quantization_bit: str,
template: str,
- source_prefix: str
+ system_prompt: str
):
if self.model is not None:
yield ALERTS["err_exists"][lang]
@@ -53,9 +53,9 @@ class WebChatModel(ChatModel):
model_name_or_path=model_name_or_path,
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
- quantization_bit=int(quantization_bit) if quantization_bit else None,
+ quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
template=template,
- source_prefix=source_prefix
+ system_prompt=system_prompt
)
super().__init__(args)
@@ -73,7 +73,7 @@ class WebChatModel(ChatModel):
chatbot: List[Tuple[str, str]],
query: str,
history: List[Tuple[str, str]],
- prefix: str,
+ system: str,
max_new_tokens: int,
top_p: float,
temperature: float
@@ -81,7 +81,7 @@ class WebChatModel(ChatModel):
chatbot.append([query, ""])
response = ""
for new_text in self.stream_chat(
- query, history, prefix, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+ query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
):
response += new_text
response = self.postprocess(response)
diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py
index bf1d18fb..965a690b 100644
--- a/src/llmtuner/webui/common.py
+++ b/src/llmtuner/webui/common.py
@@ -6,7 +6,7 @@ import gradio as gr
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
-from llmtuner.extras.constants import SUPPORTED_MODELS
+from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS
DEFAULT_CACHE_DIR = "cache"
@@ -29,14 +29,16 @@ def load_config() -> Dict[str, Any]:
with open(get_config_path(), "r", encoding="utf-8") as f:
return json.load(f)
except:
- return {"last_model": "", "path_dict": {}}
+ return {"lang": "", "last_model": "", "path_dict": {}}
-def save_config(model_name: str, model_path: str) -> None:
+def save_config(lang: str, model_name: str, model_path: str) -> None:
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
user_config = load_config()
- user_config["last_model"] = model_name
- user_config["path_dict"][model_name] = model_path
+ user_config["lang"] = lang or user_config["lang"]
+ if model_name:
+ user_config["last_model"] = model_name
+ user_config["path_dict"][model_name] = model_path
with open(get_config_path(), "w", encoding="utf-8") as f:
json.dump(user_config, f, indent=2, ensure_ascii=False)
@@ -46,6 +48,12 @@ def get_model_path(model_name: str) -> str:
return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, ""))
+def get_template(model_name: str) -> str:
+ if model_name.endswith("Chat") and model_name.split("-")[0] in DEFAULT_TEMPLATE:
+ return DEFAULT_TEMPLATE[model_name.split("-")[0]]
+ return "default"
+
+
def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
checkpoints = []
save_dir = os.path.join(get_save_dir(model_name), finetuning_type)
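# Hedged sketch of the template lookup added above: chat models get a default prompt template
# via DEFAULT_TEMPLATE keyed by the model family prefix; the mapping here is an assumed example
# rather than the project's actual table.
DEFAULT_TEMPLATE = {"Baichuan": "baichuan", "InternLM": "intern"}

def get_template(model_name: str) -> str:
    if model_name.endswith("Chat") and model_name.split("-")[0] in DEFAULT_TEMPLATE:
        return DEFAULT_TEMPLATE[model_name.split("-")[0]]
    return "default"

assert get_template("Baichuan-13B-Chat") == "baichuan"
assert get_template("LLaMA-7B") == "default"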
diff --git a/src/llmtuner/webui/components/__init__.py b/src/llmtuner/webui/components/__init__.py
index 5b86f396..32228b8e 100644
--- a/src/llmtuner/webui/components/__init__.py
+++ b/src/llmtuner/webui/components/__init__.py
@@ -1,5 +1,5 @@
from llmtuner.webui.components.top import create_top
-from llmtuner.webui.components.sft import create_sft_tab
+from llmtuner.webui.components.train import create_train_tab
from llmtuner.webui.components.eval import create_eval_tab
from llmtuner.webui.components.infer import create_infer_tab
from llmtuner.webui.components.export import create_export_tab
diff --git a/src/llmtuner/webui/components/chatbot.py b/src/llmtuner/webui/components/chatbot.py
index 6fcfc652..928a568c 100644
--- a/src/llmtuner/webui/components/chatbot.py
+++ b/src/llmtuner/webui/components/chatbot.py
@@ -17,7 +17,7 @@ def create_chat_box(
with gr.Row():
with gr.Column(scale=4):
- prefix = gr.Textbox(show_label=False)
+ system = gr.Textbox(show_label=False)
query = gr.Textbox(show_label=False, lines=8)
submit_btn = gr.Button(variant="primary")
@@ -31,7 +31,7 @@ def create_chat_box(
submit_btn.click(
chat_model.predict,
- [chatbot, query, history, prefix, max_new_tokens, top_p, temperature],
+ [chatbot, query, history, system, max_new_tokens, top_p, temperature],
[chatbot, history],
show_progress=True
).then(
@@ -41,7 +41,7 @@ def create_chat_box(
clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
return chat_box, chatbot, history, dict(
- prefix=prefix,
+ system=system,
query=query,
submit_btn=submit_btn,
clear_btn=clear_btn,
diff --git a/src/llmtuner/webui/components/data.py b/src/llmtuner/webui/components/data.py
index 9787b36a..af19cc41 100644
--- a/src/llmtuner/webui/components/data.py
+++ b/src/llmtuner/webui/components/data.py
@@ -16,6 +16,6 @@ def create_preview_box() -> Tuple["Block", "Component", "Component", "Component"
close_btn = gr.Button()
- close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box])
+ close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
return preview_box, preview_count, preview_samples, close_btn
diff --git a/src/llmtuner/webui/components/eval.py b/src/llmtuner/webui/components/eval.py
index 29b590ae..cbc71daf 100644
--- a/src/llmtuner/webui/components/eval.py
+++ b/src/llmtuner/webui/components/eval.py
@@ -14,13 +14,18 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)
- preview_btn = gr.Button(interactive=False, scale=1)
+ data_preview_btn = gr.Button(interactive=False, scale=1)
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
- dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
- preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
+ dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])
+ data_preview_btn.click(
+ get_preview,
+ [dataset_dir, dataset],
+ [preview_count, preview_samples, preview_box],
+ queue=False
+ )
with gr.Row():
max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
@@ -30,38 +35,46 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
predict = gr.Checkbox(value=True)
with gr.Row():
+ cmd_preview_btn = gr.Button()
start_btn = gr.Button()
stop_btn = gr.Button()
+ with gr.Row():
+ process_bar = gr.Slider(visible=False, interactive=False)
+
with gr.Box():
output_box = gr.Markdown()
- start_btn.click(
- runner.run_eval,
- [
- top_elems["lang"],
- top_elems["model_name"],
- top_elems["checkpoints"],
- top_elems["finetuning_type"],
- top_elems["quantization_bit"],
- top_elems["template"],
- top_elems["source_prefix"],
- dataset_dir,
- dataset,
- max_source_length,
- max_target_length,
- max_samples,
- batch_size,
- predict
- ],
- [output_box]
- )
+ input_components = [
+ top_elems["lang"],
+ top_elems["model_name"],
+ top_elems["checkpoints"],
+ top_elems["finetuning_type"],
+ top_elems["quantization_bit"],
+ top_elems["template"],
+ top_elems["system_prompt"],
+ dataset_dir,
+ dataset,
+ max_source_length,
+ max_target_length,
+ max_samples,
+ batch_size,
+ predict
+ ]
+
+ output_components = [
+ output_box,
+ process_bar
+ ]
+
+ cmd_preview_btn.click(runner.preview_eval, input_components, output_components)
+ start_btn.click(runner.run_eval, input_components, output_components)
stop_btn.click(runner.set_abort, queue=False)
return dict(
dataset_dir=dataset_dir,
dataset=dataset,
- preview_btn=preview_btn,
+ data_preview_btn=data_preview_btn,
preview_count=preview_count,
preview_samples=preview_samples,
close_btn=close_btn,
@@ -70,6 +83,7 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
max_samples=max_samples,
batch_size=batch_size,
predict=predict,
+ cmd_preview_btn=cmd_preview_btn,
start_btn=start_btn,
stop_btn=stop_btn,
output_box=output_box
diff --git a/src/llmtuner/webui/components/infer.py b/src/llmtuner/webui/components/infer.py
index 40e0323e..14aef162 100644
--- a/src/llmtuner/webui/components/infer.py
+++ b/src/llmtuner/webui/components/infer.py
@@ -28,7 +28,7 @@ def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"
top_elems["finetuning_type"],
top_elems["quantization_bit"],
top_elems["template"],
- top_elems["source_prefix"]
+ top_elems["system_prompt"]
],
[info_box]
).then(
diff --git a/src/llmtuner/webui/components/top.py b/src/llmtuner/webui/components/top.py
index 4fc5b506..62c1f9c9 100644
--- a/src/llmtuner/webui/components/top.py
+++ b/src/llmtuner/webui/components/top.py
@@ -4,7 +4,7 @@ import gradio as gr
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
from llmtuner.extras.template import templates
-from llmtuner.webui.common import list_checkpoint, get_model_path, save_config
+from llmtuner.webui.common import list_checkpoint, get_model_path, get_template, save_config
from llmtuner.webui.utils import can_quantize
if TYPE_CHECKING:
@@ -15,27 +15,32 @@ def create_top() -> Dict[str, "Component"]:
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
with gr.Row():
- lang = gr.Dropdown(choices=["en", "zh"], value="en", scale=1)
+ lang = gr.Dropdown(choices=["en", "zh"], scale=1)
model_name = gr.Dropdown(choices=available_models, scale=3)
model_path = gr.Textbox(scale=3)
with gr.Row():
- finetuning_type = gr.Dropdown(value="lora", choices=METHODS, scale=1)
+ finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
checkpoints = gr.Dropdown(multiselect=True, scale=5)
refresh_btn = gr.Button(scale=1)
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
with gr.Row():
- quantization_bit = gr.Dropdown([8, 4], scale=1)
- template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=1)
- source_prefix = gr.Textbox(scale=2)
+ quantization_bit = gr.Dropdown(choices=["None", "8", "4"], value="None", scale=1)
+ template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
+ system_prompt = gr.Textbox(scale=2)
+
+ lang.change(save_config, [lang, model_name, model_path])
model_name.change(
list_checkpoint, [model_name, finetuning_type], [checkpoints]
).then(
get_model_path, [model_name], [model_path]
+ ).then(
+ get_template, [model_name], [template]
) # do not save config since the below line will save
- model_path.change(save_config, [model_name, model_path])
+
+ model_path.change(save_config, [lang, model_name, model_path])
finetuning_type.change(
list_checkpoint, [model_name, finetuning_type], [checkpoints]
@@ -43,7 +48,9 @@ def create_top() -> Dict[str, "Component"]:
can_quantize, [finetuning_type], [quantization_bit]
)
- refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints])
+ refresh_btn.click(
+ list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
+ )
return dict(
lang=lang,
@@ -55,5 +62,5 @@ def create_top() -> Dict[str, "Component"]:
advanced_tab=advanced_tab,
quantization_bit=quantization_bit,
template=template,
- source_prefix=source_prefix
+ system_prompt=system_prompt
)
diff --git a/src/llmtuner/webui/components/sft.py b/src/llmtuner/webui/components/train.py
similarity index 52%
rename from src/llmtuner/webui/components/sft.py
rename to src/llmtuner/webui/components/train.py
index 678693b9..aab512ee 100644
--- a/src/llmtuner/webui/components/sft.py
+++ b/src/llmtuner/webui/components/train.py
@@ -3,7 +3,8 @@ from transformers.trainer_utils import SchedulerType
import gradio as gr
-from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
+from llmtuner.extras.constants import STAGES
+from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.utils import can_preview, get_preview, gen_plot
@@ -12,17 +13,23 @@ if TYPE_CHECKING:
from llmtuner.webui.runner import Runner
-def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
+def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
with gr.Row():
+ training_stage = gr.Dropdown(choices=STAGES, value=STAGES[0], scale=2)
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)
- preview_btn = gr.Button(interactive=False, scale=1)
+ data_preview_btn = gr.Button(interactive=False, scale=1)
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
- dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
- preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
+ dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])
+ data_preview_btn.click(
+ get_preview,
+ [dataset_dir, dataset],
+ [preview_count, preview_samples, preview_box],
+ queue=False
+ )
with gr.Row():
max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
@@ -35,10 +42,10 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
lr_scheduler_type = gr.Dropdown(
- value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
+ choices=[scheduler.value for scheduler in SchedulerType], value="cosine"
)
max_grad_norm = gr.Textbox(value="1.0")
- dev_ratio = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
+ val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
with gr.Row():
@@ -46,20 +53,40 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")
+ padding_side = gr.Radio(choices=["left", "right"], value="left")
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
with gr.Row():
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
- lora_dropout = gr.Slider(value=0, minimum=0, maximum=1, step=0.01, scale=1)
+ lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
lora_target = gr.Textbox(scale=2)
+ resume_lora_training = gr.Checkbox(value=True, scale=1)
+
+ with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
+ with gr.Row():
+ dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=2)
+ reward_model = gr.Dropdown(scale=2)
+ refresh_btn = gr.Button(scale=1)
+
+ refresh_btn.click(
+ list_checkpoint,
+ [top_elems["model_name"], top_elems["finetuning_type"]],
+ [reward_model],
+ queue=False
+ )
with gr.Row():
+ cmd_preview_btn = gr.Button()
start_btn = gr.Button()
stop_btn = gr.Button()
with gr.Row():
with gr.Column(scale=3):
- output_dir = gr.Textbox()
+ with gr.Row():
+ output_dir = gr.Textbox()
+
+ with gr.Row():
+ process_bar = gr.Slider(visible=False, interactive=False)
with gr.Box():
output_box = gr.Markdown()
@@ -67,49 +94,59 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
with gr.Column(scale=1):
loss_viewer = gr.Plot()
- start_btn.click(
- runner.run_train,
- [
- top_elems["lang"],
- top_elems["model_name"],
- top_elems["checkpoints"],
- top_elems["finetuning_type"],
- top_elems["quantization_bit"],
- top_elems["template"],
- top_elems["source_prefix"],
- dataset_dir,
- dataset,
- max_source_length,
- max_target_length,
- learning_rate,
- num_train_epochs,
- max_samples,
- batch_size,
- gradient_accumulation_steps,
- lr_scheduler_type,
- max_grad_norm,
- dev_ratio,
- logging_steps,
- save_steps,
- warmup_steps,
- compute_type,
- lora_rank,
- lora_dropout,
- lora_target,
- output_dir
- ],
- [output_box]
- )
+ input_components = [
+ top_elems["lang"],
+ top_elems["model_name"],
+ top_elems["checkpoints"],
+ top_elems["finetuning_type"],
+ top_elems["quantization_bit"],
+ top_elems["template"],
+ top_elems["system_prompt"],
+ training_stage,
+ dataset_dir,
+ dataset,
+ max_source_length,
+ max_target_length,
+ learning_rate,
+ num_train_epochs,
+ max_samples,
+ batch_size,
+ gradient_accumulation_steps,
+ lr_scheduler_type,
+ max_grad_norm,
+ val_size,
+ logging_steps,
+ save_steps,
+ warmup_steps,
+ compute_type,
+ padding_side,
+ lora_rank,
+ lora_dropout,
+ lora_target,
+ resume_lora_training,
+ dpo_beta,
+ reward_model,
+ output_dir
+ ]
+
+ output_components = [
+ output_box,
+ process_bar
+ ]
+
+ cmd_preview_btn.click(runner.preview_train, input_components, output_components)
+ start_btn.click(runner.run_train, input_components, output_components)
stop_btn.click(runner.set_abort, queue=False)
- output_box.change(
+ process_bar.change(
gen_plot, [top_elems["model_name"], top_elems["finetuning_type"], output_dir], loss_viewer, queue=False
)
return dict(
+ training_stage=training_stage,
dataset_dir=dataset_dir,
dataset=dataset,
- preview_btn=preview_btn,
+ data_preview_btn=data_preview_btn,
preview_count=preview_count,
preview_samples=preview_samples,
close_btn=close_btn,
@@ -122,16 +159,23 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
gradient_accumulation_steps=gradient_accumulation_steps,
lr_scheduler_type=lr_scheduler_type,
max_grad_norm=max_grad_norm,
- dev_ratio=dev_ratio,
+ val_size=val_size,
advanced_tab=advanced_tab,
logging_steps=logging_steps,
save_steps=save_steps,
warmup_steps=warmup_steps,
compute_type=compute_type,
+ padding_side=padding_side,
lora_tab=lora_tab,
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target,
+ resume_lora_training=resume_lora_training,
+ rlhf_tab=rlhf_tab,
+ dpo_beta=dpo_beta,
+ reward_model=reward_model,
+ refresh_btn=refresh_btn,
+ cmd_preview_btn=cmd_preview_btn,
start_btn=start_btn,
stop_btn=stop_btn,
output_dir=output_dir,
diff --git a/src/llmtuner/webui/interface.py b/src/llmtuner/webui/interface.py
index 2fb61d37..1a5c0c19 100644
--- a/src/llmtuner/webui/interface.py
+++ b/src/llmtuner/webui/interface.py
@@ -3,7 +3,7 @@ from transformers.utils.versions import require_version
from llmtuner.webui.components import (
create_top,
- create_sft_tab,
+ create_train_tab,
create_eval_tab,
create_infer_tab,
create_export_tab,
@@ -24,8 +24,8 @@ def create_ui() -> gr.Blocks:
with gr.Blocks(title="Web Tuner", css=CSS) as demo:
top_elems = create_top()
- with gr.Tab("SFT"):
- sft_elems = create_sft_tab(top_elems, runner)
+ with gr.Tab("Train"):
+ train_elems = create_train_tab(top_elems, runner)
with gr.Tab("Evaluate"):
eval_elems = create_eval_tab(top_elems, runner)
@@ -36,7 +36,7 @@ def create_ui() -> gr.Blocks:
with gr.Tab("Export"):
export_elems = create_export_tab(top_elems)
- elem_list = [top_elems, sft_elems, eval_elems, infer_elems, export_elems]
+ elem_list = [top_elems, train_elems, eval_elems, infer_elems, export_elems]
manager = Manager(elem_list)
demo.load(
@@ -59,7 +59,7 @@ def create_web_demo() -> gr.Blocks:
chat_model = WebChatModel(lazy_init=False)
with gr.Blocks(title="Web Demo", css=CSS) as demo:
- lang = gr.Dropdown(choices=["en", "zh"], value="en")
+ lang = gr.Dropdown(choices=["en", "zh"])
_, _, _, chat_elems = create_chat_box(chat_model, visible=True)
@@ -67,7 +67,7 @@ def create_web_demo() -> gr.Blocks:
demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
- lang.change(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
+ lang.select(manager.gen_label, [lang], [lang] + list(chat_elems.values()), queue=False)
return demo
diff --git a/src/llmtuner/webui/locales.py b/src/llmtuner/webui/locales.py
index ba9a73bd..c4032f39 100644
--- a/src/llmtuner/webui/locales.py
+++ b/src/llmtuner/webui/locales.py
@@ -77,7 +77,7 @@ LOCALES = {
"info": "构建提示词时使用的模板"
}
},
- "source_prefix": {
+ "system_prompt": {
"en": {
"label": "System prompt (optional)",
"info": "A sequence used as the default system prompt."
@@ -87,6 +87,16 @@ LOCALES = {
"info": "默认使用的系统提示词"
}
},
+ "training_stage": {
+ "en": {
+ "label": "Stage",
+ "info": "The stage to perform in training."
+ },
+ "zh": {
+ "label": "训练阶段",
+ "info": "目前采用的训练方式。"
+ }
+ },
"dataset_dir": {
"en": {
"label": "Data dir",
@@ -105,12 +115,12 @@ LOCALES = {
"label": "数据集"
}
},
- "preview_btn": {
+ "data_preview_btn": {
"en": {
- "value": "Preview"
+ "value": "Preview dataset"
},
"zh": {
- "value": "预览"
+ "value": "预览数据集"
}
},
"preview_count": {
@@ -227,9 +237,9 @@ LOCALES = {
"info": "用于梯度裁剪的范数。"
}
},
- "dev_ratio": {
+ "val_size": {
"en": {
- "label": "Dev ratio",
+ "label": "Val size",
"info": "Proportion of data in the dev set."
},
"zh": {
@@ -277,6 +287,16 @@ LOCALES = {
"info": "是否启用 FP16 或 BF16 混合精度训练。"
}
},
+ "padding_side": {
+ "en": {
+ "label": "Padding side",
+ "info": "The side on which the model should have padding applied."
+ },
+ "zh": {
+ "label": "填充位置",
+ "info": "使用左填充或右填充。"
+ }
+ },
"lora_tab": {
"en": {
"label": "LoRA configurations"
@@ -315,6 +335,52 @@ LOCALES = {
"info": "应用 LoRA 的线性层名称。使用英文逗号分隔多个名称。"
}
},
+ "resume_lora_training": {
+ "en": {
+ "label": "Resume LoRA training",
+ "info": "Whether to resume training from the last LoRA weights or create new lora weights."
+ },
+ "zh": {
+ "label": "继续上次的训练",
+ "info": "接着上次的 LoRA 权重训练或创建一个新的 LoRA 权重。"
+ }
+ },
+ "rlhf_tab": {
+ "en": {
+ "label": "RLHF configurations"
+ },
+ "zh": {
+ "label": "RLHF 参数设置"
+ }
+ },
+ "dpo_beta": {
+ "en": {
+ "label": "DPO beta",
+ "info": "Value of the beta parameter in the DPO loss."
+ },
+ "zh": {
+ "label": "DPO beta 参数",
+ "info": "DPO 损失函数中 beta 超参数大小。"
+ }
+ },
+ "reward_model": {
+ "en": {
+ "label": "Reward model",
+ "info": "Checkpoint of the reward model for PPO training."
+ },
+ "zh": {
+ "label": "奖励模型",
+ "info": "PPO 训练中奖励模型的断点路径。"
+ }
+ },
+ "cmd_preview_btn": {
+ "en": {
+ "value": "Preview command"
+ },
+ "zh": {
+ "value": "预览命令"
+ }
+ },
"start_btn": {
"en": {
"value": "Start"
@@ -389,7 +455,7 @@ LOCALES = {
"value": "模型未加载,请先加载模型。"
}
},
- "prefix": {
+ "system": {
"en": {
"placeholder": "System prompt (optional)"
},
diff --git a/src/llmtuner/webui/manager.py b/src/llmtuner/webui/manager.py
index c8f797a4..2d5a0a39 100644
--- a/src/llmtuner/webui/manager.py
+++ b/src/llmtuner/webui/manager.py
@@ -12,12 +12,18 @@ class Manager:
def __init__(self, elem_list: List[Dict[str, Component]]):
self.elem_list = elem_list
- def gen_refresh(self) -> Dict[str, Any]:
+ def gen_refresh(self, lang: str) -> Dict[str, Any]:
refresh_dict = {
"dataset": {"choices": list_dataset()["choices"]},
"output_dir": {"value": get_time()}
}
+
user_config = load_config()
+ if lang:
+ refresh_dict["lang"] = {"value": lang}
+ else:
+ refresh_dict["lang"] = {"value": user_config["lang"] if user_config["lang"] else "en"}
+
if user_config["last_model"]:
refresh_dict["model_name"] = {"value": user_config["last_model"]}
refresh_dict["model_path"] = {"value": get_model_path(user_config["last_model"])}
@@ -26,10 +32,12 @@ class Manager:
def gen_label(self, lang: str) -> Dict[Component, Dict[str, Any]]: # cannot use TYPE_CHECKING
update_dict = {}
- refresh_dict = self.gen_refresh()
+ refresh_dict = self.gen_refresh(lang)
for elems in self.elem_list:
for name, component in elems.items():
- update_dict[component] = gr.update(**LOCALES[name][lang], **refresh_dict.get(name, {}))
+ update_dict[component] = gr.update(
+ **LOCALES[name][refresh_dict["lang"]["value"]], **refresh_dict.get(name, {})
+ )
return update_dict
diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py
index 763ff614..ac74a4c7 100644
--- a/src/llmtuner/webui/runner.py
+++ b/src/llmtuner/webui/runner.py
@@ -1,10 +1,11 @@
+import gradio as gr
import logging
import os
import threading
import time
import transformers
from transformers.trainer import TRAINING_ARGS_NAME
-from typing import Generator, List, Optional, Tuple
+from typing import Any, Dict, Generator, List, Tuple
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import DEFAULT_MODULE
@@ -13,7 +14,7 @@ from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import run_exp
from llmtuner.webui.common import get_model_path, get_save_dir
from llmtuner.webui.locales import ALERTS
-from llmtuner.webui.utils import format_info, get_eval_results
+from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
class Runner:
@@ -21,39 +22,36 @@ class Runner:
def __init__(self):
self.aborted = False
self.running = False
+ self.logger_handler = LoggerHandler()
+ self.logger_handler.setLevel(logging.INFO)
+ logging.root.addHandler(self.logger_handler)
+ transformers.logging.add_handler(self.logger_handler)
def set_abort(self):
self.aborted = True
self.running = False
- def initialize(
+ def _initialize(
self, lang: str, model_name: str, dataset: List[str]
- ) -> Tuple[str, str, LoggerHandler, LogCallback]:
+ ) -> str:
if self.running:
- return None, ALERTS["err_conflict"][lang], None, None
+ return ALERTS["err_conflict"][lang]
if not model_name:
- return None, ALERTS["err_no_model"][lang], None, None
+ return ALERTS["err_no_model"][lang]
- model_name_or_path = get_model_path(model_name)
- if not model_name_or_path:
- return None, ALERTS["err_no_path"][lang], None, None
+ if not get_model_path(model_name):
+ return ALERTS["err_no_path"][lang]
if len(dataset) == 0:
- return None, ALERTS["err_no_dataset"][lang], None, None
+ return ALERTS["err_no_dataset"][lang]
self.aborted = False
- self.running = True
+ self.logger_handler.reset()
+ self.trainer_callback = LogCallback(self)
+ return ""
- logger_handler = LoggerHandler()
- logger_handler.setLevel(logging.INFO)
- logging.root.addHandler(logger_handler)
- transformers.logging.add_handler(logger_handler)
- trainer_callback = LogCallback(self)
-
- return model_name_or_path, "", logger_handler, trainer_callback
-
- def finalize(
+ def _finalize(
self, lang: str, finish_info: str
) -> str:
self.running = False
@@ -63,7 +61,7 @@ class Runner:
else:
return finish_info
- def run_train(
+ def _parse_train_args(
self,
lang: str,
model_name: str,
@@ -71,7 +69,8 @@ class Runner:
finetuning_type: str,
quantization_bit: str,
template: str,
- source_prefix: str,
+ system_prompt: str,
+ training_stage: str,
dataset_dir: str,
dataset: List[str],
max_source_length: int,
@@ -83,24 +82,23 @@ class Runner:
gradient_accumulation_steps: int,
lr_scheduler_type: str,
max_grad_norm: str,
- dev_ratio: float,
+ val_size: float,
logging_steps: int,
save_steps: int,
warmup_steps: int,
compute_type: str,
+ padding_side: str,
lora_rank: int,
lora_dropout: float,
lora_target: str,
+ resume_lora_training: bool,
+ dpo_beta: float,
+ reward_model: str,
output_dir: str
- ) -> Generator[str, None, None]:
- model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
- if error:
- yield error
- return
-
+ ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
if checkpoints:
checkpoint_dir = ",".join(
- [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
+ [os.path.join(get_save_dir(model_name), finetuning_type, ckpt) for ckpt in checkpoints]
)
else:
checkpoint_dir = None
@@ -109,14 +107,14 @@ class Runner:
args = dict(
stage="sft",
- model_name_or_path=model_name_or_path,
+ model_name_or_path=get_model_path(model_name),
do_train=True,
overwrite_cache=True,
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
- quantization_bit=int(quantization_bit) if quantization_bit else None,
+ quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
template=template,
- source_prefix=source_prefix,
+ system_prompt=system_prompt,
dataset_dir=dataset_dir,
dataset=",".join(dataset),
max_source_length=max_source_length,
@@ -131,39 +129,40 @@ class Runner:
logging_steps=logging_steps,
save_steps=save_steps,
warmup_steps=warmup_steps,
- fp16=(compute_type == "fp16"),
- bf16=(compute_type == "bf16"),
+ padding_side=padding_side,
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
+ resume_lora_training=resume_lora_training,
output_dir=output_dir
)
+ args[compute_type] = True
- if dev_ratio > 1e-6:
- args["dev_ratio"] = dev_ratio
+ if training_stage == "Reward Modeling":
+ args["stage"] = "rm"
+ args["resume_lora_training"] = False
+ elif training_stage == "PPO":
+ args["stage"] = "ppo"
+ args["resume_lora_training"] = False
+ args["reward_model"] = reward_model
+ args["padding_side"] = "left"
+ val_size = 0
+ elif training_stage == "DPO":
+ args["stage"] = "dpo"
+ args["resume_lora_training"] = False
+ args["dpo_beta"] = dpo_beta
+ elif training_stage == "Pre-Training":
+ args["stage"] = "pt"
+
+ if val_size > 1e-6:
+ args["val_size"] = val_size
args["evaluation_strategy"] = "steps"
args["eval_steps"] = save_steps
args["load_best_model_at_end"] = True
- run_kwargs = dict(args=args, callbacks=[trainer_callback])
- thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
- thread.start()
+ return lang, model_name, dataset, output_dir, args
- while thread.is_alive():
- time.sleep(1)
- if self.aborted:
- yield ALERTS["info_aborting"][lang]
- else:
- yield format_info(logger_handler.log, trainer_callback)
-
- if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
- finish_info = ALERTS["info_finished"][lang]
- else:
- finish_info = ALERTS["err_failed"][lang]
-
- yield self.finalize(lang, finish_info)
-
- def run_eval(
+ def _parse_eval_args(
self,
lang: str,
model_name: str,
@@ -171,7 +170,7 @@ class Runner:
finetuning_type: str,
quantization_bit: str,
template: str,
- source_prefix: str,
+ system_prompt: str,
dataset_dir: str,
dataset: List[str],
max_source_length: int,
@@ -179,12 +178,7 @@ class Runner:
max_samples: str,
batch_size: int,
predict: bool
- ) -> Generator[str, None, None]:
- model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
- if error:
- yield error
- return
-
+ ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
if checkpoints:
checkpoint_dir = ",".join(
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
@@ -196,15 +190,15 @@ class Runner:
args = dict(
stage="sft",
- model_name_or_path=model_name_or_path,
+ model_name_or_path=get_model_path(model_name),
do_eval=True,
overwrite_cache=True,
predict_with_generate=True,
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type,
- quantization_bit=int(quantization_bit) if quantization_bit else None,
+ quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
template=template,
- source_prefix=source_prefix,
+ system_prompt=system_prompt,
dataset_dir=dataset_dir,
dataset=",".join(dataset),
max_source_length=max_source_length,
@@ -218,20 +212,72 @@ class Runner:
args.pop("do_eval", None)
args["do_predict"] = True
- run_kwargs = dict(args=args, callbacks=[trainer_callback])
+ return lang, model_name, dataset, output_dir, args
+
+ def preview_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+ lang, model_name, dataset, _, args = self._parse_train_args(*args)
+ error = self._initialize(lang, model_name, dataset)
+ if error:
+ yield error, gr.update(visible=False)
+ else:
+ yield gen_cmd(args), gr.update(visible=False)
+
+ def preview_eval(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+ lang, model_name, dataset, _, args = self._parse_eval_args(*args)
+ error = self._initialize(lang, model_name, dataset)
+ if error:
+ yield error, gr.update(visible=False)
+ else:
+ yield gen_cmd(args), gr.update(visible=False)
+
+ def run_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+ lang, model_name, dataset, output_dir, args = self._parse_train_args(*args)
+ error = self._initialize(lang, model_name, dataset)
+ if error:
+ yield error, gr.update(visible=False)
+ return
+
+ self.running = True
+ run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
thread.start()
while thread.is_alive():
- time.sleep(1)
+ time.sleep(2)
if self.aborted:
- yield ALERTS["info_aborting"][lang]
+ yield ALERTS["info_aborting"][lang], gr.update(visible=False)
else:
- yield format_info(logger_handler.log, trainer_callback)
+ yield self.logger_handler.log, update_process_bar(self.trainer_callback)
+
+ if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
+ finish_info = ALERTS["info_finished"][lang]
+ else:
+ finish_info = ALERTS["err_failed"][lang]
+
+ yield self._finalize(lang, finish_info), gr.update(visible=False)
+
+    def run_eval(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+ lang, model_name, dataset, output_dir, args = self._parse_eval_args(*args)
+ error = self._initialize(lang, model_name, dataset)
+ if error:
+ yield error, gr.update(visible=False)
+ return
+
+ self.running = True
+ run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
+ thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
+ thread.start()
+
+ while thread.is_alive():
+ time.sleep(2)
+ if self.aborted:
+ yield ALERTS["info_aborting"][lang], gr.update(visible=False)
+ else:
+ yield self.logger_handler.log, update_process_bar(self.trainer_callback)
if os.path.exists(os.path.join(output_dir, "all_results.json")):
finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
else:
finish_info = ALERTS["err_failed"][lang]
- yield self.finalize(lang, finish_info)
+ yield self._finalize(lang, finish_info), gr.update(visible=False)
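# Hedged sketch of the streaming contract introduced above: run_train/run_eval now yield
# (log_text, process_bar_update) pairs that Gradio maps onto [output_box, process_bar]; the
# generator below is a stand-in for the runner, not part of the project.
import gradio as gr

def fake_runner():
    yield "step 10/100 ...", gr.update(label="Running 10/100", value=10.0, visible=True)
    yield "training finished", gr.update(visible=False)

for log_text, bar_update in fake_runner():
    print(log_text)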
diff --git a/src/llmtuner/webui/utils.py b/src/llmtuner/webui/utils.py
index 7b667c0f..362fa008 100644
--- a/src/llmtuner/webui/utils.py
+++ b/src/llmtuner/webui/utils.py
@@ -15,13 +15,18 @@ if TYPE_CHECKING:
from llmtuner.extras.callbacks import LogCallback
-def format_info(log: str, callback: "LogCallback") -> str:
- info = log
- if callback.max_steps:
- info += "Running **{:d}/{:d}**: {} < {}\n".format(
- callback.cur_steps, callback.max_steps, callback.elapsed_time, callback.remaining_time
- )
- return info
+def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
+ if not callback.max_steps:
+ return gr.update(visible=False)
+
+ percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0
+ label = "Running {:d}/{:d}: {} < {}".format(
+ callback.cur_steps,
+ callback.max_steps,
+ callback.elapsed_time,
+ callback.remaining_time
+ )
+ return gr.update(label=label, value=percentage, visible=True)
def get_time() -> str:
@@ -57,6 +62,18 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
return gr.update(interactive=True)
+def gen_cmd(args: Dict[str, Any]) -> str:
+ if args.get("do_train", None):
+ args["plot_loss"] = True
+ cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python "]
+ for k, v in args.items():
+ if v is not None and v != "":
+ cmd_lines.append(" --{} {} ".format(k, str(v)))
+ cmd_text = "\\\n".join(cmd_lines)
+ cmd_text = "```bash\n{}\n```".format(cmd_text)
+ return cmd_text
+
+
def get_eval_results(path: os.PathLike) -> str:
with open(path, "r", encoding="utf-8") as f:
result = json.dumps(json.load(f), indent=4)