Merge branch 'main' into main

Former-commit-id: 725290324562e093565fae79a05341ebf64486d5
commit 49d4ae3704
Authored by hoshi-hiyouga on 2023-08-18 01:37:23 +08:00, committed by GitHub
57 changed files with 1656 additions and 690 deletions

.gitignore (new file, 160 lines)

@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

README.md (119 lines changed)

@@ -12,19 +12,23 @@
## Changelog

[23/08/12] Now we support **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.

[23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models (experimental feature).

[23/08/03] Now we support training the **Qwen-7B** model in this repo. Try `--model_name_or_path Qwen/Qwen-7B-Chat` and `--lora_target c_attn` arguments to train the Qwen-7B model. Remember to use `--template chatml` argument when you are using the Qwen-7B-Chat model.

-[23/07/31] Now we support dataset streaming. Try `--streaming` and `--max_steps 100` arguments to stream your dataset.
+[23/07/31] Now we support **dataset streaming**. Try `--streaming` and `--max_steps 10000` arguments to load your dataset in streaming mode.

[23/07/29] We release two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft)) for details.

[23/07/19] Now we support training the **LLaMA-2** models in this repo. Try `--model_name_or_path meta-llama/Llama-2-7b-hf` argument to use the LLaMA-2 model. Remember to use `--template llama2` argument when you are using the LLaMA-2-chat model.

-[23/07/18] Now we develop an all-in-one Web UI for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
+[23/07/18] Now we develop an **all-in-one Web UI** for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.

[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-13B-Base` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--template baichuan` argument when you are using the Baichuan-13B-Chat model.

-[23/07/09] Now we release [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
+[23/07/09] Now we release **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.

[23/07/07] Now we support training the **InternLM-7B** model in this repo. Try `--model_name_or_path internlm/internlm-7b` argument to use the InternLM model. Remember to use `--template intern` argument when you are using the InternLM-chat model.
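To make the new RoPE scaling flags from the 23/08/12 entry concrete, here is a hypothetical training invocation; everything besides `--rope_scaling linear` is a placeholder in the style of the commands shown later in this README, not part of this commit:

```bash
# Hypothetical sketch: fine-tune with linear RoPE scaling to extend context length.
# All arguments besides --rope_scaling are illustrative placeholders.
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --do_train \
    --dataset your_dataset \
    --template llama2 \
    --finetuning_type lora \
    --rope_scaling linear \
    --output_dir path_to_sft_checkpoint \
    --fp16
```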
@@ -53,25 +57,22 @@
| [InternLM](https://github.com/InternLM/InternLM) | 7B | q_proj,v_proj | intern |
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | - |
| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |

-> * **Default module** is used for the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options.
-> * For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc.
+- **Default module** is used for the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options.
+- For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the corresponding template for the "chat" models.

## Supported Training Approaches

-- [(Continually) pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
-  - Full-parameter tuning
-  - Partial-parameter tuning
-  - [LoRA](https://arxiv.org/abs/2106.09685)
-  - [QLoRA](https://arxiv.org/abs/2305.14314)
-- [Supervised fine-tuning](https://arxiv.org/abs/2109.01652)
-  - Full-parameter tuning
-  - Partial-parameter tuning
-  - [LoRA](https://arxiv.org/abs/2106.09685)
-  - [QLoRA](https://arxiv.org/abs/2305.14314)
-- [RLHF](https://arxiv.org/abs/2203.02155)
-  - [LoRA](https://arxiv.org/abs/2106.09685)
-  - [QLoRA](https://arxiv.org/abs/2305.14314)
+| Approach               | Full-parameter     | Partial-parameter  | LoRA               | QLoRA              |
+| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
+| Pre-Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| Reward Modeling        |                    |                    | :white_check_mark: | :white_check_mark: |
+| PPO Training           |                    |                    | :white_check_mark: | :white_check_mark: |
+| DPO Training           | :white_check_mark: |                    | :white_check_mark: | :white_check_mark: |
+
+- Use `--quantization_bit 4/8` argument to enable QLoRA.
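As an illustration of the QLoRA switch just mentioned (not part of the original README), the flag is simply appended to an ordinary LoRA fine-tuning command; the surrounding arguments here are illustrative placeholders:

```bash
# Hypothetical sketch: 4-bit QLoRA fine-tuning; --quantization_bit is the only point here.
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --model_name_or_path path_to_your_model \
    --do_train \
    --dataset your_dataset \
    --finetuning_type lora \
    --quantization_bit 4 \
    --output_dir path_to_sft_checkpoint
```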
## Provided Datasets

@@ -88,7 +89,6 @@
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Self-cognition (zh)](data/self_cognition.json)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
-- [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
@@ -103,7 +103,7 @@
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
-- For reward modelling:
+- For reward modeling or DPO training:
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
@@ -139,7 +139,6 @@ Note: please update `data/dataset_info.json` to use your custom dataset. About t
### Dependence Installation (optional)

```bash
-git lfs install
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
conda create -n llama_etuning python=3.10
conda activate llama_etuning
@@ -161,7 +160,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_web.py
Currently the web UI only supports training on **a single GPU**.

-### (Continually) Pre-Training
+### Pre-Training

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -207,9 +206,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --fp16
```

-Remember to specify `--lora_target W_pack` if you are using Baichuan models.
-### Reward Model Training
+### Reward Modeling
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -222,7 +219,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --resume_lora_training False \
    --checkpoint_dir path_to_sft_checkpoint \
    --output_dir path_to_rm_checkpoint \
-    --per_device_train_batch_size 4 \
+    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
@@ -233,7 +230,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --fp16
```

-### PPO Training (RLHF)
+### PPO Training

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -257,14 +254,40 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --plot_loss
```
### DPO Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage dpo \
--model_name_or_path path_to_your_model \
--do_train \
--dataset comparison_gpt4_en \
--template default \
--finetuning_type lora \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_dpo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
### Distributed Training

#### Use Huggingface Accelerate

```bash
accelerate config # configure the environment
accelerate launch src/train_bash.py # arguments (same as above)
```

-<details><summary>Example configuration for full-tuning with DeepSpeed ZeRO-2</summary>
+<details><summary>Example config.yaml for training with DeepSpeed ZeRO-2</summary>

```yaml
compute_environment: LOCAL_MACHINE
@@ -292,6 +315,44 @@ use_cpu: false
</details>
#### Use DeepSpeed
```bash
deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
--deepspeed ds_config.json \
... # arguments (same as above)
```
<details><summary>Example ds_config.json for training with DeepSpeed ZeRO-2</summary>
```json
{
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_allow_untested_optimizer": true,
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"initial_scale_power": 16,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"overlap_comm": false,
"contiguous_gradients": true
}
}
```
</details>
### Evaluation (BLEU and ROUGE_CHINESE)

```bash
@@ -390,6 +451,8 @@ Please follow the model licenses to use the corresponding model weights:
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
- [InternLM](https://github.com/InternLM/InternLM#open-source-license)
- [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE)
- [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
- [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE)

## Citation

README_zh.md

@@ -12,31 +12,35 @@
## Changelog

-[23/08/03] Now we support training the **Qwen-7B** model. Try the `--model_name_or_path Qwen/Qwen-7B-Chat` and `--lora_target c_attn` arguments. Note that using the Qwen-7B-Chat model requires adding the `--template chatml` argument.
+[23/08/12] Now we support **RoPE interpolation** to extend the context length of the LLaMA models. Try the `--rope_scaling linear` argument for training or the `--rope_scaling dynamic` argument for evaluation.

-[23/07/31] Now we support streaming of the training data. Try the `--streaming` and `--max_steps 100` arguments to load your dataset in streaming mode.
+[23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-训练) for details (experimental feature).

[23/08/03] Now we support training the **Qwen-7B** model. Try the `--model_name_or_path Qwen/Qwen-7B-Chat` and `--lora_target c_attn` arguments. Please add the `--template chatml` argument when using the Qwen-7B-Chat model.

[23/07/31] Now we support **dataset streaming**. Try the `--streaming` and `--max_steps 10000` arguments to load datasets in streaming mode.

[23/07/29] We release two instruction-tuned 13B models at Hugging Face. See these Hugging Face repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft)) for details.

[23/07/19] Now we support training the **LLaMA-2** models. Try the `--model_name_or_path meta-llama/Llama-2-7b-hf` argument. Please add the `--template llama2` argument when using the LLaMA-2-chat model.

[23/07/18] We develop an **all-in-one browser UI** supporting training and evaluation. Try `train_web.py` to fine-tune models in your browser. Thanks to [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.

[23/07/11] Now we support training the **Baichuan-13B** model. Try the `--model_name_or_path baichuan-inc/Baichuan-13B-Base` and `--lora_target W_pack` arguments. Please add the `--template baichuan` argument when using the Baichuan-13B-Chat model.

[23/07/09] We open-source **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹, an easy-to-use package for quickly editing the factual memory of large language models. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.

[23/07/07] Now we support training the **InternLM-7B** model. Try the `--model_name_or_path internlm/internlm-7b` argument. Please add the `--template intern` argument when using the InternLM-chat model.

[23/07/05] Now we support training the **Falcon-7B/40B** models. Try the `--model_name_or_path tiiuae/falcon-7b` and `--lora_target query_key_value` arguments.

[23/06/29] We provide a **reproducible example** of fine-tuning an instruction-following model; see this [Hugging Face project](https://huggingface.co/hiyouga/baichuan-7b-sft) for details.

[23/06/22] We align the [demo API](src/api_demo.py) with the [OpenAI API](https://platform.openai.com/docs/api-reference/chat) format, so you can plug fine-tuned models into **any ChatGPT-based application**.

[23/06/15] Now we support training the **Baichuan-7B** model. Try the `--model_name_or_path baichuan-inc/Baichuan-7B` and `--lora_target W_pack` arguments.

[23/06/03] Now we implement 4-bit LoRA training (also known as **[QLoRA](https://github.com/artidoro/qlora)**). Try the `--quantization_bit 4` argument for 4-bit quantized fine-tuning.

[23/05/31] Now we support training the **BLOOM & BLOOMZ** models. Try the `--model_name_or_path bigscience/bloomz-7b1-mt` and `--lora_target query_key_value` arguments.
@@ -53,42 +57,38 @@
| [InternLM](https://github.com/InternLM/InternLM) | 7B | q_proj,v_proj | intern |
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | - |
| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |

-> * The **default module** is the default value of the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options.
-> * For all "base" models, the `--template` argument can be `default`, `alpaca`, `vicuna`, etc.
+- The **default module** is only part of the available options for the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options.
+- For all "base" (Base) models, the `--template` argument can be any of `default`, `alpaca`, `vicuna`, etc. But please make sure to use the corresponding template for the "chat" (Chat) models.

-## Fine-Tuning Approaches
+## Training Approaches

-- [Continual pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
-  - Full-parameter fine-tuning
-  - Partial-parameter fine-tuning
-  - [LoRA](https://arxiv.org/abs/2106.09685)
-  - [QLoRA](https://arxiv.org/abs/2305.14314)
-- [Supervised fine-tuning](https://arxiv.org/abs/2109.01652)
-  - Full-parameter fine-tuning
-  - Partial-parameter fine-tuning
-  - [LoRA](https://arxiv.org/abs/2106.09685)
-  - [QLoRA](https://arxiv.org/abs/2305.14314)
-- [Reinforcement learning from human feedback (RLHF)](https://arxiv.org/abs/2203.02155)
-  - [LoRA](https://arxiv.org/abs/2106.09685)
-  - [QLoRA](https://arxiv.org/abs/2305.14314)
+| Approach               | Full-parameter     | Partial-parameter  | LoRA               | QLoRA              |
+| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
+| Pre-Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| Reward Modeling        |                    |                    | :white_check_mark: | :white_check_mark: |
+| PPO Training           |                    |                    | :white_check_mark: | :white_check_mark: |
+| DPO Training           | :white_check_mark: |                    | :white_check_mark: | :white_check_mark: |
+
+- Use the `--quantization_bit 4/8` argument to enable QLoRA training.
## Datasets

-- For continual pre-training:
+- For pre-training:
- [Wiki Demo (en)](data/wiki_demo.txt)
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
- For supervised fine-tuning:
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Self-cognition (zh)](data/self_cognition.json)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
-- [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
@@ -103,7 +103,7 @@
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
-- For reward model training:
+- For reward modeling or DPO training:
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
@@ -139,7 +139,6 @@ huggingface-cli login
### Environment Setup (optional)

```bash
-git lfs install
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
conda create -n llama_etuning python=3.10
conda activate llama_etuning
@@ -161,7 +160,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_web.py
Currently the web UI only supports training on **a single GPU**.

-### Continual Pre-Training
+### Pre-Training

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -207,8 +206,6 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --fp16
```

-Remember to specify the `--lora_target W_pack` argument when using Baichuan models.

### Reward Model Training

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -222,7 +219,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --resume_lora_training False \
    --checkpoint_dir path_to_sft_checkpoint \
    --output_dir path_to_rm_checkpoint \
-    --per_device_train_batch_size 4 \
+    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
@@ -233,7 +230,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --fp16
```

-### RLHF Training
+### PPO Training

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -257,8 +254,34 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --plot_loss
```

### DPO Training
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage dpo \
--model_name_or_path path_to_your_model \
--do_train \
--dataset comparison_gpt4_zh \
--template default \
--finetuning_type lora \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_dpo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
### Multi-GPU Distributed Training

#### Use Huggingface Accelerate

```bash
accelerate config # configure the distributed environment first
accelerate launch src/train_bash.py # arguments (same as above)
@@ -292,6 +315,44 @@ use_cpu: false
</details>

#### Use DeepSpeed

```bash
deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
    --deepspeed ds_config.json \
    ... # arguments (same as above)
```

<details><summary>Example DeepSpeed configuration for full-parameter fine-tuning with ZeRO-2</summary>
```json
{
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_allow_untested_optimizer": true,
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"initial_scale_power": 16,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"overlap_comm": false,
"contiguous_gradients": true
}
}
```
</details>
### Evaluation (BLEU and Chinese ROUGE scores)

```bash
@@ -309,7 +370,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --predict_with_generate
```

We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` when evaluating quantized models.

### Model Prediction

(Binary image file changed; not shown. Size: 139 KiB before, 142 KiB after.)

data/dataset_info.json

@@ -49,26 +49,6 @@
      "history": "history"
    }
  },
-  "refgpt_zh_p1": {
-    "file_name": "refgpt_zh_50k_p1.json",
-    "file_sha1": "b40f4f4d0ffacd16da7c275b056d5b6670021752",
-    "columns": {
-      "prompt": "instruction",
-      "query": "input",
-      "response": "output",
-      "history": "history"
-    }
-  },
-  "refgpt_zh_p2": {
-    "file_name": "refgpt_zh_50k_p2.json",
-    "file_sha1": "181f32b2c60264a29f81f59d3c76095793eae1b0",
-    "columns": {
-      "prompt": "instruction",
-      "query": "input",
-      "response": "output",
-      "history": "history"
-    }
-  },
  "lima": {
    "file_name": "lima.json",
    "file_sha1": "9db59f6b7007dc4b17529fc63379b9cd61640f37",


@@ -1 +0,0 @@
f967a4f6d04a11308a15524aa9a846a19a8d1e83


@@ -1 +0,0 @@
0a4f0d74fd1c5cab2eb6d84a3a3fe669847becd8

requirements.txt

@@ -3,7 +3,7 @@ transformers>=4.29.1
datasets>=2.12.0
accelerate>=0.21.0
peft>=0.4.0
-trl>=0.4.7
+trl>=0.5.0
scipy
sentencepiece
tiktoken

src/api_demo.py

@@ -7,7 +7,7 @@ def main():
    chat_model = ChatModel()
    app = create_app(chat_model)
    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
-    # Visit http://localhost:8000/docs for document.
+    print("Visit http://localhost:8000/docs for API document.")

if __name__ == "__main__":

src/llmtuner/__init__.py

@@ -6,4 +6,4 @@ from llmtuner.tuner import export_model, run_exp
from llmtuner.webui import create_ui, create_web_demo

-__version__ = "0.1.5"
+__version__ = "0.1.6"

src/llmtuner/api/app.py

@@ -47,15 +47,15 @@ def create_app(chat_model: ChatModel) -> FastAPI:
    @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
    async def create_chat_completion(request: ChatCompletionRequest):
-        if request.messages[-1].role != Role.USER:
+        if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
            raise HTTPException(status_code=400, detail="Invalid request")

        query = request.messages[-1].content
        prev_messages = request.messages[:-1]
        if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
-            prefix = prev_messages.pop(0).content
+            system = prev_messages.pop(0).content
        else:
-            prefix = None
+            system = None

        history = []
        if len(prev_messages) % 2 == 0:
@@ -64,11 +64,11 @@ def create_app(chat_model: ChatModel) -> FastAPI:
                history.append([prev_messages[i].content, prev_messages[i+1].content])

        if request.stream:
-            generate = predict(query, history, prefix, request)
+            generate = predict(query, history, system, request)
            return EventSourceResponse(generate, media_type="text/event-stream")

        response, (prompt_length, response_length) = chat_model.chat(
-            query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
+            query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
        )

        usage = ChatCompletionResponseUsage(
@@ -85,7 +85,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
        return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)

-    async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest):
+    async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(role=Role.ASSISTANT),
@@ -95,7 +95,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
        yield chunk.json(exclude_unset=True, ensure_ascii=False)

        for new_text in chat_model.stream_chat(
-            query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
+            query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
        ):
            if len(new_text) == 0:
                continue
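Since this endpoint mirrors the OpenAI chat completion format, a request against a locally running `api_demo.py` might look like the following sketch; the `model` value is an illustrative assumption, and the port matches the `uvicorn.run` call above:

```bash
# Hypothetical request against the OpenAI-style endpoint served by src/api_demo.py.
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "default",
    "messages": [
      {"role": "system", "content": "You are a helpful assistant."},
      {"role": "user", "content": "Hello!"}
    ],
    "stream": false
  }'
```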

src/llmtuner/chat/stream_chat.py

@@ -1,10 +1,9 @@
import torch
-from types import MethodType
from typing import Any, Dict, Generator, List, Optional, Tuple
from threading import Thread
-from transformers import PreTrainedModel, TextIteratorStreamer
+from transformers import TextIteratorStreamer

-from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopping_criteria
+from llmtuner.extras.misc import dispatch_model, get_logits_processor
from llmtuner.extras.template import get_template_and_fix_tokenizer
from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
@@ -15,23 +14,21 @@ class ChatModel:
        model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
        self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
        self.model = dispatch_model(self.model)
-        self.model = self.model.eval() # change to eval mode
+        self.model = self.model.eval() # enable evaluation mode
        self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
-        self.source_prefix = data_args.source_prefix
-        self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words)
-        self.model.generate = MethodType(PreTrainedModel.generate, self.model) # disable custom method (for Qwen)
+        self.system_prompt = data_args.system_prompt

    def process_args(
        self,
        query: str,
        history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = None,
+        system: Optional[str] = None,
        **input_kwargs
    ) -> Tuple[Dict[str, Any], int]:
-        prefix = prefix or self.source_prefix
+        system = system or self.system_prompt
        prompt, _ = self.template.encode_oneturn(
-            tokenizer=self.tokenizer, query=query, resp="", history=history, prefix=prefix
+            tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
        )
        input_ids = torch.tensor([prompt], device=self.model.device)
        prompt_length = len(input_ids[0])
@@ -52,8 +49,9 @@ class ChatModel:
            top_p=top_p or gen_kwargs["top_p"],
            top_k=top_k or gen_kwargs["top_k"],
            repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
-            logits_processor=get_logits_processor(),
-            stopping_criteria=get_stopping_criteria(self.stop_ids)
+            eos_token_id=list(set([self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids)),
+            pad_token_id=self.tokenizer.pad_token_id,
+            logits_processor=get_logits_processor()
        ))

        if max_length:
@@ -71,10 +69,10 @@
        self,
        query: str,
        history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = None,
+        system: Optional[str] = None,
        **input_kwargs
    ) -> Tuple[str, Tuple[int, int]]:
-        gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
+        gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
        generation_output = self.model.generate(**gen_kwargs)
        outputs = generation_output.tolist()[0][prompt_length:]
        response = self.tokenizer.decode(outputs, skip_special_tokens=True)
@@ -86,10 +84,10 @@ class ChatModel:
        self,
        query: str,
        history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = None,
+        system: Optional[str] = None,
        **input_kwargs
    ) -> Generator[str, None, None]:
-        gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
+        gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs)
        streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs["streamer"] = streamer

src/llmtuner/dsets/loader.py

@@ -1,48 +1,25 @@
import os
-import hashlib
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Union

-from datasets import Value, concatenate_datasets, interleave_datasets, load_dataset
+from datasets import concatenate_datasets, interleave_datasets, load_dataset

+from llmtuner.dsets.utils import checksum, EXT2TYPE
from llmtuner.extras.logging import get_logger

if TYPE_CHECKING:
-    from datasets import Dataset
+    from datasets import Dataset, IterableDataset
    from llmtuner.hparams import ModelArguments, DataArguments

logger = get_logger(__name__)

-EXT2TYPE = {
-    "csv": "csv",
-    "json": "json",
-    "jsonl": "json",
-    "txt": "text"
-}

-def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
-    if file_sha1 is None:
-        logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
-        return
-
-    if len(data_files) != 1:
-        logger.warning("Checksum failed: too many files.")
-        return
-
-    with open(data_files[0], "rb") as f:
-        sha1 = hashlib.sha1(f.read()).hexdigest()
-        if sha1 != file_sha1:
-            logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))

def get_dataset(
    model_args: "ModelArguments",
    data_args: "DataArguments"
-) -> "Dataset":
+) -> Union["Dataset", "IterableDataset"]:
    max_samples = data_args.max_samples
-    all_datasets: List["Dataset"] = [] # support multiple datasets
+    all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets

    for dataset_attr in data_args.dataset_list:
        logger.info("Loading dataset {}...".format(dataset_attr))
@@ -92,12 +69,11 @@ def get_dataset(
        if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
            dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)

-        if dataset_attr.source_prefix: # add prefix
-            features = None
+        if dataset_attr.system_prompt: # add system prompt
            if data_args.streaming:
-                features = dataset.features
-                features["prefix"] = Value(dtype="string", id=None)
-                dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix}, features=features)
+                dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt})
+            else:
+                dataset = dataset.add_column("system", [dataset_attr.system_prompt] * len(dataset))

        all_datasets.append(dataset)

src/llmtuner/dsets/preprocess.py

@@ -1,24 +1,25 @@
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal
+import tiktoken
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
from itertools import chain

from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.template import get_template_and_fix_tokenizer

if TYPE_CHECKING:
-    from datasets import Dataset
+    from datasets import Dataset, IterableDataset
    from transformers import Seq2SeqTrainingArguments
    from transformers.tokenization_utils import PreTrainedTokenizer
    from llmtuner.hparams import DataArguments


def preprocess_dataset(
-    dataset: "Dataset",
+    dataset: Union["Dataset", "IterableDataset"],
    tokenizer: "PreTrainedTokenizer",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo"]
-) -> "Dataset":
+) -> Union["Dataset", "IterableDataset"]:
-    column_names = list(dataset.column_names)
+    column_names = list(next(iter(dataset)).keys())
    template = get_template_and_fix_tokenizer(data_args.template, tokenizer)

    def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
@@ -26,15 +27,16 @@ def preprocess_dataset(
        query, response = examples["prompt"][i], examples["response"][i]
        query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
        history = examples["history"][i] if "history" in examples else None
-        prefix = examples["prefix"][i] if "prefix" in examples else None
-        yield query, response, history, prefix
+        system = examples["system"][i] if "system" in examples else None
+        yield query, response, history, system

    def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
        # build grouped texts with format `X1 X2 X3 ...` (without <eos>)
-        if hasattr(tokenizer, "tokenizer"): # for tiktoken tokenizer (Qwen)
+        if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
            kwargs = dict(allowed_special="all")
        else:
            kwargs = dict(add_special_tokens=False)

        tokenized_examples = tokenizer(examples["prompt"], **kwargs)
        concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
        total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
@@ -46,7 +48,6 @@ def preprocess_dataset(
            k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
-        result["labels"] = result["input_ids"].copy()
        return result

    def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
@@ -55,10 +56,10 @@ def preprocess_dataset(
        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
        max_length = data_args.max_source_length + data_args.max_target_length

-        for query, response, history, prefix in construct_example(examples):
+        for query, response, history, system in construct_example(examples):
            input_ids, labels = [], []
-            for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, prefix):
+            for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, system):
                if len(source_ids) > data_args.max_source_length:
                    source_ids = source_ids[:data_args.max_source_length]
                if len(target_ids) > data_args.max_target_length:
@@ -77,11 +78,11 @@ def preprocess_dataset(
        return model_inputs

    def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
-        # build inputs with format `<bos> X` and labels with format `<bos> Y`
+        # build inputs with format `<bos> X` and labels with format `Y <eos>`
        model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}

-        for query, response, history, prefix in construct_example(examples):
-            source_ids, target_ids = template.encode_oneturn(tokenizer, query, response, history, prefix)
+        for query, response, history, system in construct_example(examples):
+            source_ids, target_ids = template.encode_oneturn(tokenizer, query, response, history, system)

            if len(source_ids) > data_args.max_source_length:
                source_ids = source_ids[:data_args.max_source_length]
@@ -95,24 +96,22 @@ def preprocess_dataset(
        return model_inputs

    def preprocess_pairwise_dataset(examples):
-        # build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
-        model_inputs = {"accept_ids": [], "reject_ids": []}
-        for query, response, history, prefix in construct_example(examples):
-            source_ids, accept_ids = template.encode_oneturn(tokenizer, query, response[0], history, prefix)
-            source_ids, reject_ids = template.encode_oneturn(tokenizer, query, response[1], history, prefix)
+        # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
+        model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
+        for query, response, history, system in construct_example(examples):
+            prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
+            _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)

-            if len(source_ids) > data_args.max_source_length:
-                source_ids = source_ids[:data_args.max_source_length]
-            if len(accept_ids) > data_args.max_target_length:
-                accept_ids = accept_ids[:data_args.max_target_length - 1]
-            if len(reject_ids) > data_args.max_target_length:
-                reject_ids = reject_ids[:data_args.max_target_length - 1]
+            if len(prompt_ids) > data_args.max_source_length:
+                prompt_ids = prompt_ids[:data_args.max_source_length]
+            if len(chosen_ids) > data_args.max_target_length:
+                chosen_ids = chosen_ids[:data_args.max_target_length]
+            if len(rejected_ids) > data_args.max_target_length:
+                rejected_ids = rejected_ids[:data_args.max_target_length]

-            accept_ids = source_ids + accept_ids
-            reject_ids = source_ids + reject_ids
-            model_inputs["accept_ids"].append(accept_ids)
-            model_inputs["reject_ids"].append(reject_ids)
+            model_inputs["prompt_ids"].append(prompt_ids)
+            model_inputs["chosen_ids"].append(chosen_ids)
+            model_inputs["rejected_ids"].append(rejected_ids)
        return model_inputs

    def print_supervised_dataset_example(example):
@@ -124,10 +123,12 @@ def preprocess_dataset(
        ], skip_special_tokens=False)))

    def print_pairwise_dataset_example(example):
-        print("accept_ids:\n{}".format(example["accept_ids"]))
-        print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
-        print("reject_ids:\n{}".format(example["reject_ids"]))
-        print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))
+        print("prompt_ids:\n{}".format(example["prompt_ids"]))
+        print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
+        print("chosen_ids:\n{}".format(example["chosen_ids"]))
+        print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
+        print("rejected_ids:\n{}".format(example["rejected_ids"]))
+        print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))

    def print_unsupervised_dataset_example(example):
        print("input_ids:\n{}".format(example["input_ids"]))
@@ -166,8 +167,5 @@ def preprocess_dataset(
            **kwargs
        )

-        if data_args.streaming:
-            dataset = dataset.shuffle(buffer_size=data_args.buffer_size)

        print_function(next(iter(dataset)))
        return dataset

src/llmtuner/dsets/utils.py

@@ -1,15 +1,59 @@
-from typing import TYPE_CHECKING, Dict
+import hashlib
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
+
+from llmtuner.extras.logging import get_logger

if TYPE_CHECKING:
-    from datasets import Dataset
+    from datasets import Dataset, IterableDataset
+    from transformers import TrainingArguments
+    from llmtuner.hparams import DataArguments

-def split_dataset(dataset: "Dataset", dev_ratio: float, do_train: bool) -> Dict[str, "Dataset"]:
-    if do_train:
-        if dev_ratio > 1e-6: # Split the dataset
-            dataset = dataset.train_test_split(test_size=dev_ratio)
+logger = get_logger(__name__)
+
+EXT2TYPE = {
+    "csv": "csv",
+    "json": "json",
+    "jsonl": "json",
+    "txt": "text"
+}
+
+
+def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
+    if file_sha1 is None:
+        logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
+        return
+
+    if len(data_files) != 1:
+        logger.warning("Checksum failed: too many files.")
+        return
+
+    with open(data_files[0], "rb") as f:
+        sha1 = hashlib.sha1(f.read()).hexdigest()
+        if sha1 != file_sha1:
+            logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
+
+
+def split_dataset(
+    dataset: Union["Dataset", "IterableDataset"],
+    data_args: "DataArguments",
+    training_args: "TrainingArguments"
+) -> Dict[str, "Dataset"]:
+    if training_args.do_train:
+        if data_args.val_size > 1e-6: # Split the dataset
+            if data_args.streaming:
+                val_set = dataset.take(int(data_args.val_size))
+                train_set = dataset.skip(int(data_args.val_size))
+                dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
+                return {"train_dataset": train_set, "eval_dataset": val_set}
+            else:
+                val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
+                dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed)
            return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
        else:
+            if data_args.streaming:
+                dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
            return {"train_dataset": dataset}
    else: # do_eval or do_predict
        return {"eval_dataset": dataset}

src/llmtuner/extras/callbacks.py

@@ -7,10 +7,16 @@ from datetime import timedelta
from transformers import TrainerCallback
from transformers.trainer_utils import has_length

+from llmtuner.extras.constants import LOG_FILE_NAME
+from llmtuner.extras.logging import get_logger

if TYPE_CHECKING:
    from transformers import TrainingArguments, TrainerState, TrainerControl

+logger = get_logger(__name__)


class LogCallback(TrainerCallback):

    def __init__(self, runner=None):
@@ -38,6 +44,9 @@ class LogCallback(TrainerCallback):
        self.in_training = True
        self.start_time = time.time()
        self.max_steps = state.max_steps
+        if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)):
+            logger.warning("Previous log file in this folder will be deleted.")
+            os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))

    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""

src/llmtuner/extras/constants.py

@@ -1,13 +1,23 @@
IGNORE_INDEX = -100

+LOG_FILE_NAME = "trainer_log.jsonl"

VALUE_HEAD_FILE_NAME = "value_head.bin"

FINETUNING_ARGS_NAME = "finetuning_args.json"

-LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings
+LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]

METHODS = ["full", "freeze", "lora"]

+STAGES = [
+    "SFT",
+    "Reward Modeling",
+    "PPO",
+    "DPO",
+    "Pre-Training"
+]

SUPPORTED_MODELS = {
    "LLaMA-7B": "huggyllama/llama-7b",
    "LLaMA-13B": "huggyllama/llama-13b",
@@ -19,6 +29,10 @@ SUPPORTED_MODELS = {
    "LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
    "LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
    "LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf",
+    "ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
+    "ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
+    "ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
+    "ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b",
    "BLOOM-560M": "bigscience/bloom-560m",
    "BLOOM-3B": "bigscience/bloom-3b",
    "BLOOM-7B1": "bigscience/bloom-7b1",
@@ -35,16 +49,30 @@ SUPPORTED_MODELS = {
    "InternLM-7B": "internlm/internlm-7b",
    "InternLM-7B-Chat": "internlm/internlm-chat-7b",
    "Qwen-7B": "Qwen/Qwen-7B",
-    "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat"
+    "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
+    "XVERSE-13B": "xverse/XVERSE-13B",
+    "ChatGLM2-6B-Chat": "THUDM/chatglm2-6b"
}

DEFAULT_MODULE = {
    "LLaMA": "q_proj,v_proj",
    "LLaMA2": "q_proj,v_proj",
+    "ChineseLLaMA2": "q_proj,v_proj",
    "BLOOM": "query_key_value",
    "BLOOMZ": "query_key_value",
    "Falcon": "query_key_value",
    "Baichuan": "W_pack",
    "InternLM": "q_proj,v_proj",
-    "Qwen": "c_attn"
+    "Qwen": "c_attn",
+    "XVERSE": "q_proj,v_proj",
+    "ChatGLM2": "query_key_value"
}

+DEFAULT_TEMPLATE = {
+    "LLaMA2": "llama2",
+    "ChineseLLaMA2": "llama2_zh",
+    "Baichuan": "baichuan",
+    "InternLM": "intern",
+    "Qwen": "chatml",
+    "ChatGLM2": "chatglm2"
+}

View File

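`DEFAULT_MODULE` and `DEFAULT_TEMPLATE` are keyed by model family rather than by the full `SUPPORTED_MODELS` entries, so a caller presumably strips the size suffix before the lookup. A hedged sketch (the `model_family` helper is illustrative, not part of this commit):

    def model_family(model_name: str) -> str:
        return model_name.split("-")[0]  # "Qwen-7B-Chat" -> "Qwen"

    family = model_family("Qwen-7B-Chat")
    lora_target = DEFAULT_MODULE.get(family, "q_proj,v_proj")  # -> "c_attn"
    template = DEFAULT_TEMPLATE.get(family, "default")         # -> "chatml"
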
@@ -8,6 +8,9 @@ class LoggerHandler(logging.Handler):
         super().__init__()
         self.log = ""

+    def reset(self):
+        self.log = ""
+
     def emit(self, record):
         if record.name == "httpx":
             return

@@ -1,7 +1,6 @@
 import torch
 from typing import TYPE_CHECKING, List, Optional, Tuple
-from transformers import LogitsProcessor, LogitsProcessorList, StoppingCriteria, StoppingCriteriaList
+from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList

 from llmtuner.extras.constants import LAYERNORM_NAMES
@@ -29,37 +28,12 @@ class AverageMeter:
         self.avg = self.sum / self.count

-class InvalidScoreLogitsProcessor(LogitsProcessor):
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        if torch.isnan(scores).any() or torch.isinf(scores).any():
-            scores.zero_()
-            scores[..., 0] = 1.0
-        return scores
-
 def get_logits_processor() -> LogitsProcessorList:
     logits_processor = LogitsProcessorList()
-    logits_processor.append(InvalidScoreLogitsProcessor())
+    logits_processor.append(InfNanRemoveLogitsProcessor())
     return logits_processor

-class StopWordsCriteria(StoppingCriteria):
-
-    def __init__(self, stop_ids: List[int]) -> None:
-        super().__init__()
-        self.stop_ids = stop_ids
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        return any([stop_id in input_ids[:, -1] for stop_id in self.stop_ids])
-
-def get_stopping_criteria(stop_ids: List[int]) -> StoppingCriteriaList:
-    stopping_criteria = StoppingCriteriaList()
-    stopping_criteria.append(StopWordsCriteria(stop_ids))
-    return stopping_criteria
-
 def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
     r"""
     Returns the number of trainable parameters and number of all parameters in the model.
@@ -91,7 +65,6 @@ def prepare_model_for_training(
     use_gradient_checkpointing: Optional[bool] = True,
     layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
 ) -> "PreTrainedModel":
     for name, param in model.named_parameters():
         if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
             param.data = param.data.to(torch.float32)
@@ -108,9 +81,6 @@ def prepare_model_for_training(
         model.config.use_cache = False # turn off when gradient checkpointing is enabled

     if finetuning_type != "full" and hasattr(model, output_layer_name):
-        if hasattr(model, "config") and hasattr(model.config, "pretraining_tp"):
-            model.config.pretraining_tp = 1 # disable TP for LoRA (https://github.com/huggingface/peft/pull/728)
         output_layer: torch.nn.Linear = getattr(model, output_layer_name)
         input_dtype = output_layer.weight.dtype
@@ -138,6 +108,9 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
     Dispatches a pre-trained model to GPUs with balanced memory.
     Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
     """
+    if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
+        return model
+
     if torch.cuda.device_count() > 1:
         from accelerate import dispatch_model
         from accelerate.utils import infer_auto_device_map, get_balanced_memory

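`InfNanRemoveLogitsProcessor` is a stock `transformers` processor, which lets the hand-rolled NaN/Inf guard and the stop-word criteria go; stop words are now injected as special tokens at the tokenizer level instead. Roughly how the returned list is consumed (model and inputs are placeholders):

    from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList

    logits_processor = LogitsProcessorList()
    logits_processor.append(InfNanRemoveLogitsProcessor())

    # outputs = model.generate(**inputs, logits_processor=logits_processor)
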
@@ -1,15 +1,22 @@
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+import tiktoken
 from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+from llmtuner.extras.logging import get_logger

 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer

+logger = get_logger(__name__)
+
 @dataclass
 class Template:

     prefix: List[Union[str, Dict[str, str]]]
     prompt: List[Union[str, Dict[str, str]]]
+    system: str
     sep: List[Union[str, Dict[str, str]]]
     stop_words: List[str]
     use_history: bool
@@ -20,18 +27,18 @@ class Template:
         query: str,
         resp: str,
         history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = None
+        system: Optional[str] = None
     ) -> Tuple[List[int], List[int]]:
         r"""
         Returns a single pair of token ids representing prompt and response respectively.
         """
-        prefix, history = self._format(query, resp, history, prefix)
-        encoded_pairs = self._encode(tokenizer, prefix, history)
+        system, history = self._format(query, resp, history, system)
+        encoded_pairs = self._encode(tokenizer, system, history)
         prompt_ids = []
         for query_ids, resp_ids in encoded_pairs[:-1]:
             prompt_ids = prompt_ids + query_ids + resp_ids
-        prompt_ids = prompt_ids + encoded_pairs[-1][0]
-        return prompt_ids, encoded_pairs[-1][1]
+        prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1]
+        return prompt_ids, answer_ids

     def encode_multiturn(
         self,
@@ -39,13 +46,13 @@ class Template:
         query: str,
         resp: str,
         history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = None
+        system: Optional[str] = None
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
         Returns multiple pairs of token ids representing prompts and responses respectively.
         """
-        prefix, history = self._format(query, resp, history, prefix)
-        encoded_pairs = self._encode(tokenizer, prefix, history)
+        system, history = self._format(query, resp, history, system)
+        encoded_pairs = self._encode(tokenizer, system, history)
         return encoded_pairs

     def _format(
@@ -53,26 +60,29 @@ class Template:
         query: str,
         resp: str,
         history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = None
-    ) -> Tuple[List[Union[str, Dict[str, str]]], List[Tuple[str, str]]]:
+        system: Optional[str] = None
+    ) -> Tuple[str, List[Tuple[str, str]]]:
         r"""
-        Aligns inputs to a special format.
+        Aligns inputs to the standard format.
         """
-        prefix = [prefix] if prefix else self.prefix # use prefix if provided
+        system = system or self.system # use system if provided
         history = history if (history and self.use_history) else []
         history = history + [(query, resp)]
-        return prefix, history
+        return system, history

     def _get_special_ids(
         self,
         tokenizer: "PreTrainedTokenizer"
     ) -> Tuple[List[int], List[int]]:
-        if tokenizer.bos_token_id:
+        if (
+            tokenizer.bos_token_id is not None
+            and getattr(tokenizer, "add_bos_token", True)
+        ): # baichuan-13b has no bos token
             bos_ids = [tokenizer.bos_token_id]
         else:
             bos_ids = [] # bos token is optional

-        if tokenizer.eos_token_id:
+        if tokenizer.eos_token_id is not None:
             eos_ids = [tokenizer.eos_token_id]
         else:
             raise ValueError("EOS token is required.")
@@ -82,35 +92,44 @@ class Template:
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
-        prefix: List[Union[str, Dict[str, str]]],
+        system: str,
         history: List[Tuple[str, str]]
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
         Encodes formatted inputs to pairs of token ids.
+        Turn 0: bos + prefix + sep + query    resp + eos
+        Turn t: sep + bos + query             resp + eos
         """
         bos_ids, eos_ids = self._get_special_ids(tokenizer)
         sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
         encoded_pairs = []
         for turn_idx, (query, resp) in enumerate(history):
             if turn_idx == 0:
-                prefix_ids = self._convert_inputs_to_ids(tokenizer, context=prefix) + eos_ids + sep_ids
+                prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system)
+                if len(prefix_ids) != 0: # has prefix
+                    prefix_ids = bos_ids + prefix_ids + sep_ids
+                else:
+                    prefix_ids = bos_ids
             else:
-                prefix_ids = sep_ids
-            query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
+                prefix_ids = sep_ids + bos_ids
+            query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx))
             resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
-            encoded_pairs.append((bos_ids + prefix_ids + query_ids, resp_ids + eos_ids))
+            encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids))
         return encoded_pairs

     def _convert_inputs_to_ids(
         self,
         tokenizer: "PreTrainedTokenizer",
         context: List[Union[str, Dict[str, str]]],
-        query: Optional[str] = ""
+        system: Optional[str] = None,
+        query: Optional[str] = None,
+        idx: Optional[str] = None
     ) -> List[int]:
         r"""
         Converts context to token ids.
         """
-        if hasattr(tokenizer, "tokenizer"): # for tiktoken tokenizer (Qwen)
+        if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
             kwargs = dict(allowed_special="all")
         else:
             kwargs = dict(add_special_tokens=False)
@@ -118,12 +137,15 @@ class Template:
         token_ids = []
         for elem in context:
             if isinstance(elem, str):
-                elem = elem.replace("{{query}}", query, 1)
+                elem = elem.replace("{{system}}", system, 1) if system is not None else elem
+                elem = elem.replace("{{query}}", query, 1) if query is not None else elem
+                elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
                 token_ids = token_ids + tokenizer.encode(elem, **kwargs)
             elif isinstance(elem, dict):
                 token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
             else:
                 raise NotImplementedError
         return token_ids
@@ -133,18 +155,19 @@ class Llama2Template(Template):
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
-        prefix: List[Union[str, Dict[str, str]]],
+        system: str,
         history: List[Tuple[str, str]]
     ) -> List[Tuple[List[int], List[int]]]:
         r"""
         Encodes formatted inputs to pairs of token ids.
+        Turn 0: bos + prefix + query    resp + eos
+        Turn t: bos + query             resp + eos
         """
         bos_ids, eos_ids = self._get_special_ids(tokenizer)
         encoded_pairs = []
-        assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single str."
         for turn_idx, (query, resp) in enumerate(history):
-            if turn_idx == 0: # llama2 template has not sep_ids
-                query = prefix[0] + query
+            if turn_idx == 0: # llama2 template has no sep_ids
+                query = self.prefix[0].replace("{{system}}", system) + query
             query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
             resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
             encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
@@ -158,14 +181,16 @@ def register_template(
     name: str,
     prefix: List[Union[str, Dict[str, str]]],
     prompt: List[Union[str, Dict[str, str]]],
+    system: str,
     sep: List[Union[str, Dict[str, str]]],
-    stop_words: List[str],
-    use_history: bool
+    stop_words: Optional[List[str]] = [],
+    use_history: Optional[bool] = True
 ) -> None:
-    template_class = Llama2Template if name == "llama2" else Template
+    template_class = Llama2Template if "llama2" in name else Template
     templates[name] = template_class(
         prefix=prefix,
         prompt=prompt,
+        system=system,
         sep=sep,
         stop_words=stop_words,
         use_history=use_history
@@ -179,13 +204,27 @@ def get_template_and_fix_tokenizer(
     template = templates.get(name, None)
     assert template is not None, "Template {} does not exist.".format(name)

-    if tokenizer.eos_token_id is None and len(template.stop_words): # inplace method
-        tokenizer.eos_token = template.stop_words[0]
+    additional_special_tokens = template.stop_words
+    if len(template.stop_words): # inplace method
+        if tokenizer.eos_token_id is not None:
+            additional_special_tokens.append(tokenizer.eos_token)
+
+        tokenizer.eos_token = additional_special_tokens[0] # use the first stop word as eos token
+        additional_special_tokens.pop(0)
+        logger.info("Replace eos token: {}".format(tokenizer.eos_token))
+
+    if tokenizer.eos_token_id is None:
+        tokenizer.eos_token = "<|endoftext|>"
+        logger.info("Add eos token: {}".format(tokenizer.eos_token))

-    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
-        tokenizer.pad_token = tokenizer.eos_token
+    if tokenizer.pad_token_id is None:
+        if tokenizer.unk_token_id is not None:
+            tokenizer.pad_token = tokenizer.unk_token
+        else:
+            tokenizer.pad_token = tokenizer.eos_token
+        logger.info("Add pad token: {}".format(tokenizer.pad_token))

-    tokenizer.add_special_tokens(dict(additional_special_tokens=template.stop_words))
+    tokenizer.add_special_tokens(dict(additional_special_tokens=additional_special_tokens))
     return template
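A rough end-to-end sketch of the reworked template API, assuming the module path `llmtuner.extras.template` used by this repository (exact token ids depend on the tokenizer):

    from transformers import AutoTokenizer
    from llmtuner.extras.template import get_template_and_fix_tokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
    template = get_template_and_fix_tokenizer("llama2", tokenizer)

    # system falls back to template.system when not given
    prompt_ids, answer_ids = template.encode_oneturn(
        tokenizer, query="Hello!", resp="Hi there.", history=None, system=None
    )
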
@@ -198,8 +237,8 @@ register_template(
     prompt=[
         "{{query}}"
     ],
+    system="",
     sep=[],
-    stop_words=[],
     use_history=False
 )
@@ -210,17 +249,18 @@ Default template.
 register_template(
     name="default",
     prefix=[
-        "A chat between a curious user and an artificial intelligence assistant. "
-        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+        "{{system}}"
     ],
     prompt=[
         "Human: {{query}}\nAssistant: "
     ],
+    system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
     sep=[
         "\n"
-    ],
-    stop_words=[],
-    use_history=True
+    ]
 )
@@ -232,21 +272,39 @@ Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
 register_template(
     name="llama2",
     prefix=[
-        "<<SYS>>\nYou are a helpful, respectful and honest assistant. "
+        "<<SYS>>\n{{system}}\n<</SYS>>\n\n"
+    ],
+    prompt=[
+        "[INST] {{query}} [/INST] "
+    ],
+    system=(
+        "You are a helpful, respectful and honest assistant. "
         "Always answer as helpfully as possible, while being safe. "
         "Your answers should not include any harmful, unethical, "
         "racist, sexist, toxic, dangerous, or illegal content. "
         "Please ensure that your responses are socially unbiased and positive in nature.\n"
         "If a question does not make any sense, or is not factually coherent, "
         "explain why instead of answering something not correct. "
-        "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
+        "If you don't know the answer to a question, please don't share false information."
+    ),
+    sep=[]
+)
+
+
+r"""
+Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
+          https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
+"""
+register_template(
+    name="llama2_zh",
+    prefix=[
+        "<<SYS>>\n{{system}}\n<</SYS>>\n\n"
     ],
     prompt=[
         "[INST] {{query}} [/INST] "
     ],
-    sep=[],
-    stop_words=[],
-    use_history=True
+    system="You are a helpful assistant. 你是一个乐于助人的助手。",
+    sep=[]
 )
@@ -257,17 +315,18 @@ Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
 register_template(
     name="alpaca",
     prefix=[
-        "Below is an instruction that describes a task. "
-        "Write a response that appropriately completes the request."
+        "{{system}}"
     ],
     prompt=[
         "### Instruction:\n{{query}}\n\n### Response:\n"
     ],
+    system=(
+        "Below is an instruction that describes a task. "
+        "Write a response that appropriately completes the request."
+    ),
     sep=[
         "\n\n"
-    ],
-    stop_words=[],
-    use_history=True
+    ]
 )
@@ -278,15 +337,16 @@ Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
 register_template(
     name="vicuna",
     prefix=[
-        "A chat between a curious user and an artificial intelligence assistant. "
-        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+        "{{system}}"
     ],
     prompt=[
         "USER: {{query}} ASSISTANT: "
     ],
-    sep=[],
-    stop_words=[],
-    use_history=True
+    system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    sep=[]
 )
@@ -295,15 +355,16 @@ Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
 register_template(
     name="belle",
-    prefix=[],
+    prefix=[
+        "{{system}}"
+    ],
     prompt=[
         "Human: {{query}}\n\nBelle: "
     ],
+    system="",
     sep=[
         "\n\n"
-    ],
-    stop_words=[],
-    use_history=True
+    ]
 )
@@ -312,15 +373,16 @@ Supports: https://github.com/CVI-SZU/Linly
 register_template(
     name="linly",
-    prefix=[],
+    prefix=[
+        "{{system}}"
+    ],
     prompt=[
         "User: {{query}}\nBot: "
     ],
+    system="",
     sep=[
         "\n"
-    ],
-    stop_words=[],
-    use_history=True
+    ]
 )
@@ -329,15 +391,16 @@ Supports: https://github.com/Neutralzz/BiLLa
 register_template(
     name="billa",
-    prefix=[],
+    prefix=[
+        "{{system}}"
+    ],
     prompt=[
         "Human: {{query}}\nAssistant: "
     ],
+    system="",
     sep=[
         "\n"
-    ],
-    stop_words=[],
-    use_history=True
+    ]
 )
@@ -346,18 +409,19 @@ Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
 register_template(
     name="ziya",
-    prefix=[],
+    prefix=[
+        "{{system}}"
+    ],
     prompt=[
         {"token": "<human>"},
         ":{{query}}\n",
         {"token": "<bot>"},
         ":"
     ],
+    system="",
     sep=[
         "\n"
-    ],
-    stop_words=[],
-    use_history=True
+    ]
 )
@@ -367,17 +431,18 @@ Supports: https://huggingface.co/qhduan/aquilachat-7b
 register_template(
     name="aquila",
     prefix=[
-        "A chat between a curious human and an artificial intelligence assistant. "
-        "The assistant gives helpful, detailed, and polite answers to the human's questions."
+        "{{system}}"
     ],
     prompt=[
         "Human: {{query}}###Assistant: "
     ],
+    system=(
+        "A chat between a curious human and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the human's questions."
+    ),
     sep=[
         "###"
-    ],
-    stop_words=[],
-    use_history=True
+    ]
 )
@@ -386,19 +451,22 @@ Supports: https://huggingface.co/internlm/internlm-chat-7b
 register_template(
     name="intern",
-    prefix=[],
+    prefix=[
+        "{{system}}"
+    ],
     prompt=[
         "<|User|>:{{query}}",
         {"token": "<eoh>"},
         "\n<|Bot|>:"
     ],
+    system="",
     sep=[
         "\n"
     ],
     stop_words=[
+        "</s>", # internlm cannot replace eos token
         "<eoa>"
-    ],
-    use_history=True
+    ]
 )
@@ -407,15 +475,19 @@ Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
 register_template(
     name="baichuan",
-    prefix=[],
-    prompt=[
-        {"token": "<reserved_102>"},
-        "{{query}}",
-        {"token": "<reserved_103>"}
+    prefix=[
+        "{{system}}",
+        {"token": "<reserved_102>"} # user token
     ],
+    prompt=[
+        "{{query}}",
+        {"token": "<reserved_103>"} # assistant token
+    ],
+    system="",
     sep=[],
-    stop_words=[],
-    use_history=True
+    stop_words=[
+        "<reserved_102>" # user token
+    ]
 )
@@ -427,7 +499,8 @@ register_template(
     name="starchat",
     prefix=[
         {"token": "<|system|>"},
-        "\n"
+        "\n{{system}}",
+        {"token": "<|end|>"}
     ],
     prompt=[
         {"token": "<|user|>"},
@@ -436,13 +509,13 @@ register_template(
         "\n",
         {"token": "<|assistant|>"}
     ],
+    system="",
     sep=[
         "\n"
     ],
     stop_words=[
         "<|end|>"
-    ],
-    use_history=True
+    ]
 )
@@ -453,7 +526,8 @@ register_template(
     name="chatml",
     prefix=[
         {"token": "<|im_start|>"},
-        "system\nYou are a helpful assistant."
+        "system\n{{system}}",
+        {"token": "<|im_end|>"}
     ],
     prompt=[
         {"token": "<|im_start|>"},
@@ -463,11 +537,31 @@ register_template(
         {"token": "<|im_start|>"},
         "assistant\n"
     ],
+    system="You are a helpful assistant.",
     sep=[
         "\n"
     ],
     stop_words=[
         "<|im_end|>"
-    ],
-    use_history=True
+    ]
+)
+
+
+r"""
+Supports: https://huggingface.co/THUDM/chatglm2-6b
+"""
+register_template(
+    name="chatglm2",
+    prefix=[
+        {"token": "[gMASK]"},
+        {"token": "sop"},
+        "{{system}}"
+    ],
+    prompt=[
+        "[Round {{idx}}]\n\n问:{{query}}\n\n答:"
+    ],
+    system="",
+    sep=[
+        "\n\n"
+    ]
 )

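For the new `chatglm2` template, `_encode` substitutes `{{idx}}` with `str(turn_idx)`, so the rendered text, ignoring the `[gMASK]`/`sop` special tokens, looks roughly like the sketch below; note the round counter is zero-based here:

    history = [("你好", "你好!"), ("介绍一下你自己", "")]
    sep = "\n\n"
    text = sep.join(
        "[Round {}]\n\n问:{}\n\n答:{}".format(turn_idx, query, resp)
        for turn_idx, (query, resp) in enumerate(history)
    )
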
@@ -10,7 +10,7 @@ class DatasetAttr:
     load_from: str
     dataset_name: Optional[str] = None
     dataset_sha1: Optional[str] = None
-    source_prefix: Optional[str] = None
+    system_prompt: Optional[str] = None

     def __repr__(self) -> str:
         return self.dataset_name
@@ -24,7 +24,7 @@ class DatasetAttr:

 @dataclass
 class DataArguments:
-    """
+    r"""
     Arguments pertaining to what data we are going to input our model for training and evaluation.
     """
     template: str = field(
@@ -86,13 +86,13 @@ class DataArguments:
         default=True,
         metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
     )
-    source_prefix: Optional[str] = field(
+    system_prompt: Optional[str] = field(
         default=None,
-        metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
+        metadata={"help": "System prompt to add before the user query. Use `|` to separate multiple prompts in training."}
     )
-    dev_ratio: Optional[float] = field(
+    val_size: Optional[float] = field(
         default=0,
-        metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
+        metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
     )

     def init_for_training(self): # support mixing multiple datasets
@@ -100,12 +100,9 @@ class DataArguments:
         with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
             dataset_info = json.load(f)

-        if self.source_prefix is not None:
-            prefix_list = self.source_prefix.split("|")
-            prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
-            assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical with datasets or 1."
-        else:
-            prefix_list = [None] * len(dataset_names)
+        prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
+        prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
+        assert len(prompt_list) == len(dataset_names), "Number of system prompts should be equal to datasets or 1."

         if self.interleave_probs is not None:
             self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
@@ -126,12 +123,11 @@ class DataArguments:
                 dataset_sha1=dataset_info[name].get("file_sha1", None)
             )

-            dataset_attr.source_prefix = prefix_list[i]
-
             if "columns" in dataset_info[name]:
                 dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
                 dataset_attr.query = dataset_info[name]["columns"].get("query", None)
                 dataset_attr.response = dataset_info[name]["columns"].get("response", None)
                 dataset_attr.history = dataset_info[name]["columns"].get("history", None)

+            dataset_attr.system_prompt = prompt_list[i]
             self.dataset_list.append(dataset_attr)

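The renamed `system_prompt` keeps the `|` convention but tightens it: the number of prompts must be 1 or exactly the number of datasets, which the integer division enforces. A sketch of the arithmetic with placeholder names:

    dataset_names = ["alpaca_en", "alpaca_zh", "oaast_sft"]  # illustrative
    system_prompt = "You are a helpful assistant."

    prompt_list = system_prompt.split("|") if system_prompt else [None]
    prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
    assert len(prompt_list) == len(dataset_names)  # holds for 1 or 3 prompts, fails for 2
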
@@ -5,7 +5,7 @@ from dataclasses import asdict, dataclass, field

 @dataclass
 class FinetuningArguments:
-    """
+    r"""
     Arguments pertaining to which techniques we are going to fine-tuning with.
     """
     finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
@@ -14,7 +14,7 @@ class FinetuningArguments:
     )
     num_hidden_layers: Optional[int] = field(
         default=32,
-        metadata={"help": "Number of decoder blocks in the model. \
+        metadata={"help": "Number of decoder blocks in the model for partial-parameter (freeze) fine-tuning. \
                   LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
                   LLaMA-2 choices: [\"32\", \"40\", \"80\"], \
                   BLOOM choices: [\"24\", \"30\", \"70\"], \
@@ -25,16 +25,16 @@ class FinetuningArguments:
     )
     num_layer_trainable: Optional[int] = field(
         default=3,
-        metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
+        metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
     )
     name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
         default="mlp",
-        metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
-                  LLaMA & LLaMA-2 choices: [\"mlp\", \"self_attn\"], \
+        metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
+                  LLaMA choices: [\"mlp\", \"self_attn\"], \
                   BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
                   Baichuan choices: [\"mlp\", \"self_attn\"], \
                   Qwen choices: [\"mlp\", \"attn\"], \
-                  InternLM, XVERSE choices: the same as LLaMA."}
+                  LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
     )
     lora_rank: Optional[int] = field(
         default=8,
@@ -51,11 +51,19 @@ class FinetuningArguments:
     lora_target: Optional[str] = field(
         default="q_proj,v_proj",
         metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
-                  LLaMA & LLaMA-2 choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
+                  LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
                   BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
                   Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
                   Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
-                  InternLM, XVERSE choices: the same as LLaMA."}
+                  LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
+    )
+    resume_lora_training: Optional[bool] = field(
+        default=True,
+        metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
+    )
+    dpo_beta: Optional[float] = field(
+        default=0.1,
+        metadata={"help": "The beta parameter for the DPO loss."}
     )

     def __post_init__(self):
@@ -72,14 +80,14 @@ class FinetuningArguments:
         assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."

     def save_to_json(self, json_path: str):
-        """Saves the content of this instance in JSON format inside `json_path`."""
+        r"""Saves the content of this instance in JSON format inside `json_path`."""
         json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
         with open(json_path, "w", encoding="utf-8") as f:
             f.write(json_string)

     @classmethod
     def load_from_json(cls, json_path: str):
-        """Creates an instance from the content of `json_path`."""
+        r"""Creates an instance from the content of `json_path`."""
         with open(json_path, "r", encoding="utf-8") as f:
             text = f.read()
         return cls(**json.loads(text))

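`dpo_beta` is the β of the DPO objective, scaling the margin between policy and reference log-ratios. A minimal PyTorch sketch of the loss (not the trl implementation the trainer actually uses):

    import torch.nn.functional as F

    def dpo_loss(policy_chosen_logps, policy_rejected_logps,
                 reference_chosen_logps, reference_rejected_logps, beta=0.1):
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps
        # -log sigmoid(beta * (policy log-ratio - reference log-ratio))
        return -F.logsigmoid(beta * (pi_logratios - ref_logratios)).mean()
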
@@ -4,10 +4,10 @@ from dataclasses import dataclass, field

 @dataclass
 class GeneralArguments:
-    """
+    r"""
     Arguments pertaining to which stage we are going to perform.
     """
-    stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = field(
+    stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
         default="sft",
         metadata={"help": "Which stage will be performed in training."}
     )

@@ -4,7 +4,7 @@ from dataclasses import asdict, dataclass, field

 @dataclass
 class GeneratingArguments:
-    """
+    r"""
     Arguments pertaining to specify the decoding parameters.
     """
     do_sample: Optional[bool] = field(

@@ -5,7 +5,7 @@ from dataclasses import dataclass, field

 @dataclass
 class ModelArguments:
-    """
+    r"""
     Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
     """
     model_name_or_path: str = field(
@@ -43,9 +43,9 @@ class ModelArguments:
         default=True,
         metadata={"help": "Whether to use double quantization in int4 training or not."}
     )
-    compute_dtype: Optional[torch.dtype] = field(
+    rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
         default=None,
-        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
+        metadata={"help": "Adopt scaled rotary positional embeddings."}
     )
     checkpoint_dir: Optional[str] = field(
         default=None,
@@ -55,18 +55,33 @@ class ModelArguments:
         default=None,
         metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
     )
-    resume_lora_training: Optional[bool] = field(
-        default=True,
-        metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
-    )
     plot_loss: Optional[bool] = field(
         default=False,
         metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
     )
+    hf_auth_token: Optional[str] = field(
+        default=None,
+        metadata={"help": "Auth token to log in with Hugging Face Hub."}
+    )
+    compute_dtype: Optional[torch.dtype] = field(
+        default=None,
+        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
+    )
+    model_max_length: Optional[int] = field(
+        default=None,
+        metadata={"help": "Used in rope scaling. Do not specify this argument manually."}
+    )

     def __post_init__(self):
+        if self.compute_dtype is not None or self.model_max_length is not None:
+            raise ValueError("These arguments cannot be specified.")
+
         if self.checkpoint_dir is not None: # support merging multiple lora weights
             self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]

         if self.quantization_bit is not None:
             assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
+
+        if self.use_auth_token == True and self.hf_auth_token is not None:
+            from huggingface_hub.hf_api import HfFolder # lazy load
+            HfFolder.save_token(self.hf_auth_token)

@@ -39,7 +39,7 @@ def init_adapter(
     if finetuning_args.finetuning_type == "none" and is_trainable:
         raise ValueError("You cannot use finetuning_type=none while training.")

-    if finetuning_args.finetuning_type == "full":
+    if finetuning_args.finetuning_type == "full" and is_trainable:
         logger.info("Fine-tuning method: Full")
         model = model.float()
@@ -65,7 +65,7 @@ def init_adapter(
         assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
             "The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."

-        if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
+        if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable): # continually fine-tuning
             checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
         else:
             checkpoints_to_merge = model_args.checkpoint_dir

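`resume_lora_training` now lives on `finetuning_args`: when resuming, the last checkpoint stays an active adapter, otherwise every checkpoint is folded into the base weights. With peft, sequential merging looks roughly like this (checkpoint paths are placeholders, `model` is a base model loaded beforehand):

    from peft import PeftModel

    for checkpoint in ["ckpt_stage1", "ckpt_stage2"]:  # checkpoints_to_merge
        model = PeftModel.from_pretrained(model, checkpoint)
        model = model.merge_and_unload()  # fold the LoRA deltas into the base weights
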
@@ -1,5 +1,7 @@
 import os
+import math
 import torch
+from types import MethodType
 from typing import TYPE_CHECKING, Literal, Optional, Tuple

 from transformers import (
@@ -34,7 +36,7 @@ check_min_version("4.29.1")
 require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
 require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
 require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
-require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7")
+require_version("trl>=0.5.0", "To fix: pip install trl>=0.5.0")

 def load_model_and_tokenizer(
@@ -52,9 +54,6 @@ def load_model_and_tokenizer(
         logger.warning("Checkpoint is not found at evaluation, load the original model.")
         finetuning_args = FinetuningArguments(finetuning_type="none")

-    assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
-        "RM and PPO training can only be performed with the LoRA method."
-
     config_kwargs = {
         "trust_remote_code": True,
         "cache_dir": model_args.cache_dir,
@@ -69,15 +68,58 @@ def load_model_and_tokenizer(
         **config_kwargs
     )

-    if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
+    if finetuning_args.finetuning_type == "full" and model_args.checkpoint_dir is not None:
         model_to_load = model_args.checkpoint_dir[0]
     else:
         model_to_load = model_args.model_name_or_path

     config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
-    is_mergeable = True

+    if hasattr(config, "fp16") and hasattr(config, "bf16"): # fix Qwen config
+        if model_args.compute_dtype == torch.bfloat16:
+            setattr(config, "bf16", True)
+        else:
+            setattr(config, "fp16", True)
+
+    # Set RoPE scaling
+    if model_args.rope_scaling is not None:
+        if hasattr(config, "use_dynamic_ntk"): # for Qwen models
+            if is_trainable:
+                logger.warning("Qwen model does not support RoPE scaling in training.")
+            else:
+                setattr(config, "use_dynamic_ntk", True)
+                setattr(config, "use_logn_attn", True)
+                logger.info("Using dynamic NTK scaling.")
+
+        elif hasattr(config, "rope_scaling"): # for LLaMA models
+            require_version("transformers>=4.31.0", "RoPE scaling requires transformers>=4.31.0")
+
+            if is_trainable:
+                if model_args.rope_scaling == "dynamic":
+                    logger.warning(
+                        "Dynamic NTK may not work well with fine-tuning. "
+                        "See: https://github.com/huggingface/transformers/pull/24653"
+                    )
+
+                current_max_length = getattr(config, "max_position_embeddings", None)
+                if current_max_length and model_args.model_max_length > current_max_length:
+                    scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
+                else:
+                    logger.warning("Input length is smaller than max length. Consider increase input length.")
+                    scaling_factor = 1.0
+            else:
+                scaling_factor = 2.0
+
+            setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
+            logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
+                model_args.rope_scaling, scaling_factor
+            ))
+
+        else:
+            logger.warning("Current model does not support RoPE scaling.")
+
     # Quantization configurations (using bitsandbytes library).
+    is_mergeable = True
     if model_args.quantization_bit is not None:
         if model_args.quantization_bit == 8:
             require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
@@ -95,10 +137,10 @@ def load_model_and_tokenizer(
         )
         is_mergeable = False

-        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
+        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

-    # Load and prepare pretrained models (without valuehead).
+    # Load and prepare pre-trained models (without valuehead).
     model = AutoModelForCausalLM.from_pretrained(
         model_to_load,
         config=config,
@@ -107,6 +149,14 @@ def load_model_and_tokenizer(
         **config_kwargs
     )

+    # Disable custom generate method (for Qwen)
+    if "GenerationMixin" not in str(model.generate.__func__):
+        model.generate = MethodType(PreTrainedModel.generate, model)
+
+    # Fix LM head (for ChatGLM2)
+    if not hasattr(model, "lm_head") and hasattr(model, "transformer"):
+        setattr(model, "lm_head", model.transformer.output_layer)
+
     # Register auto class to save the custom code files.
     if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
         config.__class__.register_for_auto_class()
@@ -119,10 +169,10 @@ def load_model_and_tokenizer(
     model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
     model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)

-    if stage == "rm" or stage == "ppo": # add value head
-        model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
+    # Prepare model with valuehead for RLHF
+    if stage == "rm" or stage == "ppo":
+        model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(model)
         reset_logging()
         if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
             logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
             if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
@@ -132,15 +182,15 @@ def load_model_and_tokenizer(
             })

         if stage == "ppo": # load reward model
-            assert is_trainable, "PPO stage cannot be performed at evaluation."
-            assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
             logger.info("Load reward model from {}".format(model_args.reward_model))
             model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
             assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."

+    # Prepare model for inference
     if not is_trainable:
         model.requires_grad_(False) # fix all model params
-        model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
+        infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # detect cuda capability
+        model = model.to(infer_dtype) if model_args.quantization_bit is None else model

     trainable_params, all_param = count_parameters(model)
     logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(

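The linear scaling factor amounts to stretching positions just enough to cover the requested context at training time (`model_max_length` is `max_source_length + max_target_length`, set in the parser below) and defaulting to 2.0 at inference. The arithmetic in isolation:

    import math

    max_position_embeddings = 4096  # e.g. a LLaMA-2 config value
    model_max_length = 8192

    scaling_factor = float(math.ceil(model_max_length / max_position_embeddings))  # -> 2.0
    # config.rope_scaling = {"type": "linear", "factor": scaling_factor}
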
@@ -19,7 +19,7 @@ from llmtuner.hparams import (
 logger = get_logger(__name__)

-def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None):
+def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
     if args is not None:
         return parser.parse_dict(args)
     elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
@@ -32,26 +32,53 @@ def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None)
 def parse_train_args(
     args: Optional[Dict[str, Any]] = None
-) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
+) -> Tuple[
+    ModelArguments,
+    DataArguments,
+    Seq2SeqTrainingArguments,
+    FinetuningArguments,
+    GeneratingArguments,
+    GeneralArguments
+]:
     parser = HfArgumentParser((
-        ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments
+        ModelArguments,
+        DataArguments,
+        Seq2SeqTrainingArguments,
+        FinetuningArguments,
+        GeneratingArguments,
+        GeneralArguments
     ))
     return _parse_args(parser, args)

 def parse_infer_args(
     args: Optional[Dict[str, Any]] = None
-) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
+) -> Tuple[
+    ModelArguments,
+    DataArguments,
+    FinetuningArguments,
+    GeneratingArguments
+]:
     parser = HfArgumentParser((
-        ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+        ModelArguments,
+        DataArguments,
+        FinetuningArguments,
+        GeneratingArguments
     ))
     return _parse_args(parser, args)

 def get_train_args(
     args: Optional[Dict[str, Any]] = None
-) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
-    model_args, data_args, training_args, finetuning_args, general_args = parse_train_args(args)
+) -> Tuple[
+    ModelArguments,
+    DataArguments,
+    Seq2SeqTrainingArguments,
+    FinetuningArguments,
+    GeneratingArguments,
+    GeneralArguments
+]:
+    model_args, data_args, training_args, finetuning_args, generating_args, general_args = parse_train_args(args)

     # Setup logging
     if training_args.should_log:
@@ -67,33 +94,42 @@ def get_train_args(
     # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
     data_args.init_for_training()

-    assert general_args.stage == "sft" or (not training_args.predict_with_generate), \
-        "`predict_with_generate` cannot be set as True at PT, RM and PPO stages."
+    if general_args.stage != "sft" and training_args.predict_with_generate:
+        raise ValueError("`predict_with_generate` cannot be set as True except SFT.")

-    assert not (training_args.do_train and training_args.predict_with_generate), \
-        "`predict_with_generate` cannot be set as True while training."
+    if training_args.do_train and training_args.predict_with_generate:
+        raise ValueError("`predict_with_generate` cannot be set as True while training.")

-    assert general_args.stage != "sft" or (not training_args.do_predict) or training_args.predict_with_generate, \
-        "Please enable `predict_with_generate` to save model predictions."
+    if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
+        raise ValueError("Please enable `predict_with_generate` to save model predictions.")

-    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
-        "Quantization is only compatible with the LoRA method."
+    if general_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
+        raise ValueError("RM and PPO training can only be performed with the LoRA method.")

-    assert not (training_args.max_steps == -1 and data_args.streaming), \
-        "Please specify `max_steps` in streaming mode."
+    if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
+        raise ValueError("PPO and DPO stage can only be performed at training.")

-    assert training_args.evaluation_strategy == "no" or (not data_args.streaming), \
-        "Streaming mode does not support evaluation currently."
+    if general_args.stage == "ppo" and model_args.reward_model is None:
+        raise ValueError("Reward model is necessary for PPO training.")

-    assert not (general_args.stage == "ppo" and data_args.streaming), \
-        "Streaming mode does not suppport PPO training currently."
+    if training_args.max_steps == -1 and data_args.streaming:
+        raise ValueError("Please specify `max_steps` in streaming mode.")
+
+    if general_args.stage == "ppo" and data_args.streaming:
+        raise ValueError("Streaming mode does not suppport PPO training currently.")
+
+    if data_args.val_size > 1e-6 and data_args.val_size < 1 and data_args.streaming:
+        raise ValueError("Streaming mode should have an integer val size.")
+
+    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Quantization is only compatible with the LoRA method.")

     if model_args.checkpoint_dir is not None:
         if finetuning_args.finetuning_type != "lora":
-            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
-        else:
-            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
-                "Quantized model only accepts a single checkpoint."
+            if len(model_args.checkpoint_dir) != 1:
+                raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+        elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
+            raise ValueError("Quantized model only accepts a single checkpoint.")

     if model_args.quantization_bit is not None and (not training_args.do_train):
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
@ -113,46 +149,48 @@ def get_train_args(
logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.") logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
data_args.max_samples = None data_args.max_samples = None
if data_args.dev_ratio > 1e-6 and data_args.streaming:
logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")
data_args.dev_ratio = 0
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
if model_args.quantization_bit is not None: if training_args.bf16:
if training_args.fp16: if not torch.cuda.is_bf16_supported():
model_args.compute_dtype = torch.float16 raise ValueError("Current device does not support bf16 training.")
elif training_args.bf16:
model_args.compute_dtype = torch.bfloat16 model_args.compute_dtype = torch.bfloat16
else: else:
model_args.compute_dtype = torch.float32 model_args.compute_dtype = torch.float16
model_args.model_max_length = data_args.max_source_length + data_args.max_target_length
# Log on each process the small summary: # Log on each process the small summary:
logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, 16-bits training: {}".format( logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
training_args.local_rank, training_args.device, training_args.n_gpu, training_args.local_rank, training_args.device, training_args.n_gpu,
bool(training_args.local_rank != -1), training_args.fp16 bool(training_args.local_rank != -1), str(model_args.compute_dtype)
)) ))
logger.info(f"Training/evaluation parameters {training_args}") logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model. # Set seed before initializing model.
transformers.set_seed(training_args.seed) transformers.set_seed(training_args.seed)
return model_args, data_args, training_args, finetuning_args, general_args return model_args, data_args, training_args, finetuning_args, generating_args, general_args
def get_infer_args(
    args: Optional[Dict[str, Any]] = None
) -> Tuple[
    ModelArguments,
    DataArguments,
    FinetuningArguments,
    GeneratingArguments
]:
    model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)

    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
        raise ValueError("Quantization is only compatible with the LoRA method.")

    if model_args.checkpoint_dir is not None:
        if finetuning_args.finetuning_type != "lora":
            if len(model_args.checkpoint_dir) != 1:
                raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
        elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
            raise ValueError("Quantized model only accepts a single checkpoint.")

    return model_args, data_args, finetuning_args, generating_args
@ -13,26 +13,25 @@ from llmtuner.extras.logging import get_logger
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, Seq2SeqTrainingArguments, TrainerState
    from llmtuner.hparams import FinetuningArguments


logger = get_logger(__name__)


class PeftModelMixin:
    r"""
    Patches the save and load methods in Hugging Face Trainer for PeftModel and ModelWithValueHead.
    """

    def __init__(self) -> None: # for type checking
        self.model: "PreTrainedModel" = None
        self.tokenizer: "PreTrainedTokenizer" = None
        self.args: "Seq2SeqTrainingArguments" = None
        self.finetuning_args: "FinetuningArguments" = None
        self.state: "TrainerState" = None
        raise AssertionError("Mixin should not be initialized.")

    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
        r"""
@ -96,3 +95,13 @@ class PeftTrainer(Seq2SeqTrainer):
            model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
        else: # freeze/full-tuning
            load_trainable_params(model, self.state.best_model_checkpoint)


class PeftTrainer(PeftModelMixin, Seq2SeqTrainer):
    r"""
    Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
    """

    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
        Seq2SeqTrainer.__init__(self, **kwargs)
        self.finetuning_args = finetuning_args
@ -0,0 +1 @@
from llmtuner.tuner.dpo.workflow import run_dpo
@ -0,0 +1,51 @@
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Sequence, Tuple
from transformers import DataCollatorForSeq2Seq


@dataclass
class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Data collator for pairwise data.
    """

    def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
        padded_labels = []
        for feature, (prompt_len, answer_len) in zip(batch, positions):
            if self.tokenizer.padding_side == "left":
                start, end = feature.size(0) - answer_len, feature.size(0)
            else:
                start, end = prompt_len, prompt_len + answer_len # answer spans [prompt_len, prompt_len + answer_len)
            padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
            padded_tensor[start:end] = feature[start:end]
            padded_labels.append(padded_tensor)
        return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads batched data to the longest sequence in the batch.

        We generate 2 * n examples where the first n examples represent chosen examples and
        the last n examples represent rejected examples.
        """
        concatenated_features = []
        label_positions = []
        for key in ("chosen_ids", "rejected_ids"):
            for feature in features:
                prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
                concatenated_features.append({
                    "input_ids": feature["prompt_ids"] + feature[key],
                    "attention_mask": [1] * (prompt_len + answer_len)
                })
                label_positions.append((prompt_len, answer_len))

        batch = self.tokenizer.pad(
            concatenated_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
        return batch
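
To make the pairing behavior concrete, here is a hedged usage sketch (not part of this commit): the toy features and the GPT-2 tokenizer below are stand-ins for illustration only.

# Illustrative sketch only, assuming any Hugging Face tokenizer with a pad token.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in model
tokenizer.pad_token = tokenizer.eos_token

collator = DPODataCollatorWithPadding(tokenizer=tokenizer, label_pad_token_id=-100)
features = [
    {"prompt_ids": [1, 2], "chosen_ids": [3, 4, 5], "rejected_ids": [6]},
    {"prompt_ids": [7], "chosen_ids": [8, 9], "rejected_ids": [10, 11]},
]
batch = collator(features)
# batch["input_ids"] has 4 rows: the two chosen sequences first, then the two
# rejected ones; batch["labels"] masks prompt and padding positions with -100.
print(batch["input_ids"].shape, batch["labels"].shape)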
@ -0,0 +1,77 @@
import torch

from collections import defaultdict
from peft import PeftModel
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
from transformers import BatchEncoding, Trainer
from trl import DPOTrainer

from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.tuner.core.trainer import PeftModelMixin

if TYPE_CHECKING:
    from transformers import PreTrainedModel
    from llmtuner.hparams import FinetuningArguments, GeneratingArguments


class DPOPeftTrainer(PeftModelMixin, DPOTrainer):

    def __init__(
        self,
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
        **kwargs
    ):
        self.finetuning_args = finetuning_args
        self.generating_args = generating_args
        self.ref_model = ref_model
        self.use_dpo_data_collator = True # hack to avoid warning
        self.label_pad_token_id = IGNORE_INDEX
        self.padding_value = 0
        self.beta = finetuning_args.dpo_beta
        self._stored_metrics = defaultdict(lambda: defaultdict(list))
        Trainer.__init__(self, **kwargs)
        if not hasattr(self, "accelerator"):
            raise AttributeError("Please update `transformers`.")

        if ref_model is not None:
            self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

    def concatenated_forward(
        self,
        model: Optional[torch.nn.Module] = None,
        batch: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error

        unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
        if not torch.is_grad_enabled():
            unwrapped_model.gradient_checkpointing_disable()

        if model is None and isinstance(unwrapped_model, PeftModel): # peft model has no ref_model
            with unwrapped_model.disable_adapter():
                all_logits = self.model(
                    input_ids=batch_copied["input_ids"],
                    attention_mask=batch_copied["attention_mask"],
                    return_dict=True
                ).logits.to(torch.float32)
        else:
            all_logits = model(
                input_ids=batch_copied["input_ids"],
                attention_mask=batch_copied["attention_mask"],
                return_dict=True
            ).logits.to(torch.float32)

        if not torch.is_grad_enabled():
            unwrapped_model.gradient_checkpointing_enable()

        all_logps = self._get_batch_logps(
            all_logits,
            batch["labels"],
            average_log_prob=False
        )
        batch_size = batch["input_ids"].size(0) // 2
        chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
        chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
        return chosen_logps, rejected_logps, chosen_logits, rejected_logits
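
The four tensors returned here feed the standard DPO objective. Below is a minimal sketch of that loss, assuming beta corresponds to finetuning_args.dpo_beta and the reference log-probs come from the adapter-disabled pass above; this is illustrative, not code from this commit.

import torch.nn.functional as F

def dpo_loss_sketch(policy_chosen_logps, policy_rejected_logps,
                    ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # log-ratios of the tuned policy against the frozen reference model
    chosen_ratio = policy_chosen_logps - ref_chosen_logps
    rejected_ratio = policy_rejected_logps - ref_rejected_logps
    # reward the margin between chosen and rejected responses
    return -F.logsigmoid(beta * (chosen_ratio - rejected_ratio)).mean()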
@ -0,0 +1,59 @@
# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py

from copy import deepcopy
from peft import PeftModel
from typing import TYPE_CHECKING, Optional, List

from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
from llmtuner.tuner.dpo.trainer import DPOPeftTrainer

if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, TrainerCallback
    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments


def run_dpo(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    generating_args: "GeneratingArguments",
    callbacks: Optional[List["TrainerCallback"]] = None
):
    dataset = get_dataset(model_args, data_args)
    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
    data_collator = DPODataCollatorWithPadding(
        tokenizer=tokenizer,
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    )

    training_args.remove_unused_columns = False # important for pairwise dataset
    ref_model = deepcopy(model) if not isinstance(model, PeftModel) else None

    # Initialize our Trainer
    trainer = DPOPeftTrainer(
        finetuning_args=finetuning_args,
        generating_args=generating_args,
        ref_model=ref_model,
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks,
        **split_dataset(dataset, data_args, training_args)
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        trainer.save_model()
        if trainer.is_world_process_zero() and model_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
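
Assuming run_exp (patched later in this commit) is importable from llmtuner.tuner.tune, a DPO run could be launched programmatically as sketched below; the module path, every file path, and the dataset name are placeholders, not values taken from this diff.

from llmtuner.tuner.tune import run_exp  # module path assumed

run_exp(args={
    "stage": "dpo",                      # routes to run_dpo above
    "model_name_or_path": "path_to_your_model",
    "do_train": True,
    "dataset": "comparison_gpt4_en",     # placeholder pairwise dataset
    "finetuning_type": "lora",
    "output_dir": "path_to_dpo_checkpoint"
})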
@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
from transformers import TrainerState, TrainerControl
from trl import PPOTrainer
from trl.core import LengthSampler, PPODecorators, logprobs_from_logits

from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
@ -18,7 +18,7 @@ if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments
    from trl import AutoModelForCausalLMWithValueHead
    from llmtuner.extras.callbacks import LogCallback
    from llmtuner.hparams import FinetuningArguments, GeneratingArguments


logger = get_logger(__name__)
@ -33,16 +33,19 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
        self,
        training_args: "Seq2SeqTrainingArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
        callbacks: List["LogCallback"],
        compute_dtype: torch.dtype,
        **kwargs
    ):
        PPOTrainer.__init__(self, **kwargs)
        self.args = training_args
        self.finetuning_args = finetuning_args
        self.generating_args = generating_args
        self.log_callback = callbacks[0]
        self.compute_dtype = compute_dtype
        self.state = TrainerState()
        self.control = TrainerControl()

    def ppo_train(self, max_target_length: int) -> None:
        r"""
@ -72,14 +75,11 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}") logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")
# Keyword arguments for `model.generate` # Keyword arguments for `model.generate`
gen_kwargs = { gen_kwargs = self.generating_args.to_dict()
"top_k": 0.0, gen_kwargs["eos_token_id"] = list(set([self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids))
"top_p": 1.0, gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
"do_sample": True, gen_kwargs["logits_processor"] = get_logits_processor()
"pad_token_id": self.tokenizer.pad_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"logits_processor": get_logits_processor()
}
length_sampler = LengthSampler(max_target_length // 2, max_target_length) length_sampler = LengthSampler(max_target_length // 2, max_target_length)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
@ -185,10 +185,74 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
        replace_model(unwrapped_model, target="reward")
        batch = self.prepare_model_inputs(queries, responses)
        _, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
        if values.size(0) != batch["input_ids"].size(0): # adapt chatglm2
            values = torch.transpose(values, 0, 1)
        rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type
        replace_model(unwrapped_model, target="default")
        return rewards
    @PPODecorators.empty_cuda_cache()
    def batched_forward_pass(
        self,
        model: "AutoModelForCausalLMWithValueHead",
        queries: torch.Tensor,
        responses: torch.Tensor,
        model_inputs: dict,
        return_logits: Optional[bool] = False
    ):
        r"""
        Calculates model outputs in multiple batches.

        Subclass and override to inject custom behavior.
        """
        bs = len(queries)
        fbs = self.config.mini_batch_size
        all_logprobs = []
        all_logits = []
        all_masks = []
        all_values = []

        for i in range(math.ceil(bs / fbs)):
            input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
            query_batch = queries[i * fbs : (i + 1) * fbs]
            response_batch = responses[i * fbs : (i + 1) * fbs]
            input_ids = input_kwargs["input_ids"]
            attention_mask = input_kwargs["attention_mask"]

            with torch.cuda.amp.autocast(dtype=self.compute_dtype): # support bf16
                logits, _, values = model(**input_kwargs)

            if values.size(0) != input_ids.size(0): # adapt chatglm2
                values = torch.transpose(values, 0, 1)

            logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
            masks = torch.zeros_like(attention_mask)
            masks[:, :-1] = attention_mask[:, 1:]

            for j in range(len(query_batch)):
                start = len(query_batch[j]) - 1
                if attention_mask[j, 0] == 0: # offset left padding
                    start += attention_mask[j, :].nonzero()[0]
                end = start + len(response_batch[j])
                masks[j, :start] = 0
                masks[j, end:] = 0

            if return_logits:
                all_logits.append(logits)
            else:
                del logits

            all_values.append(values)
            all_logprobs.append(logprobs)
            all_masks.append(masks)

        return (
            torch.cat(all_logprobs),
            torch.cat(all_logits)[:, :-1] if return_logits else None,
            torch.cat(all_values)[:, :-1],
            torch.cat(all_masks)[:, :-1],
        )
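
A quick hedged walkthrough of the masking arithmetic above, using a made-up left-padded row:

import torch

# Row layout: [PAD, PAD, q1, q2, r1, r2] -> query length 2, response length 2.
attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1]])
masks = torch.zeros_like(attention_mask)
masks[:, :-1] = attention_mask[:, 1:]                  # shift left to align with next-token log-probs
start = (2 - 1) + int(attention_mask[0].nonzero()[0])  # query length minus one, plus padding offset
end = start + 2                                        # plus response length
masks[0, :start] = 0
masks[0, end:] = 0
# masks == [[0, 0, 0, 1, 1, 0]]: only the (shifted) response positions survive.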
    def save_model(self, output_dir: Optional[str] = None) -> None:
        r"""
        Saves model checkpoint.
@ -1,11 +1,9 @@
# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py

import math
from trl import PPOConfig
from torch.optim import AdamW
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq
from transformers.optimization import get_scheduler
@ -16,7 +14,7 @@ from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, TrainerCallback
    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments


def run_ppo(
@ -24,6 +22,7 @@ def run_ppo(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None callbacks: Optional[List["TrainerCallback"]] = None
): ):
dataset = get_dataset(model_args, data_args) dataset = get_dataset(model_args, data_args)
@ -38,24 +37,30 @@ def run_ppo(
        batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        ppo_epochs=1,
        max_grad_norm=training_args.max_grad_norm,
        seed=training_args.seed,
        optimize_cuda_cache=True
    )

    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)

    total_train_batch_size = (
        training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
    )
    num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
    lr_scheduler = get_scheduler(
        training_args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
        num_training_steps=num_training_steps
    )
    # Initialize our Trainer
    ppo_trainer = PPOPeftTrainer(
        training_args=training_args,
        finetuning_args=finetuning_args,
        generating_args=generating_args,
        callbacks=callbacks,
        compute_dtype=model_args.compute_dtype,
        config=ppo_config,
        model=model,
        ref_model=None,
@ -66,8 +71,10 @@ def run_ppo(
        lr_scheduler=lr_scheduler
    )

    # Training
    if training_args.do_train:
        ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
        ppo_trainer.save_model()
        ppo_trainer.save_state() # must be called after save_model to have a folder
        if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss", "reward"])
@ -2,10 +2,9 @@
import math
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForLanguageModeling

from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.core.trainer import PeftTrainer
@ -25,10 +24,7 @@ def run_pt(
    dataset = get_dataset(model_args, data_args)
    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Initialize our Trainer
    trainer = PeftTrainer(
@ -38,7 +34,7 @@ def run_pt(
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks,
        **split_dataset(dataset, data_args, training_args)
    )

    # Training
@ -60,6 +56,5 @@ def run_pt(
perplexity = float("inf") perplexity = float("inf")
metrics["perplexity"] = perplexity metrics["perplexity"] = perplexity
trainer.log_metrics("eval", metrics) trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics) trainer.save_metrics("eval", metrics)
@ -1,8 +1,10 @@
import torch

from dataclasses import dataclass
from typing import Any, Dict, Sequence
from transformers import DataCollatorWithPadding


@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
    r"""
    Data collator for pairwise data.
@ -16,7 +18,10 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
        the last n examples represent rejected examples.
        """
        features = [
            {
                "input_ids": feature["prompt_ids"] + feature[key],
                "attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key]))
            }
            for key in ("chosen_ids", "rejected_ids") for feature in features
        ]
        return super().__call__(features)
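
A hedged toy call (features invented for illustration) makes the 2 * n row ordering concrete:

# Illustrative only: two features yield four rows, chosen examples first.
features = [
    {"prompt_ids": [1], "chosen_ids": [2, 3], "rejected_ids": [4]},
    {"prompt_ids": [5], "chosen_ids": [6], "rejected_ids": [7, 8]},
]
# Rows handed to the parent collator's padding logic:
#   [1, 2, 3]   <- chosen, example 1
#   [5, 6]      <- chosen, example 2
#   [1, 4]      <- rejected, example 1
#   [5, 7, 8]   <- rejected, example 2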
@ -42,6 +42,8 @@ class PairwisePeftTrainer(PeftTrainer):
""" """
batch_size = inputs["input_ids"].size(0) // 2 batch_size = inputs["input_ids"].size(0) // 2
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True) _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
if values.size(0) != inputs["input_ids"].size(0): # adapt chatglm2
values = torch.transpose(values, 0, 1)
r_accept, r_reject = values[:, -1].split(batch_size, dim=0) r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean() loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
return (loss, [loss, r_accept, r_reject]) if return_outputs else loss return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
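
The ranking loss above is the Bradley-Terry negative log-likelihood over reward margins; a quick numerical check with invented scores:

import torch

r_accept = torch.tensor([2.0, 0.5])
r_reject = torch.tensor([1.0, 1.5])
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
# sigmoid(1.0) ~= 0.731 and sigmoid(-1.0) ~= 0.269, so
# loss ~= (0.313 + 1.313) / 2 ~= 0.813
print(loss.item())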
@ -39,7 +39,7 @@ def run_rm(
        data_collator=data_collator,
        callbacks=callbacks,
        compute_metrics=compute_accuracy,
        **split_dataset(dataset, data_args, training_args)
    )

    # Training
@ -25,7 +25,7 @@ class ComputeMetrics:
        Uses the model predictions to compute metrics.
        """
        preds, labels = eval_preds
        score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}

        preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
        labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
@ -49,6 +49,5 @@ class ComputeMetrics:
            bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
            score_dict["bleu-4"].append(round(bleu_score * 100, 4))

        return {k: float(np.mean(v)) for k, v in score_dict.items()}
@ -50,8 +50,9 @@ class Seq2SeqPeftTrainer(PeftTrainer):
        loss, generated_tokens, labels = super().prediction_step(
            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
        )
        if generated_tokens is not None:
            generated_tokens[:, :max(prompt_len, label_len)] = (
                self.tokenizer.pad_token_id * torch.ones_like(generated_tokens[:, :max(prompt_len, label_len)])
            )

        return (loss, generated_tokens, labels)
@ -72,14 +73,11 @@ class Seq2SeqPeftTrainer(PeftTrainer):
            assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
            pad_token_id = self.tokenizer.pad_token_id
        else:
            raise ValueError("PAD token is required.")

        padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
        padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
        return padded_tensor.contiguous() # in contiguous memory
    def save_predictions(
        self,
@ -16,7 +16,7 @@ from llmtuner.extras.logging import reset_logging, get_logger
if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, TrainerCallback
    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments


logger = get_logger(__name__)
@ -25,6 +25,7 @@ def run_sft(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None callbacks: Optional[List["TrainerCallback"]] = None
): ):
dataset = get_dataset(model_args, data_args) dataset = get_dataset(model_args, data_args)
@ -50,31 +51,15 @@ def run_sft(
        data_collator=data_collator,
        callbacks=callbacks,
        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
        **split_dataset(dataset, data_args, training_args)
    )

    # Keyword arguments for `model.generate`
    gen_kwargs = generating_args.to_dict()
    gen_kwargs["eos_token_id"] = list(set([tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids))
    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
    gen_kwargs["logits_processor"] = get_logits_processor()

    # Training
    if training_args.do_train:
        checkpoint = None
@ -1,35 +1,47 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.logging import get_logger
from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer
from llmtuner.tuner.pt import run_pt
from llmtuner.tuner.sft import run_sft
from llmtuner.tuner.rm import run_rm
from llmtuner.tuner.ppo import run_ppo
from llmtuner.tuner.dpo import run_dpo

if TYPE_CHECKING:
    from transformers import TrainerCallback


logger = get_logger(__name__)


def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
    model_args, data_args, training_args, finetuning_args, generating_args, general_args = get_train_args(args)
    callbacks = [LogCallback()] if callbacks is None else callbacks

    if general_args.stage == "pt":
        run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
    elif general_args.stage == "sft":
        run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
    elif general_args.stage == "rm":
        run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
    elif general_args.stage == "ppo":
        run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
    elif general_args.stage == "dpo":
        run_dpo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
    else:
        raise ValueError("Unknown task.")
def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"):
    model_args, _, training_args, finetuning_args, _, _ = get_train_args(args)
    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
    model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size)
    try:
        tokenizer.save_pretrained(training_args.output_dir)
    except:
        logger.warning("Cannot save tokenizer, please copy the files manually.")


if __name__ == "__main__":
@ -26,7 +26,7 @@ class WebChatModel(ChatModel):
        finetuning_type: str,
        quantization_bit: str,
        template: str,
        system_prompt: str
    ):
        if self.model is not None:
            yield ALERTS["err_exists"][lang]
@ -53,9 +53,9 @@ class WebChatModel(ChatModel):
            model_name_or_path=model_name_or_path,
            checkpoint_dir=checkpoint_dir,
            finetuning_type=finetuning_type,
            quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
            template=template,
            system_prompt=system_prompt
        )
        super().__init__(args)
@ -73,7 +73,7 @@ class WebChatModel(ChatModel):
        chatbot: List[Tuple[str, str]],
        query: str,
        history: List[Tuple[str, str]],
        system: str,
        max_new_tokens: int,
        top_p: float,
        temperature: float
@ -81,7 +81,7 @@ class WebChatModel(ChatModel):
        chatbot.append([query, ""])
        response = ""
        for new_text in self.stream_chat(
            query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
        ):
            response += new_text
            response = self.postprocess(response)
@ -6,7 +6,7 @@ import gradio as gr
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME

from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS


DEFAULT_CACHE_DIR = "cache"
@ -29,12 +29,14 @@ def load_config() -> Dict[str, Any]:
        with open(get_config_path(), "r", encoding="utf-8") as f:
            return json.load(f)
    except:
        return {"lang": "", "last_model": "", "path_dict": {}}
def save_config(lang: str, model_name: str, model_path: str) -> None:
    os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
    user_config = load_config()
    user_config["lang"] = lang or user_config["lang"]
    if model_name:
        user_config["last_model"] = model_name
        user_config["path_dict"][model_name] = model_path
    with open(get_config_path(), "w", encoding="utf-8") as f:
@ -46,6 +48,12 @@ def get_model_path(model_name: str) -> str:
return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, "")) return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, ""))
def get_template(model_name: str) -> str:
    if model_name.endswith("Chat") and model_name.split("-")[0] in DEFAULT_TEMPLATE:
        return DEFAULT_TEMPLATE[model_name.split("-")[0]]
    return "default"
def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
    checkpoints = []
    save_dir = os.path.join(get_save_dir(model_name), finetuning_type)
@ -1,5 +1,5 @@
from llmtuner.webui.components.top import create_top
from llmtuner.webui.components.train import create_train_tab
from llmtuner.webui.components.eval import create_eval_tab
from llmtuner.webui.components.infer import create_infer_tab
from llmtuner.webui.components.export import create_export_tab
@ -17,7 +17,7 @@ def create_chat_box(
    with gr.Row():
        with gr.Column(scale=4):
            system = gr.Textbox(show_label=False)
            query = gr.Textbox(show_label=False, lines=8)
            submit_btn = gr.Button(variant="primary")
@ -31,7 +31,7 @@ def create_chat_box(
    submit_btn.click(
        chat_model.predict,
        [chatbot, query, history, system, max_new_tokens, top_p, temperature],
        [chatbot, history],
        show_progress=True
    ).then(
@ -41,7 +41,7 @@ def create_chat_box(
    clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)

    return chat_box, chatbot, history, dict(
        system=system,
        query=query,
        submit_btn=submit_btn,
        clear_btn=clear_btn,
@ -16,6 +16,6 @@ def create_preview_box() -> Tuple["Block", "Component", "Component", "Component"
        close_btn = gr.Button()

    close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)

    return preview_box, preview_count, preview_samples, close_btn
@ -14,13 +14,18 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
    with gr.Row():
        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
        dataset = gr.Dropdown(multiselect=True, scale=4)
        data_preview_btn = gr.Button(interactive=False, scale=1)

    preview_box, preview_count, preview_samples, close_btn = create_preview_box()

    dataset_dir.change(list_dataset, [dataset_dir], [dataset])
    dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])
    data_preview_btn.click(
        get_preview,
        [dataset_dir, dataset],
        [preview_count, preview_samples, preview_box],
        queue=False
    )

    with gr.Row():
        max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
@ -30,22 +35,24 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
        predict = gr.Checkbox(value=True)

    with gr.Row():
        cmd_preview_btn = gr.Button()
        start_btn = gr.Button()
        stop_btn = gr.Button()

    with gr.Row():
        process_bar = gr.Slider(visible=False, interactive=False)

    with gr.Box():
        output_box = gr.Markdown()

    input_components = [
        top_elems["lang"],
        top_elems["model_name"],
        top_elems["checkpoints"],
        top_elems["finetuning_type"],
        top_elems["quantization_bit"],
        top_elems["template"],
        top_elems["system_prompt"],
        dataset_dir,
        dataset,
        max_source_length,
@ -53,15 +60,21 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
        max_samples,
        batch_size,
        predict
    ]

    output_components = [
        output_box,
        process_bar
    ]

    cmd_preview_btn.click(runner.preview_eval, input_components, output_components)
    start_btn.click(runner.run_eval, input_components, output_components)
    stop_btn.click(runner.set_abort, queue=False)
    return dict(
        dataset_dir=dataset_dir,
        dataset=dataset,
        data_preview_btn=data_preview_btn,
        preview_count=preview_count,
        preview_samples=preview_samples,
        close_btn=close_btn,
@ -70,6 +83,7 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
        max_samples=max_samples,
        batch_size=batch_size,
        predict=predict,
        cmd_preview_btn=cmd_preview_btn,
        start_btn=start_btn,
        stop_btn=stop_btn,
        output_box=output_box
@ -28,7 +28,7 @@ def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"
top_elems["finetuning_type"], top_elems["finetuning_type"],
top_elems["quantization_bit"], top_elems["quantization_bit"],
top_elems["template"], top_elems["template"],
top_elems["source_prefix"] top_elems["system_prompt"]
], ],
[info_box] [info_box]
).then( ).then(
@ -4,7 +4,7 @@ import gradio as gr
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
from llmtuner.extras.template import templates
from llmtuner.webui.common import list_checkpoint, get_model_path, get_template, save_config
from llmtuner.webui.utils import can_quantize

if TYPE_CHECKING:
@ -15,27 +15,32 @@ def create_top() -> Dict[str, "Component"]:
    available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]

    with gr.Row():
        lang = gr.Dropdown(choices=["en", "zh"], scale=1)
        model_name = gr.Dropdown(choices=available_models, scale=3)
        model_path = gr.Textbox(scale=3)

    with gr.Row():
        finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
        checkpoints = gr.Dropdown(multiselect=True, scale=5)
        refresh_btn = gr.Button(scale=1)

    with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
        with gr.Row():
            quantization_bit = gr.Dropdown(choices=["None", "8", "4"], value="None", scale=1)
            template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
            system_prompt = gr.Textbox(scale=2)

    lang.change(save_config, [lang, model_name, model_path])

    model_name.change(
        list_checkpoint, [model_name, finetuning_type], [checkpoints]
    ).then(
        get_model_path, [model_name], [model_path]
    ).then(
        get_template, [model_name], [template]
    ) # do not save config since the below line will save

    model_path.change(save_config, [lang, model_name, model_path])

    finetuning_type.change(
        list_checkpoint, [model_name, finetuning_type], [checkpoints]
@ -43,7 +48,9 @@ def create_top() -> Dict[str, "Component"]:
        can_quantize, [finetuning_type], [quantization_bit]
    )

    refresh_btn.click(
        list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
    )
    return dict(
        lang=lang,
@ -55,5 +62,5 @@ def create_top() -> Dict[str, "Component"]:
        advanced_tab=advanced_tab,
        quantization_bit=quantization_bit,
        template=template,
        system_prompt=system_prompt
    )
@ -3,7 +3,8 @@ from transformers.trainer_utils import SchedulerType
import gradio as gr

from llmtuner.extras.constants import STAGES
from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.utils import can_preview, get_preview, gen_plot
@ -12,17 +13,23 @@ if TYPE_CHECKING:
    from llmtuner.webui.runner import Runner


def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
    with gr.Row():
        training_stage = gr.Dropdown(choices=STAGES, value=STAGES[0], scale=2)
        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
        dataset = gr.Dropdown(multiselect=True, scale=4)
        data_preview_btn = gr.Button(interactive=False, scale=1)

    preview_box, preview_count, preview_samples, close_btn = create_preview_box()

    dataset_dir.change(list_dataset, [dataset_dir], [dataset])
    dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])
    data_preview_btn.click(
        get_preview,
        [dataset_dir, dataset],
        [preview_count, preview_samples, preview_box],
        queue=False
    )

    with gr.Row():
        max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
@ -35,10 +42,10 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
        batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
        gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
        lr_scheduler_type = gr.Dropdown(
            choices=[scheduler.value for scheduler in SchedulerType], value="cosine"
        )
        max_grad_norm = gr.Textbox(value="1.0")
        val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)

    with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
        with gr.Row():
@ -46,37 +53,56 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
            save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
            warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
            compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")
            padding_side = gr.Radio(choices=["left", "right"], value="left")

    with gr.Accordion(label="LoRA config", open=False) as lora_tab:
        with gr.Row():
            lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
            lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
            lora_target = gr.Textbox(scale=2)
            resume_lora_training = gr.Checkbox(value=True, scale=1)

    with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
        with gr.Row():
            dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=2)
            reward_model = gr.Dropdown(scale=2)
            refresh_btn = gr.Button(scale=1)

    refresh_btn.click(
        list_checkpoint,
        [top_elems["model_name"], top_elems["finetuning_type"]],
        [reward_model],
        queue=False
    )
    with gr.Row():
        cmd_preview_btn = gr.Button()
        start_btn = gr.Button()
        stop_btn = gr.Button()

    with gr.Row():
        with gr.Column(scale=3):
            with gr.Row():
                output_dir = gr.Textbox()

            with gr.Row():
                process_bar = gr.Slider(visible=False, interactive=False)

            with gr.Box():
                output_box = gr.Markdown()

        with gr.Column(scale=1):
            loss_viewer = gr.Plot()

    input_components = [
        top_elems["lang"],
        top_elems["model_name"],
        top_elems["checkpoints"],
        top_elems["finetuning_type"],
        top_elems["quantization_bit"],
        top_elems["template"],
        top_elems["system_prompt"],
        training_stage,
        dataset_dir,
        dataset,
        max_source_length,
@ -88,28 +114,39 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
        gradient_accumulation_steps,
        lr_scheduler_type,
        max_grad_norm,
        val_size,
        logging_steps,
        save_steps,
        warmup_steps,
        compute_type,
        padding_side,
        lora_rank,
        lora_dropout,
        lora_target,
        resume_lora_training,
        dpo_beta,
        reward_model,
        output_dir
    ]

    output_components = [
        output_box,
        process_bar
    ]

    cmd_preview_btn.click(runner.preview_train, input_components, output_components)
    start_btn.click(runner.run_train, input_components, output_components)
    stop_btn.click(runner.set_abort, queue=False)

    process_bar.change(
        gen_plot, [top_elems["model_name"], top_elems["finetuning_type"], output_dir], loss_viewer, queue=False
    )
    return dict(
        training_stage=training_stage,
        dataset_dir=dataset_dir,
        dataset=dataset,
        data_preview_btn=data_preview_btn,
        preview_count=preview_count,
        preview_samples=preview_samples,
        close_btn=close_btn,
@@ -122,16 +159,23 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
        gradient_accumulation_steps=gradient_accumulation_steps,
        lr_scheduler_type=lr_scheduler_type,
        max_grad_norm=max_grad_norm,
-       dev_ratio=dev_ratio,
+       val_size=val_size,
        advanced_tab=advanced_tab,
        logging_steps=logging_steps,
        save_steps=save_steps,
        warmup_steps=warmup_steps,
        compute_type=compute_type,
+       padding_side=padding_side,
        lora_tab=lora_tab,
        lora_rank=lora_rank,
        lora_dropout=lora_dropout,
        lora_target=lora_target,
+       resume_lora_training=resume_lora_training,
+       rlhf_tab=rlhf_tab,
+       dpo_beta=dpo_beta,
+       reward_model=reward_model,
+       refresh_btn=refresh_btn,
+       cmd_preview_btn=cmd_preview_btn,
        start_btn=start_btn,
        stop_btn=stop_btn,
        output_dir=output_dir,
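The hunk above factors the click wiring into shared `input_components` / `output_components` lists so that the new "Preview command" button and the "Start" button can bind different callbacks to the same component signature. A minimal, self-contained sketch of that Gradio pattern (the component and function names below are illustrative, not from this repository):

```python
import gradio as gr

def preview(name: str) -> str:
    # stands in for runner.preview_train: show what would run, without running it
    return "would train on: {}".format(name)

def start(name: str) -> str:
    # stands in for runner.run_train
    return "training on: {}".format(name)

with gr.Blocks() as demo:
    name = gr.Textbox(label="dataset")
    output = gr.Markdown()
    inputs, outputs = [name], [output]  # one list of components, two consumers
    gr.Button("Preview command").click(preview, inputs, outputs)
    gr.Button("Start").click(start, inputs, outputs)
```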

View File

@@ -3,7 +3,7 @@ from transformers.utils.versions import require_version
from llmtuner.webui.components import (
    create_top,
-   create_sft_tab,
+   create_train_tab,
    create_eval_tab,
    create_infer_tab,
    create_export_tab,
@@ -24,8 +24,8 @@ def create_ui() -> gr.Blocks:
    with gr.Blocks(title="Web Tuner", css=CSS) as demo:
        top_elems = create_top()

-       with gr.Tab("SFT"):
-           sft_elems = create_sft_tab(top_elems, runner)
+       with gr.Tab("Train"):
+           train_elems = create_train_tab(top_elems, runner)

        with gr.Tab("Evaluate"):
            eval_elems = create_eval_tab(top_elems, runner)
@@ -36,7 +36,7 @@ def create_ui() -> gr.Blocks:
        with gr.Tab("Export"):
            export_elems = create_export_tab(top_elems)

-       elem_list = [top_elems, sft_elems, eval_elems, infer_elems, export_elems]
+       elem_list = [top_elems, train_elems, eval_elems, infer_elems, export_elems]
        manager = Manager(elem_list)

        demo.load(
@@ -59,7 +59,7 @@ def create_web_demo() -> gr.Blocks:
    chat_model = WebChatModel(lazy_init=False)

    with gr.Blocks(title="Web Demo", css=CSS) as demo:
-       lang = gr.Dropdown(choices=["en", "zh"], value="en")
+       lang = gr.Dropdown(choices=["en", "zh"])

        _, _, _, chat_elems = create_chat_box(chat_model, visible=True)
@@ -67,7 +67,7 @@ def create_web_demo() -> gr.Blocks:
        demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
-       lang.change(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
+       lang.select(manager.gen_label, [lang], [lang] + list(chat_elems.values()), queue=False)

    return demo
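Switching the language handler from `lang.change` to `lang.select` is not cosmetic: in Gradio 3.x, `change` also fires when the dropdown value is updated programmatically (as `demo.load` does here), while `select` fires only on a user action, which avoids a redundant second refresh. A minimal sketch of the distinction (assuming Gradio 3.x event semantics; names are illustrative):

```python
import gradio as gr

def relabel(lang: str) -> dict:
    # return an update for the textbox label in the chosen language
    return gr.update(label={"en": "Message", "zh": "消息"}.get(lang, "Message"))

with gr.Blocks() as demo:
    lang = gr.Dropdown(choices=["en", "zh"])
    box = gr.Textbox()
    # fires only when the user picks a value, not on programmatic updates:
    lang.select(relabel, [lang], [box], queue=False)
```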

View File

@@ -77,7 +77,7 @@ LOCALES = {
            "info": "构建提示词时使用的模板"
        }
    },
-   "source_prefix": {
+   "system_prompt": {
        "en": {
            "label": "System prompt (optional)",
            "info": "A sequence used as the default system prompt."
@@ -87,6 +87,16 @@ LOCALES = {
            "info": "默认使用的系统提示词"
        }
    },
+   "training_stage": {
+       "en": {
+           "label": "Stage",
+           "info": "The stage to perform in training."
+       },
+       "zh": {
+           "label": "训练阶段",
+           "info": "目前采用的训练方式。"
+       }
+   },
    "dataset_dir": {
        "en": {
            "label": "Data dir",
@@ -105,12 +115,12 @@ LOCALES = {
            "label": "数据集"
        }
    },
-   "preview_btn": {
+   "data_preview_btn": {
        "en": {
-           "value": "Preview"
+           "value": "Preview dataset"
        },
        "zh": {
-           "value": "预览"
+           "value": "预览数据集"
        }
    },
    "preview_count": {
@@ -227,9 +237,9 @@ LOCALES = {
            "info": "用于梯度裁剪的范数。"
        }
    },
-   "dev_ratio": {
+   "val_size": {
        "en": {
-           "label": "Dev ratio",
+           "label": "Val size",
            "info": "Proportion of data in the dev set."
        },
        "zh": {
@@ -277,6 +287,16 @@ LOCALES = {
            "info": "是否启用 FP16 或 BF16 混合精度训练。"
        }
    },
+   "padding_side": {
+       "en": {
+           "label": "Padding side",
+           "info": "The side on which the model should have padding applied."
+       },
+       "zh": {
+           "label": "填充位置",
+           "info": "使用左填充或右填充。"
+       }
+   },
    "lora_tab": {
        "en": {
            "label": "LoRA configurations"
@@ -315,6 +335,52 @@ LOCALES = {
            "info": "应用 LoRA 的线性层名称。使用英文逗号分隔多个名称。"
        }
    },
+   "resume_lora_training": {
+       "en": {
+           "label": "Resume LoRA training",
+           "info": "Whether to resume training from the last LoRA weights or create new LoRA weights."
+       },
+       "zh": {
+           "label": "继续上次的训练",
+           "info": "接着上次的 LoRA 权重训练或创建一个新的 LoRA 权重。"
+       }
+   },
+   "rlhf_tab": {
+       "en": {
+           "label": "RLHF configurations"
+       },
+       "zh": {
+           "label": "RLHF 参数设置"
+       }
+   },
+   "dpo_beta": {
+       "en": {
+           "label": "DPO beta",
+           "info": "Value of the beta parameter in the DPO loss."
+       },
+       "zh": {
+           "label": "DPO beta 参数",
+           "info": "DPO 损失函数中 beta 超参数大小。"
+       }
+   },
+   "reward_model": {
+       "en": {
+           "label": "Reward model",
+           "info": "Checkpoint of the reward model for PPO training."
+       },
+       "zh": {
+           "label": "奖励模型",
+           "info": "PPO 训练中奖励模型的断点路径。"
+       }
+   },
+   "cmd_preview_btn": {
+       "en": {
+           "value": "Preview command"
+       },
+       "zh": {
+           "value": "预览命令"
+       }
+   },
    "start_btn": {
        "en": {
            "value": "Start"
@@ -389,7 +455,7 @@ LOCALES = {
            "value": "模型未加载,请先加载模型。"
        }
    },
-   "prefix": {
+   "system": {
        "en": {
            "placeholder": "System prompt (optional)"
        },
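All of these entries follow the same two-level scheme: `LOCALES[component_name][lang]` holds the keyword arguments (`label`, `info`, `value`, `placeholder`) applied to the matching component when the UI language changes. A condensed, self-contained illustration using the `training_stage` entry added above (the lookup helper is hypothetical):

```python
LOCALES = {
    "training_stage": {
        "en": {"label": "Stage", "info": "The stage to perform in training."},
        "zh": {"label": "训练阶段", "info": "目前采用的训练方式。"},
    },
}

def localized_kwargs(name: str, lang: str) -> dict:
    # gen_label (see the Manager hunk below) expands these as gr.update(**kwargs)
    return LOCALES[name][lang]

assert localized_kwargs("training_stage", "zh")["label"] == "训练阶段"
```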

View File

@@ -12,12 +12,18 @@ class Manager:
    def __init__(self, elem_list: List[Dict[str, Component]]):
        self.elem_list = elem_list

-   def gen_refresh(self) -> Dict[str, Any]:
+   def gen_refresh(self, lang: str) -> Dict[str, Any]:
        refresh_dict = {
            "dataset": {"choices": list_dataset()["choices"]},
            "output_dir": {"value": get_time()}
        }

        user_config = load_config()
+       if lang:
+           refresh_dict["lang"] = {"value": lang}
+       else:
+           refresh_dict["lang"] = {"value": user_config["lang"] if user_config["lang"] else "en"}
        if user_config["last_model"]:
            refresh_dict["model_name"] = {"value": user_config["last_model"]}
            refresh_dict["model_path"] = {"value": get_model_path(user_config["last_model"])}
@@ -26,10 +32,12 @@ class Manager:
    def gen_label(self, lang: str) -> Dict[Component, Dict[str, Any]]:  # cannot use TYPE_CHECKING
        update_dict = {}
-       refresh_dict = self.gen_refresh()
+       refresh_dict = self.gen_refresh(lang)

        for elems in self.elem_list:
            for name, component in elems.items():
-               update_dict[component] = gr.update(**LOCALES[name][lang], **refresh_dict.get(name, {}))
+               update_dict[component] = gr.update(
+                   **LOCALES[name][refresh_dict["lang"]["value"]], **refresh_dict.get(name, {})
+               )

        return update_dict
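The net effect of the new `lang` parameter is a three-step fallback when resolving the UI language: an explicit dropdown value wins, then the saved user config, then `"en"`. A standalone sketch of that precedence (the helper name is hypothetical, not from this repository):

```python
def resolve_lang(lang: str, user_config: dict) -> str:
    # explicit choice > persisted config > default
    if lang:
        return lang
    return user_config.get("lang") or "en"

assert resolve_lang("zh", {"lang": "en"}) == "zh"  # UI value wins
assert resolve_lang("", {"lang": "zh"}) == "zh"    # falls back to saved config
assert resolve_lang("", {}) == "en"                # falls back to the default
```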

View File

@@ -1,10 +1,11 @@
+import gradio as gr
import logging
import os
import threading
import time
import transformers
from transformers.trainer import TRAINING_ARGS_NAME
-from typing import Generator, List, Optional, Tuple
+from typing import Any, Dict, Generator, List, Tuple

from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import DEFAULT_MODULE
@@ -13,7 +14,7 @@ from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import run_exp
from llmtuner.webui.common import get_model_path, get_save_dir
from llmtuner.webui.locales import ALERTS
-from llmtuner.webui.utils import format_info, get_eval_results
+from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar


class Runner:
@@ -21,39 +22,36 @@ class Runner:
    def __init__(self):
        self.aborted = False
        self.running = False
+       self.logger_handler = LoggerHandler()
+       self.logger_handler.setLevel(logging.INFO)
+       logging.root.addHandler(self.logger_handler)
+       transformers.logging.add_handler(self.logger_handler)

    def set_abort(self):
        self.aborted = True
        self.running = False

-   def initialize(
+   def _initialize(
        self, lang: str, model_name: str, dataset: List[str]
-   ) -> Tuple[str, str, LoggerHandler, LogCallback]:
+   ) -> str:
        if self.running:
-           return None, ALERTS["err_conflict"][lang], None, None
+           return ALERTS["err_conflict"][lang]
        if not model_name:
-           return None, ALERTS["err_no_model"][lang], None, None
-       model_name_or_path = get_model_path(model_name)
-       if not model_name_or_path:
-           return None, ALERTS["err_no_path"][lang], None, None
+           return ALERTS["err_no_model"][lang]
+       if not get_model_path(model_name):
+           return ALERTS["err_no_path"][lang]
        if len(dataset) == 0:
-           return None, ALERTS["err_no_dataset"][lang], None, None
+           return ALERTS["err_no_dataset"][lang]
        self.aborted = False
-       self.running = True
-
-       logger_handler = LoggerHandler()
-       logger_handler.setLevel(logging.INFO)
-       logging.root.addHandler(logger_handler)
-       transformers.logging.add_handler(logger_handler)
-       trainer_callback = LogCallback(self)
-       return model_name_or_path, "", logger_handler, trainer_callback
+       self.logger_handler.reset()
+       self.trainer_callback = LogCallback(self)
+       return ""

-   def finalize(
+   def _finalize(
        self, lang: str, finish_info: str
    ) -> str:
        self.running = False
@@ -63,7 +61,7 @@ class Runner:
        else:
            return finish_info

-   def run_train(
+   def _parse_train_args(
        self,
        lang: str,
        model_name: str,
@@ -71,7 +69,8 @@ class Runner:
        finetuning_type: str,
        quantization_bit: str,
        template: str,
-       source_prefix: str,
+       system_prompt: str,
+       training_stage: str,
        dataset_dir: str,
        dataset: List[str],
        max_source_length: int,
@@ -83,24 +82,23 @@ class Runner:
        gradient_accumulation_steps: int,
        lr_scheduler_type: str,
        max_grad_norm: str,
-       dev_ratio: float,
+       val_size: float,
        logging_steps: int,
        save_steps: int,
        warmup_steps: int,
        compute_type: str,
+       padding_side: str,
        lora_rank: int,
        lora_dropout: float,
        lora_target: str,
+       resume_lora_training: bool,
+       dpo_beta: float,
+       reward_model: str,
        output_dir: str
-   ) -> Generator[str, None, None]:
-       model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
-       if error:
-           yield error
-           return
+   ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
        if checkpoints:
            checkpoint_dir = ",".join(
-               [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
+               [os.path.join(get_save_dir(model_name), finetuning_type, ckpt) for ckpt in checkpoints]
            )
        else:
            checkpoint_dir = None
@@ -109,14 +107,14 @@ class Runner:
        args = dict(
            stage="sft",
-           model_name_or_path=model_name_or_path,
+           model_name_or_path=get_model_path(model_name),
            do_train=True,
            overwrite_cache=True,
            checkpoint_dir=checkpoint_dir,
            finetuning_type=finetuning_type,
-           quantization_bit=int(quantization_bit) if quantization_bit else None,
+           quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
            template=template,
-           source_prefix=source_prefix,
+           system_prompt=system_prompt,
            dataset_dir=dataset_dir,
            dataset=",".join(dataset),
            max_source_length=max_source_length,
@@ -131,39 +129,40 @@ class Runner:
            logging_steps=logging_steps,
            save_steps=save_steps,
            warmup_steps=warmup_steps,
-           fp16=(compute_type == "fp16"),
-           bf16=(compute_type == "bf16"),
+           padding_side=padding_side,
            lora_rank=lora_rank,
            lora_dropout=lora_dropout,
            lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
+           resume_lora_training=resume_lora_training,
            output_dir=output_dir
        )
+       args[compute_type] = True

-       if dev_ratio > 1e-6:
-           args["dev_ratio"] = dev_ratio
+       if training_stage == "Reward Modeling":
+           args["stage"] = "rm"
+           args["resume_lora_training"] = False
+       elif training_stage == "PPO":
+           args["stage"] = "ppo"
+           args["resume_lora_training"] = False
+           args["reward_model"] = reward_model
+           args["padding_side"] = "left"
+           val_size = 0
+       elif training_stage == "DPO":
+           args["stage"] = "dpo"
+           args["resume_lora_training"] = False
+           args["dpo_beta"] = dpo_beta
+       elif training_stage == "Pre-Training":
+           args["stage"] = "pt"
+
+       if val_size > 1e-6:
+           args["val_size"] = val_size
            args["evaluation_strategy"] = "steps"
            args["eval_steps"] = save_steps
            args["load_best_model_at_end"] = True

-       run_kwargs = dict(args=args, callbacks=[trainer_callback])
-       thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
-       thread.start()
-       while thread.is_alive():
-           time.sleep(1)
-           if self.aborted:
-               yield ALERTS["info_aborting"][lang]
-           else:
-               yield format_info(logger_handler.log, trainer_callback)
-       if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
-           finish_info = ALERTS["info_finished"][lang]
-       else:
-           finish_info = ALERTS["err_failed"][lang]
-       yield self.finalize(lang, finish_info)
+       return lang, model_name, dataset, output_dir, args

-   def run_eval(
+   def _parse_eval_args(
        self,
        lang: str,
        model_name: str,
@@ -171,7 +170,7 @@ class Runner:
        finetuning_type: str,
        quantization_bit: str,
        template: str,
-       source_prefix: str,
+       system_prompt: str,
        dataset_dir: str,
        dataset: List[str],
        max_source_length: int,
@@ -179,12 +178,7 @@ class Runner:
        max_samples: str,
        batch_size: int,
        predict: bool
-   ) -> Generator[str, None, None]:
-       model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
-       if error:
-           yield error
-           return
+   ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
        if checkpoints:
            checkpoint_dir = ",".join(
                [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
@@ -196,15 +190,15 @@ class Runner:
        args = dict(
            stage="sft",
-           model_name_or_path=model_name_or_path,
+           model_name_or_path=get_model_path(model_name),
            do_eval=True,
            overwrite_cache=True,
            predict_with_generate=True,
            checkpoint_dir=checkpoint_dir,
            finetuning_type=finetuning_type,
-           quantization_bit=int(quantization_bit) if quantization_bit else None,
+           quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
            template=template,
-           source_prefix=source_prefix,
+           system_prompt=system_prompt,
            dataset_dir=dataset_dir,
            dataset=",".join(dataset),
            max_source_length=max_source_length,
@@ -218,20 +212,72 @@ class Runner:
            args.pop("do_eval", None)
            args["do_predict"] = True

-       run_kwargs = dict(args=args, callbacks=[trainer_callback])
+       return lang, model_name, dataset, output_dir, args
+
+   def preview_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+       lang, model_name, dataset, _, args = self._parse_train_args(*args)
+       error = self._initialize(lang, model_name, dataset)
+       if error:
+           yield error, gr.update(visible=False)
+       else:
+           yield gen_cmd(args), gr.update(visible=False)
+
+   def preview_eval(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+       lang, model_name, dataset, _, args = self._parse_eval_args(*args)
+       error = self._initialize(lang, model_name, dataset)
+       if error:
+           yield error, gr.update(visible=False)
+       else:
+           yield gen_cmd(args), gr.update(visible=False)
+
+   def run_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+       lang, model_name, dataset, output_dir, args = self._parse_train_args(*args)
+       error = self._initialize(lang, model_name, dataset)
+       if error:
+           yield error, gr.update(visible=False)
+           return
+
+       self.running = True
+       run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
        thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
        thread.start()
        while thread.is_alive():
-           time.sleep(1)
+           time.sleep(2)
            if self.aborted:
-               yield ALERTS["info_aborting"][lang]
+               yield ALERTS["info_aborting"][lang], gr.update(visible=False)
            else:
-               yield format_info(logger_handler.log, trainer_callback)
+               yield self.logger_handler.log, update_process_bar(self.trainer_callback)
+
+       if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
+           finish_info = ALERTS["info_finished"][lang]
+       else:
+           finish_info = ALERTS["err_failed"][lang]
+       yield self._finalize(lang, finish_info), gr.update(visible=False)
+
+   def run_eval(self, *args) -> Generator[str, None, None]:
+       lang, model_name, dataset, output_dir, args = self._parse_eval_args(*args)
+       error = self._initialize(lang, model_name, dataset)
+       if error:
+           yield error, gr.update(visible=False)
+           return
+
+       self.running = True
+       run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
+       thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
+       thread.start()
+       while thread.is_alive():
+           time.sleep(2)
+           if self.aborted:
+               yield ALERTS["info_aborting"][lang], gr.update(visible=False)
+           else:
+               yield self.logger_handler.log, update_process_bar(self.trainer_callback)
+
        if os.path.exists(os.path.join(output_dir, "all_results.json")):
            finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
        else:
            finish_info = ALERTS["err_failed"][lang]
-       yield self.finalize(lang, finish_info)
+       yield self._finalize(lang, finish_info), gr.update(visible=False)
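After this refactor, `run_train` and `run_eval` share one shape: parse args, validate, then run `run_exp` on a worker thread while the generator polls every two seconds and yields `(log_text, progress_update)` pairs for Gradio to stream. A minimal standalone sketch of that poll-and-yield pattern (names are illustrative, not the repository's code):

```python
import threading
import time

def run_streaming(job, log: list):
    # run the job in the background and stream its log while it is alive
    thread = threading.Thread(target=job, args=(log,))
    thread.start()
    while thread.is_alive():
        time.sleep(2)
        yield "\n".join(log)
    yield "\n".join(log) + "\nfinished"

def job(log: list) -> None:
    for step in range(3):
        time.sleep(1)
        log.append("step {:d} done".format(step))

for update in run_streaming(job, []):
    print(update)
```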

View File

@@ -15,13 +15,18 @@ if TYPE_CHECKING:
    from llmtuner.extras.callbacks import LogCallback


-def format_info(log: str, callback: "LogCallback") -> str:
-   info = log
-   if callback.max_steps:
-       info += "Running **{:d}/{:d}**: {} < {}\n".format(
-           callback.cur_steps, callback.max_steps, callback.elapsed_time, callback.remaining_time
-       )
-   return info
+def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
+   if not callback.max_steps:
+       return gr.update(visible=False)
+
+   percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0
+   label = "Running {:d}/{:d}: {} < {}".format(
+       callback.cur_steps,
+       callback.max_steps,
+       callback.elapsed_time,
+       callback.remaining_time
+   )
+   return gr.update(label=label, value=percentage, visible=True)


def get_time() -> str:
@@ -57,6 +62,18 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
        return gr.update(interactive=True)


+def gen_cmd(args: Dict[str, Any]) -> str:
+    if args.get("do_train", None):
+        args["plot_loss"] = True
+    cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python "]
+    for k, v in args.items():
+        if v is not None and v != "":
+            cmd_lines.append(" --{} {} ".format(k, str(v)))
+    cmd_text = "\\\n".join(cmd_lines)
+    cmd_text = "```bash\n{}\n```".format(cmd_text)
+    return cmd_text
def get_eval_results(path: os.PathLike) -> str:
    with open(path, "r", encoding="utf-8") as f:
        result = json.dumps(json.load(f), indent=4)
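For reference, the `gen_cmd` helper added above renders the parsed argument dict as a copy-pastable shell command wrapped in a Markdown code fence; that string is what the new "Preview command" button displays. A hypothetical input/output pair (the argument values are illustrative, and `gen_cmd` is assumed importable from `llmtuner.webui.utils`):

```python
args = {"stage": "sft", "do_train": True, "model_name_or_path": "path/to/model"}
print(gen_cmd(args))
# Prints roughly:
# ```bash
# CUDA_VISIBLE_DEVICES=0 python \
#  --stage sft \
#  --do_train True \
#  --model_name_or_path path/to/model \
#  --plot_loss True
# ```
```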