diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..51e9d59e --- /dev/null +++ b/.gitignore @@ -0,0 +1,160 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ diff --git a/README.md b/README.md index cca8fdc3..be43a481 100644 --- a/README.md +++ b/README.md @@ -12,19 +12,23 @@ ## Changelog +[23/08/12] Now we support **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings. + +[23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models (experimental feature). + [23/08/03] Now we support training the **Qwen-7B** model in this repo. Try `--model_name_or_path Qwen/Qwen-7B-Chat` and `--lora_target c_attn` arguments to train the Qwen-7B model. Remember to use `--template chatml` argument when you are using the Qwen-7B-Chat model. -[23/07/31] Now we support dataset streaming. Try `--streaming` and `--max_steps 100` arguments to stream your dataset. +[23/07/31] Now we support **dataset streaming**. Try `--streaming` and `--max_steps 10000` arguments to load your dataset in streaming mode. [23/07/29] We release two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft)) for details. [23/07/19] Now we support training the **LLaMA-2** models in this repo. Try `--model_name_or_path meta-llama/Llama-2-7b-hf` argument to use the LLaMA-2 model. Remember to use `--template llama2` argument when you are using the LLaMA-2-chat model. -[23/07/18] Now we develop an all-in-one Web UI for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development. +[23/07/18] Now we develop an **all-in-one Web UI** for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development. [23/07/11] Now we support training the **Baichuan-13B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-13B-Base` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--template baichuan` argument when you are using the Baichuan-13B-Chat model. -[23/07/09] Now we release [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested. +[23/07/09] Now we release **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested. [23/07/07] Now we support training the **InternLM-7B** model in this repo. Try `--model_name_or_path internlm/internlm-7b` argument to use the InternLM model. Remember to use `--template intern` argument when you are using the InternLM-chat model. 
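For context on the new RoPE scaling flags above: in recent `transformers` releases (4.31+), linear and dynamic NTK scaling are exposed through the `rope_scaling` field of the LLaMA config. The sketch below is illustrative only and is not this repository's implementation; the model name and scaling factor are placeholders.

```python
# Illustrative sketch of what `--rope_scaling linear` amounts to at the
# transformers level (assumption: NOT this repo's code; model name and
# factor are placeholders). Requires transformers>=4.31.
from transformers import AutoConfig, AutoModelForCausalLM

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder model

config = AutoConfig.from_pretrained(model_name)
# Linear scaling stretches position ids by `factor` (e.g. 2.0 roughly doubles
# the usable context); "dynamic" (NTK-aware) scaling is typically applied only
# at inference time, as the changelog entry suggests.
config.rope_scaling = {"type": "linear", "factor": 2.0}

model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
```

Swapping the type to `"dynamic"` at inference time mirrors the `--rope_scaling dynamic` recommendation in the changelog entry.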
@@ -53,25 +57,22 @@ | [InternLM](https://github.com/InternLM/InternLM) | 7B | q_proj,v_proj | intern | | [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml | | [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | - | +| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 | -> * **Default module** is used for the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options. -> * For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. +- **Default module** is used for the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options. +- For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the corresponding template for the "chat" models. ## Supported Training Approaches -- [(Continually) pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf) - - Full-parameter tuning - - Partial-parameter tuning - - [LoRA](https://arxiv.org/abs/2106.09685) - - [QLoRA](https://arxiv.org/abs/2305.14314) -- [Supervised fine-tuning](https://arxiv.org/abs/2109.01652) - - Full-parameter tuning - - Partial-parameter tuning - - [LoRA](https://arxiv.org/abs/2106.09685) - - [QLoRA](https://arxiv.org/abs/2305.14314) -- [RLHF](https://arxiv.org/abs/2203.02155) - - [LoRA](https://arxiv.org/abs/2106.09685) - - [QLoRA](https://arxiv.org/abs/2305.14314) +| Approach | Full-parameter | Partial-parameter | LoRA | QLoRA | +| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ | +| Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| Reward Modeling | | | :white_check_mark: | :white_check_mark: | +| PPO Training | | | :white_check_mark: | :white_check_mark: | +| DPO Training | :white_check_mark: | | :white_check_mark: | :white_check_mark: | + +- Use `--quantization_bit 4/8` argument to enable QLoRA. ## Provided Datasets @@ -88,7 +89,6 @@ - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [Self-cognition (zh)](data/self_cognition.json) - [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection) - - [RefGPT (zh)](https://github.com/sufengniu/RefGPT) - [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) - [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN) - [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN) @@ -103,7 +103,7 @@ - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) - [UltraChat (en)](https://github.com/thunlp/UltraChat) - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) -- For reward modelling: +- For reward modeling or DPO training: - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) @@ -139,7 +139,6 @@ Note: please update `data/dataset_info.json` to use your custom dataset. 
About t ### Dependence Installation (optional) ```bash -git lfs install git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git conda create -n llama_etuning python=3.10 conda activate llama_etuning @@ -161,7 +160,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_web.py Currently the web UI only supports training on **a single GPU**. -### (Continually) Pre-Training +### Pre-Training ```bash CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ @@ -207,9 +206,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --fp16 ``` -Remember to specify `--lora_target W_pack` if you are using Baichuan models. - -### Reward Model Training +### Reward Modeling ```bash CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ @@ -222,7 +219,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --resume_lora_training False \ --checkpoint_dir path_to_sft_checkpoint \ --output_dir path_to_rm_checkpoint \ - --per_device_train_batch_size 4 \ + --per_device_train_batch_size 2 \ --gradient_accumulation_steps 4 \ --lr_scheduler_type cosine \ --logging_steps 10 \ @@ -233,7 +230,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --fp16 ``` -### PPO Training (RLHF) +### PPO Training ```bash CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ @@ -257,14 +254,40 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --plot_loss ``` +### DPO Training + +```bash +CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ + --stage dpo \ + --model_name_or_path path_to_your_model \ + --do_train \ + --dataset comparison_gpt4_en \ + --template default \ + --finetuning_type lora \ + --resume_lora_training False \ + --checkpoint_dir path_to_sft_checkpoint \ + --output_dir path_to_dpo_checkpoint \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 4 \ + --lr_scheduler_type cosine \ + --logging_steps 10 \ + --save_steps 1000 \ + --learning_rate 1e-5 \ + --num_train_epochs 1.0 \ + --plot_loss \ + --fp16 +``` + ### Distributed Training +#### Use Huggingface Accelerate + ```bash accelerate config # configure the environment accelerate launch src/train_bash.py # arguments (same as above) ``` -
Example configuration for full-tuning with DeepSpeed ZeRO-2 +
Example config.yaml for training with DeepSpeed ZeRO-2 ```yaml compute_environment: LOCAL_MACHINE @@ -292,6 +315,44 @@ use_cpu: false
+#### Use DeepSpeed + +```bash +deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \ + --deepspeed ds_config.json \ + ... # arguments (same as above) +``` + +
Example ds_config.json for training with DeepSpeed ZeRO-2 + +```json +{ + "train_micro_batch_size_per_gpu": "auto", + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "zero_allow_untested_optimizer": true, + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "initial_scale_power": 16, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "zero_optimization": { + "stage": 2, + "allgather_partitions": true, + "allgather_bucket_size": 5e8, + "reduce_scatter": true, + "reduce_bucket_size": 5e8, + "overlap_comm": false, + "contiguous_gradients": true + } +} +``` + +
+ ### Evaluation (BLEU and ROUGE_CHINESE) ```bash @@ -390,6 +451,8 @@ Please follow the model licenses to use the corresponding model weights: - [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) - [InternLM](https://github.com/InternLM/InternLM#open-source-license) - [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE) +- [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) +- [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE) ## Citation diff --git a/README_zh.md b/README_zh.md index d5eca99d..2a84e697 100644 --- a/README_zh.md +++ b/README_zh.md @@ -12,31 +12,35 @@ ## 更新日志 -[23/08/03] 现在我们支持了 **Qwen-7B** 模型的训练。请尝试使用 `--model_name_or_path Qwen/Qwen-7B-Chat` 和 `--lora_target c_attn` 参数。请注意使用 Qwen-7B-Chat 模型需要添加 `--template chatml` 参数。 +[23/08/12] 现在我们支持了 **RoPE 插值**来扩展 LLaMA 模型的上下文长度。请尝试使用 `--rope_scaling linear` 参数训练模型或使用 `--rope_scaling dynamic` 参数评估模型。 -[23/07/31] 现在我们支持了训练数据流式加载。请尝试使用 `--streaming` 和 `--max_steps 100` 参数来流式加载数据集。 +[23/08/11] 现在我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。详情请参阅[此示例](#dpo-训练)(实验性功能)。 + +[23/08/03] 现在我们支持了 **Qwen-7B** 模型的训练。请尝试使用 `--model_name_or_path Qwen/Qwen-7B-Chat` 和 `--lora_target c_attn` 参数。使用 Qwen-7B-Chat 模型时请添加 `--template chatml` 参数。 + +[23/07/31] 现在我们支持了**数据流式加载**。请尝试使用 `--streaming` 和 `--max_steps 10000` 参数来流式加载数据集。 [23/07/29] 我们在 Hugging Face 发布了两个 13B 指令微调模型。详细内容请查阅我们的 Hugging Face 项目([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft))。 -[23/07/19] 现在我们支持了 **LLaMA-2** 模型的训练。请尝试使用 `--model_name_or_path meta-llama/Llama-2-7b-hf` 参数。请注意使用 LLaMA-2-chat 模型需要添加 `--template llama2` 参数。 +[23/07/19] 现在我们支持了 **LLaMA-2** 模型的训练。请尝试使用 `--model_name_or_path meta-llama/Llama-2-7b-hf` 参数。使用 LLaMA-2-chat 模型时请添加 `--template llama2` 参数。 -[23/07/18] 我们开发了支持训练和测试的浏览器一键微调界面。请尝试使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。 +[23/07/18] 我们开发了支持训练和测试的**一体化浏览器界面**。请尝试使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。 -[23/07/11] 现在我们支持了 **Baichuan-13B** 模型的训练。请尝试使用 `--model_name_or_path baichuan-inc/Baichuan-13B-Base` 和 `--lora_target W_pack` 参数。请注意使用 Baichuan-13B-Chat 模型需要添加 `--template baichuan` 参数。 +[23/07/11] 现在我们支持了 **Baichuan-13B** 模型的训练。请尝试使用 `--model_name_or_path baichuan-inc/Baichuan-13B-Base` 和 `--lora_target W_pack` 参数。使用 Baichuan-13B-Chat 模型时请添加 `--template baichuan` 参数。 -[23/07/09] 我们开源了 [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹,一个简单易用的、能迅速编辑大模型事实记忆的工具包。如果您感兴趣请关注我们的 [FastEdit](https://github.com/hiyouga/FastEdit) 项目。 +[23/07/09] 我们开源了 **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹,一个简单易用的、能迅速编辑大模型事实记忆的工具包。如果您感兴趣请关注我们的 [FastEdit](https://github.com/hiyouga/FastEdit) 项目。 -[23/07/07] 现在我们支持了 **InternLM-7B** 模型的训练。请尝试使用 `--model_name_or_path internlm/internlm-7b` 参数。请注意使用 InternLM-chat 模型需要添加 `--template intern` 参数。 +[23/07/07] 现在我们支持了 **InternLM-7B** 模型的训练。请尝试使用 `--model_name_or_path internlm/internlm-7b` 参数。使用 InternLM-chat 模型时请添加 `--template intern` 参数。 [23/07/05] 现在我们支持了 **Falcon-7B/40B** 模型的训练。请尝试使用 `--model_name_or_path tiiuae/falcon-7b` 和 `--lora_target query_key_value` 参数。 [23/06/29] 我们提供了一个**可复现的**指令模型微调示例,详细内容请查阅 [Hugging Face 项目](https://huggingface.co/hiyouga/baichuan-7b-sft)。 -[23/06/22] 我们对齐了[示例 
API](src/api_demo.py) 与 [OpenAI API](https://platform.openai.com/docs/api-reference/chat) 的格式,您可以将微调模型接入任意基于 ChatGPT 的应用中。 +[23/06/22] 我们对齐了[示例 API](src/api_demo.py) 与 [OpenAI API](https://platform.openai.com/docs/api-reference/chat) 的格式,您可以将微调模型接入**任意基于 ChatGPT 的应用**中。 [23/06/15] 现在我们支持了 **Baichuan-7B** 模型的训练。请尝试使用 `--model_name_or_path baichuan-inc/Baichuan-7B` 和 `--lora_target W_pack` 参数。 -[23/06/03] 现在我们实现了 4 比特的 LoRA 训练(也称 [QLoRA](https://github.com/artidoro/qlora))。请尝试使用 `--quantization_bit 4` 参数进行 4 比特量化微调。 +[23/06/03] 现在我们实现了 4 比特的 LoRA 训练(也称 **[QLoRA](https://github.com/artidoro/qlora)**)。请尝试使用 `--quantization_bit 4` 参数进行 4 比特量化微调。 [23/05/31] 现在我们支持了 **BLOOM & BLOOMZ** 模型的训练。请尝试使用 `--model_name_or_path bigscience/bloomz-7b1-mt` 和 `--lora_target query_key_value` 参数。 @@ -53,42 +57,38 @@ | [InternLM](https://github.com/InternLM/InternLM) | 7B | q_proj,v_proj | intern | | [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml | | [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | - | +| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 | -> * **默认模块**是 `--lora_target` 参数的默认值。请使用 `python src/train_bash.py -h` 查看全部可选项。 -> * 对于所有“基座”模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等值。 +- **默认模块**是 `--lora_target` 参数的部分可选项。请使用 `python src/train_bash.py -h` 查看全部可选项。 +- 对于所有“基座”(Base)模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Chat)模型请务必使用对应的模板。 -## 微调方法 +## 训练方法 -- [二次预训练](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf) - - 全参数微调 - - 部分参数微调 - - [LoRA](https://arxiv.org/abs/2106.09685) - - [QLoRA](https://arxiv.org/abs/2305.14314) -- [指令监督微调](https://arxiv.org/abs/2109.01652) - - 全参数微调 - - 部分参数微调 - - [LoRA](https://arxiv.org/abs/2106.09685) - - [QLoRA](https://arxiv.org/abs/2305.14314) -- [人类反馈的强化学习(RLHF)](https://arxiv.org/abs/2203.02155) - - [LoRA](https://arxiv.org/abs/2106.09685) - - [QLoRA](https://arxiv.org/abs/2305.14314) +| 方法 | 全参数训练 | 部分参数训练 | LoRA | QLoRA | +| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ | +| 预训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| 指令监督微调 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| 奖励模型训练 | | | :white_check_mark: | :white_check_mark: | +| PPO 训练 | | | :white_check_mark: | :white_check_mark: | +| DPO 训练 | :white_check_mark: | | :white_check_mark: | :white_check_mark: | + +- 使用 `--quantization_bit 4/8` 参数来启用 QLoRA 训练。 ## 数据集 -- 用于二次预训练: +- 用于预训练: - [Wiki Demo (en)](data/wiki_demo.txt) - [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb) - [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata) - [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220) - [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered) -- 用于指令监督微调: +- 用于指令监督微调: - [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca) - [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca) - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [Self-cognition (zh)](data/self_cognition.json) - [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection) - - [RefGPT 
(zh)](https://github.com/sufengniu/RefGPT) - [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) - [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN) - [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN) @@ -103,7 +103,7 @@ - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) - [UltraChat (en)](https://github.com/thunlp/UltraChat) - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) -- 用于奖励模型训练: +- 用于奖励模型或 DPO 训练: - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) @@ -139,7 +139,6 @@ huggingface-cli login ### 环境搭建(可跳过) ```bash -git lfs install git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git conda create -n llama_etuning python=3.10 conda activate llama_etuning @@ -161,7 +160,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_web.py 目前网页 UI 仅支持**单卡训练**。 -### 二次预训练 +### 预训练 ```bash CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ @@ -207,8 +206,6 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --fp16 ``` -使用 Baichuan 模型时请指定 `--lora_target W_pack` 参数。 - ### 奖励模型训练 ```bash @@ -222,7 +219,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --resume_lora_training False \ --checkpoint_dir path_to_sft_checkpoint \ --output_dir path_to_rm_checkpoint \ - --per_device_train_batch_size 4 \ + --per_device_train_batch_size 2 \ --gradient_accumulation_steps 4 \ --lr_scheduler_type cosine \ --logging_steps 10 \ @@ -233,7 +230,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --fp16 ``` -### RLHF 训练 +### PPO 训练 ```bash CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ @@ -257,8 +254,34 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --plot_loss ``` +### DPO 训练 + +```bash +CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ + --stage dpo \ + --model_name_or_path path_to_your_model \ + --do_train \ + --dataset comparison_gpt4_zh \ + --template default \ + --finetuning_type lora \ + --resume_lora_training False \ + --checkpoint_dir path_to_sft_checkpoint \ + --output_dir path_to_dpo_checkpoint \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 4 \ + --lr_scheduler_type cosine \ + --logging_steps 10 \ + --save_steps 1000 \ + --learning_rate 1e-5 \ + --num_train_epochs 1.0 \ + --plot_loss \ + --fp16 +``` + ### 多 GPU 分布式训练 +#### 使用 Huggingface Accelerate + ```bash accelerate config # 首先配置分布式环境 accelerate launch src/train_bash.py # 参数同上 @@ -292,7 +315,45 @@ use_cpu: false
-### 指标评估(BLEU分数和汉语ROUGE分数) +#### 使用 DeepSpeed + +```bash +deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \ + --deepspeed ds_config.json \ + ... # 参数同上 +``` + +
使用 DeepSpeed ZeRO-2 进行全参数微调的 DeepSpeed 配置示例 + +```json +{ + "train_micro_batch_size_per_gpu": "auto", + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "zero_allow_untested_optimizer": true, + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "initial_scale_power": 16, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "zero_optimization": { + "stage": 2, + "allgather_partitions": true, + "allgather_bucket_size": 5e8, + "reduce_scatter": true, + "reduce_bucket_size": 5e8, + "overlap_comm": false, + "contiguous_gradients": true + } +} +``` + +
+ +### 指标评估(BLEU 分数和汉语 ROUGE 分数) ```bash CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ @@ -309,7 +370,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --predict_with_generate ``` -我们建议在量化模型的评估中使用 `--per_device_eval_batch_size=1` 和 `--max_target_length 128` 参数。 +我们建议在量化模型的评估中使用 `--per_device_eval_batch_size=1` 和 `--max_target_length 128`。 ### 模型预测 diff --git a/assets/wechat.jpg b/assets/wechat.jpg index 741cef3c..da80faae 100644 Binary files a/assets/wechat.jpg and b/assets/wechat.jpg differ diff --git a/data/dataset_info.json b/data/dataset_info.json index 3a3b4e76..3eaf920e 100644 --- a/data/dataset_info.json +++ b/data/dataset_info.json @@ -49,26 +49,6 @@ "history": "history" } }, - "refgpt_zh_p1": { - "file_name": "refgpt_zh_50k_p1.json", - "file_sha1": "b40f4f4d0ffacd16da7c275b056d5b6670021752", - "columns": { - "prompt": "instruction", - "query": "input", - "response": "output", - "history": "history" - } - }, - "refgpt_zh_p2": { - "file_name": "refgpt_zh_50k_p2.json", - "file_sha1": "181f32b2c60264a29f81f59d3c76095793eae1b0", - "columns": { - "prompt": "instruction", - "query": "input", - "response": "output", - "history": "history" - } - }, "lima": { "file_name": "lima.json", "file_sha1": "9db59f6b7007dc4b17529fc63379b9cd61640f37", diff --git a/data/refgpt_zh_50k_p1.json.REMOVED.git-id b/data/refgpt_zh_50k_p1.json.REMOVED.git-id deleted file mode 100644 index 3e8a9e41..00000000 --- a/data/refgpt_zh_50k_p1.json.REMOVED.git-id +++ /dev/null @@ -1 +0,0 @@ -f967a4f6d04a11308a15524aa9a846a19a8d1e83 \ No newline at end of file diff --git a/data/refgpt_zh_50k_p2.json.REMOVED.git-id b/data/refgpt_zh_50k_p2.json.REMOVED.git-id deleted file mode 100644 index a6525b27..00000000 --- a/data/refgpt_zh_50k_p2.json.REMOVED.git-id +++ /dev/null @@ -1 +0,0 @@ -0a4f0d74fd1c5cab2eb6d84a3a3fe669847becd8 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 9b74b21d..fb5fa72f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ transformers>=4.29.1 datasets>=2.12.0 accelerate>=0.21.0 peft>=0.4.0 -trl>=0.4.7 +trl>=0.5.0 scipy sentencepiece tiktoken diff --git a/src/api_demo.py b/src/api_demo.py index c0ca9760..777f9dcf 100644 --- a/src/api_demo.py +++ b/src/api_demo.py @@ -7,7 +7,7 @@ def main(): chat_model = ChatModel() app = create_app(chat_model) uvicorn.run(app, host="0.0.0.0", port=8000, workers=1) - # Visit http://localhost:8000/docs for document. 
+ print("Visit http://localhost:8000/docs for API document.") if __name__ == "__main__": diff --git a/src/llmtuner/__init__.py b/src/llmtuner/__init__.py index e647b92b..bbc1420b 100644 --- a/src/llmtuner/__init__.py +++ b/src/llmtuner/__init__.py @@ -6,4 +6,4 @@ from llmtuner.tuner import export_model, run_exp from llmtuner.webui import create_ui, create_web_demo -__version__ = "0.1.5" +__version__ = "0.1.6" diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py index 4fc5fc43..41a7fe9a 100644 --- a/src/llmtuner/api/app.py +++ b/src/llmtuner/api/app.py @@ -47,15 +47,15 @@ def create_app(chat_model: ChatModel) -> FastAPI: @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) async def create_chat_completion(request: ChatCompletionRequest): - if request.messages[-1].role != Role.USER: + if len(request.messages) < 1 or request.messages[-1].role != Role.USER: raise HTTPException(status_code=400, detail="Invalid request") - query = request.messages[-1].content + query = request.messages[-1].content prev_messages = request.messages[:-1] if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM: - prefix = prev_messages.pop(0).content + system = prev_messages.pop(0).content else: - prefix = None + system = None history = [] if len(prev_messages) % 2 == 0: @@ -64,11 +64,11 @@ def create_app(chat_model: ChatModel) -> FastAPI: history.append([prev_messages[i].content, prev_messages[i+1].content]) if request.stream: - generate = predict(query, history, prefix, request) + generate = predict(query, history, system, request) return EventSourceResponse(generate, media_type="text/event-stream") response, (prompt_length, response_length) = chat_model.chat( - query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens + query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens ) usage = ChatCompletionResponseUsage( @@ -85,7 +85,7 @@ def create_app(chat_model: ChatModel) -> FastAPI: return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage) - async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest): + async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest): choice_data = ChatCompletionResponseStreamChoice( index=0, delta=DeltaMessage(role=Role.ASSISTANT), @@ -95,7 +95,7 @@ def create_app(chat_model: ChatModel) -> FastAPI: yield chunk.json(exclude_unset=True, ensure_ascii=False) for new_text in chat_model.stream_chat( - query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens + query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens ): if len(new_text) == 0: continue diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py index 79d3b92d..0d22bb5a 100644 --- a/src/llmtuner/chat/stream_chat.py +++ b/src/llmtuner/chat/stream_chat.py @@ -1,10 +1,9 @@ import torch -from types import MethodType from typing import Any, Dict, Generator, List, Optional, Tuple from threading import Thread -from transformers import PreTrainedModel, TextIteratorStreamer +from transformers import TextIteratorStreamer -from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopping_criteria +from llmtuner.extras.misc import dispatch_model, get_logits_processor from llmtuner.extras.template import 
get_template_and_fix_tokenizer from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer @@ -15,23 +14,21 @@ class ChatModel: model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args) self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args) self.model = dispatch_model(self.model) - self.model = self.model.eval() # change to eval mode + self.model = self.model.eval() # enable evaluation mode self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer) - self.source_prefix = data_args.source_prefix - self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words) - self.model.generate = MethodType(PreTrainedModel.generate, self.model) # disable custom method (for Qwen) + self.system_prompt = data_args.system_prompt def process_args( self, query: str, history: Optional[List[Tuple[str, str]]] = None, - prefix: Optional[str] = None, + system: Optional[str] = None, **input_kwargs ) -> Tuple[Dict[str, Any], int]: - prefix = prefix or self.source_prefix + system = system or self.system_prompt prompt, _ = self.template.encode_oneturn( - tokenizer=self.tokenizer, query=query, resp="", history=history, prefix=prefix + tokenizer=self.tokenizer, query=query, resp="", history=history, system=system ) input_ids = torch.tensor([prompt], device=self.model.device) prompt_length = len(input_ids[0]) @@ -52,8 +49,9 @@ class ChatModel: top_p=top_p or gen_kwargs["top_p"], top_k=top_k or gen_kwargs["top_k"], repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"], - logits_processor=get_logits_processor(), - stopping_criteria=get_stopping_criteria(self.stop_ids) + eos_token_id=list(set([self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids)), + pad_token_id=self.tokenizer.pad_token_id, + logits_processor=get_logits_processor() )) if max_length: @@ -71,10 +69,10 @@ class ChatModel: self, query: str, history: Optional[List[Tuple[str, str]]] = None, - prefix: Optional[str] = None, + system: Optional[str] = None, **input_kwargs ) -> Tuple[str, Tuple[int, int]]: - gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs) + gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs) generation_output = self.model.generate(**gen_kwargs) outputs = generation_output.tolist()[0][prompt_length:] response = self.tokenizer.decode(outputs, skip_special_tokens=True) @@ -86,10 +84,10 @@ class ChatModel: self, query: str, history: Optional[List[Tuple[str, str]]] = None, - prefix: Optional[str] = None, + system: Optional[str] = None, **input_kwargs ) -> Generator[str, None, None]: - gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs) + gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs) streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) gen_kwargs["streamer"] = streamer diff --git a/src/llmtuner/dsets/loader.py b/src/llmtuner/dsets/loader.py index 90a4212f..08c35f27 100644 --- a/src/llmtuner/dsets/loader.py +++ b/src/llmtuner/dsets/loader.py @@ -1,48 +1,25 @@ import os -import hashlib -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, Union -from datasets import Value, concatenate_datasets, interleave_datasets, load_dataset +from datasets import concatenate_datasets, interleave_datasets, load_dataset +from llmtuner.dsets.utils import checksum, EXT2TYPE from llmtuner.extras.logging import get_logger 
if TYPE_CHECKING: - from datasets import Dataset + from datasets import Dataset, IterableDataset from llmtuner.hparams import ModelArguments, DataArguments logger = get_logger(__name__) -EXT2TYPE = { - "csv": "csv", - "json": "json", - "jsonl": "json", - "txt": "text" -} - - -def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None: - if file_sha1 is None: - logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.") - return - - if len(data_files) != 1: - logger.warning("Checksum failed: too many files.") - return - - with open(data_files[0], "rb") as f: - sha1 = hashlib.sha1(f.read()).hexdigest() - if sha1 != file_sha1: - logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0])) - - def get_dataset( model_args: "ModelArguments", data_args: "DataArguments" -) -> "Dataset": +) -> Union["Dataset", "IterableDataset"]: max_samples = data_args.max_samples - all_datasets: List["Dataset"] = [] # support multiple datasets + all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets for dataset_attr in data_args.dataset_list: logger.info("Loading dataset {}...".format(dataset_attr)) @@ -92,12 +69,11 @@ def get_dataset( if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name: dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name) - if dataset_attr.source_prefix: # add prefix - features = None + if dataset_attr.system_prompt: # add system prompt if data_args.streaming: - features = dataset.features - features["prefix"] = Value(dtype="string", id=None) - dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix}, features=features) + dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt}) + else: + dataset = dataset.add_column("system", [dataset_attr.system_prompt] * len(dataset)) all_datasets.append(dataset) diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py index d2150dbc..1fc146f8 100644 --- a/src/llmtuner/dsets/preprocess.py +++ b/src/llmtuner/dsets/preprocess.py @@ -1,24 +1,25 @@ -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal +import tiktoken +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union from itertools import chain from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.template import get_template_and_fix_tokenizer if TYPE_CHECKING: - from datasets import Dataset + from datasets import Dataset, IterableDataset from transformers import Seq2SeqTrainingArguments from transformers.tokenization_utils import PreTrainedTokenizer from llmtuner.hparams import DataArguments def preprocess_dataset( - dataset: "Dataset", + dataset: Union["Dataset", "IterableDataset"], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", stage: Literal["pt", "sft", "rm", "ppo"] -) -> "Dataset": - column_names = list(dataset.column_names) +) -> Union["Dataset", "IterableDataset"]: + column_names = list(next(iter(dataset)).keys()) template = get_template_and_fix_tokenizer(data_args.template, tokenizer) def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]: @@ -26,15 +27,16 @@ def preprocess_dataset( query, response = examples["prompt"][i], examples["response"][i] query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query history = examples["history"][i] if "history" in examples else None - prefix = 
examples["prefix"][i] if "prefix" in examples else None - yield query, response, history, prefix + system = examples["system"][i] if "system" in examples else None + yield query, response, history, system def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]: # build grouped texts with format `X1 X2 X3 ...` (without ) - if hasattr(tokenizer, "tokenizer"): # for tiktoken tokenizer (Qwen) + if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen) kwargs = dict(allowed_special="all") else: kwargs = dict(add_special_tokens=False) + tokenized_examples = tokenizer(examples["prompt"], **kwargs) concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]]) @@ -46,7 +48,6 @@ def preprocess_dataset( k: [t[i: i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items() } - result["labels"] = result["input_ids"].copy() return result def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]: @@ -55,10 +56,10 @@ def preprocess_dataset( model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} max_length = data_args.max_source_length + data_args.max_target_length - for query, response, history, prefix in construct_example(examples): + for query, response, history, system in construct_example(examples): input_ids, labels = [], [] - for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, prefix): + for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, system): if len(source_ids) > data_args.max_source_length: source_ids = source_ids[:data_args.max_source_length] if len(target_ids) > data_args.max_target_length: @@ -77,11 +78,11 @@ def preprocess_dataset( return model_inputs def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]: - # build inputs with format ` X` and labels with format ` Y` + # build inputs with format ` X` and labels with format `Y ` model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} - for query, response, history, prefix in construct_example(examples): - source_ids, target_ids = template.encode_oneturn(tokenizer, query, response, history, prefix) + for query, response, history, system in construct_example(examples): + source_ids, target_ids = template.encode_oneturn(tokenizer, query, response, history, system) if len(source_ids) > data_args.max_source_length: source_ids = source_ids[:data_args.max_source_length] @@ -95,24 +96,22 @@ def preprocess_dataset( return model_inputs def preprocess_pairwise_dataset(examples): - # build input pairs with format ` X Y1 ` and ` X Y2 ` - model_inputs = {"accept_ids": [], "reject_ids": []} - for query, response, history, prefix in construct_example(examples): - source_ids, accept_ids = template.encode_oneturn(tokenizer, query, response[0], history, prefix) - source_ids, reject_ids = template.encode_oneturn(tokenizer, query, response[1], history, prefix) + # build input pairs with format ` X`, `Y1 ` and `Y2 ` + model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []} + for query, response, history, system in construct_example(examples): + prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system) + _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system) - if len(source_ids) > 
data_args.max_source_length: - source_ids = source_ids[:data_args.max_source_length] - if len(accept_ids) > data_args.max_target_length: - accept_ids = accept_ids[:data_args.max_target_length - 1] - if len(reject_ids) > data_args.max_target_length: - reject_ids = reject_ids[:data_args.max_target_length - 1] + if len(prompt_ids) > data_args.max_source_length: + prompt_ids = prompt_ids[:data_args.max_source_length] + if len(chosen_ids) > data_args.max_target_length: + chosen_ids = chosen_ids[:data_args.max_target_length] + if len(rejected_ids) > data_args.max_target_length: + rejected_ids = rejected_ids[:data_args.max_target_length] - accept_ids = source_ids + accept_ids - reject_ids = source_ids + reject_ids - - model_inputs["accept_ids"].append(accept_ids) - model_inputs["reject_ids"].append(reject_ids) + model_inputs["prompt_ids"].append(prompt_ids) + model_inputs["chosen_ids"].append(chosen_ids) + model_inputs["rejected_ids"].append(rejected_ids) return model_inputs def print_supervised_dataset_example(example): @@ -124,10 +123,12 @@ def preprocess_dataset( ], skip_special_tokens=False))) def print_pairwise_dataset_example(example): - print("accept_ids:\n{}".format(example["accept_ids"])) - print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False))) - print("reject_ids:\n{}".format(example["reject_ids"])) - print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False))) + print("prompt_ids:\n{}".format(example["prompt_ids"])) + print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False))) + print("chosen_ids:\n{}".format(example["chosen_ids"])) + print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False))) + print("rejected_ids:\n{}".format(example["rejected_ids"])) + print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False))) def print_unsupervised_dataset_example(example): print("input_ids:\n{}".format(example["input_ids"])) @@ -166,8 +167,5 @@ def preprocess_dataset( **kwargs ) - if data_args.streaming: - dataset = dataset.shuffle(buffer_size=data_args.buffer_size) - print_function(next(iter(dataset))) return dataset diff --git a/src/llmtuner/dsets/utils.py b/src/llmtuner/dsets/utils.py index 31c48222..bf337014 100644 --- a/src/llmtuner/dsets/utils.py +++ b/src/llmtuner/dsets/utils.py @@ -1,15 +1,59 @@ -from typing import TYPE_CHECKING, Dict +import hashlib +from typing import TYPE_CHECKING, Dict, List, Optional, Union + +from llmtuner.extras.logging import get_logger if TYPE_CHECKING: - from datasets import Dataset + from datasets import Dataset, IterableDataset + from transformers import TrainingArguments + from llmtuner.hparams import DataArguments -def split_dataset(dataset: "Dataset", dev_ratio: float, do_train: bool) -> Dict[str, "Dataset"]: - if do_train: - if dev_ratio > 1e-6: # Split the dataset - dataset = dataset.train_test_split(test_size=dev_ratio) - return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]} +logger = get_logger(__name__) + + +EXT2TYPE = { + "csv": "csv", + "json": "json", + "jsonl": "json", + "txt": "text" +} + + +def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None: + if file_sha1 is None: + logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.") + return + + if len(data_files) != 1: + logger.warning("Checksum failed: too many files.") + return + + with open(data_files[0], "rb") as f: + sha1 = 
hashlib.sha1(f.read()).hexdigest() + if sha1 != file_sha1: + logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0])) + + +def split_dataset( + dataset: Union["Dataset", "IterableDataset"], + data_args: "DataArguments", + training_args: "TrainingArguments" +) -> Dict[str, "Dataset"]: + if training_args.do_train: + if data_args.val_size > 1e-6: # Split the dataset + if data_args.streaming: + val_set = dataset.take(int(data_args.val_size)) + train_set = dataset.skip(int(data_args.val_size)) + dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) + return {"train_dataset": train_set, "eval_dataset": val_set} + else: + val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size + dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed) + return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]} else: + if data_args.streaming: + dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) return {"train_dataset": dataset} else: # do_eval or do_predict return {"eval_dataset": dataset} diff --git a/src/llmtuner/extras/callbacks.py b/src/llmtuner/extras/callbacks.py index d325b0a8..61deae25 100644 --- a/src/llmtuner/extras/callbacks.py +++ b/src/llmtuner/extras/callbacks.py @@ -7,10 +7,16 @@ from datetime import timedelta from transformers import TrainerCallback from transformers.trainer_utils import has_length +from llmtuner.extras.constants import LOG_FILE_NAME +from llmtuner.extras.logging import get_logger + if TYPE_CHECKING: from transformers import TrainingArguments, TrainerState, TrainerControl +logger = get_logger(__name__) + + class LogCallback(TrainerCallback): def __init__(self, runner=None): @@ -38,6 +44,9 @@ class LogCallback(TrainerCallback): self.in_training = True self.start_time = time.time() self.max_steps = state.max_steps + if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)): + logger.warning("Previous log file in this folder will be deleted.") + os.remove(os.path.join(args.output_dir, LOG_FILE_NAME)) def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): r""" diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py index 6f6dbdd7..cd22943f 100644 --- a/src/llmtuner/extras/constants.py +++ b/src/llmtuner/extras/constants.py @@ -1,13 +1,23 @@ IGNORE_INDEX = -100 +LOG_FILE_NAME = "trainer_log.jsonl" + VALUE_HEAD_FILE_NAME = "value_head.bin" FINETUNING_ARGS_NAME = "finetuning_args.json" -LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings +LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] METHODS = ["full", "freeze", "lora"] +STAGES = [ + "SFT", + "Reward Modeling", + "PPO", + "DPO", + "Pre-Training" +] + SUPPORTED_MODELS = { "LLaMA-7B": "huggyllama/llama-7b", "LLaMA-13B": "huggyllama/llama-13b", @@ -19,6 +29,10 @@ SUPPORTED_MODELS = { "LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf", "LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf", "LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf", + "ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b", + "ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b", + "ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b", + "ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b", "BLOOM-560M": "bigscience/bloom-560m", "BLOOM-3B": "bigscience/bloom-3b", "BLOOM-7B1": "bigscience/bloom-7b1", @@ -35,16 +49,30 @@ SUPPORTED_MODELS = { "InternLM-7B": 
"internlm/internlm-7b", "InternLM-7B-Chat": "internlm/internlm-chat-7b", "Qwen-7B": "Qwen/Qwen-7B", - "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat" + "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat", + "XVERSE-13B": "xverse/XVERSE-13B", + "ChatGLM2-6B-Chat": "THUDM/chatglm2-6b" } DEFAULT_MODULE = { "LLaMA": "q_proj,v_proj", "LLaMA2": "q_proj,v_proj", + "ChineseLLaMA2": "q_proj,v_proj", "BLOOM": "query_key_value", "BLOOMZ": "query_key_value", "Falcon": "query_key_value", "Baichuan": "W_pack", "InternLM": "q_proj,v_proj", - "Qwen": "c_attn" + "Qwen": "c_attn", + "XVERSE": "q_proj,v_proj", + "ChatGLM2": "query_key_value" +} + +DEFAULT_TEMPLATE = { + "LLaMA2": "llama2", + "ChineseLLaMA2": "llama2_zh", + "Baichuan": "baichuan", + "InternLM": "intern", + "Qwen": "chatml", + "ChatGLM2": "chatglm2" } diff --git a/src/llmtuner/extras/logging.py b/src/llmtuner/extras/logging.py index 0b1a68f6..d6f185e6 100644 --- a/src/llmtuner/extras/logging.py +++ b/src/llmtuner/extras/logging.py @@ -8,6 +8,9 @@ class LoggerHandler(logging.Handler): super().__init__() self.log = "" + def reset(self): + self.log = "" + def emit(self, record): if record.name == "httpx": return diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py index e1fbb156..b57b1c8f 100644 --- a/src/llmtuner/extras/misc.py +++ b/src/llmtuner/extras/misc.py @@ -1,7 +1,6 @@ import torch from typing import TYPE_CHECKING, List, Optional, Tuple - -from transformers import LogitsProcessor, LogitsProcessorList, StoppingCriteria, StoppingCriteriaList +from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList from llmtuner.extras.constants import LAYERNORM_NAMES @@ -29,37 +28,12 @@ class AverageMeter: self.avg = self.sum / self.count -class InvalidScoreLogitsProcessor(LogitsProcessor): - - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - if torch.isnan(scores).any() or torch.isinf(scores).any(): - scores.zero_() - scores[..., 0] = 1.0 - return scores - - def get_logits_processor() -> LogitsProcessorList: logits_processor = LogitsProcessorList() - logits_processor.append(InvalidScoreLogitsProcessor()) + logits_processor.append(InfNanRemoveLogitsProcessor()) return logits_processor -class StopWordsCriteria(StoppingCriteria): - - def __init__(self, stop_ids: List[int]) -> None: - super().__init__() - self.stop_ids = stop_ids - - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: - return any([stop_id in input_ids[:, -1] for stop_id in self.stop_ids]) - - -def get_stopping_criteria(stop_ids: List[int]) -> StoppingCriteriaList: - stopping_criteria = StoppingCriteriaList() - stopping_criteria.append(StopWordsCriteria(stop_ids)) - return stopping_criteria - - def count_parameters(model: torch.nn.Module) -> Tuple[int, int]: r""" Returns the number of trainable parameters and number of all parameters in the model. 
@@ -91,7 +65,6 @@ def prepare_model_for_training( use_gradient_checkpointing: Optional[bool] = True, layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES ) -> "PreTrainedModel": - for name, param in model.named_parameters(): if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names): param.data = param.data.to(torch.float32) @@ -108,9 +81,6 @@ def prepare_model_for_training( model.config.use_cache = False # turn off when gradient checkpointing is enabled if finetuning_type != "full" and hasattr(model, output_layer_name): - if hasattr(model, "config") and hasattr(model.config, "pretraining_tp"): - model.config.pretraining_tp = 1 # disable TP for LoRA (https://github.com/huggingface/peft/pull/728) - output_layer: torch.nn.Linear = getattr(model, output_layer_name) input_dtype = output_layer.weight.dtype @@ -138,6 +108,9 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel": Dispatches a pre-trained model to GPUs with balanced memory. Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803 """ + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing + return model + if torch.cuda.device_count() > 1: from accelerate import dispatch_model from accelerate.utils import infer_auto_device_map, get_balanced_memory diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index 5b00af03..5d5a03fb 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -1,15 +1,22 @@ -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +import tiktoken from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +from llmtuner.extras.logging import get_logger if TYPE_CHECKING: from transformers import PreTrainedTokenizer +logger = get_logger(__name__) + + @dataclass class Template: prefix: List[Union[str, Dict[str, str]]] prompt: List[Union[str, Dict[str, str]]] + system: str sep: List[Union[str, Dict[str, str]]] stop_words: List[str] use_history: bool @@ -20,18 +27,18 @@ class Template: query: str, resp: str, history: Optional[List[Tuple[str, str]]] = None, - prefix: Optional[str] = None + system: Optional[str] = None ) -> Tuple[List[int], List[int]]: r""" Returns a single pair of token ids representing prompt and response respectively. """ - prefix, history = self._format(query, resp, history, prefix) - encoded_pairs = self._encode(tokenizer, prefix, history) + system, history = self._format(query, resp, history, system) + encoded_pairs = self._encode(tokenizer, system, history) prompt_ids = [] for query_ids, resp_ids in encoded_pairs[:-1]: prompt_ids = prompt_ids + query_ids + resp_ids - prompt_ids = prompt_ids + encoded_pairs[-1][0] - return prompt_ids, encoded_pairs[-1][1] + prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1] + return prompt_ids, answer_ids def encode_multiturn( self, @@ -39,13 +46,13 @@ class Template: query: str, resp: str, history: Optional[List[Tuple[str, str]]] = None, - prefix: Optional[str] = None + system: Optional[str] = None ) -> List[Tuple[List[int], List[int]]]: r""" Returns multiple pairs of token ids representing prompts and responses respectively. 
""" - prefix, history = self._format(query, resp, history, prefix) - encoded_pairs = self._encode(tokenizer, prefix, history) + system, history = self._format(query, resp, history, system) + encoded_pairs = self._encode(tokenizer, system, history) return encoded_pairs def _format( @@ -53,26 +60,29 @@ class Template: query: str, resp: str, history: Optional[List[Tuple[str, str]]] = None, - prefix: Optional[str] = None - ) -> Tuple[List[Union[str, Dict[str, str]]], List[Tuple[str, str]]]: + system: Optional[str] = None + ) -> Tuple[str, List[Tuple[str, str]]]: r""" - Aligns inputs to a special format. + Aligns inputs to the standard format. """ - prefix = [prefix] if prefix else self.prefix # use prefix if provided + system = system or self.system # use system if provided history = history if (history and self.use_history) else [] history = history + [(query, resp)] - return prefix, history + return system, history def _get_special_ids( self, tokenizer: "PreTrainedTokenizer" ) -> Tuple[List[int], List[int]]: - if tokenizer.bos_token_id: + if ( + tokenizer.bos_token_id is not None + and getattr(tokenizer, "add_bos_token", True) + ): # baichuan-13b has no bos token bos_ids = [tokenizer.bos_token_id] else: bos_ids = [] # bos token is optional - if tokenizer.eos_token_id: + if tokenizer.eos_token_id is not None: eos_ids = [tokenizer.eos_token_id] else: raise ValueError("EOS token is required.") @@ -82,35 +92,44 @@ class Template: def _encode( self, tokenizer: "PreTrainedTokenizer", - prefix: List[Union[str, Dict[str, str]]], + system: str, history: List[Tuple[str, str]] ) -> List[Tuple[List[int], List[int]]]: r""" Encodes formatted inputs to pairs of token ids. + Turn 0: bos + prefix + sep + query resp + eos + Turn t: sep + bos + query resp + eos """ bos_ids, eos_ids = self._get_special_ids(tokenizer) sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep) encoded_pairs = [] for turn_idx, (query, resp) in enumerate(history): if turn_idx == 0: - prefix_ids = self._convert_inputs_to_ids(tokenizer, context=prefix) + eos_ids + sep_ids + prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system) + if len(prefix_ids) != 0: # has prefix + prefix_ids = bos_ids + prefix_ids + sep_ids + else: + prefix_ids = bos_ids else: - prefix_ids = sep_ids - query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query) + prefix_ids = sep_ids + bos_ids + + query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx)) resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp]) - encoded_pairs.append((bos_ids + prefix_ids + query_ids, resp_ids + eos_ids)) + encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids)) return encoded_pairs def _convert_inputs_to_ids( self, tokenizer: "PreTrainedTokenizer", context: List[Union[str, Dict[str, str]]], - query: Optional[str] = "" + system: Optional[str] = None, + query: Optional[str] = None, + idx: Optional[str] = None ) -> List[int]: r""" Converts context to token ids. 
""" - if hasattr(tokenizer, "tokenizer"): # for tiktoken tokenizer (Qwen) + if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen) kwargs = dict(allowed_special="all") else: kwargs = dict(add_special_tokens=False) @@ -118,12 +137,15 @@ class Template: token_ids = [] for elem in context: if isinstance(elem, str): - elem = elem.replace("{{query}}", query, 1) + elem = elem.replace("{{system}}", system, 1) if system is not None else elem + elem = elem.replace("{{query}}", query, 1) if query is not None else elem + elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem token_ids = token_ids + tokenizer.encode(elem, **kwargs) elif isinstance(elem, dict): token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))] else: raise NotImplementedError + return token_ids @@ -133,18 +155,19 @@ class Llama2Template(Template): def _encode( self, tokenizer: "PreTrainedTokenizer", - prefix: List[Union[str, Dict[str, str]]], + system: str, history: List[Tuple[str, str]] ) -> List[Tuple[List[int], List[int]]]: r""" Encodes formatted inputs to pairs of token ids. + Turn 0: bos + prefix + query resp + eos + Turn t: bos + query resp + eos """ bos_ids, eos_ids = self._get_special_ids(tokenizer) encoded_pairs = [] - assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single str." for turn_idx, (query, resp) in enumerate(history): - if turn_idx == 0: # llama2 template has not sep_ids - query = prefix[0] + query + if turn_idx == 0: # llama2 template has no sep_ids + query = self.prefix[0].replace("{{system}}", system) + query query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query) resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp]) encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids)) @@ -158,14 +181,16 @@ def register_template( name: str, prefix: List[Union[str, Dict[str, str]]], prompt: List[Union[str, Dict[str, str]]], + system: str, sep: List[Union[str, Dict[str, str]]], - stop_words: List[str], - use_history: bool + stop_words: Optional[List[str]] = [], + use_history: Optional[bool] = True ) -> None: - template_class = Llama2Template if name == "llama2" else Template + template_class = Llama2Template if "llama2" in name else Template templates[name] = template_class( prefix=prefix, prompt=prompt, + system=system, sep=sep, stop_words=stop_words, use_history=use_history @@ -179,13 +204,27 @@ def get_template_and_fix_tokenizer( template = templates.get(name, None) assert template is not None, "Template {} does not exist.".format(name) - if tokenizer.eos_token_id is None and len(template.stop_words): # inplace method - tokenizer.eos_token = template.stop_words[0] + additional_special_tokens = template.stop_words + if len(template.stop_words): # inplace method + if tokenizer.eos_token_id is not None: + additional_special_tokens.append(tokenizer.eos_token) - if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None: - tokenizer.pad_token = tokenizer.eos_token + tokenizer.eos_token = additional_special_tokens[0] # use the first stop word as eos token + additional_special_tokens.pop(0) + logger.info("Replace eos token: {}".format(tokenizer.eos_token)) - tokenizer.add_special_tokens(dict(additional_special_tokens=template.stop_words)) + if tokenizer.eos_token_id is None: + tokenizer.eos_token = "<|endoftext|>" + logger.info("Add eos token: {}".format(tokenizer.eos_token)) + + if tokenizer.pad_token_id is None: + if tokenizer.unk_token_id is not 
None: + tokenizer.pad_token = tokenizer.unk_token + else: + tokenizer.pad_token = tokenizer.eos_token + logger.info("Add pad token: {}".format(tokenizer.pad_token)) + + tokenizer.add_special_tokens(dict(additional_special_tokens=additional_special_tokens)) return template @@ -198,8 +237,8 @@ register_template( prompt=[ "{{query}}" ], + system="", sep=[], - stop_words=[], use_history=False ) @@ -210,17 +249,18 @@ Default template. register_template( name="default", prefix=[ - "A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the user's questions." + "{{system}}" ], prompt=[ "Human: {{query}}\nAssistant: " ], + system=( + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." + ), sep=[ "\n" - ], - stop_words=[], - use_history=True + ] ) @@ -232,21 +272,39 @@ Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf register_template( name="llama2", prefix=[ - "<<SYS>>\nYou are a helpful, respectful and honest assistant. " + "<<SYS>>\n{{system}}\n<</SYS>>\n\n" + ], + prompt=[ + "[INST] {{query}} [/INST] " + ], + system=( + "You are a helpful, respectful and honest assistant. " "Always answer as helpfully as possible, while being safe. " "Your answers should not include any harmful, unethical, " "racist, sexist, toxic, dangerous, or illegal content. " "Please ensure that your responses are socially unbiased and positive in nature.\n" "If a question does not make any sense, or is not factually coherent, " "explain why instead of answering something not correct. " - "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n" + "If you don't know the answer to a question, please don't share false information." + ), + sep=[] +) + + +r""" +Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2 + https://huggingface.co/ziqingyang/chinese-alpaca-2-7b +""" +register_template( + name="llama2_zh", + prefix=[ + "<<SYS>>\n{{system}}\n<</SYS>>\n\n" ], prompt=[ "[INST] {{query}} [/INST] " ], - sep=[], - stop_words=[], - use_history=True + system="You are a helpful assistant. 你是一个乐于助人的助手。", + sep=[] ) @@ -257,17 +315,18 @@ Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff register_template( name="alpaca", prefix=[ - "Below is an instruction that describes a task. " - "Write a response that appropriately completes the request." + "{{system}}" ], prompt=[ "### Instruction:\n{{query}}\n\n### Response:\n" ], + system=( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request." + ), sep=[ "\n\n" - ], - stop_words=[], - use_history=True + ] ) @@ -278,15 +337,16 @@ Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1 register_template( name="vicuna", prefix=[ - "A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the user's questions." + "{{system}}" ], prompt=[ "USER: {{query}} ASSISTANT: " ], - sep=[], - stop_words=[], - use_history=True + system=( + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions."
+ ), + sep=[] ) @@ -295,15 +355,16 @@ Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B """ register_template( name="belle", - prefix=[], + prefix=[ + "{{system}}" + ], prompt=[ "Human: {{query}}\n\nBelle: " ], + system="", sep=[ "\n\n" - ], - stop_words=[], - use_history=True + ] ) @@ -312,15 +373,16 @@ Supports: https://github.com/CVI-SZU/Linly """ register_template( name="linly", - prefix=[], + prefix=[ + "{{system}}" + ], prompt=[ "User: {{query}}\nBot: " ], + system="", sep=[ "\n" - ], - stop_words=[], - use_history=True + ] ) @@ -329,15 +391,16 @@ Supports: https://github.com/Neutralzz/BiLLa """ register_template( name="billa", - prefix=[], + prefix=[ + "{{system}}" + ], prompt=[ "Human: {{query}}\nAssistant: " ], + system="", sep=[ "\n" - ], - stop_words=[], - use_history=True + ] ) @@ -346,18 +409,19 @@ Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1 """ register_template( name="ziya", - prefix=[], + prefix=[ + "{{system}}" + ], prompt=[ {"token": "<human>"}, ":{{query}}\n", {"token": "<bot>"}, ":" ], + system="", sep=[ "\n" - ], - stop_words=[], - use_history=True + ] ) @@ -367,17 +431,18 @@ Supports: https://huggingface.co/qhduan/aquilachat-7b register_template( name="aquila", prefix=[ - "A chat between a curious human and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the human's questions." + "{{system}}" ], prompt=[ "Human: {{query}}###Assistant: " ], + system=( + "A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions." + ), sep=[ "###" - ], - stop_words=[], - use_history=True + ] ) @@ -386,19 +451,22 @@ Supports: https://huggingface.co/internlm/internlm-chat-7b """ register_template( name="intern", - prefix=[], + prefix=[ + "{{system}}" + ], prompt=[ "<|User|>:{{query}}", {"token": "<eoh>"}, "\n<|Bot|>:" ], + system="", sep=[ "\n" ], stop_words=[ + "</s>", # internlm cannot replace eos token "<eoa>" - ], - use_history=True + ] ) @@ -407,15 +475,19 @@ Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat """ register_template( name="baichuan", - prefix=[], - prompt=[ - {"token": "<reserved_102>"}, - "{{query}}", - {"token": "<reserved_106>"} + prefix=[ + "{{system}}", + {"token": "<reserved_102>"} # user token ], + prompt=[ + "{{query}}", + {"token": "<reserved_106>"} # assistant token + ], + system="", sep=[], - stop_words=[], - use_history=True + stop_words=[ + "<reserved_102>" # user token + ] ) @@ -427,7 +499,8 @@ register_template( name="starchat", prefix=[ {"token": "<|system|>"}, - "\n" + "\n{{system}}", + {"token": "<|end|>"} ], prompt=[ {"token": "<|user|>"}, @@ -436,13 +509,13 @@ register_template( "\n", {"token": "<|assistant|>"} ], + system="", sep=[ "\n" ], stop_words=[ "<|end|>" - ], - use_history=True + ] ) @@ -453,7 +526,8 @@ register_template( name="chatml", prefix=[ {"token": "<|im_start|>"}, - "system\nYou are a helpful assistant."
+ "system\n{{system}}", + {"token": "<|im_end|>"} ], prompt=[ {"token": "<|im_start|>"}, @@ -463,11 +537,31 @@ register_template( {"token": "<|im_start|>"}, "assistant\n" ], + system="You are a helpful assistant.", sep=[ "\n" ], stop_words=[ "<|im_end|>" - ], - use_history=True + ] +) + + +r""" +Supports: https://huggingface.co/THUDM/chatglm2-6b +""" +register_template( + name="chatglm2", + prefix=[ + {"token": "[gMASK]"}, + {"token": "sop"}, + "{{system}}" + ], + prompt=[ + "[Round {{idx}}]\n\n问:{{query}}\n\n答:" + ], + system="", + sep=[ + "\n\n" + ] ) diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py index 60945b60..374d03c6 100644 --- a/src/llmtuner/hparams/data_args.py +++ b/src/llmtuner/hparams/data_args.py @@ -10,7 +10,7 @@ class DatasetAttr: load_from: str dataset_name: Optional[str] = None dataset_sha1: Optional[str] = None - source_prefix: Optional[str] = None + system_prompt: Optional[str] = None def __repr__(self) -> str: return self.dataset_name @@ -24,7 +24,7 @@ class DatasetAttr: @dataclass class DataArguments: - """ + r""" Arguments pertaining to what data we are going to input our model for training and evaluation. """ template: str = field( @@ -86,13 +86,13 @@ class DataArguments: default=True, metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."} ) - source_prefix: Optional[str] = field( + system_prompt: Optional[str] = field( default=None, - metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."} + metadata={"help": "System prompt to add before the user query. Use `|` to separate multiple prompts in training."} ) - dev_ratio: Optional[float] = field( + val_size: Optional[float] = field( default=0, - metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."} + metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."} ) def init_for_training(self): # support mixing multiple datasets @@ -100,12 +100,9 @@ class DataArguments: with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f: dataset_info = json.load(f) - if self.source_prefix is not None: - prefix_list = self.source_prefix.split("|") - prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list - assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical with datasets or 1." - else: - prefix_list = [None] * len(dataset_names) + prompt_list = self.system_prompt.split("|") if self.system_prompt else [None] + prompt_list = prompt_list * (len(dataset_names) // len(prompt_list)) + assert len(prompt_list) == len(dataset_names), "Number of system prompts should be equal to datasets or 1." 
if self.interleave_probs is not None: self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")] @@ -126,12 +123,11 @@ class DataArguments: dataset_sha1=dataset_info[name].get("file_sha1", None) ) - dataset_attr.source_prefix = prefix_list[i] - if "columns" in dataset_info[name]: dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None) dataset_attr.query = dataset_info[name]["columns"].get("query", None) dataset_attr.response = dataset_info[name]["columns"].get("response", None) dataset_attr.history = dataset_info[name]["columns"].get("history", None) + dataset_attr.system_prompt = prompt_list[i] self.dataset_list.append(dataset_attr) diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py index 277602ae..d7d651dd 100644 --- a/src/llmtuner/hparams/finetuning_args.py +++ b/src/llmtuner/hparams/finetuning_args.py @@ -5,7 +5,7 @@ from dataclasses import asdict, dataclass, field @dataclass class FinetuningArguments: - """ + r""" Arguments pertaining to which techniques we are going to fine-tuning with. """ finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field( @@ -14,7 +14,7 @@ class FinetuningArguments: ) num_hidden_layers: Optional[int] = field( default=32, - metadata={"help": "Number of decoder blocks in the model. \ + metadata={"help": "Number of decoder blocks in the model for partial-parameter (freeze) fine-tuning. \ LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \ LLaMA-2 choices: [\"32\", \"40\", \"80\"], \ BLOOM choices: [\"24\", \"30\", \"70\"], \ @@ -25,16 +25,16 @@ class FinetuningArguments: ) num_layer_trainable: Optional[int] = field( default=3, - metadata={"help": "Number of trainable layers for Freeze fine-tuning."} + metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."} ) name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field( default="mlp", - metadata={"help": "Name of trainable modules for Freeze fine-tuning. \ - LLaMA & LLaMA-2 choices: [\"mlp\", \"self_attn\"], \ + metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \ + LLaMA choices: [\"mlp\", \"self_attn\"], \ BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \ Baichuan choices: [\"mlp\", \"self_attn\"], \ Qwen choices: [\"mlp\", \"attn\"], \ - InternLM, XVERSE choices: the same as LLaMA."} + LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."} ) lora_rank: Optional[int] = field( default=8, @@ -51,11 +51,19 @@ class FinetuningArguments: lora_target: Optional[str] = field( default="q_proj,v_proj", metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. 
\ - LLaMA & LLaMA-2 choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ + LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \ Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \ - InternLM, XVERSE choices: the same as LLaMA."} + LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."} + ) + resume_lora_training: Optional[bool] = field( + default=True, + metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."} + ) + dpo_beta: Optional[float] = field( + default=0.1, + metadata={"help": "The beta parameter for the DPO loss."} ) def __post_init__(self): @@ -72,14 +80,14 @@ class FinetuningArguments: assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method." def save_to_json(self, json_path: str): - """Saves the content of this instance in JSON format inside `json_path`.""" + r"""Saves the content of this instance in JSON format inside `json_path`.""" json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n" with open(json_path, "w", encoding="utf-8") as f: f.write(json_string) @classmethod def load_from_json(cls, json_path: str): - """Creates an instance from the content of `json_path`.""" + r"""Creates an instance from the content of `json_path`.""" with open(json_path, "r", encoding="utf-8") as f: text = f.read() return cls(**json.loads(text)) diff --git a/src/llmtuner/hparams/general_args.py b/src/llmtuner/hparams/general_args.py index 397d3019..c0c1a0de 100644 --- a/src/llmtuner/hparams/general_args.py +++ b/src/llmtuner/hparams/general_args.py @@ -4,10 +4,10 @@ from dataclasses import dataclass, field @dataclass class GeneralArguments: - """ + r""" Arguments pertaining to which stage we are going to perform. """ - stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = field( + stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field( default="sft", metadata={"help": "Which stage will be performed in training."} ) diff --git a/src/llmtuner/hparams/generating_args.py b/src/llmtuner/hparams/generating_args.py index e25ff4b9..f8b935fb 100644 --- a/src/llmtuner/hparams/generating_args.py +++ b/src/llmtuner/hparams/generating_args.py @@ -4,7 +4,7 @@ from dataclasses import asdict, dataclass, field @dataclass class GeneratingArguments: - """ + r""" Arguments pertaining to specify the decoding parameters. """ do_sample: Optional[bool] = field( diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py index 253d9839..dc515f51 100644 --- a/src/llmtuner/hparams/model_args.py +++ b/src/llmtuner/hparams/model_args.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field @dataclass class ModelArguments: - """ + r""" Arguments pertaining to which model/config/tokenizer we are going to fine-tune. """ model_name_or_path: str = field( @@ -43,9 +43,9 @@ class ModelArguments: default=True, metadata={"help": "Whether to use double quantization in int4 training or not."} ) - compute_dtype: Optional[torch.dtype] = field( + rope_scaling: Optional[Literal["linear", "dynamic"]] = field( default=None, - metadata={"help": "Used in quantization configs. 
Do not specify this argument manually."} + metadata={"help": "Adopt scaled rotary positional embeddings."} ) checkpoint_dir: Optional[str] = field( default=None, @@ -55,18 +55,33 @@ class ModelArguments: default=None, metadata={"help": "Path to the directory containing the checkpoints of the reward model."} ) - resume_lora_training: Optional[bool] = field( - default=True, - metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."} - ) plot_loss: Optional[bool] = field( default=False, metadata={"help": "Whether to plot the training loss after fine-tuning or not."} ) + hf_auth_token: Optional[str] = field( + default=None, + metadata={"help": "Auth token to log in with Hugging Face Hub."} + ) + compute_dtype: Optional[torch.dtype] = field( + default=None, + metadata={"help": "Used in quantization configs. Do not specify this argument manually."} + ) + model_max_length: Optional[int] = field( + default=None, + metadata={"help": "Used in rope scaling. Do not specify this argument manually."} + ) def __post_init__(self): + if self.compute_dtype is not None or self.model_max_length is not None: + raise ValueError("These arguments cannot be specified.") + if self.checkpoint_dir is not None: # support merging multiple lora weights self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")] if self.quantization_bit is not None: assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization." + + if self.use_auth_token == True and self.hf_auth_token is not None: + from huggingface_hub.hf_api import HfFolder # lazy load + HfFolder.save_token(self.hf_auth_token) diff --git a/src/llmtuner/tuner/core/adapter.py b/src/llmtuner/tuner/core/adapter.py index 4afad13a..5db56876 100644 --- a/src/llmtuner/tuner/core/adapter.py +++ b/src/llmtuner/tuner/core/adapter.py @@ -39,7 +39,7 @@ def init_adapter( if finetuning_args.finetuning_type == "none" and is_trainable: raise ValueError("You cannot use finetuning_type=none while training.") - if finetuning_args.finetuning_type == "full": + if finetuning_args.finetuning_type == "full" and is_trainable: logger.info("Fine-tuning method: Full") model = model.float() @@ -65,7 +65,7 @@ def init_adapter( assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \ "The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead." 
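The adapter change just below reads `resume_lora_training` from `finetuning_args` instead of `model_args` (the field was moved to `FinetuningArguments` earlier in this diff). The flag controls how a comma-separated `--checkpoint_dir` is consumed; a rough standalone sketch of that split, assuming the same ordering convention as the repository:

```python
from typing import List, Optional, Tuple

def split_lora_checkpoints(
    checkpoint_dirs: List[str],
    is_trainable: bool,
    resume_lora_training: bool,
    is_mergeable: bool = True
) -> Tuple[List[str], Optional[str]]:
    """Decide which LoRA checkpoints get merged and which one training resumes from."""
    if (is_trainable and resume_lora_training) or not is_mergeable:
        # keep the newest adapter trainable and merge everything before it
        return checkpoint_dirs[:-1], checkpoint_dirs[-1]
    # otherwise merge all given adapters and create a fresh one for training
    return checkpoint_dirs, None

# hypothetical checkpoint names for illustration
to_merge, latest = split_lora_checkpoints(["ckpt_a", "ckpt_b", "ckpt_c"], is_trainable=True, resume_lora_training=True)
print(to_merge, latest)  # ['ckpt_a', 'ckpt_b'] ckpt_c
```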
- if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights + if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable): # continually fine-tuning checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1] else: checkpoints_to_merge = model_args.checkpoint_dir diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py index c06eabfa..4bf767a6 100644 --- a/src/llmtuner/tuner/core/loader.py +++ b/src/llmtuner/tuner/core/loader.py @@ -1,5 +1,7 @@ import os +import math import torch +from types import MethodType from typing import TYPE_CHECKING, Literal, Optional, Tuple from transformers import ( @@ -34,7 +36,7 @@ check_min_version("4.29.1") require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0") require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0") require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0") -require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7") +require_version("trl>=0.5.0", "To fix: pip install trl>=0.5.0") def load_model_and_tokenizer( @@ -52,9 +54,6 @@ def load_model_and_tokenizer( logger.warning("Checkpoint is not found at evaluation, load the original model.") finetuning_args = FinetuningArguments(finetuning_type="none") - assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \ - "RM and PPO training can only be performed with the LoRA method." - config_kwargs = { "trust_remote_code": True, "cache_dir": model_args.cache_dir, @@ -69,15 +68,58 @@ def load_model_and_tokenizer( **config_kwargs ) - if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full": + if finetuning_args.finetuning_type == "full" and model_args.checkpoint_dir is not None: model_to_load = model_args.checkpoint_dir[0] else: model_to_load = model_args.model_name_or_path config = AutoConfig.from_pretrained(model_to_load, **config_kwargs) - is_mergeable = True + + if hasattr(config, "fp16") and hasattr(config, "bf16"): # fix Qwen config + if model_args.compute_dtype == torch.bfloat16: + setattr(config, "bf16", True) + else: + setattr(config, "fp16", True) + + # Set RoPE scaling + if model_args.rope_scaling is not None: + if hasattr(config, "use_dynamic_ntk"): # for Qwen models + if is_trainable: + logger.warning("Qwen model does not support RoPE scaling in training.") + else: + setattr(config, "use_dynamic_ntk", True) + setattr(config, "use_logn_attn", True) + logger.info("Using dynamic NTK scaling.") + + elif hasattr(config, "rope_scaling"): # for LLaMA models + require_version("transformers>=4.31.0", "RoPE scaling requires transformers>=4.31.0") + + if is_trainable: + if model_args.rope_scaling == "dynamic": + logger.warning( + "Dynamic NTK may not work well with fine-tuning. " + "See: https://github.com/huggingface/transformers/pull/24653" + ) + + current_max_length = getattr(config, "max_position_embeddings", None) + if current_max_length and model_args.model_max_length > current_max_length: + scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length)) + else: + logger.warning("Input length is smaller than max length. 
Consider increase input length.") + scaling_factor = 1.0 + else: + scaling_factor = 2.0 + + setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor}) + logger.info("Using {} scaling strategy and setting scaling factor to {}".format( + model_args.rope_scaling, scaling_factor + )) + + else: + logger.warning("Current model does not support RoPE scaling.") # Quantization configurations (using bitsandbytes library). + is_mergeable = True if model_args.quantization_bit is not None: if model_args.quantization_bit == 8: require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") @@ -95,10 +137,10 @@ def load_model_and_tokenizer( ) is_mergeable = False - config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} + config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto" logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit)) - # Load and prepare pretrained models (without valuehead). + # Load and prepare pre-trained models (without valuehead). model = AutoModelForCausalLM.from_pretrained( model_to_load, config=config, @@ -107,6 +149,14 @@ def load_model_and_tokenizer( **config_kwargs ) + # Disable custom generate method (for Qwen) + if "GenerationMixin" not in str(model.generate.__func__): + model.generate = MethodType(PreTrainedModel.generate, model) + + # Fix LM head (for ChatGLM2) + if not hasattr(model, "lm_head") and hasattr(model, "transformer"): + setattr(model, "lm_head", model.transformer.output_layer) + # Register auto class to save the custom code files. if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}): config.__class__.register_for_auto_class() @@ -119,10 +169,10 @@ def load_model_and_tokenizer( model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable) - if stage == "rm" or stage == "ppo": # add value head - model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model) + # Prepare model with valuehead for RLHF + if stage == "rm" or stage == "ppo": + model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(model) reset_logging() - if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.") if load_valuehead_params(model, model_args.checkpoint_dir[-1]): @@ -132,15 +182,15 @@ def load_model_and_tokenizer( }) if stage == "ppo": # load reward model - assert is_trainable, "PPO stage cannot be performed at evaluation." - assert model_args.reward_model is not None, "Reward model is necessary for PPO training." logger.info("Load reward model from {}".format(model_args.reward_model)) model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False) assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded." 
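For reference, the RoPE scaling branch added to `loader.py` above derives the linear scaling factor from the ratio between the requested `model_max_length` and the model's pretrained `max_position_embeddings`, and falls back to a fixed factor of 2.0 at inference time. A standalone sketch of that arithmetic (the 4096/2048 figures are illustrative values, not repository defaults):

```python
import math

def rope_scaling_factor(model_max_length: int, max_position_embeddings: int, is_trainable: bool) -> float:
    """Mirror of the factor selection used when patching config.rope_scaling for LLaMA models."""
    if not is_trainable:
        return 2.0  # fixed extrapolation factor at inference time
    if model_max_length > max_position_embeddings:
        # round up so the scaled context always covers the requested length
        return float(math.ceil(model_max_length / max_position_embeddings))
    return 1.0  # input already fits within the pretrained context window

# e.g. fine-tuning a 2048-position LLaMA with 4096-token samples -> factor 2.0
print(rope_scaling_factor(4096, 2048, is_trainable=True))
```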
+ # Prepare model for inference if not is_trainable: model.requires_grad_(False) # fix all model params - model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16 + infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # detect cuda capability + model = model.to(infer_dtype) if model_args.quantization_bit is None else model trainable_params, all_param = count_parameters(model) logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format( diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py index d872afcc..e039513d 100644 --- a/src/llmtuner/tuner/core/parser.py +++ b/src/llmtuner/tuner/core/parser.py @@ -19,7 +19,7 @@ from llmtuner.hparams import ( logger = get_logger(__name__) -def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None): +def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]: if args is not None: return parser.parse_dict(args) elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): @@ -32,26 +32,53 @@ def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) def parse_train_args( args: Optional[Dict[str, Any]] = None -) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]: +) -> Tuple[ + ModelArguments, + DataArguments, + Seq2SeqTrainingArguments, + FinetuningArguments, + GeneratingArguments, + GeneralArguments +]: parser = HfArgumentParser(( - ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments + ModelArguments, + DataArguments, + Seq2SeqTrainingArguments, + FinetuningArguments, + GeneratingArguments, + GeneralArguments )) return _parse_args(parser, args) def parse_infer_args( args: Optional[Dict[str, Any]] = None -) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]: +) -> Tuple[ + ModelArguments, + DataArguments, + FinetuningArguments, + GeneratingArguments +]: parser = HfArgumentParser(( - ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments + ModelArguments, + DataArguments, + FinetuningArguments, + GeneratingArguments )) return _parse_args(parser, args) def get_train_args( args: Optional[Dict[str, Any]] = None -) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]: - model_args, data_args, training_args, finetuning_args, general_args = parse_train_args(args) +) -> Tuple[ + ModelArguments, + DataArguments, + Seq2SeqTrainingArguments, + FinetuningArguments, + GeneratingArguments, + GeneralArguments +]: + model_args, data_args, training_args, finetuning_args, generating_args, general_args = parse_train_args(args) # Setup logging if training_args.should_log: @@ -67,33 +94,42 @@ def get_train_args( # Check arguments (do not check finetuning_args since it may be loaded from checkpoints) data_args.init_for_training() - assert general_args.stage == "sft" or (not training_args.predict_with_generate), \ - "`predict_with_generate` cannot be set as True at PT, RM and PPO stages." + if general_args.stage != "sft" and training_args.predict_with_generate: + raise ValueError("`predict_with_generate` cannot be set as True except SFT.") - assert not (training_args.do_train and training_args.predict_with_generate), \ - "`predict_with_generate` cannot be set as True while training." 
+ if training_args.do_train and training_args.predict_with_generate: + raise ValueError("`predict_with_generate` cannot be set as True while training.") - assert general_args.stage != "sft" or (not training_args.do_predict) or training_args.predict_with_generate, \ - "Please enable `predict_with_generate` to save model predictions." + if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate: + raise ValueError("Please enable `predict_with_generate` to save model predictions.") - assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \ - "Quantization is only compatible with the LoRA method." + if general_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora": + raise ValueError("RM and PPO training can only be performed with the LoRA method.") - assert not (training_args.max_steps == -1 and data_args.streaming), \ - "Please specify `max_steps` in streaming mode." + if general_args.stage in ["ppo", "dpo"] and not training_args.do_train: + raise ValueError("PPO and DPO stage can only be performed at training.") - assert training_args.evaluation_strategy == "no" or (not data_args.streaming), \ - "Streaming mode does not support evaluation currently." + if general_args.stage == "ppo" and model_args.reward_model is None: + raise ValueError("Reward model is necessary for PPO training.") - assert not (general_args.stage == "ppo" and data_args.streaming), \ - "Streaming mode does not suppport PPO training currently." + if training_args.max_steps == -1 and data_args.streaming: + raise ValueError("Please specify `max_steps` in streaming mode.") + + if general_args.stage == "ppo" and data_args.streaming: + raise ValueError("Streaming mode does not suppport PPO training currently.") + + if data_args.val_size > 1e-6 and data_args.val_size < 1 and data_args.streaming: + raise ValueError("Streaming mode should have an integer val size.") + + if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora": + raise ValueError("Quantization is only compatible with the LoRA method.") if model_args.checkpoint_dir is not None: if finetuning_args.finetuning_type != "lora": - assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints." - else: - assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \ - "Quantized model only accepts a single checkpoint." + if len(model_args.checkpoint_dir) != 1: + raise ValueError("Only LoRA tuning accepts multiple checkpoints.") + elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1: + raise ValueError("Quantized model only accepts a single checkpoint.") if model_args.quantization_bit is not None and (not training_args.do_train): logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.") @@ -113,46 +149,48 @@ def get_train_args( logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.") data_args.max_samples = None - if data_args.dev_ratio > 1e-6 and data_args.streaming: - logger.warning("`dev_ratio` is incompatible with `streaming`. 
Disabling development set.") - data_args.dev_ratio = 0 - training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning - if model_args.quantization_bit is not None: - if training_args.fp16: - model_args.compute_dtype = torch.float16 - elif training_args.bf16: - model_args.compute_dtype = torch.bfloat16 - else: - model_args.compute_dtype = torch.float32 + if training_args.bf16: + if not torch.cuda.is_bf16_supported(): + raise ValueError("Current device does not support bf16 training.") + model_args.compute_dtype = torch.bfloat16 + else: + model_args.compute_dtype = torch.float16 + + model_args.model_max_length = data_args.max_source_length + data_args.max_target_length # Log on each process the small summary: - logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, 16-bits training: {}".format( + logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format( training_args.local_rank, training_args.device, training_args.n_gpu, - bool(training_args.local_rank != -1), training_args.fp16 + bool(training_args.local_rank != -1), str(model_args.compute_dtype) )) logger.info(f"Training/evaluation parameters {training_args}") # Set seed before initializing model. transformers.set_seed(training_args.seed) - return model_args, data_args, training_args, finetuning_args, general_args + return model_args, data_args, training_args, finetuning_args, generating_args, general_args def get_infer_args( args: Optional[Dict[str, Any]] = None -) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]: +) -> Tuple[ + ModelArguments, + DataArguments, + FinetuningArguments, + GeneratingArguments +]: model_args, data_args, finetuning_args, generating_args = parse_infer_args(args) - assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \ - "Quantization is only compatible with the LoRA method." + if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora": + raise ValueError("Quantization is only compatible with the LoRA method.") if model_args.checkpoint_dir is not None: if finetuning_args.finetuning_type != "lora": - assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints." - else: - assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \ - "Quantized model only accepts a single checkpoint." + if len(model_args.checkpoint_dir) != 1: + raise ValueError("Only LoRA tuning accepts multiple checkpoints.") + elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1: + raise ValueError("Quantized model only accepts a single checkpoint.") return model_args, data_args, finetuning_args, generating_args diff --git a/src/llmtuner/tuner/core/trainer.py b/src/llmtuner/tuner/core/trainer.py index ae80f32f..058bb740 100644 --- a/src/llmtuner/tuner/core/trainer.py +++ b/src/llmtuner/tuner/core/trainer.py @@ -13,26 +13,25 @@ from llmtuner.extras.logging import get_logger from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params if TYPE_CHECKING: + from transformers import PreTrainedTokenizer, Seq2SeqTrainingArguments, TrainerState from llmtuner.hparams import FinetuningArguments logger = get_logger(__name__) -class PeftTrainer(Seq2SeqTrainer): +class PeftModelMixin: r""" - Inherits Seq2SeqTrainer to support parameter-efficient checkpoints. 
+ Patches the save and load methods in Hugging Face Trainer for PeftModel and ModelWithValueHead. """ - def __init__(self, finetuning_args: "FinetuningArguments", **kwargs): - super().__init__(**kwargs) - self.finetuning_args = finetuning_args - self._remove_log() - - def _remove_log(self): - if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")): - logger.warning("Previous log file in this folder will be deleted.") - os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl")) + def __init__(self) -> None: # for type checking + self.model: PreTrainedModel = None + self.tokenizer: "PreTrainedTokenizer" = None + self.args: "Seq2SeqTrainingArguments" = None + self.finetuning_args: "FinetuningArguments" = None + self.state: "TrainerState" = None + raise AssertionError("Mixin should not be initialized.") def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None: r""" @@ -96,3 +95,13 @@ class PeftTrainer(Seq2SeqTrainer): model.load_adapter(self.state.best_model_checkpoint, model.active_adapter) else: # freeze/full-tuning load_trainable_params(model, self.state.best_model_checkpoint) + + +class PeftTrainer(PeftModelMixin, Seq2SeqTrainer): + r""" + Inherits Seq2SeqTrainer to support parameter-efficient checkpoints. + """ + + def __init__(self, finetuning_args: "FinetuningArguments", **kwargs): + Seq2SeqTrainer.__init__(self, **kwargs) + self.finetuning_args = finetuning_args diff --git a/src/llmtuner/tuner/dpo/__init__.py b/src/llmtuner/tuner/dpo/__init__.py new file mode 100644 index 00000000..f2b5cfb5 --- /dev/null +++ b/src/llmtuner/tuner/dpo/__init__.py @@ -0,0 +1 @@ +from llmtuner.tuner.dpo.workflow import run_dpo diff --git a/src/llmtuner/tuner/dpo/collator.py b/src/llmtuner/tuner/dpo/collator.py new file mode 100644 index 00000000..2f0f4bdc --- /dev/null +++ b/src/llmtuner/tuner/dpo/collator.py @@ -0,0 +1,51 @@ +import torch +from dataclasses import dataclass +from typing import Any, Dict, List, Sequence, Tuple +from transformers import DataCollatorForSeq2Seq + + +@dataclass +class DPODataCollatorWithPadding(DataCollatorForSeq2Seq): + r""" + Data collator for pairwise data. + """ + + def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor: + padded_labels = [] + for feature, (prompt_len, answer_len) in zip(batch, positions): + if self.tokenizer.padding_side == "left": + start, end = feature.size(0) - answer_len, feature.size(0) + else: + start, end = prompt_len, answer_len + padded_tensor = self.label_pad_token_id * torch.ones_like(feature) + padded_tensor[start:end] = feature[start:end] + padded_labels.append(padded_tensor) + return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory + + def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]: + r""" + Pads batched data to the longest sequence in the batch. + + We generate 2 * n examples where the first n examples represent chosen examples and + the last n examples represent rejected examples. 
+ """ + concatenated_features = [] + label_positions = [] + for key in ("chosen_ids", "rejected_ids"): + for feature in features: + prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key]) + concatenated_features.append({ + "input_ids": feature["prompt_ids"] + feature[key], + "attention_mask": [1] * (prompt_len + answer_len) + }) + label_positions.append((prompt_len, answer_len)) + + batch = self.tokenizer.pad( + concatenated_features, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors=self.return_tensors, + ) + batch["labels"] = self._pad_labels(batch["input_ids"], label_positions) + return batch diff --git a/src/llmtuner/tuner/dpo/trainer.py b/src/llmtuner/tuner/dpo/trainer.py new file mode 100644 index 00000000..458e99db --- /dev/null +++ b/src/llmtuner/tuner/dpo/trainer.py @@ -0,0 +1,77 @@ +import torch +from collections import defaultdict +from peft import PeftModel +from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union +from transformers import BatchEncoding, Trainer +from trl import DPOTrainer + +from llmtuner.extras.constants import IGNORE_INDEX +from llmtuner.tuner.core.trainer import PeftModelMixin + +if TYPE_CHECKING: + from transformers import PreTrainedModel + from llmtuner.hparams import FinetuningArguments, GeneratingArguments + + +class DPOPeftTrainer(PeftModelMixin, DPOTrainer): + + def __init__( + self, + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None, + **kwargs + ): + self.finetuning_args = finetuning_args + self.generating_args = generating_args + self.ref_model = ref_model + self.use_dpo_data_collator = True # hack to avoid warning + self.label_pad_token_id = IGNORE_INDEX + self.padding_value = 0 + self.beta = finetuning_args.dpo_beta + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + Trainer.__init__(self, **kwargs) + if not hasattr(self, "accelerator"): + raise AttributeError("Please update `transformers`.") + + if ref_model is not None: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + def concatenated_forward( + self, + model: Optional[torch.nn.Module] = None, + batch: Optional[Dict[str, torch.Tensor]] = None + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error + unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model) + + if not torch.is_grad_enabled(): + unwrapped_model.gradient_checkpointing_disable() + + if model is None and isinstance(unwrapped_model, PeftModel): # peft model has no ref_model + with unwrapped_model.disable_adapter(): + all_logits = self.model( + input_ids=batch_copied["input_ids"], + attention_mask=batch_copied["attention_mask"], + return_dict=True + ).logits.to(torch.float32) + else: + all_logits = model( + input_ids=batch_copied["input_ids"], + attention_mask=batch_copied["attention_mask"], + return_dict=True + ).logits.to(torch.float32) + + if not torch.is_grad_enabled(): + unwrapped_model.gradient_checkpointing_enable() + + all_logps = self._get_batch_logps( + all_logits, + batch["labels"], + average_log_prob=False + ) + batch_size = batch["input_ids"].size(0) // 2 + chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0) + chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0) + return chosen_logps, 
rejected_logps, chosen_logits, rejected_logits diff --git a/src/llmtuner/tuner/dpo/workflow.py b/src/llmtuner/tuner/dpo/workflow.py new file mode 100644 index 00000000..350d64a7 --- /dev/null +++ b/src/llmtuner/tuner/dpo/workflow.py @@ -0,0 +1,59 @@ +# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py + +from copy import deepcopy +from peft import PeftModel +from typing import TYPE_CHECKING, Optional, List + +from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset +from llmtuner.extras.constants import IGNORE_INDEX +from llmtuner.extras.ploting import plot_loss +from llmtuner.tuner.core import load_model_and_tokenizer +from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding +from llmtuner.tuner.dpo.trainer import DPOPeftTrainer + +if TYPE_CHECKING: + from transformers import Seq2SeqTrainingArguments, TrainerCallback + from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments + + +def run_dpo( + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + callbacks: Optional[List["TrainerCallback"]] = None +): + dataset = get_dataset(model_args, data_args) + model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft") + dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm") + data_collator = DPODataCollatorWithPadding( + tokenizer=tokenizer, + label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id + ) + + training_args.remove_unused_columns = False # important for pairwise dataset + ref_model = deepcopy(model) if not isinstance(model, PeftModel) else None + + # Initialize our Trainer + trainer = DPOPeftTrainer( + finetuning_args=finetuning_args, + generating_args=generating_args, + ref_model=ref_model, + model=model, + args=training_args, + tokenizer=tokenizer, + data_collator=data_collator, + callbacks=callbacks, + **split_dataset(dataset, data_args, training_args) + ) + + # Training + if training_args.do_train: + train_result = trainer.train() + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + trainer.save_model() + if trainer.is_world_process_zero() and model_args.plot_loss: + plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py index 3392aa4d..f929b6ec 100644 --- a/src/llmtuner/tuner/ppo/trainer.py +++ b/src/llmtuner/tuner/ppo/trainer.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple from transformers import TrainerState, TrainerControl from trl import PPOTrainer -from trl.core import LengthSampler +from trl.core import LengthSampler, PPODecorators, logprobs_from_logits from llmtuner.extras.logging import get_logger from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor @@ -18,7 +18,7 @@ if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments from trl import AutoModelForCausalLMWithValueHead from llmtuner.extras.callbacks import LogCallback - from llmtuner.hparams import FinetuningArguments + from llmtuner.hparams import FinetuningArguments, GeneratingArguments logger = get_logger(__name__) @@ -33,16 +33,19 @@ class PPOPeftTrainer(PPOTrainer, 
PeftTrainer): self, training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", callbacks: List["LogCallback"], + compute_dtype: torch.dtype, **kwargs ): PPOTrainer.__init__(self, **kwargs) self.args = training_args self.finetuning_args = finetuning_args + self.generating_args = generating_args self.log_callback = callbacks[0] + self.compute_dtype = compute_dtype self.state = TrainerState() self.control = TrainerControl() - self._remove_log() def ppo_train(self, max_target_length: int) -> None: r""" @@ -72,14 +75,11 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}") # Keyword arguments for `model.generate` - gen_kwargs = { - "top_k": 0.0, - "top_p": 1.0, - "do_sample": True, - "pad_token_id": self.tokenizer.pad_token_id, - "eos_token_id": self.tokenizer.eos_token_id, - "logits_processor": get_logits_processor() - } + gen_kwargs = self.generating_args.to_dict() + gen_kwargs["eos_token_id"] = list(set([self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids)) + gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id + gen_kwargs["logits_processor"] = get_logits_processor() + length_sampler = LengthSampler(max_target_length // 2, max_target_length) unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) @@ -185,10 +185,74 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): replace_model(unwrapped_model, target="reward") batch = self.prepare_model_inputs(queries, responses) _, _, values = self.model(**batch, output_hidden_states=True, return_dict=True) + if values.size(0) != batch["input_ids"].size(0): # adapt chatglm2 + values = torch.transpose(values, 0, 1) rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type replace_model(unwrapped_model, target="default") return rewards + @PPODecorators.empty_cuda_cache() + def batched_forward_pass( + self, + model: "AutoModelForCausalLMWithValueHead", + queries: torch.Tensor, + responses: torch.Tensor, + model_inputs: dict, + return_logits: Optional[bool] = False + ): + r""" + Calculates model outputs in multiple batches. + + Subclass and override to inject custom behavior. 
+ """ + bs = len(queries) + fbs = self.config.mini_batch_size + all_logprobs = [] + all_logits = [] + all_masks = [] + all_values = [] + + for i in range(math.ceil(bs / fbs)): + input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()} + query_batch = queries[i * fbs : (i + 1) * fbs] + response_batch = responses[i * fbs : (i + 1) * fbs] + input_ids = input_kwargs["input_ids"] + attention_mask = input_kwargs["attention_mask"] + + with torch.cuda.amp.autocast(dtype=self.compute_dtype): # support bf16 + logits, _, values = model(**input_kwargs) + + if values.size(0) != input_ids.size(0): # adapt chatglm2 + values = torch.transpose(values, 0, 1) + + logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:]) + masks = torch.zeros_like(attention_mask) + masks[:, :-1] = attention_mask[:, 1:] + + for j in range(len(query_batch)): + start = len(query_batch[j]) - 1 + if attention_mask[j, 0] == 0: # offset left padding + start += attention_mask[j, :].nonzero()[0] + end = start + len(response_batch[j]) + + masks[j, :start] = 0 + masks[j, end:] = 0 + + if return_logits: + all_logits.append(logits) + else: + del logits + all_values.append(values) + all_logprobs.append(logprobs) + all_masks.append(masks) + + return ( + torch.cat(all_logprobs), + torch.cat(all_logits)[:, :-1] if return_logits else None, + torch.cat(all_values)[:, :-1], + torch.cat(all_masks)[:, :-1], + ) + def save_model(self, output_dir: Optional[str] = None) -> None: r""" Saves model checkpoint. diff --git a/src/llmtuner/tuner/ppo/workflow.py b/src/llmtuner/tuner/ppo/workflow.py index 0ca8cbd4..12fcdef1 100644 --- a/src/llmtuner/tuner/ppo/workflow.py +++ b/src/llmtuner/tuner/ppo/workflow.py @@ -1,11 +1,9 @@ -# Inspired by: -# https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py +# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py import math -from typing import TYPE_CHECKING from trl import PPOConfig from torch.optim import AdamW -from typing import Optional, List +from typing import TYPE_CHECKING, Optional, List from transformers import DataCollatorForSeq2Seq from transformers.optimization import get_scheduler @@ -16,7 +14,7 @@ from llmtuner.tuner.ppo.trainer import PPOPeftTrainer if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments, TrainerCallback - from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments + from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments def run_ppo( @@ -24,6 +22,7 @@ def run_ppo( data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", callbacks: Optional[List["TrainerCallback"]] = None ): dataset = get_dataset(model_args, data_args) @@ -38,24 +37,30 @@ def run_ppo( batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps, gradient_accumulation_steps=training_args.gradient_accumulation_steps, ppo_epochs=1, - max_grad_norm=training_args.max_grad_norm + max_grad_norm=training_args.max_grad_norm, + seed=training_args.seed, + optimize_cuda_cache=True ) - optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=ppo_config.learning_rate) - total_train_batch_size = \ + optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate) + total_train_batch_size = ( 
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size + ) + num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size) lr_scheduler = get_scheduler( training_args.lr_scheduler_type, optimizer=optimizer, - num_warmup_steps=training_args.warmup_steps, - num_training_steps=(training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)) + num_warmup_steps=training_args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps ) # Initialize our Trainer ppo_trainer = PPOPeftTrainer( training_args=training_args, finetuning_args=finetuning_args, + generating_args=generating_args, callbacks=callbacks, + compute_dtype=model_args.compute_dtype, config=ppo_config, model=model, ref_model=None, @@ -66,8 +71,10 @@ def run_ppo( lr_scheduler=lr_scheduler ) - ppo_trainer.ppo_train(max_target_length=data_args.max_target_length) - ppo_trainer.save_model() - ppo_trainer.save_state() # must be after save_model - if ppo_trainer.is_world_process_zero() and model_args.plot_loss: - plot_loss(training_args.output_dir, keys=["loss", "reward"]) + # Training + if training_args.do_train: + ppo_trainer.ppo_train(max_target_length=data_args.max_target_length) + ppo_trainer.save_model() + ppo_trainer.save_state() # must be called after save_model to have a folder + if ppo_trainer.is_world_process_zero() and model_args.plot_loss: + plot_loss(training_args.output_dir, keys=["loss", "reward"]) diff --git a/src/llmtuner/tuner/pt/workflow.py b/src/llmtuner/tuner/pt/workflow.py index 2a9f8279..865ec218 100644 --- a/src/llmtuner/tuner/pt/workflow.py +++ b/src/llmtuner/tuner/pt/workflow.py @@ -2,10 +2,9 @@ import math from typing import TYPE_CHECKING, Optional, List -from transformers import DataCollatorForSeq2Seq +from transformers import DataCollatorForLanguageModeling from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset -from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.ploting import plot_loss from llmtuner.tuner.core import load_model_and_tokenizer from llmtuner.tuner.core.trainer import PeftTrainer @@ -25,10 +24,7 @@ def run_pt( dataset = get_dataset(model_args, data_args) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt") dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt") - data_collator = DataCollatorForSeq2Seq( - tokenizer=tokenizer, - label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id - ) + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) # Initialize our Trainer trainer = PeftTrainer( @@ -38,7 +34,7 @@ def run_pt( tokenizer=tokenizer, data_collator=data_collator, callbacks=callbacks, - **split_dataset(dataset, data_args.dev_ratio, training_args.do_train) + **split_dataset(dataset, data_args, training_args) ) # Training @@ -60,6 +56,5 @@ def run_pt( perplexity = float("inf") metrics["perplexity"] = perplexity - trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) diff --git a/src/llmtuner/tuner/rm/collator.py b/src/llmtuner/tuner/rm/collator.py index c0da0579..161f003d 100644 --- a/src/llmtuner/tuner/rm/collator.py +++ b/src/llmtuner/tuner/rm/collator.py @@ -1,8 +1,10 @@ import torch +from dataclasses import dataclass from typing import Any, Dict, Sequence from transformers import DataCollatorWithPadding +@dataclass class 
PairwiseDataCollatorWithPadding(DataCollatorWithPadding): r""" Data collator for pairwise data. @@ -16,7 +18,10 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding): the last n examples represent rejected examples. """ features = [ - {"input_ids": feature[key], "attention_mask": [1] * len(feature[key])} - for key in ("accept_ids", "reject_ids") for feature in features + { + "input_ids": feature["prompt_ids"] + feature[key], + "attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key])) + } + for key in ("chosen_ids", "rejected_ids") for feature in features ] return super().__call__(features) diff --git a/src/llmtuner/tuner/rm/trainer.py b/src/llmtuner/tuner/rm/trainer.py index e69d48a8..99b4b152 100644 --- a/src/llmtuner/tuner/rm/trainer.py +++ b/src/llmtuner/tuner/rm/trainer.py @@ -42,6 +42,8 @@ class PairwisePeftTrainer(PeftTrainer): """ batch_size = inputs["input_ids"].size(0) // 2 _, _, values = model(**inputs, output_hidden_states=True, return_dict=True) + if values.size(0) != inputs["input_ids"].size(0): # adapt chatglm2 + values = torch.transpose(values, 0, 1) r_accept, r_reject = values[:, -1].split(batch_size, dim=0) loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean() return (loss, [loss, r_accept, r_reject]) if return_outputs else loss diff --git a/src/llmtuner/tuner/rm/workflow.py b/src/llmtuner/tuner/rm/workflow.py index 19527ce8..b19a13e6 100644 --- a/src/llmtuner/tuner/rm/workflow.py +++ b/src/llmtuner/tuner/rm/workflow.py @@ -39,7 +39,7 @@ def run_rm( data_collator=data_collator, callbacks=callbacks, compute_metrics=compute_accuracy, - **split_dataset(dataset, data_args.dev_ratio, training_args.do_train) + **split_dataset(dataset, data_args, training_args) ) # Training diff --git a/src/llmtuner/tuner/sft/metric.py b/src/llmtuner/tuner/sft/metric.py index 663b037d..812896ee 100644 --- a/src/llmtuner/tuner/sft/metric.py +++ b/src/llmtuner/tuner/sft/metric.py @@ -25,7 +25,7 @@ class ComputeMetrics: Uses the model predictions to compute metrics. 
""" preds, labels = eval_preds - score_dict = {"accuracy": [], "rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} + score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id) labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id) @@ -49,6 +49,5 @@ class ComputeMetrics: bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) score_dict["bleu-4"].append(round(bleu_score * 100, 4)) - score_dict["accuracy"].append(float(len(label) != 0 and pred[:len(label)] == label)) return {k: float(np.mean(v)) for k, v in score_dict.items()} diff --git a/src/llmtuner/tuner/sft/trainer.py b/src/llmtuner/tuner/sft/trainer.py index 21739ac1..1ddaec1f 100644 --- a/src/llmtuner/tuner/sft/trainer.py +++ b/src/llmtuner/tuner/sft/trainer.py @@ -50,9 +50,10 @@ class Seq2SeqPeftTrainer(PeftTrainer): loss, generated_tokens, labels = super().prediction_step( model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys ) - generated_tokens = ( - generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None - ) + if generated_tokens is not None: + generated_tokens[:, :max(prompt_len, label_len)] = ( + self.tokenizer.pad_token_id * torch.ones_like(generated_tokens[:, :max(prompt_len, label_len)]) + ) return (loss, generated_tokens, labels) @@ -72,14 +73,11 @@ class Seq2SeqPeftTrainer(PeftTrainer): assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor." pad_token_id = self.tokenizer.pad_token_id else: - if self.model.config.pad_token_id is not None: - pad_token_id = self.model.config.pad_token_id - else: - raise ValueError("Pad_token_id must be set in the configuration of the model.") + raise ValueError("PAD token is required.") padded_tensor = pad_token_id * torch.ones_like(tgt_tensor) padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding - return padded_tensor.contiguous() + return padded_tensor.contiguous() # in contiguous memory def save_predictions( self, diff --git a/src/llmtuner/tuner/sft/workflow.py b/src/llmtuner/tuner/sft/workflow.py index 693fbd52..ebb16edd 100644 --- a/src/llmtuner/tuner/sft/workflow.py +++ b/src/llmtuner/tuner/sft/workflow.py @@ -16,7 +16,7 @@ from llmtuner.extras.logging import reset_logging, get_logger if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments, TrainerCallback - from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments + from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments logger = get_logger(__name__) @@ -25,6 +25,7 @@ def run_sft( data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", callbacks: Optional[List["TrainerCallback"]] = None ): dataset = get_dataset(model_args, data_args) @@ -50,31 +51,15 @@ def run_sft( data_collator=data_collator, callbacks=callbacks, compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None, - **split_dataset(dataset, data_args.dev_ratio, training_args.do_train) + **split_dataset(dataset, data_args, training_args) ) # Keyword arguments for `model.generate` - gen_kwargs = { - "do_sample": True, - "top_p": 0.7, - "max_new_tokens": data_args.max_target_length + 1, - "temperature": 0.95, - "logits_processor": get_logits_processor() - } - # Detecting last checkpoint. 
- last_checkpoint = None - if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: - last_checkpoint = get_last_checkpoint(training_args.output_dir) - if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: - raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and is not empty. " - "Use --overwrite_output_dir to overcome." - ) - elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: - logger.info( - f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " - "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." - ) + gen_kwargs = generating_args.to_dict() + gen_kwargs["eos_token_id"] = list(set([tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids)) + gen_kwargs["pad_token_id"] = tokenizer.pad_token_id + gen_kwargs["logits_processor"] = get_logits_processor() + # Training if training_args.do_train: checkpoint = None diff --git a/src/llmtuner/tuner/tune.py b/src/llmtuner/tuner/tune.py index 99f5d2a9..a4a4c2a1 100644 --- a/src/llmtuner/tuner/tune.py +++ b/src/llmtuner/tuner/tune.py @@ -1,35 +1,47 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional from llmtuner.extras.callbacks import LogCallback +from llmtuner.extras.logging import get_logger from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer from llmtuner.tuner.pt import run_pt from llmtuner.tuner.sft import run_sft from llmtuner.tuner.rm import run_rm from llmtuner.tuner.ppo import run_ppo +from llmtuner.tuner.dpo import run_dpo if TYPE_CHECKING: from transformers import TrainerCallback +logger = get_logger(__name__) + + def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None): - model_args, data_args, training_args, finetuning_args, general_args = get_train_args(args) + model_args, data_args, training_args, finetuning_args, generating_args, general_args = get_train_args(args) callbacks = [LogCallback()] if callbacks is None else callbacks if general_args.stage == "pt": run_pt(model_args, data_args, training_args, finetuning_args, callbacks) elif general_args.stage == "sft": - run_sft(model_args, data_args, training_args, finetuning_args, callbacks) + run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks) elif general_args.stage == "rm": run_rm(model_args, data_args, training_args, finetuning_args, callbacks) elif general_args.stage == "ppo": - run_ppo(model_args, data_args, training_args, finetuning_args, callbacks) + run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks) + elif general_args.stage == "dpo": + run_dpo(model_args, data_args, training_args, finetuning_args, callbacks) + else: + raise ValueError("Unknown task.") def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"): - model_args, _, training_args, finetuning_args, _ = get_train_args(args) + model_args, _, training_args, finetuning_args, _, _ = get_train_args(args) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args) model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size) - tokenizer.save_pretrained(training_args.output_dir) + try: + tokenizer.save_pretrained(training_args.output_dir) + except Exception: + logger.warning("Cannot save tokenizer, please copy the files manually.") if __name__ == "__main__": diff --git a/src/llmtuner/webui/chat.py
b/src/llmtuner/webui/chat.py index d0eb61df..154efa5a 100644 --- a/src/llmtuner/webui/chat.py +++ b/src/llmtuner/webui/chat.py @@ -26,7 +26,7 @@ class WebChatModel(ChatModel): finetuning_type: str, quantization_bit: str, template: str, - source_prefix: str + system_prompt: str ): if self.model is not None: yield ALERTS["err_exists"][lang] @@ -53,9 +53,9 @@ class WebChatModel(ChatModel): model_name_or_path=model_name_or_path, checkpoint_dir=checkpoint_dir, finetuning_type=finetuning_type, - quantization_bit=int(quantization_bit) if quantization_bit else None, + quantization_bit=int(quantization_bit) if quantization_bit != "None" else None, template=template, - source_prefix=source_prefix + system_prompt=system_prompt ) super().__init__(args) @@ -73,7 +73,7 @@ class WebChatModel(ChatModel): chatbot: List[Tuple[str, str]], query: str, history: List[Tuple[str, str]], - prefix: str, + system: str, max_new_tokens: int, top_p: float, temperature: float @@ -81,7 +81,7 @@ class WebChatModel(ChatModel): chatbot.append([query, ""]) response = "" for new_text in self.stream_chat( - query, history, prefix, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature + query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature ): response += new_text response = self.postprocess(response) diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py index bf1d18fb..965a690b 100644 --- a/src/llmtuner/webui/common.py +++ b/src/llmtuner/webui/common.py @@ -6,7 +6,7 @@ import gradio as gr from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME -from llmtuner.extras.constants import SUPPORTED_MODELS +from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS DEFAULT_CACHE_DIR = "cache" @@ -29,14 +29,16 @@ def load_config() -> Dict[str, Any]: with open(get_config_path(), "r", encoding="utf-8") as f: return json.load(f) except: - return {"last_model": "", "path_dict": {}} + return {"lang": "", "last_model": "", "path_dict": {}} -def save_config(model_name: str, model_path: str) -> None: +def save_config(lang: str, model_name: str, model_path: str) -> None: os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True) user_config = load_config() - user_config["last_model"] = model_name - user_config["path_dict"][model_name] = model_path + user_config["lang"] = lang or user_config["lang"] + if model_name: + user_config["last_model"] = model_name + user_config["path_dict"][model_name] = model_path with open(get_config_path(), "w", encoding="utf-8") as f: json.dump(user_config, f, indent=2, ensure_ascii=False) @@ -46,6 +48,12 @@ def get_model_path(model_name: str) -> str: return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, "")) +def get_template(model_name: str) -> str: + if model_name.endswith("Chat") and model_name.split("-")[0] in DEFAULT_TEMPLATE: + return DEFAULT_TEMPLATE[model_name.split("-")[0]] + return "default" + + def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]: checkpoints = [] save_dir = os.path.join(get_save_dir(model_name), finetuning_type) diff --git a/src/llmtuner/webui/components/__init__.py b/src/llmtuner/webui/components/__init__.py index 5b86f396..32228b8e 100644 --- a/src/llmtuner/webui/components/__init__.py +++ b/src/llmtuner/webui/components/__init__.py @@ -1,5 +1,5 @@ from llmtuner.webui.components.top import create_top -from llmtuner.webui.components.sft import create_sft_tab +from llmtuner.webui.components.train 
import create_train_tab from llmtuner.webui.components.eval import create_eval_tab from llmtuner.webui.components.infer import create_infer_tab from llmtuner.webui.components.export import create_export_tab diff --git a/src/llmtuner/webui/components/chatbot.py b/src/llmtuner/webui/components/chatbot.py index 6fcfc652..928a568c 100644 --- a/src/llmtuner/webui/components/chatbot.py +++ b/src/llmtuner/webui/components/chatbot.py @@ -17,7 +17,7 @@ def create_chat_box( with gr.Row(): with gr.Column(scale=4): - prefix = gr.Textbox(show_label=False) + system = gr.Textbox(show_label=False) query = gr.Textbox(show_label=False, lines=8) submit_btn = gr.Button(variant="primary") @@ -31,7 +31,7 @@ def create_chat_box( submit_btn.click( chat_model.predict, - [chatbot, query, history, prefix, max_new_tokens, top_p, temperature], + [chatbot, query, history, system, max_new_tokens, top_p, temperature], [chatbot, history], show_progress=True ).then( @@ -41,7 +41,7 @@ def create_chat_box( clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True) return chat_box, chatbot, history, dict( - prefix=prefix, + system=system, query=query, submit_btn=submit_btn, clear_btn=clear_btn, diff --git a/src/llmtuner/webui/components/data.py b/src/llmtuner/webui/components/data.py index 9787b36a..af19cc41 100644 --- a/src/llmtuner/webui/components/data.py +++ b/src/llmtuner/webui/components/data.py @@ -16,6 +16,6 @@ def create_preview_box() -> Tuple["Block", "Component", "Component", "Component" close_btn = gr.Button() - close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box]) + close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False) return preview_box, preview_count, preview_samples, close_btn diff --git a/src/llmtuner/webui/components/eval.py b/src/llmtuner/webui/components/eval.py index 29b590ae..cbc71daf 100644 --- a/src/llmtuner/webui/components/eval.py +++ b/src/llmtuner/webui/components/eval.py @@ -14,13 +14,18 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict with gr.Row(): dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2) dataset = gr.Dropdown(multiselect=True, scale=4) - preview_btn = gr.Button(interactive=False, scale=1) + data_preview_btn = gr.Button(interactive=False, scale=1) preview_box, preview_count, preview_samples, close_btn = create_preview_box() dataset_dir.change(list_dataset, [dataset_dir], [dataset]) - dataset.change(can_preview, [dataset_dir, dataset], [preview_btn]) - preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box]) + dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn]) + data_preview_btn.click( + get_preview, + [dataset_dir, dataset], + [preview_count, preview_samples, preview_box], + queue=False + ) with gr.Row(): max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1) @@ -30,38 +35,46 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict predict = gr.Checkbox(value=True) with gr.Row(): + cmd_preview_btn = gr.Button() start_btn = gr.Button() stop_btn = gr.Button() + with gr.Row(): + process_bar = gr.Slider(visible=False, interactive=False) + with gr.Box(): output_box = gr.Markdown() - start_btn.click( - runner.run_eval, - [ - top_elems["lang"], - top_elems["model_name"], - top_elems["checkpoints"], - top_elems["finetuning_type"], - top_elems["quantization_bit"], - top_elems["template"], - top_elems["source_prefix"], - dataset_dir, - dataset, - max_source_length, - 
max_target_length, - max_samples, - batch_size, - predict - ], - [output_box] - ) + input_components = [ + top_elems["lang"], + top_elems["model_name"], + top_elems["checkpoints"], + top_elems["finetuning_type"], + top_elems["quantization_bit"], + top_elems["template"], + top_elems["system_prompt"], + dataset_dir, + dataset, + max_source_length, + max_target_length, + max_samples, + batch_size, + predict + ] + + output_components = [ + output_box, + process_bar + ] + + cmd_preview_btn.click(runner.preview_eval, input_components, output_components) + start_btn.click(runner.run_eval, input_components, output_components) stop_btn.click(runner.set_abort, queue=False) return dict( dataset_dir=dataset_dir, dataset=dataset, - preview_btn=preview_btn, + data_preview_btn=data_preview_btn, preview_count=preview_count, preview_samples=preview_samples, close_btn=close_btn, @@ -70,6 +83,7 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict max_samples=max_samples, batch_size=batch_size, predict=predict, + cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_box=output_box diff --git a/src/llmtuner/webui/components/infer.py b/src/llmtuner/webui/components/infer.py index 40e0323e..14aef162 100644 --- a/src/llmtuner/webui/components/infer.py +++ b/src/llmtuner/webui/components/infer.py @@ -28,7 +28,7 @@ def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component" top_elems["finetuning_type"], top_elems["quantization_bit"], top_elems["template"], - top_elems["source_prefix"] + top_elems["system_prompt"] ], [info_box] ).then( diff --git a/src/llmtuner/webui/components/top.py b/src/llmtuner/webui/components/top.py index 4fc5b506..62c1f9c9 100644 --- a/src/llmtuner/webui/components/top.py +++ b/src/llmtuner/webui/components/top.py @@ -4,7 +4,7 @@ import gradio as gr from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS from llmtuner.extras.template import templates -from llmtuner.webui.common import list_checkpoint, get_model_path, save_config +from llmtuner.webui.common import list_checkpoint, get_model_path, get_template, save_config from llmtuner.webui.utils import can_quantize if TYPE_CHECKING: @@ -15,27 +15,32 @@ def create_top() -> Dict[str, "Component"]: available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"] with gr.Row(): - lang = gr.Dropdown(choices=["en", "zh"], value="en", scale=1) + lang = gr.Dropdown(choices=["en", "zh"], scale=1) model_name = gr.Dropdown(choices=available_models, scale=3) model_path = gr.Textbox(scale=3) with gr.Row(): - finetuning_type = gr.Dropdown(value="lora", choices=METHODS, scale=1) + finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1) checkpoints = gr.Dropdown(multiselect=True, scale=5) refresh_btn = gr.Button(scale=1) with gr.Accordion(label="Advanced config", open=False) as advanced_tab: with gr.Row(): - quantization_bit = gr.Dropdown([8, 4], scale=1) - template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=1) - source_prefix = gr.Textbox(scale=2) + quantization_bit = gr.Dropdown(choices=["None", "8", "4"], value="None", scale=1) + template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1) + system_prompt = gr.Textbox(scale=2) + + lang.change(save_config, [lang, model_name, model_path]) model_name.change( list_checkpoint, [model_name, finetuning_type], [checkpoints] ).then( get_model_path, [model_name], [model_path] + ).then( + get_template, [model_name], [template] ) # do not save config since the below line 
will save - model_path.change(save_config, [model_name, model_path]) + + model_path.change(save_config, [lang, model_name, model_path]) finetuning_type.change( list_checkpoint, [model_name, finetuning_type], [checkpoints] @@ -43,7 +48,9 @@ def create_top() -> Dict[str, "Component"]: can_quantize, [finetuning_type], [quantization_bit] ) - refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints]) + refresh_btn.click( + list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False + ) return dict( lang=lang, @@ -55,5 +62,5 @@ def create_top() -> Dict[str, "Component"]: advanced_tab=advanced_tab, quantization_bit=quantization_bit, template=template, - source_prefix=source_prefix + system_prompt=system_prompt ) diff --git a/src/llmtuner/webui/components/sft.py b/src/llmtuner/webui/components/train.py similarity index 52% rename from src/llmtuner/webui/components/sft.py rename to src/llmtuner/webui/components/train.py index 678693b9..aab512ee 100644 --- a/src/llmtuner/webui/components/sft.py +++ b/src/llmtuner/webui/components/train.py @@ -3,7 +3,8 @@ from transformers.trainer_utils import SchedulerType import gradio as gr -from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR +from llmtuner.extras.constants import STAGES +from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR from llmtuner.webui.components.data import create_preview_box from llmtuner.webui.utils import can_preview, get_preview, gen_plot @@ -12,17 +13,23 @@ if TYPE_CHECKING: from llmtuner.webui.runner import Runner -def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]: +def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]: with gr.Row(): + training_stage = gr.Dropdown(choices=STAGES, value=STAGES[0], scale=2) dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2) dataset = gr.Dropdown(multiselect=True, scale=4) - preview_btn = gr.Button(interactive=False, scale=1) + data_preview_btn = gr.Button(interactive=False, scale=1) preview_box, preview_count, preview_samples, close_btn = create_preview_box() dataset_dir.change(list_dataset, [dataset_dir], [dataset]) - dataset.change(can_preview, [dataset_dir, dataset], [preview_btn]) - preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box]) + dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn]) + data_preview_btn.click( + get_preview, + [dataset_dir, dataset], + [preview_count, preview_samples, preview_box], + queue=False + ) with gr.Row(): max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1) @@ -35,10 +42,10 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[ batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1) gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1) lr_scheduler_type = gr.Dropdown( - value="cosine", choices=[scheduler.value for scheduler in SchedulerType] + choices=[scheduler.value for scheduler in SchedulerType], value="cosine" ) max_grad_norm = gr.Textbox(value="1.0") - dev_ratio = gr.Slider(value=0, minimum=0, maximum=1, step=0.001) + val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001) with gr.Accordion(label="Advanced config", open=False) as advanced_tab: with gr.Row(): @@ -46,20 +53,40 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[ save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10) 
warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1) compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16") + padding_side = gr.Radio(choices=["left", "right"], value="left") with gr.Accordion(label="LoRA config", open=False) as lora_tab: with gr.Row(): lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1) - lora_dropout = gr.Slider(value=0, minimum=0, maximum=1, step=0.01, scale=1) + lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1) lora_target = gr.Textbox(scale=2) + resume_lora_training = gr.Checkbox(value=True, scale=1) + + with gr.Accordion(label="RLHF config", open=False) as rlhf_tab: + with gr.Row(): + dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=2) + reward_model = gr.Dropdown(scale=2) + refresh_btn = gr.Button(scale=1) + + refresh_btn.click( + list_checkpoint, + [top_elems["model_name"], top_elems["finetuning_type"]], + [reward_model], + queue=False + ) with gr.Row(): + cmd_preview_btn = gr.Button() start_btn = gr.Button() stop_btn = gr.Button() with gr.Row(): with gr.Column(scale=3): - output_dir = gr.Textbox() + with gr.Row(): + output_dir = gr.Textbox() + + with gr.Row(): + process_bar = gr.Slider(visible=False, interactive=False) with gr.Box(): output_box = gr.Markdown() @@ -67,49 +94,59 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[ with gr.Column(scale=1): loss_viewer = gr.Plot() - start_btn.click( - runner.run_train, - [ - top_elems["lang"], - top_elems["model_name"], - top_elems["checkpoints"], - top_elems["finetuning_type"], - top_elems["quantization_bit"], - top_elems["template"], - top_elems["source_prefix"], - dataset_dir, - dataset, - max_source_length, - max_target_length, - learning_rate, - num_train_epochs, - max_samples, - batch_size, - gradient_accumulation_steps, - lr_scheduler_type, - max_grad_norm, - dev_ratio, - logging_steps, - save_steps, - warmup_steps, - compute_type, - lora_rank, - lora_dropout, - lora_target, - output_dir - ], - [output_box] - ) + input_components = [ + top_elems["lang"], + top_elems["model_name"], + top_elems["checkpoints"], + top_elems["finetuning_type"], + top_elems["quantization_bit"], + top_elems["template"], + top_elems["system_prompt"], + training_stage, + dataset_dir, + dataset, + max_source_length, + max_target_length, + learning_rate, + num_train_epochs, + max_samples, + batch_size, + gradient_accumulation_steps, + lr_scheduler_type, + max_grad_norm, + val_size, + logging_steps, + save_steps, + warmup_steps, + compute_type, + padding_side, + lora_rank, + lora_dropout, + lora_target, + resume_lora_training, + dpo_beta, + reward_model, + output_dir + ] + + output_components = [ + output_box, + process_bar + ] + + cmd_preview_btn.click(runner.preview_train, input_components, output_components) + start_btn.click(runner.run_train, input_components, output_components) stop_btn.click(runner.set_abort, queue=False) - output_box.change( + process_bar.change( gen_plot, [top_elems["model_name"], top_elems["finetuning_type"], output_dir], loss_viewer, queue=False ) return dict( + training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, - preview_btn=preview_btn, + data_preview_btn=data_preview_btn, preview_count=preview_count, preview_samples=preview_samples, close_btn=close_btn, @@ -122,16 +159,23 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[ gradient_accumulation_steps=gradient_accumulation_steps, lr_scheduler_type=lr_scheduler_type, 
max_grad_norm=max_grad_norm, - dev_ratio=dev_ratio, + val_size=val_size, advanced_tab=advanced_tab, logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps, compute_type=compute_type, + padding_side=padding_side, lora_tab=lora_tab, lora_rank=lora_rank, lora_dropout=lora_dropout, lora_target=lora_target, + resume_lora_training=resume_lora_training, + rlhf_tab=rlhf_tab, + dpo_beta=dpo_beta, + reward_model=reward_model, + refresh_btn=refresh_btn, + cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir, diff --git a/src/llmtuner/webui/interface.py b/src/llmtuner/webui/interface.py index 2fb61d37..1a5c0c19 100644 --- a/src/llmtuner/webui/interface.py +++ b/src/llmtuner/webui/interface.py @@ -3,7 +3,7 @@ from transformers.utils.versions import require_version from llmtuner.webui.components import ( create_top, - create_sft_tab, + create_train_tab, create_eval_tab, create_infer_tab, create_export_tab, @@ -24,8 +24,8 @@ def create_ui() -> gr.Blocks: with gr.Blocks(title="Web Tuner", css=CSS) as demo: top_elems = create_top() - with gr.Tab("SFT"): - sft_elems = create_sft_tab(top_elems, runner) + with gr.Tab("Train"): + train_elems = create_train_tab(top_elems, runner) with gr.Tab("Evaluate"): eval_elems = create_eval_tab(top_elems, runner) @@ -36,7 +36,7 @@ def create_ui() -> gr.Blocks: with gr.Tab("Export"): export_elems = create_export_tab(top_elems) - elem_list = [top_elems, sft_elems, eval_elems, infer_elems, export_elems] + elem_list = [top_elems, train_elems, eval_elems, infer_elems, export_elems] manager = Manager(elem_list) demo.load( @@ -59,7 +59,7 @@ def create_web_demo() -> gr.Blocks: chat_model = WebChatModel(lazy_init=False) with gr.Blocks(title="Web Demo", css=CSS) as demo: - lang = gr.Dropdown(choices=["en", "zh"], value="en") + lang = gr.Dropdown(choices=["en", "zh"]) _, _, _, chat_elems = create_chat_box(chat_model, visible=True) @@ -67,7 +67,7 @@ def create_web_demo() -> gr.Blocks: demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values())) - lang.change(manager.gen_label, [lang], [lang] + list(chat_elems.values())) + lang.select(manager.gen_label, [lang], [lang] + list(chat_elems.values()), queue=False) return demo diff --git a/src/llmtuner/webui/locales.py b/src/llmtuner/webui/locales.py index ba9a73bd..c4032f39 100644 --- a/src/llmtuner/webui/locales.py +++ b/src/llmtuner/webui/locales.py @@ -77,7 +77,7 @@ LOCALES = { "info": "构建提示词时使用的模板" } }, - "source_prefix": { + "system_prompt": { "en": { "label": "System prompt (optional)", "info": "A sequence used as the default system prompt." @@ -87,6 +87,16 @@ LOCALES = { "info": "默认使用的系统提示词" } }, + "training_stage": { + "en": { + "label": "Stage", + "info": "The stage to perform in training." + }, + "zh": { + "label": "训练阶段", + "info": "目前采用的训练方式。" + } + }, "dataset_dir": { "en": { "label": "Data dir", @@ -105,12 +115,12 @@ LOCALES = { "label": "数据集" } }, - "preview_btn": { + "data_preview_btn": { "en": { - "value": "Preview" + "value": "Preview dataset" }, "zh": { - "value": "预览" + "value": "预览数据集" } }, "preview_count": { @@ -227,9 +237,9 @@ LOCALES = { "info": "用于梯度裁剪的范数。" } }, - "dev_ratio": { + "val_size": { "en": { - "label": "Dev ratio", + "label": "Val size", "info": "Proportion of data in the dev set." }, "zh": { @@ -277,6 +287,16 @@ LOCALES = { "info": "是否启用 FP16 或 BF16 混合精度训练。" } }, + "padding_side": { + "en": { + "label": "Padding side", + "info": "The side on which the model should have padding applied." 
+ }, + "zh": { + "label": "填充位置", + "info": "使用左填充或右填充。" + } + }, "lora_tab": { "en": { "label": "LoRA configurations" @@ -315,6 +335,52 @@ LOCALES = { "info": "应用 LoRA 的线性层名称。使用英文逗号分隔多个名称。" } }, + "resume_lora_training": { + "en": { + "label": "Resume LoRA training", + "info": "Whether to resume training from the last LoRA weights or create new lora weights." + }, + "zh": { + "label": "继续上次的训练", + "info": "接着上次的 LoRA 权重训练或创建一个新的 LoRA 权重。" + } + }, + "rlhf_tab": { + "en": { + "label": "RLHF configurations" + }, + "zh": { + "label": "RLHF 参数设置" + } + }, + "dpo_beta": { + "en": { + "label": "DPO beta", + "info": "Value of the beta parameter in the DPO loss." + }, + "zh": { + "label": "DPO beta 参数", + "info": "DPO 损失函数中 beta 超参数大小。" + } + }, + "reward_model": { + "en": { + "label": "Reward model", + "info": "Checkpoint of the reward model for PPO training." + }, + "zh": { + "label": "奖励模型", + "info": "PPO 训练中奖励模型的断点路径。" + } + }, + "cmd_preview_btn": { + "en": { + "value": "Preview command" + }, + "zh": { + "value": "预览命令" + } + }, "start_btn": { "en": { "value": "Start" @@ -389,7 +455,7 @@ LOCALES = { "value": "模型未加载,请先加载模型。" } }, - "prefix": { + "system": { "en": { "placeholder": "System prompt (optional)" }, diff --git a/src/llmtuner/webui/manager.py b/src/llmtuner/webui/manager.py index c8f797a4..2d5a0a39 100644 --- a/src/llmtuner/webui/manager.py +++ b/src/llmtuner/webui/manager.py @@ -12,12 +12,18 @@ class Manager: def __init__(self, elem_list: List[Dict[str, Component]]): self.elem_list = elem_list - def gen_refresh(self) -> Dict[str, Any]: + def gen_refresh(self, lang: str) -> Dict[str, Any]: refresh_dict = { "dataset": {"choices": list_dataset()["choices"]}, "output_dir": {"value": get_time()} } + user_config = load_config() + if lang: + refresh_dict["lang"] = {"value": lang} + else: + refresh_dict["lang"] = {"value": user_config["lang"] if user_config["lang"] else "en"} + if user_config["last_model"]: refresh_dict["model_name"] = {"value": user_config["last_model"]} refresh_dict["model_path"] = {"value": get_model_path(user_config["last_model"])} @@ -26,10 +32,12 @@ class Manager: def gen_label(self, lang: str) -> Dict[Component, Dict[str, Any]]: # cannot use TYPE_CHECKING update_dict = {} - refresh_dict = self.gen_refresh() + refresh_dict = self.gen_refresh(lang) for elems in self.elem_list: for name, component in elems.items(): - update_dict[component] = gr.update(**LOCALES[name][lang], **refresh_dict.get(name, {})) + update_dict[component] = gr.update( + **LOCALES[name][refresh_dict["lang"]["value"]], **refresh_dict.get(name, {}) + ) return update_dict diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index 763ff614..ac74a4c7 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -1,10 +1,11 @@ +import gradio as gr import logging import os import threading import time import transformers from transformers.trainer import TRAINING_ARGS_NAME -from typing import Generator, List, Optional, Tuple +from typing import Any, Dict, Generator, List, Tuple from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.constants import DEFAULT_MODULE @@ -13,7 +14,7 @@ from llmtuner.extras.misc import torch_gc from llmtuner.tuner import run_exp from llmtuner.webui.common import get_model_path, get_save_dir from llmtuner.webui.locales import ALERTS -from llmtuner.webui.utils import format_info, get_eval_results +from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar class Runner: @@ -21,39 +22,36 @@ class Runner: def 
__init__(self): self.aborted = False self.running = False + self.logger_handler = LoggerHandler() + self.logger_handler.setLevel(logging.INFO) + logging.root.addHandler(self.logger_handler) + transformers.logging.add_handler(self.logger_handler) def set_abort(self): self.aborted = True self.running = False - def initialize( + def _initialize( self, lang: str, model_name: str, dataset: List[str] - ) -> Tuple[str, str, LoggerHandler, LogCallback]: + ) -> str: if self.running: - return None, ALERTS["err_conflict"][lang], None, None + return ALERTS["err_conflict"][lang] if not model_name: - return None, ALERTS["err_no_model"][lang], None, None + return ALERTS["err_no_model"][lang] - model_name_or_path = get_model_path(model_name) - if not model_name_or_path: - return None, ALERTS["err_no_path"][lang], None, None + if not get_model_path(model_name): + return ALERTS["err_no_path"][lang] if len(dataset) == 0: - return None, ALERTS["err_no_dataset"][lang], None, None + return ALERTS["err_no_dataset"][lang] self.aborted = False - self.running = True + self.logger_handler.reset() + self.trainer_callback = LogCallback(self) + return "" - logger_handler = LoggerHandler() - logger_handler.setLevel(logging.INFO) - logging.root.addHandler(logger_handler) - transformers.logging.add_handler(logger_handler) - trainer_callback = LogCallback(self) - - return model_name_or_path, "", logger_handler, trainer_callback - - def finalize( + def _finalize( self, lang: str, finish_info: str ) -> str: self.running = False @@ -63,7 +61,7 @@ class Runner: else: return finish_info - def run_train( + def _parse_train_args( self, lang: str, model_name: str, @@ -71,7 +69,8 @@ class Runner: finetuning_type: str, quantization_bit: str, template: str, - source_prefix: str, + system_prompt: str, + training_stage: str, dataset_dir: str, dataset: List[str], max_source_length: int, @@ -83,24 +82,23 @@ class Runner: gradient_accumulation_steps: int, lr_scheduler_type: str, max_grad_norm: str, - dev_ratio: float, + val_size: float, logging_steps: int, save_steps: int, warmup_steps: int, compute_type: str, + padding_side: str, lora_rank: int, lora_dropout: float, lora_target: str, + resume_lora_training: bool, + dpo_beta: float, + reward_model: str, output_dir: str - ) -> Generator[str, None, None]: - model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset) - if error: - yield error - return - + ) -> Tuple[str, str, List[str], str, Dict[str, Any]]: if checkpoints: checkpoint_dir = ",".join( - [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints] + [os.path.join(get_save_dir(model_name), finetuning_type, ckpt) for ckpt in checkpoints] ) else: checkpoint_dir = None @@ -109,14 +107,14 @@ class Runner: args = dict( stage="sft", - model_name_or_path=model_name_or_path, + model_name_or_path=get_model_path(model_name), do_train=True, overwrite_cache=True, checkpoint_dir=checkpoint_dir, finetuning_type=finetuning_type, - quantization_bit=int(quantization_bit) if quantization_bit else None, + quantization_bit=int(quantization_bit) if quantization_bit != "None" else None, template=template, - source_prefix=source_prefix, + system_prompt=system_prompt, dataset_dir=dataset_dir, dataset=",".join(dataset), max_source_length=max_source_length, @@ -131,39 +129,40 @@ class Runner: logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps, - fp16=(compute_type == "fp16"), - bf16=(compute_type == "bf16"), + padding_side=padding_side, 
lora_rank=lora_rank, lora_dropout=lora_dropout, lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"), + resume_lora_training=resume_lora_training, output_dir=output_dir ) + args[compute_type] = True - if dev_ratio > 1e-6: - args["dev_ratio"] = dev_ratio + if training_stage == "Reward Modeling": + args["stage"] = "rm" + args["resume_lora_training"] = False + elif training_stage == "PPO": + args["stage"] = "ppo" + args["resume_lora_training"] = False + args["reward_model"] = reward_model + args["padding_side"] = "left" + val_size = 0 + elif training_stage == "DPO": + args["stage"] = "dpo" + args["resume_lora_training"] = False + args["dpo_beta"] = dpo_beta + elif training_stage == "Pre-Training": + args["stage"] = "pt" + + if val_size > 1e-6: + args["val_size"] = val_size args["evaluation_strategy"] = "steps" args["eval_steps"] = save_steps args["load_best_model_at_end"] = True - run_kwargs = dict(args=args, callbacks=[trainer_callback]) - thread = threading.Thread(target=run_exp, kwargs=run_kwargs) - thread.start() + return lang, model_name, dataset, output_dir, args - while thread.is_alive(): - time.sleep(1) - if self.aborted: - yield ALERTS["info_aborting"][lang] - else: - yield format_info(logger_handler.log, trainer_callback) - - if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)): - finish_info = ALERTS["info_finished"][lang] - else: - finish_info = ALERTS["err_failed"][lang] - - yield self.finalize(lang, finish_info) - - def run_eval( + def _parse_eval_args( self, lang: str, model_name: str, @@ -171,7 +170,7 @@ class Runner: finetuning_type: str, quantization_bit: str, template: str, - source_prefix: str, + system_prompt: str, dataset_dir: str, dataset: List[str], max_source_length: int, @@ -179,12 +178,7 @@ class Runner: max_samples: str, batch_size: int, predict: bool - ) -> Generator[str, None, None]: - model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset) - if error: - yield error - return - + ) -> Tuple[str, str, List[str], str, Dict[str, Any]]: if checkpoints: checkpoint_dir = ",".join( [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints] @@ -196,15 +190,15 @@ class Runner: args = dict( stage="sft", - model_name_or_path=model_name_or_path, + model_name_or_path=get_model_path(model_name), do_eval=True, overwrite_cache=True, predict_with_generate=True, checkpoint_dir=checkpoint_dir, finetuning_type=finetuning_type, - quantization_bit=int(quantization_bit) if quantization_bit else None, + quantization_bit=int(quantization_bit) if quantization_bit != "None" else None, template=template, - source_prefix=source_prefix, + system_prompt=system_prompt, dataset_dir=dataset_dir, dataset=",".join(dataset), max_source_length=max_source_length, @@ -218,20 +212,72 @@ class Runner: args.pop("do_eval", None) args["do_predict"] = True - run_kwargs = dict(args=args, callbacks=[trainer_callback]) + return lang, model_name, dataset, output_dir, args + + def preview_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]: + lang, model_name, dataset, _, args = self._parse_train_args(*args) + error = self._initialize(lang, model_name, dataset) + if error: + yield error, gr.update(visible=False) + else: + yield gen_cmd(args), gr.update(visible=False) + + def preview_eval(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]: + lang, model_name, dataset, _, args = self._parse_eval_args(*args) + error = self._initialize(lang, 
model_name, dataset) + if error: + yield error, gr.update(visible=False) + else: + yield gen_cmd(args), gr.update(visible=False) + + def run_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]: + lang, model_name, dataset, output_dir, args = self._parse_train_args(*args) + error = self._initialize(lang, model_name, dataset) + if error: + yield error, gr.update(visible=False) + return + + self.running = True + run_kwargs = dict(args=args, callbacks=[self.trainer_callback]) thread = threading.Thread(target=run_exp, kwargs=run_kwargs) thread.start() while thread.is_alive(): - time.sleep(1) + time.sleep(2) if self.aborted: - yield ALERTS["info_aborting"][lang] + yield ALERTS["info_aborting"][lang], gr.update(visible=False) else: - yield format_info(logger_handler.log, trainer_callback) + yield self.logger_handler.log, update_process_bar(self.trainer_callback) + + if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)): + finish_info = ALERTS["info_finished"][lang] + else: + finish_info = ALERTS["err_failed"][lang] + + yield self._finalize(lang, finish_info), gr.update(visible=False) + + def run_eval(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]: + lang, model_name, dataset, output_dir, args = self._parse_eval_args(*args) + error = self._initialize(lang, model_name, dataset) + if error: + yield error, gr.update(visible=False) + return + + self.running = True + run_kwargs = dict(args=args, callbacks=[self.trainer_callback]) thread = threading.Thread(target=run_exp, kwargs=run_kwargs) thread.start() + + while thread.is_alive(): + time.sleep(2) + if self.aborted: + yield ALERTS["info_aborting"][lang], gr.update(visible=False) + else: + yield self.logger_handler.log, update_process_bar(self.trainer_callback) if os.path.exists(os.path.join(output_dir, "all_results.json")): finish_info = get_eval_results(os.path.join(output_dir, "all_results.json")) else: finish_info = ALERTS["err_failed"][lang] - yield self.finalize(lang, finish_info) + yield self._finalize(lang, finish_info), gr.update(visible=False) diff --git a/src/llmtuner/webui/utils.py b/src/llmtuner/webui/utils.py index 7b667c0f..362fa008 100644 --- a/src/llmtuner/webui/utils.py +++ b/src/llmtuner/webui/utils.py @@ -15,13 +15,18 @@ if TYPE_CHECKING: from llmtuner.extras.callbacks import LogCallback -def format_info(log: str, callback: "LogCallback") -> str: - info = log - if callback.max_steps: - info += "Running **{:d}/{:d}**: {} < {}\n".format( - callback.cur_steps, callback.max_steps, callback.elapsed_time, callback.remaining_time - ) - return info +def update_process_bar(callback: "LogCallback") -> Dict[str, Any]: + if not callback.max_steps: + return gr.update(visible=False) + + percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0 + label = "Running {:d}/{:d}: {} < {}".format( + callback.cur_steps, + callback.max_steps, + callback.elapsed_time, + callback.remaining_time + ) + return gr.update(label=label, value=percentage, visible=True) def get_time() -> str: @@ -57,6 +62,18 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]: return gr.update(interactive=True) +def gen_cmd(args: Dict[str, Any]) -> str: + if args.get("do_train", None): + args["plot_loss"] = True + cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python "] + for k, v in args.items(): + if v is not None and v != "": + cmd_lines.append(" --{} {} ".format(k, str(v))) + cmd_text = "\\\n".join(cmd_lines) + cmd_text = "```bash\n{}\n```".format(cmd_text) + return cmd_text + + def 
get_eval_results(path: os.PathLike) -> str: with open(path, "r", encoding="utf-8") as f: result = json.dumps(json.load(f), indent=4)
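For reference, the loss computed in `PairwisePeftTrainer.compute_loss` (see the `src/llmtuner/tuner/rm/trainer.py` hunk above) is a Bradley-Terry style pairwise ranking objective, `-log(sigmoid(r_accept - r_reject))` averaged over the batch. A minimal standalone sketch of that objective follows; the function name, tensor names, and toy values are illustrative assumptions, not part of this patch:

```python
import torch
import torch.nn.functional as F

def pairwise_ranking_loss(r_chosen: torch.Tensor, r_rejected: torch.Tensor) -> torch.Tensor:
    # r_chosen / r_rejected: one scalar reward per example, shape (batch_size,)
    # F.logsigmoid(x) is the numerically stable form of log(sigmoid(x))
    return -F.logsigmoid(r_chosen - r_rejected).mean()

# toy usage with three chosen/rejected pairs
loss = pairwise_ranking_loss(torch.tensor([1.2, 0.3, 0.8]), torch.tensor([0.4, 0.5, -0.1]))
print(loss.item())
```

Minimizing this quantity pushes the reward of each chosen response above its rejected counterpart, which is what the `r_accept` / `r_reject` split on the last-token values in the trainer relies on.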