From abdfa26d06a5549a70865d356452e572dbe50209 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Fri, 11 Aug 2023 03:02:53 +0800 Subject: [PATCH] support DPO training (2305.18290) Former-commit-id: 3ec4351cfdaf2aefcc7d13345e19d79874ed61d3 --- README.md | 60 +++++++++++------- README_zh.md | 68 ++++++++++++-------- data/dataset_info.json | 20 ------ data/refgpt_zh_50k_p1.json.REMOVED.git-id | 1 - data/refgpt_zh_50k_p2.json.REMOVED.git-id | 1 - requirements.txt | 3 +- src/api_demo.py | 2 +- src/llmtuner/chat/stream_chat.py | 3 +- src/llmtuner/dsets/preprocess.py | 39 ++++++------ src/llmtuner/extras/callbacks.py | 9 +++ src/llmtuner/extras/constants.py | 4 +- src/llmtuner/extras/misc.py | 8 ++- src/llmtuner/extras/template.py | 73 +++++++++++++++------- src/llmtuner/hparams/data_args.py | 2 +- src/llmtuner/hparams/finetuning_args.py | 24 +++++--- src/llmtuner/hparams/general_args.py | 4 +- src/llmtuner/hparams/generating_args.py | 2 +- src/llmtuner/hparams/model_args.py | 13 ++-- src/llmtuner/tuner/core/adapter.py | 2 +- src/llmtuner/tuner/core/loader.py | 7 +-- src/llmtuner/tuner/core/parser.py | 61 +++++++++++++++--- src/llmtuner/tuner/core/trainer.py | 31 ++++++---- src/llmtuner/tuner/dpo/__init__.py | 1 + src/llmtuner/tuner/dpo/collator.py | 51 +++++++++++++++ src/llmtuner/tuner/dpo/trainer.py | 75 +++++++++++++++++++++++ src/llmtuner/tuner/dpo/workflow.py | 59 ++++++++++++++++++ src/llmtuner/tuner/ppo/trainer.py | 19 +++--- src/llmtuner/tuner/ppo/workflow.py | 25 ++++---- src/llmtuner/tuner/pt/workflow.py | 8 +-- src/llmtuner/tuner/rm/collator.py | 9 ++- src/llmtuner/tuner/sft/trainer.py | 2 +- src/llmtuner/tuner/sft/workflow.py | 15 ++--- src/llmtuner/tuner/tune.py | 22 +++++-- src/llmtuner/webui/runner.py | 2 +- 34 files changed, 513 insertions(+), 212 deletions(-) delete mode 100644 data/refgpt_zh_50k_p1.json.REMOVED.git-id delete mode 100644 data/refgpt_zh_50k_p2.json.REMOVED.git-id create mode 100644 src/llmtuner/tuner/dpo/__init__.py create mode 100644 src/llmtuner/tuner/dpo/collator.py create mode 100644 src/llmtuner/tuner/dpo/trainer.py create mode 100644 src/llmtuner/tuner/dpo/workflow.py diff --git a/README.md b/README.md index cca8fdc3..fcbe0761 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,8 @@ ## Changelog +[23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models (experimental feature). + [23/08/03] Now we support training the **Qwen-7B** model in this repo. Try `--model_name_or_path Qwen/Qwen-7B-Chat` and `--lora_target c_attn` arguments to train the Qwen-7B model. Remember to use `--template chatml` argument when you are using the Qwen-7B-Chat model. [23/07/31] Now we support dataset streaming. Try `--streaming` and `--max_steps 100` arguments to stream your dataset. @@ -54,24 +56,18 @@ | [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml | | [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | - | -> * **Default module** is used for the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options. -> * For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. +- **Default module** is used for the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options. +- For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. 
But make sure to use the corresponding template for the "chat" models. ## Supported Training Approaches -- [(Continually) pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf) - - Full-parameter tuning - - Partial-parameter tuning - - [LoRA](https://arxiv.org/abs/2106.09685) - - [QLoRA](https://arxiv.org/abs/2305.14314) -- [Supervised fine-tuning](https://arxiv.org/abs/2109.01652) - - Full-parameter tuning - - Partial-parameter tuning - - [LoRA](https://arxiv.org/abs/2106.09685) - - [QLoRA](https://arxiv.org/abs/2305.14314) -- [RLHF](https://arxiv.org/abs/2203.02155) - - [LoRA](https://arxiv.org/abs/2106.09685) - - [QLoRA](https://arxiv.org/abs/2305.14314) +| Approach | Full-parameter | Partial-parameter | LoRA | QLoRA | +| ---------------------- | -------------- | ----------------- | ---- | ----- | +| Pre-Training | ✅ | ✅ | ✅ | ✅ | +| Supervised Fine-Tuning | ✅ | ✅ | ✅ | ✅ | +| Reward Model Training | | | ✅ | ✅ | +| PPO Training | | | ✅ | ✅ | +| DPO Training | ✅ | | ✅ | ✅ | ## Provided Datasets @@ -88,7 +84,6 @@ - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [Self-cognition (zh)](data/self_cognition.json) - [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection) - - [RefGPT (zh)](https://github.com/sufengniu/RefGPT) - [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) - [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN) - [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN) @@ -103,7 +98,7 @@ - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) - [UltraChat (en)](https://github.com/thunlp/UltraChat) - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) -- For reward modelling: +- For reward modelling or DPO training: - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) @@ -139,7 +134,6 @@ Note: please update `data/dataset_info.json` to use your custom dataset. About t ### Dependence Installation (optional) ```bash -git lfs install git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git conda create -n llama_etuning python=3.10 conda activate llama_etuning @@ -161,7 +155,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_web.py Currently the web UI only supports training on **a single GPU**. 
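The CLI examples below can also be launched from Python. A minimal sketch (assuming the `run_exp` entry point added in `src/llmtuner/tuner/tune.py` later in this patch; paths are placeholders), mirroring the DPO example in this README:

```python
# Minimal sketch: starting the new DPO stage programmatically instead of via
# `python src/train_bash.py`. `run_exp` parses this dict with the same argument
# parser as the CLI flags, so the keys mirror the DPO bash example below.
from llmtuner.tuner.tune import run_exp

run_exp(args={
    "stage": "dpo",                           # new training stage added by this patch
    "model_name_or_path": "path_to_your_model",
    "do_train": True,
    "dataset": "comparison_gpt4_en",          # pairwise (chosen/rejected) dataset
    "template": "default",
    "finetuning_type": "lora",
    "resume_lora_training": False,
    "checkpoint_dir": "path_to_sft_checkpoint",
    "output_dir": "path_to_dpo_checkpoint",
    "per_device_train_batch_size": 2,
    "gradient_accumulation_steps": 4,
    "learning_rate": 1e-5,
    "num_train_epochs": 1.0,
    "fp16": True,
})
```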
-### (Continually) Pre-Training +### Pre-Training ```bash CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ @@ -222,7 +216,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --resume_lora_training False \ --checkpoint_dir path_to_sft_checkpoint \ --output_dir path_to_rm_checkpoint \ - --per_device_train_batch_size 4 \ + --per_device_train_batch_size 2 \ --gradient_accumulation_steps 4 \ --lr_scheduler_type cosine \ --logging_steps 10 \ @@ -233,7 +227,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --fp16 ``` -### PPO Training (RLHF) +### PPO Training ```bash CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ @@ -257,6 +251,30 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --plot_loss ``` +### DPO Training + +```bash +CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ + --stage dpo \ + --model_name_or_path path_to_your_model \ + --do_train \ + --dataset comparison_gpt4_en \ + --template default \ + --finetuning_type lora \ + --resume_lora_training False \ + --checkpoint_dir path_to_sft_checkpoint \ + --output_dir path_to_dpo_checkpoint \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 4 \ + --lr_scheduler_type cosine \ + --logging_steps 10 \ + --save_steps 1000 \ + --learning_rate 1e-5 \ + --num_train_epochs 1.0 \ + --plot_loss \ + --fp16 +``` + ### Distributed Training ```bash diff --git a/README_zh.md b/README_zh.md index d5eca99d..955fefc1 100644 --- a/README_zh.md +++ b/README_zh.md @@ -12,7 +12,9 @@ ## 更新日志 -[23/08/03] 现在我们支持了 **Qwen-7B** 模型的训练。请尝试使用 `--model_name_or_path Qwen/Qwen-7B-Chat` 和 `--lora_target c_attn` 参数。请注意使用 Qwen-7B-Chat 模型需要添加 `--template chatml` 参数。 +[23/08/11] 现在我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。详情请参阅[此示例](#dpo-training)(实验性功能)。 + +[23/08/03] 现在我们支持了 **Qwen-7B** 模型的训练。请尝试使用 `--model_name_or_path Qwen/Qwen-7B-Chat` 和 `--lora_target c_attn` 参数。使用 Qwen-7B-Chat 模型请添加 `--template chatml` 参数。 [23/07/31] 现在我们支持了训练数据流式加载。请尝试使用 `--streaming` 和 `--max_steps 100` 参数来流式加载数据集。 @@ -54,41 +56,34 @@ | [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml | | [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | - | -> * **默认模块**是 `--lora_target` 参数的默认值。请使用 `python src/train_bash.py -h` 查看全部可选项。 -> * 对于所有“基座”模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等值。 +- **默认模块**是 `--lora_target` 参数的部分可选项。请使用 `python src/train_bash.py -h` 查看全部可选项。 +- 对于所有“基座”(Base)模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Chat)模型请务必使用对应的模板。 -## 微调方法 +## 训练方法 -- [二次预训练](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf) - - 全参数微调 - - 部分参数微调 - - [LoRA](https://arxiv.org/abs/2106.09685) - - [QLoRA](https://arxiv.org/abs/2305.14314) -- [指令监督微调](https://arxiv.org/abs/2109.01652) - - 全参数微调 - - 部分参数微调 - - [LoRA](https://arxiv.org/abs/2106.09685) - - [QLoRA](https://arxiv.org/abs/2305.14314) -- [人类反馈的强化学习(RLHF)](https://arxiv.org/abs/2203.02155) - - [LoRA](https://arxiv.org/abs/2106.09685) - - [QLoRA](https://arxiv.org/abs/2305.14314) +| 方法 | 全参数训练 | 部分参数训练 | LoRA | QLoRA | +| ---------- | ---------- | ----------- | ---- | ----- | +| 预训练 | ✅ | ✅ | ✅ | ✅ | +| 指令监督微调 | ✅ | ✅ | ✅ | ✅ | +| 奖励模型训练 | | | ✅ | ✅ | +| PPO 训练 | | | ✅ | ✅ | +| DPO 训练 | ✅ | | ✅ | ✅ | ## 数据集 -- 用于二次预训练: +- 用于预训练: - [Wiki Demo (en)](data/wiki_demo.txt) - [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb) - [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata) - [Wikipedia 
(en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220) - [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered) -- 用于指令监督微调: +- 用于指令监督微调: - [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca) - [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca) - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [Self-cognition (zh)](data/self_cognition.json) - [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection) - - [RefGPT (zh)](https://github.com/sufengniu/RefGPT) - [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) - [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN) - [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN) @@ -103,7 +98,7 @@ - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) - [UltraChat (en)](https://github.com/thunlp/UltraChat) - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) -- 用于奖励模型训练: +- 用于奖励模型或 DPO 训练: - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) @@ -139,7 +134,6 @@ huggingface-cli login ### 环境搭建(可跳过) ```bash -git lfs install git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git conda create -n llama_etuning python=3.10 conda activate llama_etuning @@ -161,7 +155,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_web.py 目前网页 UI 仅支持**单卡训练**。 -### 二次预训练 +### 预训练 ```bash CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ @@ -222,7 +216,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --resume_lora_training False \ --checkpoint_dir path_to_sft_checkpoint \ --output_dir path_to_rm_checkpoint \ - --per_device_train_batch_size 4 \ + --per_device_train_batch_size 2 \ --gradient_accumulation_steps 4 \ --lr_scheduler_type cosine \ --logging_steps 10 \ @@ -233,7 +227,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --fp16 ``` -### RLHF 训练 +### PPO 训练 ```bash CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ @@ -257,6 +251,30 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --plot_loss ``` +### DPO 训练 + +```bash +CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ + --stage dpo \ + --model_name_or_path path_to_your_model \ + --do_train \ + --dataset comparison_gpt4_zh \ + --template default \ + --finetuning_type lora \ + --resume_lora_training False \ + --checkpoint_dir path_to_sft_checkpoint \ + --output_dir path_to_dpo_checkpoint \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 4 \ + --lr_scheduler_type cosine \ + --logging_steps 10 \ + --save_steps 1000 \ + --learning_rate 1e-5 \ + --num_train_epochs 1.0 \ + --plot_loss \ + --fp16 +``` + ### 多 GPU 分布式训练 ```bash diff --git a/data/dataset_info.json b/data/dataset_info.json index 3a3b4e76..3eaf920e 100644 --- a/data/dataset_info.json +++ b/data/dataset_info.json @@ -49,26 +49,6 @@ "history": "history" } }, - "refgpt_zh_p1": { - "file_name": "refgpt_zh_50k_p1.json", - "file_sha1": "b40f4f4d0ffacd16da7c275b056d5b6670021752", - "columns": { - "prompt": "instruction", - "query": "input", - "response": "output", - "history": "history" - } - }, - "refgpt_zh_p2": { - "file_name": "refgpt_zh_50k_p2.json", - "file_sha1": 
"181f32b2c60264a29f81f59d3c76095793eae1b0", - "columns": { - "prompt": "instruction", - "query": "input", - "response": "output", - "history": "history" - } - }, "lima": { "file_name": "lima.json", "file_sha1": "9db59f6b7007dc4b17529fc63379b9cd61640f37", diff --git a/data/refgpt_zh_50k_p1.json.REMOVED.git-id b/data/refgpt_zh_50k_p1.json.REMOVED.git-id deleted file mode 100644 index 3e8a9e41..00000000 --- a/data/refgpt_zh_50k_p1.json.REMOVED.git-id +++ /dev/null @@ -1 +0,0 @@ -f967a4f6d04a11308a15524aa9a846a19a8d1e83 \ No newline at end of file diff --git a/data/refgpt_zh_50k_p2.json.REMOVED.git-id b/data/refgpt_zh_50k_p2.json.REMOVED.git-id deleted file mode 100644 index a6525b27..00000000 --- a/data/refgpt_zh_50k_p2.json.REMOVED.git-id +++ /dev/null @@ -1 +0,0 @@ -0a4f0d74fd1c5cab2eb6d84a3a3fe669847becd8 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index c71b6c9c..fb5fa72f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ transformers>=4.29.1 datasets>=2.12.0 accelerate>=0.21.0 peft>=0.4.0 -trl>=0.4.7 +trl>=0.5.0 scipy sentencepiece tiktoken @@ -16,4 +16,3 @@ pydantic==1.10.11 fastapi==0.95.1 sse-starlette matplotlib -huggingface_hub \ No newline at end of file diff --git a/src/api_demo.py b/src/api_demo.py index c0ca9760..777f9dcf 100644 --- a/src/api_demo.py +++ b/src/api_demo.py @@ -7,7 +7,7 @@ def main(): chat_model = ChatModel() app = create_app(chat_model) uvicorn.run(app, host="0.0.0.0", port=8000, workers=1) - # Visit http://localhost:8000/docs for document. + print("Visit http://localhost:8000/docs for API document.") if __name__ == "__main__": diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py index 79d3b92d..8220c0b3 100644 --- a/src/llmtuner/chat/stream_chat.py +++ b/src/llmtuner/chat/stream_chat.py @@ -18,7 +18,6 @@ class ChatModel: self.model = self.model.eval() # change to eval mode self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer) self.source_prefix = data_args.source_prefix - self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words) self.model.generate = MethodType(PreTrainedModel.generate, self.model) # disable custom method (for Qwen) def process_args( @@ -53,7 +52,7 @@ class ChatModel: top_k=top_k or gen_kwargs["top_k"], repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"], logits_processor=get_logits_processor(), - stopping_criteria=get_stopping_criteria(self.stop_ids) + stopping_criteria=get_stopping_criteria(self.tokenizer.additional_special_tokens_ids) )) if max_length: diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py index 534d77b5..4efbcbb6 100644 --- a/src/llmtuner/dsets/preprocess.py +++ b/src/llmtuner/dsets/preprocess.py @@ -46,7 +46,6 @@ def preprocess_dataset( k: [t[i: i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items() } - result["labels"] = result["input_ids"].copy() return result def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]: @@ -95,24 +94,22 @@ def preprocess_dataset( return model_inputs def preprocess_pairwise_dataset(examples): - # build input pairs with format ` X Y1 ` and ` X Y2 ` - model_inputs = {"accept_ids": [], "reject_ids": []} + # build input pairs with format ` X`, `Y1 ` and `Y2 ` + model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []} for query, response, history, prefix in construct_example(examples): - source_ids, accept_ids = 
template.encode_oneturn(tokenizer, query, response[0], history, prefix) - source_ids, reject_ids = template.encode_oneturn(tokenizer, query, response[1], history, prefix) + prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, prefix) + _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, prefix) - if len(source_ids) > data_args.max_source_length: - source_ids = source_ids[:data_args.max_source_length] - if len(accept_ids) > data_args.max_target_length: - accept_ids = accept_ids[:data_args.max_target_length] - if len(reject_ids) > data_args.max_target_length: - reject_ids = reject_ids[:data_args.max_target_length] + if len(prompt_ids) > data_args.max_source_length: + prompt_ids = prompt_ids[:data_args.max_source_length] + if len(chosen_ids) > data_args.max_target_length: + chosen_ids = chosen_ids[:data_args.max_target_length] + if len(rejected_ids) > data_args.max_target_length: + rejected_ids = rejected_ids[:data_args.max_target_length] - accept_ids = source_ids + accept_ids - reject_ids = source_ids + reject_ids - - model_inputs["accept_ids"].append(accept_ids) - model_inputs["reject_ids"].append(reject_ids) + model_inputs["prompt_ids"].append(prompt_ids) + model_inputs["chosen_ids"].append(chosen_ids) + model_inputs["rejected_ids"].append(rejected_ids) return model_inputs def print_supervised_dataset_example(example): @@ -124,10 +121,12 @@ def preprocess_dataset( ], skip_special_tokens=False))) def print_pairwise_dataset_example(example): - print("accept_ids:\n{}".format(example["accept_ids"])) - print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False))) - print("reject_ids:\n{}".format(example["reject_ids"])) - print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False))) + print("prompt_ids:\n{}".format(example["prompt_ids"])) + print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False))) + print("chosen_ids:\n{}".format(example["chosen_ids"])) + print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False))) + print("rejected_ids:\n{}".format(example["rejected_ids"])) + print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False))) def print_unsupervised_dataset_example(example): print("input_ids:\n{}".format(example["input_ids"])) diff --git a/src/llmtuner/extras/callbacks.py b/src/llmtuner/extras/callbacks.py index d325b0a8..61deae25 100644 --- a/src/llmtuner/extras/callbacks.py +++ b/src/llmtuner/extras/callbacks.py @@ -7,10 +7,16 @@ from datetime import timedelta from transformers import TrainerCallback from transformers.trainer_utils import has_length +from llmtuner.extras.constants import LOG_FILE_NAME +from llmtuner.extras.logging import get_logger + if TYPE_CHECKING: from transformers import TrainingArguments, TrainerState, TrainerControl +logger = get_logger(__name__) + + class LogCallback(TrainerCallback): def __init__(self, runner=None): @@ -38,6 +44,9 @@ class LogCallback(TrainerCallback): self.in_training = True self.start_time = time.time() self.max_steps = state.max_steps + if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)): + logger.warning("Previous log file in this folder will be deleted.") + os.remove(os.path.join(args.output_dir, LOG_FILE_NAME)) def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): r""" diff --git a/src/llmtuner/extras/constants.py 
b/src/llmtuner/extras/constants.py index 6f6dbdd7..dc1cb2e7 100644 --- a/src/llmtuner/extras/constants.py +++ b/src/llmtuner/extras/constants.py @@ -1,10 +1,12 @@ IGNORE_INDEX = -100 +LOG_FILE_NAME = "trainer_log.jsonl" + VALUE_HEAD_FILE_NAME = "value_head.bin" FINETUNING_ARGS_NAME = "finetuning_args.json" -LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings +LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] METHODS = ["full", "freeze", "lora"] diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py index e1fbb156..ee918fbb 100644 --- a/src/llmtuner/extras/misc.py +++ b/src/llmtuner/extras/misc.py @@ -1,7 +1,11 @@ import torch from typing import TYPE_CHECKING, List, Optional, Tuple - -from transformers import LogitsProcessor, LogitsProcessorList, StoppingCriteria, StoppingCriteriaList +from transformers import ( + LogitsProcessor, + LogitsProcessorList, + StoppingCriteria, + StoppingCriteriaList +) from llmtuner.extras.constants import LAYERNORM_NAMES diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index d94d5b37..c3388ab3 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -61,7 +61,7 @@ class Template: prefix: Optional[str] = None ) -> Tuple[List[Union[str, Dict[str, str]]], List[Tuple[str, str]]]: r""" - Aligns inputs to a special format. + Aligns inputs to the standard format. """ prefix = [prefix] if prefix else self.prefix # use prefix if provided history = history if (history and self.use_history) else [] @@ -92,28 +92,32 @@ class Template: ) -> List[Tuple[List[int], List[int]]]: r""" Encodes formatted inputs to pairs of token ids. + Turn 0: bos + prefix + sep + query resp + eos + Turn t: sep + bos + query resp + eos """ bos_ids, eos_ids = self._get_special_ids(tokenizer) sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep) encoded_pairs = [] for turn_idx, (query, resp) in enumerate(history): - if turn_idx != 0: - prefix_ids = sep_ids - elif prefix: - prefix_ids = self._convert_inputs_to_ids(tokenizer, context=prefix) + eos_ids + sep_ids + if turn_idx == 0: + if prefix: # has prefix + prefix_ids = bos_ids + self._convert_inputs_to_ids(tokenizer, context=prefix) + sep_ids + else: + prefix_ids = bos_ids else: - prefix_ids = [] + prefix_ids = sep_ids + bos_ids - query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query) + query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx)) resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp]) - encoded_pairs.append((bos_ids + prefix_ids + query_ids, resp_ids + eos_ids)) + encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids)) return encoded_pairs def _convert_inputs_to_ids( self, tokenizer: "PreTrainedTokenizer", context: List[Union[str, Dict[str, str]]], - query: Optional[str] = "" + query: Optional[str] = "", + idx: Optional[str] = "" ) -> List[int]: r""" Converts context to token ids. @@ -127,6 +131,7 @@ class Template: for elem in context: if isinstance(elem, str): elem = elem.replace("{{query}}", query, 1) + elem = elem.replace("{{idx}}", idx, 1) token_ids = token_ids + tokenizer.encode(elem, **kwargs) elif isinstance(elem, dict): token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))] @@ -146,10 +151,12 @@ class Llama2Template(Template): ) -> List[Tuple[List[int], List[int]]]: r""" Encodes formatted inputs to pairs of token ids. 
+ Turn 0: bos + prefix + query resp + eos + Turn t: bos + query resp + eos """ bos_ids, eos_ids = self._get_special_ids(tokenizer) encoded_pairs = [] - assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single str." + assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single string." for turn_idx, (query, resp) in enumerate(history): if turn_idx == 0: # llama2 template has not sep_ids query = prefix[0] + query @@ -187,11 +194,12 @@ def get_template_and_fix_tokenizer( template = templates.get(name, None) assert template is not None, "Template {} does not exist.".format(name) - if tokenizer.eos_token_id is None: # inplace method - if len(template.stop_words): - tokenizer.eos_token = template.stop_words[0] - else: - tokenizer.eos_token = "<|endoftext|>" + if len(template.stop_words): # inplace method + tokenizer.eos_token = template.stop_words[0] + logger.info("Replace eos token: {}".format(tokenizer.eos_token)) + + if tokenizer.eos_token_id is None: + tokenizer.eos_token = "<|endoftext|>" logger.info("Add eos token: {}".format(tokenizer.eos_token)) if tokenizer.pad_token_id is None: @@ -422,12 +430,13 @@ register_template( name="baichuan", prefix=[], prompt=[ - {"token": ""}, - "{{query}}", - {"token": ""} + {"token": ""}, # user token (a little difference in position) + "{{query}}" ], sep=[], - stop_words=[], + stop_words=[ + "" # assistant token + ], use_history=True ) @@ -440,7 +449,8 @@ register_template( name="starchat", prefix=[ {"token": "<|system|>"}, - "\n" + "\n", + {"token": "<|end|>"} ], prompt=[ {"token": "<|user|>"}, @@ -466,7 +476,8 @@ register_template( name="chatml", prefix=[ {"token": "<|im_start|>"}, - "system\nYou are a helpful assistant." + "system\nYou are a helpful assistant.", + {"token": "<|im_end|>"} ], prompt=[ {"token": "<|im_start|>"}, @@ -484,3 +495,23 @@ register_template( ], use_history=True ) + + +r""" +Supports: https://huggingface.co/THUDM/chatglm2-6b +""" +register_template( + name="chatglm2", + prefix=[ + {"token": "[gMASK]"}, + {"token": "sop"} + ], + prompt=[ + "[Round {{idx}}]\n\n问:{{query}}\n\n答:" + ], + sep=[ + "\n\n" + ], + stop_words=[], + use_history=True +) diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py index de470ae2..7d1c982c 100644 --- a/src/llmtuner/hparams/data_args.py +++ b/src/llmtuner/hparams/data_args.py @@ -24,7 +24,7 @@ class DatasetAttr: @dataclass class DataArguments: - """ + r""" Arguments pertaining to what data we are going to input our model for training and evaluation. """ template: str = field( diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py index 277602ae..c4713c5e 100644 --- a/src/llmtuner/hparams/finetuning_args.py +++ b/src/llmtuner/hparams/finetuning_args.py @@ -5,7 +5,7 @@ from dataclasses import asdict, dataclass, field @dataclass class FinetuningArguments: - """ + r""" Arguments pertaining to which techniques we are going to fine-tuning with. """ finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field( @@ -14,7 +14,7 @@ class FinetuningArguments: ) num_hidden_layers: Optional[int] = field( default=32, - metadata={"help": "Number of decoder blocks in the model. \ + metadata={"help": "Number of decoder blocks in the model for partial-parameter (freeze) fine-tuning. 
\ LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \ LLaMA-2 choices: [\"32\", \"40\", \"80\"], \ BLOOM choices: [\"24\", \"30\", \"70\"], \ @@ -25,16 +25,16 @@ class FinetuningArguments: ) num_layer_trainable: Optional[int] = field( default=3, - metadata={"help": "Number of trainable layers for Freeze fine-tuning."} + metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."} ) name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field( default="mlp", - metadata={"help": "Name of trainable modules for Freeze fine-tuning. \ - LLaMA & LLaMA-2 choices: [\"mlp\", \"self_attn\"], \ + metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \ + LLaMA choices: [\"mlp\", \"self_attn\"], \ BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \ Baichuan choices: [\"mlp\", \"self_attn\"], \ Qwen choices: [\"mlp\", \"attn\"], \ - InternLM, XVERSE choices: the same as LLaMA."} + LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."} ) lora_rank: Optional[int] = field( default=8, @@ -51,11 +51,15 @@ class FinetuningArguments: lora_target: Optional[str] = field( default="q_proj,v_proj", metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \ - LLaMA & LLaMA-2 choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ + LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \ Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \ - InternLM, XVERSE choices: the same as LLaMA."} + LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."} + ) + dpo_beta: Optional[float] = field( + default=0.1, + metadata={"help": "The beta parameter for the DPO loss."} ) def __post_init__(self): @@ -72,14 +76,14 @@ class FinetuningArguments: assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method." def save_to_json(self, json_path: str): - """Saves the content of this instance in JSON format inside `json_path`.""" + r"""Saves the content of this instance in JSON format inside `json_path`.""" json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n" with open(json_path, "w", encoding="utf-8") as f: f.write(json_string) @classmethod def load_from_json(cls, json_path: str): - """Creates an instance from the content of `json_path`.""" + r"""Creates an instance from the content of `json_path`.""" with open(json_path, "r", encoding="utf-8") as f: text = f.read() return cls(**json.loads(text)) diff --git a/src/llmtuner/hparams/general_args.py b/src/llmtuner/hparams/general_args.py index 397d3019..c0c1a0de 100644 --- a/src/llmtuner/hparams/general_args.py +++ b/src/llmtuner/hparams/general_args.py @@ -4,10 +4,10 @@ from dataclasses import dataclass, field @dataclass class GeneralArguments: - """ + r""" Arguments pertaining to which stage we are going to perform. 
""" - stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = field( + stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field( default="sft", metadata={"help": "Which stage will be performed in training."} ) diff --git a/src/llmtuner/hparams/generating_args.py b/src/llmtuner/hparams/generating_args.py index e25ff4b9..f8b935fb 100644 --- a/src/llmtuner/hparams/generating_args.py +++ b/src/llmtuner/hparams/generating_args.py @@ -4,7 +4,7 @@ from dataclasses import asdict, dataclass, field @dataclass class GeneratingArguments: - """ + r""" Arguments pertaining to specify the decoding parameters. """ do_sample: Optional[bool] = field( diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py index 94e882fe..6c8491da 100644 --- a/src/llmtuner/hparams/model_args.py +++ b/src/llmtuner/hparams/model_args.py @@ -1,12 +1,11 @@ import torch from typing import Literal, Optional from dataclasses import dataclass, field -from huggingface_hub.hf_api import HfFolder @dataclass class ModelArguments: - """ + r""" Arguments pertaining to which model/config/tokenizer we are going to fine-tune. """ model_name_or_path: str = field( @@ -64,12 +63,11 @@ class ModelArguments: default=False, metadata={"help": "Whether to plot the training loss after fine-tuning or not."} ) - hf_hub_token : Optional[str] = field( + hf_auth_token: Optional[str] = field( default=None, - metadata={"help": "Path to the directory containing the checkpoints of the reward model."} + metadata={"help": "Auth token to log in with Hugging Face Hub."} ) - def __post_init__(self): if self.checkpoint_dir is not None: # support merging multiple lora weights self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")] @@ -77,5 +75,6 @@ class ModelArguments: if self.quantization_bit is not None: assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization." 
- if self.use_auth_token == True and self.hf_hub_token != None: - HfFolder.save_token(self.hf_hub_token) + if self.use_auth_token == True and self.hf_auth_token is not None: + from huggingface_hub.hf_api import HfFolder # lazy load + HfFolder.save_token(self.hf_auth_token) diff --git a/src/llmtuner/tuner/core/adapter.py b/src/llmtuner/tuner/core/adapter.py index 4afad13a..a8ac5a84 100644 --- a/src/llmtuner/tuner/core/adapter.py +++ b/src/llmtuner/tuner/core/adapter.py @@ -39,7 +39,7 @@ def init_adapter( if finetuning_args.finetuning_type == "none" and is_trainable: raise ValueError("You cannot use finetuning_type=none while training.") - if finetuning_args.finetuning_type == "full": + if finetuning_args.finetuning_type == "full" and is_trainable: logger.info("Fine-tuning method: Full") model = model.float() diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py index c06eabfa..39bec1d8 100644 --- a/src/llmtuner/tuner/core/loader.py +++ b/src/llmtuner/tuner/core/loader.py @@ -34,7 +34,7 @@ check_min_version("4.29.1") require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0") require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0") require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0") -require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7") +require_version("trl>=0.5.0", "To fix: pip install trl>=0.5.0") def load_model_and_tokenizer( @@ -52,9 +52,6 @@ def load_model_and_tokenizer( logger.warning("Checkpoint is not found at evaluation, load the original model.") finetuning_args = FinetuningArguments(finetuning_type="none") - assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \ - "RM and PPO training can only be performed with the LoRA method." - config_kwargs = { "trust_remote_code": True, "cache_dir": model_args.cache_dir, @@ -132,8 +129,6 @@ def load_model_and_tokenizer( }) if stage == "ppo": # load reward model - assert is_trainable, "PPO stage cannot be performed at evaluation." - assert model_args.reward_model is not None, "Reward model is necessary for PPO training." logger.info("Load reward model from {}".format(model_args.reward_model)) model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False) assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded." 
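Before the parser and trainer changes below, a conceptual sketch of how one pairwise example emitted by `preprocess_pairwise_dataset` above is expanded by the new collators (`DPODataCollatorWithPadding` and `PairwiseDataCollatorWithPadding`, added further below) into a batch whose first half holds the chosen answers and second half the rejected ones. Token ids are made up; this is not the repository code:

```python
# Conceptual sketch (made-up token ids, not the repo code): expanding one
# pairwise feature into a chosen example and a rejected example, with the
# prompt positions masked out of the labels using IGNORE_INDEX (-100).
IGNORE_INDEX = -100

feature = {
    "prompt_ids":   [1, 5, 6, 7],   # "<bos> X"
    "chosen_ids":   [8, 9, 2],      # "Y_w <eos>"
    "rejected_ids": [10, 2],        # "Y_l <eos>"
}

batch = []
for key in ("chosen_ids", "rejected_ids"):  # chosen examples first, then rejected
    input_ids = feature["prompt_ids"] + feature[key]
    labels = [IGNORE_INDEX] * len(feature["prompt_ids"]) + feature[key]
    batch.append({
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),
        "labels": labels,
    })

# batch[0] (chosen):   input_ids=[1, 5, 6, 7, 8, 9, 2], labels=[-100, -100, -100, -100, 8, 9, 2]
# batch[1] (rejected): input_ids=[1, 5, 6, 7, 10, 2],   labels=[-100, -100, -100, -100, 10, 2]
```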
diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py index 692f9b13..c5bdbe16 100644 --- a/src/llmtuner/tuner/core/parser.py +++ b/src/llmtuner/tuner/core/parser.py @@ -19,7 +19,7 @@ from llmtuner.hparams import ( logger = get_logger(__name__) -def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None): +def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]: if args is not None: return parser.parse_dict(args) elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): @@ -32,26 +32,53 @@ def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) def parse_train_args( args: Optional[Dict[str, Any]] = None -) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]: +) -> Tuple[ + ModelArguments, + DataArguments, + Seq2SeqTrainingArguments, + FinetuningArguments, + GeneratingArguments, + GeneralArguments +]: parser = HfArgumentParser(( - ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments + ModelArguments, + DataArguments, + Seq2SeqTrainingArguments, + FinetuningArguments, + GeneratingArguments, + GeneralArguments )) return _parse_args(parser, args) def parse_infer_args( args: Optional[Dict[str, Any]] = None -) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]: +) -> Tuple[ + ModelArguments, + DataArguments, + FinetuningArguments, + GeneratingArguments +]: parser = HfArgumentParser(( - ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments + ModelArguments, + DataArguments, + FinetuningArguments, + GeneratingArguments )) return _parse_args(parser, args) def get_train_args( args: Optional[Dict[str, Any]] = None -) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]: - model_args, data_args, training_args, finetuning_args, general_args = parse_train_args(args) +) -> Tuple[ + ModelArguments, + DataArguments, + Seq2SeqTrainingArguments, + FinetuningArguments, + GeneratingArguments, + GeneralArguments +]: + model_args, data_args, training_args, finetuning_args, generating_args, general_args = parse_train_args(args) # Setup logging if training_args.should_log: @@ -68,7 +95,7 @@ def get_train_args( data_args.init_for_training() if general_args.stage != "sft" and training_args.predict_with_generate: - raise ValueError("`predict_with_generate` cannot be set as True at PT, RM and PPO stages.") + raise ValueError("`predict_with_generate` cannot be set as True except SFT.") if training_args.do_train and training_args.predict_with_generate: raise ValueError("`predict_with_generate` cannot be set as True while training.") @@ -76,6 +103,15 @@ def get_train_args( if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate: raise ValueError("Please enable `predict_with_generate` to save model predictions.") + if general_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora": + raise ValueError("RM and PPO training can only be performed with the LoRA method.") + + if general_args.stage in ["ppo", "dpo"] and not training_args.do_train: + raise ValueError("PPO and DPO stage can only be performed at training.") + + if general_args.stage == "ppo" and model_args.reward_model is None: + raise ValueError("Reward model is necessary for PPO training.") + if training_args.max_steps == -1 and data_args.streaming: raise ValueError("Please specify `max_steps` 
in streaming mode.") @@ -133,12 +169,17 @@ def get_train_args( # Set seed before initializing model. transformers.set_seed(training_args.seed) - return model_args, data_args, training_args, finetuning_args, general_args + return model_args, data_args, training_args, finetuning_args, generating_args, general_args def get_infer_args( args: Optional[Dict[str, Any]] = None -) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]: +) -> Tuple[ + ModelArguments, + DataArguments, + FinetuningArguments, + GeneratingArguments +]: model_args, data_args, finetuning_args, generating_args = parse_infer_args(args) if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora": diff --git a/src/llmtuner/tuner/core/trainer.py b/src/llmtuner/tuner/core/trainer.py index ae80f32f..058bb740 100644 --- a/src/llmtuner/tuner/core/trainer.py +++ b/src/llmtuner/tuner/core/trainer.py @@ -13,26 +13,25 @@ from llmtuner.extras.logging import get_logger from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params if TYPE_CHECKING: + from transformers import PreTrainedTokenizer, Seq2SeqTrainingArguments, TrainerState from llmtuner.hparams import FinetuningArguments logger = get_logger(__name__) -class PeftTrainer(Seq2SeqTrainer): +class PeftModelMixin: r""" - Inherits Seq2SeqTrainer to support parameter-efficient checkpoints. + Patches the save and load methods in Hugging Face Trainer for PeftModel and ModelWithValueHead. """ - def __init__(self, finetuning_args: "FinetuningArguments", **kwargs): - super().__init__(**kwargs) - self.finetuning_args = finetuning_args - self._remove_log() - - def _remove_log(self): - if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")): - logger.warning("Previous log file in this folder will be deleted.") - os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl")) + def __init__(self) -> None: # for type checking + self.model: PreTrainedModel = None + self.tokenizer: "PreTrainedTokenizer" = None + self.args: "Seq2SeqTrainingArguments" = None + self.finetuning_args: "FinetuningArguments" = None + self.state: "TrainerState" = None + raise AssertionError("Mixin should not be initialized.") def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None: r""" @@ -96,3 +95,13 @@ class PeftTrainer(Seq2SeqTrainer): model.load_adapter(self.state.best_model_checkpoint, model.active_adapter) else: # freeze/full-tuning load_trainable_params(model, self.state.best_model_checkpoint) + + +class PeftTrainer(PeftModelMixin, Seq2SeqTrainer): + r""" + Inherits Seq2SeqTrainer to support parameter-efficient checkpoints. 
+ """ + + def __init__(self, finetuning_args: "FinetuningArguments", **kwargs): + Seq2SeqTrainer.__init__(self, **kwargs) + self.finetuning_args = finetuning_args diff --git a/src/llmtuner/tuner/dpo/__init__.py b/src/llmtuner/tuner/dpo/__init__.py new file mode 100644 index 00000000..f2b5cfb5 --- /dev/null +++ b/src/llmtuner/tuner/dpo/__init__.py @@ -0,0 +1 @@ +from llmtuner.tuner.dpo.workflow import run_dpo diff --git a/src/llmtuner/tuner/dpo/collator.py b/src/llmtuner/tuner/dpo/collator.py new file mode 100644 index 00000000..2f0f4bdc --- /dev/null +++ b/src/llmtuner/tuner/dpo/collator.py @@ -0,0 +1,51 @@ +import torch +from dataclasses import dataclass +from typing import Any, Dict, List, Sequence, Tuple +from transformers import DataCollatorForSeq2Seq + + +@dataclass +class DPODataCollatorWithPadding(DataCollatorForSeq2Seq): + r""" + Data collator for pairwise data. + """ + + def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor: + padded_labels = [] + for feature, (prompt_len, answer_len) in zip(batch, positions): + if self.tokenizer.padding_side == "left": + start, end = feature.size(0) - answer_len, feature.size(0) + else: + start, end = prompt_len, answer_len + padded_tensor = self.label_pad_token_id * torch.ones_like(feature) + padded_tensor[start:end] = feature[start:end] + padded_labels.append(padded_tensor) + return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory + + def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]: + r""" + Pads batched data to the longest sequence in the batch. + + We generate 2 * n examples where the first n examples represent chosen examples and + the last n examples represent rejected examples. + """ + concatenated_features = [] + label_positions = [] + for key in ("chosen_ids", "rejected_ids"): + for feature in features: + prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key]) + concatenated_features.append({ + "input_ids": feature["prompt_ids"] + feature[key], + "attention_mask": [1] * (prompt_len + answer_len) + }) + label_positions.append((prompt_len, answer_len)) + + batch = self.tokenizer.pad( + concatenated_features, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors=self.return_tensors, + ) + batch["labels"] = self._pad_labels(batch["input_ids"], label_positions) + return batch diff --git a/src/llmtuner/tuner/dpo/trainer.py b/src/llmtuner/tuner/dpo/trainer.py new file mode 100644 index 00000000..a94642c1 --- /dev/null +++ b/src/llmtuner/tuner/dpo/trainer.py @@ -0,0 +1,75 @@ +import torch +from collections import defaultdict +from peft import PeftModel +from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union +from transformers import Trainer +from trl import DPOTrainer + +from llmtuner.extras.constants import IGNORE_INDEX +from llmtuner.tuner.core.trainer import PeftModelMixin + +if TYPE_CHECKING: + from transformers import PreTrainedModel + from llmtuner.hparams import FinetuningArguments, GeneratingArguments + + +class DPOPeftTrainer(PeftModelMixin, DPOTrainer): + + def __init__( + self, + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None, + **kwargs + ): + self.finetuning_args = finetuning_args + self.generating_args = generating_args + self.ref_model = ref_model + self.use_dpo_data_collator = True # hack to avoid warning + self.label_pad_token_id = IGNORE_INDEX + 
self.padding_value = 0 + self.beta = finetuning_args.dpo_beta + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + Trainer.__init__(self, **kwargs) + if ref_model is not None: + if hasattr(self, "accelerator"): + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + else: + raise AttributeError("Please update `transformers`.") + + def concatenated_forward( + self, + model: Optional[torch.nn.Module] = None, + batch: Optional[Dict[str, torch.Tensor]] = None + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model) + if not torch.is_grad_enabled(): + unwrapped_model.gradient_checkpointing_disable() + + if model is None and isinstance(unwrapped_model, PeftModel): # peft model has no ref_model + with unwrapped_model.disable_adapter(): + all_logits: torch.Tensor = self.model( + batch["input_ids"], + attention_mask=batch["attention_mask"], + return_dict=True + ).logits.to(torch.float32) + else: + all_logits: torch.Tensor = model( + batch["input_ids"], + attention_mask=batch["attention_mask"], + return_dict=True + ).logits.to(torch.float32) + + if not torch.is_grad_enabled(): + unwrapped_model.gradient_checkpointing_enable() + + all_logps = self._get_batch_logps( + all_logits, + batch["labels"], + average_log_prob=False + ) + batch_size = batch["input_ids"].size(0) // 2 + chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0) + chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0) + return chosen_logps, rejected_logps, chosen_logits, rejected_logits diff --git a/src/llmtuner/tuner/dpo/workflow.py b/src/llmtuner/tuner/dpo/workflow.py new file mode 100644 index 00000000..350d64a7 --- /dev/null +++ b/src/llmtuner/tuner/dpo/workflow.py @@ -0,0 +1,59 @@ +# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py + +from copy import deepcopy +from peft import PeftModel +from typing import TYPE_CHECKING, Optional, List + +from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset +from llmtuner.extras.constants import IGNORE_INDEX +from llmtuner.extras.ploting import plot_loss +from llmtuner.tuner.core import load_model_and_tokenizer +from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding +from llmtuner.tuner.dpo.trainer import DPOPeftTrainer + +if TYPE_CHECKING: + from transformers import Seq2SeqTrainingArguments, TrainerCallback + from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments + + +def run_dpo( + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + callbacks: Optional[List["TrainerCallback"]] = None +): + dataset = get_dataset(model_args, data_args) + model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft") + dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm") + data_collator = DPODataCollatorWithPadding( + tokenizer=tokenizer, + label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id + ) + + training_args.remove_unused_columns = False # important for pairwise dataset + ref_model = deepcopy(model) if not isinstance(model, PeftModel) else None + + # Initialize our Trainer + trainer = DPOPeftTrainer( + 
finetuning_args=finetuning_args, + generating_args=generating_args, + ref_model=ref_model, + model=model, + args=training_args, + tokenizer=tokenizer, + data_collator=data_collator, + callbacks=callbacks, + **split_dataset(dataset, data_args, training_args) + ) + + # Training + if training_args.do_train: + train_result = trainer.train() + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + trainer.save_model() + if trainer.is_world_process_zero() and model_args.plot_loss: + plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py index 3392aa4d..26a8db99 100644 --- a/src/llmtuner/tuner/ppo/trainer.py +++ b/src/llmtuner/tuner/ppo/trainer.py @@ -10,7 +10,7 @@ from trl import PPOTrainer from trl.core import LengthSampler from llmtuner.extras.logging import get_logger -from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor +from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor, get_stopping_criteria from llmtuner.tuner.core.trainer import PeftTrainer from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model @@ -18,7 +18,7 @@ if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments from trl import AutoModelForCausalLMWithValueHead from llmtuner.extras.callbacks import LogCallback - from llmtuner.hparams import FinetuningArguments + from llmtuner.hparams import FinetuningArguments, GeneratingArguments logger = get_logger(__name__) @@ -33,16 +33,17 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): self, training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", callbacks: List["LogCallback"], **kwargs ): PPOTrainer.__init__(self, **kwargs) self.args = training_args self.finetuning_args = finetuning_args + self.generating_args = generating_args self.log_callback = callbacks[0] self.state = TrainerState() self.control = TrainerControl() - self._remove_log() def ppo_train(self, max_target_length: int) -> None: r""" @@ -72,14 +73,10 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}") # Keyword arguments for `model.generate` - gen_kwargs = { - "top_k": 0.0, - "top_p": 1.0, - "do_sample": True, - "pad_token_id": self.tokenizer.pad_token_id, - "eos_token_id": self.tokenizer.eos_token_id, - "logits_processor": get_logits_processor() - } + gen_kwargs = self.generating_args.to_dict() + gen_kwargs["logits_processor"] = get_logits_processor() + gen_kwargs["stopping_criteria"] = get_stopping_criteria(self.tokenizer.additional_special_tokens_ids) + length_sampler = LengthSampler(max_target_length // 2, max_target_length) unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) diff --git a/src/llmtuner/tuner/ppo/workflow.py b/src/llmtuner/tuner/ppo/workflow.py index aa372671..6734ab78 100644 --- a/src/llmtuner/tuner/ppo/workflow.py +++ b/src/llmtuner/tuner/ppo/workflow.py @@ -1,11 +1,9 @@ -# Inspired by: -# https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py +# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py import math -from typing import TYPE_CHECKING from trl import PPOConfig from torch.optim import AdamW -from typing import Optional, List +from typing 
import TYPE_CHECKING, Optional, List from transformers import DataCollatorForSeq2Seq from transformers.optimization import get_scheduler @@ -16,7 +14,7 @@ from llmtuner.tuner.ppo.trainer import PPOPeftTrainer if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments, TrainerCallback - from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments + from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments def run_ppo( @@ -24,6 +22,7 @@ def run_ppo( data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", callbacks: Optional[List["TrainerCallback"]] = None ): dataset = get_dataset(model_args, data_args) @@ -42,8 +41,9 @@ def run_ppo( ) optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate) - total_train_batch_size = \ + total_train_batch_size = ( training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size + ) num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size) lr_scheduler = get_scheduler( training_args.lr_scheduler_type, @@ -56,6 +56,7 @@ def run_ppo( ppo_trainer = PPOPeftTrainer( training_args=training_args, finetuning_args=finetuning_args, + generating_args=generating_args, callbacks=callbacks, config=ppo_config, model=model, @@ -67,8 +68,10 @@ def run_ppo( lr_scheduler=lr_scheduler ) - ppo_trainer.ppo_train(max_target_length=data_args.max_target_length) - ppo_trainer.save_model() - ppo_trainer.save_state() # must be after save_model - if ppo_trainer.is_world_process_zero() and model_args.plot_loss: - plot_loss(training_args.output_dir, keys=["loss", "reward"]) + # Training + if training_args.do_train: + ppo_trainer.ppo_train(max_target_length=data_args.max_target_length) + ppo_trainer.save_model() + ppo_trainer.save_state() # must be called after save_model to have a folder + if ppo_trainer.is_world_process_zero() and model_args.plot_loss: + plot_loss(training_args.output_dir, keys=["loss", "reward"]) diff --git a/src/llmtuner/tuner/pt/workflow.py b/src/llmtuner/tuner/pt/workflow.py index b4ea148b..f7bf6448 100644 --- a/src/llmtuner/tuner/pt/workflow.py +++ b/src/llmtuner/tuner/pt/workflow.py @@ -2,10 +2,9 @@ import math from typing import TYPE_CHECKING, Optional, List -from transformers import DataCollatorForSeq2Seq +from transformers import DataCollatorForLanguageModeling from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset -from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.ploting import plot_loss from llmtuner.tuner.core import load_model_and_tokenizer from llmtuner.tuner.core.trainer import PeftTrainer @@ -25,10 +24,7 @@ def run_pt( dataset = get_dataset(model_args, data_args) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt") dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt") - data_collator = DataCollatorForSeq2Seq( - tokenizer=tokenizer, - label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id - ) + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) # Initialize our Trainer trainer = PeftTrainer( diff --git a/src/llmtuner/tuner/rm/collator.py b/src/llmtuner/tuner/rm/collator.py index c0da0579..161f003d 100644 --- a/src/llmtuner/tuner/rm/collator.py +++ 
b/src/llmtuner/tuner/rm/collator.py @@ -1,8 +1,10 @@ import torch +from dataclasses import dataclass from typing import Any, Dict, Sequence from transformers import DataCollatorWithPadding +@dataclass class PairwiseDataCollatorWithPadding(DataCollatorWithPadding): r""" Data collator for pairwise data. @@ -16,7 +18,10 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding): the last n examples represent rejected examples. """ features = [ - {"input_ids": feature[key], "attention_mask": [1] * len(feature[key])} - for key in ("accept_ids", "reject_ids") for feature in features + { + "input_ids": feature["prompt_ids"] + feature[key], + "attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key])) + } + for key in ("chosen_ids", "rejected_ids") for feature in features ] return super().__call__(features) diff --git a/src/llmtuner/tuner/sft/trainer.py b/src/llmtuner/tuner/sft/trainer.py index 21739ac1..6243928f 100644 --- a/src/llmtuner/tuner/sft/trainer.py +++ b/src/llmtuner/tuner/sft/trainer.py @@ -79,7 +79,7 @@ class Seq2SeqPeftTrainer(PeftTrainer): padded_tensor = pad_token_id * torch.ones_like(tgt_tensor) padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding - return padded_tensor.contiguous() + return padded_tensor.contiguous() # in contiguous memory def save_predictions( self, diff --git a/src/llmtuner/tuner/sft/workflow.py b/src/llmtuner/tuner/sft/workflow.py index a5cd2cd3..10f7aafb 100644 --- a/src/llmtuner/tuner/sft/workflow.py +++ b/src/llmtuner/tuner/sft/workflow.py @@ -5,7 +5,7 @@ from transformers import DataCollatorForSeq2Seq from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset from llmtuner.extras.constants import IGNORE_INDEX -from llmtuner.extras.misc import get_logits_processor +from llmtuner.extras.misc import get_logits_processor, get_stopping_criteria from llmtuner.extras.ploting import plot_loss from llmtuner.tuner.core import load_model_and_tokenizer from llmtuner.tuner.sft.metric import ComputeMetrics @@ -13,7 +13,7 @@ from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments, TrainerCallback - from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments + from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments def run_sft( @@ -21,6 +21,7 @@ def run_sft( data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", callbacks: Optional[List["TrainerCallback"]] = None ): dataset = get_dataset(model_args, data_args) @@ -50,13 +51,9 @@ def run_sft( ) # Keyword arguments for `model.generate` - gen_kwargs = { - "do_sample": True, - "top_p": 0.7, - "max_new_tokens": data_args.max_target_length + 1, - "temperature": 0.95, - "logits_processor": get_logits_processor() - } + gen_kwargs = generating_args.to_dict() + gen_kwargs["logits_processor"] = get_logits_processor() + gen_kwargs["stopping_criteria"] = get_stopping_criteria(tokenizer.additional_special_tokens_ids) # Training if training_args.do_train: diff --git a/src/llmtuner/tuner/tune.py b/src/llmtuner/tuner/tune.py index 99f5d2a9..dee49ef4 100644 --- a/src/llmtuner/tuner/tune.py +++ b/src/llmtuner/tuner/tune.py @@ -1,35 +1,47 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional from llmtuner.extras.callbacks import LogCallback +from llmtuner.extras.logging import get_logger from llmtuner.tuner.core import get_train_args, 
load_model_and_tokenizer from llmtuner.tuner.pt import run_pt from llmtuner.tuner.sft import run_sft from llmtuner.tuner.rm import run_rm from llmtuner.tuner.ppo import run_ppo +from llmtuner.tuner.dpo import run_dpo if TYPE_CHECKING: from transformers import TrainerCallback +logger = get_logger(__name__) + + def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None): - model_args, data_args, training_args, finetuning_args, general_args = get_train_args(args) - callbacks = [LogCallback()] if callbacks is None else callbacks + model_args, data_args, training_args, finetuning_args, generating_args, general_args = get_train_args(args) + callbacks = [LogCallback()] if callbacks is None else callbacks + [LogCallback()] if general_args.stage == "pt": run_pt(model_args, data_args, training_args, finetuning_args, callbacks) elif general_args.stage == "sft": - run_sft(model_args, data_args, training_args, finetuning_args, callbacks) + run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks) elif general_args.stage == "rm": run_rm(model_args, data_args, training_args, finetuning_args, callbacks) elif general_args.stage == "ppo": - run_ppo(model_args, data_args, training_args, finetuning_args, callbacks) + run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks) + elif general_args.stage == "dpo": + run_dpo(model_args, data_args, training_args, finetuning_args, callbacks) + else: + raise ValueError("Unknown task.") def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"): model_args, _, training_args, finetuning_args, _ = get_train_args(args) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args) model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size) - tokenizer.save_pretrained(training_args.output_dir) + try: + tokenizer.save_pretrained(training_args.output_dir) + except: + logger.warning("Cannot save tokenizer, please copy the files manually.") if __name__ == "__main__": diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index fddc4070..36a8bf53 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -4,7 +4,7 @@ import threading import time import transformers from transformers.trainer import TRAINING_ARGS_NAME -from typing import Generator, List, Optional, Tuple +from typing import Generator, List, Tuple from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.constants import DEFAULT_MODULE
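For reference, the objective behind the new `dpo` stage is the loss from the paper cited in the changelog (arXiv:2305.18290); `DPOPeftTrainer` delegates the computation to trl's `DPOTrainer`, with `dpo_beta` as the temperature. A minimal sketch of the formula itself (not trl's implementation):

```python
# Minimal sketch of the DPO objective (arXiv:2305.18290): maximize the
# log-sigmoid of the beta-scaled margin between the policy's and the frozen
# reference model's log-probability ratios of chosen over rejected answers.
import torch
import torch.nn.functional as F

def dpo_loss(
    policy_chosen_logps: torch.Tensor,    # log p_theta(y_w | x), shape (batch,)
    policy_rejected_logps: torch.Tensor,  # log p_theta(y_l | x)
    ref_chosen_logps: torch.Tensor,       # log p_ref(y_w | x)
    ref_rejected_logps: torch.Tensor,     # log p_ref(y_l | x)
    beta: float = 0.1,                    # corresponds to --dpo_beta above
) -> torch.Tensor:
    policy_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = ref_chosen_logps - ref_rejected_logps
    losses = -F.logsigmoid(beta * (policy_logratios - ref_logratios))
    return losses.mean()
```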