Merge pull request #3799 from hiyouga/dev

improve KTO impl, replace datasets

Former-commit-id: b4cc207855aa1dbb120f7999165e176e649af338
This commit is contained in:
hoshi-hiyouga 2024-05-18 03:49:13 +08:00 committed by GitHub
commit 31daec2749
53 changed files with 448 additions and 330 deletions

View File

@ -45,7 +45,7 @@ Choose your path:
## Features
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO and ORPO.
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO and ORPO.
- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning.
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
@ -69,14 +69,16 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Changelog
[24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.
[24/05/14] We supported training and inference on the Ascend NPU devices. Check [installation](#installation) section for details.
[24/05/13] We supported fine-tuning the **Yi-1.5** series models.
[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage.
<details><summary>Full Changelog</summary>
[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage.
[24/04/22] We provided a **[Colab notebook](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)** for fine-tuning the Llama-3 model on a free T4 GPU. Two Llama-3-derived models fine-tuned using LLaMA Factory are available at Hugging Face, check [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) and [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese) for details.
[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage.
@ -188,6 +190,7 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| KTO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
## Provided Datasets
@ -208,12 +211,12 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
<details><summary>Supervised fine-tuning datasets</summary>
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Identity (en&zh)](data/identity.json)
- [Open Assistant (zh)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca-3)
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Glaive Function Calling V2 (en&zh)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
@ -222,7 +225,6 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
@ -235,15 +237,16 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- [Advertise Generating (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
@ -259,13 +262,12 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
<details><summary>Preference datasets</summary>
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
- [Open Assistant (zh)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
- [KTO mixed (en)](https://huggingface.co/datasets/argilla/kto-mix-15k)
</details>

View File

@ -45,7 +45,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 项目特色
- **多种模型**LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
- **集成方法**增量预训练、多模态指令监督微调、奖励模型训练、PPO 训练、DPO 训练和 ORPO 训练。
- **集成方法**增量预训练、多模态指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练和 ORPO 训练。
- **多种精度**32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
- **先进算法**GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 Agent 微调。
- **实用技巧**FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
@ -69,14 +69,16 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 更新日志
[24/05/18] 我们支持了 **[KTO](https://arxiv.org/abs/2402.01306)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
[24/05/14] 我们支持了昇腾 NPU 设备的训练和推理。详情请查阅[安装](#安装-llama-factory)部分。
[24/05/13] 我们支持了 Yi-1.5 系列模型的微调。
[24/04/26] 我们支持了多模态模型 **LLaVA-1.5** 的微调。详细用法请参照 [examples](examples/README_zh.md)。
<details><summary>展开日志</summary>
[24/04/26] 我们支持了多模态模型 **LLaVA-1.5** 的微调。详细用法请参照 [examples](examples/README_zh.md)。
[24/04/22] 我们提供了在免费 T4 GPU 上微调 Llama-3 模型的 **[Colab 笔记本](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)**。Hugging Face 社区公开了两个利用 LLaMA Factory 微调的 Llama-3 模型,详情请见 [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) 和 [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese)。
[24/04/21] 我们基于 [AstraMindAI 的仓库](https://github.com/astramind-ai/Mixture-of-depths)支持了 **[混合深度训练](https://arxiv.org/abs/2404.02258)**。详细用法请参照 [examples](examples/README_zh.md)。
@ -188,6 +190,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
| 奖励模型训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| PPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| DPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| KTO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| ORPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
## 数据集
@ -208,12 +211,12 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
<details><summary>指令微调数据集</summary>
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Identity (en&zh)](data/identity.json)
- [Open Assistant (zh)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca-3)
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Glaive Function Calling V2 (en&zh)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
@ -222,7 +225,6 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
@ -235,15 +237,16 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- [Advertise Generating (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
@ -259,13 +262,12 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
<details><summary>偏好数据集</summary>
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
- [Open Assistant (zh)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
- [KTO mixed (en)](https://huggingface.co/datasets/argilla/kto-mix-15k)
</details>

View File

@ -19,7 +19,10 @@ If you are using a custom dataset, please add your **dataset description** to `d
"messages": "the column name in the dataset containing the messages. (default: conversations)",
"system": "the column name in the dataset containing the system prompts. (default: None)",
"tools": "the column name in the dataset containing the tool description. (default: None)",
"images": "the column name in the dataset containing the image inputs. (default: None)"
"images": "the column name in the dataset containing the image inputs. (default: None)",
"chosen": "the column name in the dataset containing the chosen answers. (default: None)",
"rejected": "the column name in the dataset containing the rejected answers. (default: None)",
"kto_tag": "the column name in the dataset containing the kto tags. (default: None)"
},
"tags (optional, used for the sharegpt format)": {
"role_tag": "the key in the message represents the identity. (default: from)",
@ -42,13 +45,13 @@ Currently we support dataset in **alpaca** or **sharegpt** format, the dataset i
```json
[
{
"instruction": "user instruction (required)",
"input": "user input (optional)",
"instruction": "human instruction (required)",
"input": "human input (optional)",
"output": "model response (required)",
"system": "system prompt (optional)",
"history": [
["user instruction in the first round (optional)", "model response in the first round (optional)"],
["user instruction in the second round (optional)", "model response in the second round (optional)"]
["human instruction in the first round (optional)", "model response in the first round (optional)"],
["human instruction in the second round (optional)", "model response in the second round (optional)"]
]
}
]
@ -69,7 +72,7 @@ Regarding the above dataset, the description in `dataset_info.json` should be:
}
```
The `query` column will be concatenated with the `prompt` column and used as the user prompt, then the user prompt would be `prompt\nquery`. The `response` column represents the model response.
The `query` column will be concatenated with the `prompt` column and used as the human prompt, then the human prompt would be `prompt\nquery`. The `response` column represents the model response.
The `system` column will be used as the system prompt. The `history` column is a list consisting string tuples representing prompt-response pairs in the history. Note that the responses in the history **will also be used for training** in supervised fine-tuning.
@ -98,12 +101,10 @@ For the **preference datasets**, the `response` column should be a string list w
```json
[
{
"instruction": "user instruction",
"input": "user input",
"output": [
"chosen answer",
"rejected answer"
]
"instruction": "human instruction",
"input": "human input",
"chosen": "chosen answer",
"rejected": "rejected answer"
}
]
```
@ -117,7 +118,8 @@ Regarding the above dataset, the description in `dataset_info.json` should be:
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"chosen": "chosen",
"rejected": "rejected"
}
}
```
@ -132,7 +134,7 @@ The dataset in **sharegpt** format should follow the below format:
"conversations": [
{
"from": "human",
"value": "user instruction"
"value": "human instruction"
},
{
"from": "gpt",
@ -179,7 +181,7 @@ We also supports the dataset in the **openai** format:
},
{
"role": "user",
"content": "user instruction"
"content": "human instruction"
},
{
"role": "assistant",

View File

@ -19,7 +19,10 @@
"messages": "数据集代表消息列表的表头名称默认conversations",
"system": "数据集代表系统提示的表头名称默认None",
"tools": "数据集代表工具描述的表头名称默认None",
"images": "数据集代表图像输入的表头名称默认None"
"images": "数据集代表图像输入的表头名称默认None",
"chosen": "数据集代表更优回复的表头名称默认None",
"rejected": "数据集代表更差回复的表头名称默认None",
"kto_tag": "数据集代表 KTO 标签的表头名称默认None"
},
"tags可选用于 sharegpt 格式)": {
"role_tag": "消息中代表发送者身份的键名默认from",
@ -42,8 +45,8 @@
```json
[
{
"instruction": "用户指令(必填)",
"input": "用户输入(选填)",
"instruction": "人类指令(必填)",
"input": "人类输入(选填)",
"output": "模型回答(必填)",
"system": "系统提示词(选填)",
"history": [
@ -69,7 +72,7 @@
}
```
其中 `query` 列对应的内容会与 `prompt` 列对应的内容拼接后作为用户指令,即用户指令为 `prompt\nquery``response` 列对应的内容为模型回答。
其中 `query` 列对应的内容会与 `prompt` 列对应的内容拼接后作为人类指令,即人类指令为 `prompt\nquery``response` 列对应的内容为模型回答。
`system` 列对应的内容将被作为系统提示词。`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意在指令监督学习时,历史消息中的回答**也会被用于训练**。
@ -98,12 +101,10 @@
```json
[
{
"instruction": "用户指令",
"input": "用户输入",
"output": [
"优质回答",
"劣质回答"
]
"instruction": "人类指令",
"input": "人类输入",
"chosen": "优质回答",
"rejected": "劣质回答"
}
]
```
@ -117,7 +118,8 @@
"columns": {
"prompt": "instruction",
"query": "input",
"response": "output",
"chosen": "chosen",
"rejected": "rejected"
}
}
```
@ -132,7 +134,7 @@
"conversations": [
{
"from": "human",
"value": "用户指令"
"value": "人类指令"
},
{
"from": "gpt",
@ -165,7 +167,7 @@
}
```
其中 `messages` 列应当是一个列表,且符合 `用户/模型/用户/模型/用户/模型` 的顺序。
其中 `messages` 列应当是一个列表,且符合 `人类/模型/人类/模型/人类/模型` 的顺序。
我们同样支持 **openai** 格式的数据集:
@ -179,7 +181,7 @@
},
{
"role": "user",
"content": "用户指令"
"content": "人类指令"
},
{
"role": "assistant",

View File

@ -1 +0,0 @@
3779ddbc040543ab1834ef216c983d6fcc06cc9a

View File

@ -1 +0,0 @@
a97cf9475291591843976554878568e046d8a46d

View File

@ -1 +0,0 @@
25508714b7879a1e5a6764ba7f979a980f549f1a

View File

@ -1 +0,0 @@
7cb6a7d11455bddc3d495750a2392683d775b184

View File

@ -1 +0,0 @@
f5cb08305ff5dc9c17a09809c54c8c8834aadc70

View File

@ -1 +0,0 @@
aee47b7b443496e37808d7f34ef10403ff99bcc3

View File

@ -1,37 +0,0 @@
import json
from typing import Any, Dict, Generator, List, Tuple
import datasets
_DESCRIPTION = "An example of dataset."
_CITATION = ""
_HOMEPAGE = ""
_LICENSE = ""
_URL = "examples.json"
class ExampleDataset(datasets.GeneratorBasedBuilder):
VERSION = datasets.Version("0.0.0")
def _info(self) -> datasets.DatasetInfo:
features = datasets.Features(
{
"instruction": datasets.Value("string"),
"input": datasets.Value("string"),
"output": datasets.Value("string"),
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
}
)
return datasets.DatasetInfo(
description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
)
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
file_path = dl_manager.download(_URL)
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]
def _generate_examples(self, filepath: str) -> Generator[Tuple[int, Dict[str, Any]], None, None]:
example_dataset = json.load(open(filepath, "r", encoding="utf-8"))
for key, example in enumerate(example_dataset):
yield key, example

View File

@ -1 +0,0 @@
4748dff00d1dc42768a5b6cc772143c313017812

View File

@ -79,5 +79,5 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
break
prompt = prompt[:human_idx]
yield key, {"instruction": query, "output": [r_accept, r_reject], "history": history}
yield key, {"instruction": query, "chosen": r_accept, "rejected": r_reject, "history": history}
key += 1

View File

@ -1 +0,0 @@
736bcedea2b24a1414765c6d69cbdafaea839f3c

30
data/wiki_demo.txt Normal file

File diff suppressed because one or more lines are too long

View File

@ -1 +0,0 @@
c9cf509b7fdac5490cfd6dae72c2d7b8a60af6cb

View File

@ -53,6 +53,12 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml
```
#### KTO Training
```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml
```
#### ORPO Training
```bash

View File

@ -53,6 +53,12 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml
```
#### KTO 训练
```bash
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml
```
#### ORPO 训练
```bash

View File

@ -11,7 +11,7 @@ badam_switch_interval: 50
badam_verbose: 2
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -12,7 +12,7 @@ lora_target: q_proj,v_proj
ddp_timeout: 180000000
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -12,7 +12,7 @@ galore_rank: 128
galore_scale: 2.0
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -10,7 +10,7 @@ freeze_trainable_modules: all
use_llama_pro: true
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -9,7 +9,7 @@ lora_target: q_proj,v_proj
loraplus_lr_ratio: 16.0
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -8,7 +8,7 @@ finetuning_type: full
mixture_of_depths: convert
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -7,7 +7,7 @@ do_predict: true
finetuning_type: full
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 50

View File

@ -11,7 +11,7 @@ ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z3_config.json
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -11,7 +11,7 @@ lora_target: q_proj,v_proj
ddp_timeout: 180000000
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -12,7 +12,7 @@ ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z3_config.json
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -12,7 +12,7 @@ ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z0_config.json
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -9,7 +9,7 @@ lora_target: q_proj,v_proj
dpo_ftx: 1.0
### dataset
dataset: orca_rlhf
dataset: dpo_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000
@ -26,7 +26,7 @@ overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 0.00001
learning_rate: 0.000005
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_steps: 0.1

View File

@ -0,0 +1,39 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
### method
stage: kto
do_train: true
finetuning_type: lora
lora_target: q_proj,v_proj
kto_ftx: 0.1
### dataset
dataset: kto_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: saves/llama3-8b/lora/kto
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 0.000005
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_steps: 0.1
fp16: true
### eval
val_size: 0.1
per_device_eval_batch_size: 1
evaluation_strategy: steps
eval_steps: 500

View File

@ -8,7 +8,7 @@ finetuning_type: lora
lora_target: q_proj,v_proj
### dataset
dataset: orca_rlhf
dataset: dpo_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000
@ -25,7 +25,7 @@ overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 0.00001
learning_rate: 0.000005
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_steps: 0.1

View File

@ -9,7 +9,7 @@ finetuning_type: lora
lora_target: q_proj,v_proj
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -8,7 +8,7 @@ do_predict: true
finetuning_type: lora
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 50

View File

@ -8,7 +8,7 @@ finetuning_type: lora
lora_target: q_proj,v_proj
### dataset
dataset: orca_rlhf
dataset: dpo_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -8,7 +8,7 @@ finetuning_type: lora
lora_target: q_proj,v_proj
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -8,7 +8,7 @@ finetuning_type: lora
lora_target: q_proj,v_proj
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -8,7 +8,7 @@ finetuning_type: lora
lora_target: q_proj,v_proj
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -8,7 +8,7 @@ finetuning_type: lora
lora_target: q_proj,v_proj
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -9,7 +9,7 @@ finetuning_type: lora
lora_target: q_proj,v_proj
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -8,7 +8,7 @@ finetuning_type: lora
lora_target: q_proj,v_proj
### dataset
dataset: identity,alpaca_gpt4_en
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 1024
max_samples: 1000

View File

@ -1,12 +1,12 @@
from .collator import PairwiseDataCollatorWithPadding,KTODataCollatorWithPadding
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
from .loader import get_dataset
from .template import Template, get_template_and_fix_tokenizer, templates
from .utils import Role, split_dataset
__all__ = [
"PairwiseDataCollatorWithPadding",
"KTODataCollatorWithPadding",
"PairwiseDataCollatorWithPadding",
"get_dataset",
"Template",
"get_template_and_fix_tokenizer",

View File

@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
from datasets import Features
from ..extras.logging import get_logger
from .utils import Role
@ -14,7 +15,13 @@ if TYPE_CHECKING:
from .parser import DatasetAttr
logger = get_logger(__name__)
def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
r"""
Optionally concatenates image path to dataset dir when loading from local disk.
"""
outputs = []
if dataset_attr.load_from in ["script", "file"]:
for image in images:
@ -29,7 +36,10 @@ def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "
def convert_alpaca(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": [], "tag": []}
r"""
Converts alpaca format dataset to the standard format.
"""
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
for i in range(len(examples[dataset_attr.prompt])):
prompt = []
@ -45,23 +55,33 @@ def convert_alpaca(
if dataset_attr.query and examples[dataset_attr.query][i]:
content.append(examples[dataset_attr.query][i])
prompt.append({"role": Role.USER.value, "content": "\n".join(content)})
prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) # "prompt\nquery"
if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list):
response = [
{"role": Role.ASSISTANT.value, "content": content} for content in examples[dataset_attr.response][i]
]
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag], bool): # kto example
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
else:
if examples[dataset_attr.kto_tag]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(examples[dataset_attr.chosen][i], str)
and isinstance(examples[dataset_attr.rejected][i], str)
): # pairwise example
response = [
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.chosen][i]},
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.rejected][i]},
]
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str): # normal example
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
else: # unsupervised
response = []
outputs["prompt"].append(prompt)
outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tools"].append("")
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
outputs["tag"].append(examples[dataset_attr.tag][i] if dataset_attr.tag else True)
return outputs
@ -69,6 +89,9 @@ def convert_alpaca(
def convert_sharegpt(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
r"""
Converts sharegpt format dataset to the standard format.
"""
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
tag_mapping = {
@ -88,21 +111,62 @@ def convert_sharegpt(
else:
system = examples[dataset_attr.system][i] if dataset_attr.system else ""
messages = messages[: len(messages) // 2 * 2] # should be multiples of 2
if len(messages) == 0:
continue
aligned_messages = []
broken_data = False
for turn_idx, message in enumerate(messages):
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
raise ValueError("Invalid role tag in {}.".format(messages))
logger.warning("Invalid role tag in {}.".format(messages))
broken_data = True
aligned_messages.append(
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
)
outputs["prompt"].append(aligned_messages[:-1])
outputs["response"].append(aligned_messages[-1:])
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
dataset_attr.ranking and len(aligned_messages) % 2 == 0
):
logger.warning("Invalid message count in {}.".format(messages))
broken_data = True
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if examples[dataset_attr.kto_tag][i]:
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
else:
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
elif (
dataset_attr.ranking
and isinstance(examples[dataset_attr.chosen][i], dict)
and isinstance(examples[dataset_attr.rejected][i], dict)
): # pairwise example
chosen = examples[dataset_attr.chosen][i]
rejected = examples[dataset_attr.rejected][i]
if (
chosen[dataset_attr.role_tag] not in accept_tags[-1]
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
):
logger.warning("Invalid role tag in {}.".format(messages))
broken_data = True
prompt = aligned_messages
response = [
{"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
{"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
]
else: # normal example
prompt = aligned_messages[:-1]
response = aligned_messages[-1:]
if broken_data:
logger.warning("Skipping this abnormal example.")
continue
outputs["prompt"].append(prompt)
outputs["response"].append(response)
outputs["system"].append(system)
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
@ -138,7 +202,6 @@ def align_dataset(
"system": {"dtype": "string", "_type": "Value"},
"tools": {"dtype": "string", "_type": "Value"},
"images": [{"_type": "Image"}],
"tag": {"dtype": "bool", "_type": "Value"},
}
)
kwargs = {}

View File

@ -50,35 +50,38 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
return batch
@dataclass
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
r"""
Data collator for KTO data.
"""
def __call__(self, features, return_tensors=None):
concatenated_features = []
kl_concatenated_features = []
tags = []
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
target_features = []
kl_features = []
kto_tags = []
for feature in features:
concatenated_features.append(
target_features.append(
{
"input_ids": feature["input_ids"],
"attention_mask": feature["attention_mask"],
"labels": feature["labels"],
}
)
kl_concatenated_features.append(
kl_features.append(
{
"input_ids": feature["kl_input_ids"],
"attention_mask": feature["kl_attention_mask"],
"labels": feature["kl_labels"],
}
)
tags.append(feature["tag"])
batch = super().__call__(concatenated_features)
kl_batch = super().__call__(kl_concatenated_features)
batch["KL_completion_input_ids"] = kl_batch["input_ids"]
batch["KL_completion_attention_mask"] = kl_batch["attention_mask"]
kto_tags.append(feature["kto_tags"])
batch = super().__call__(target_features)
kl_batch = super().__call__(kl_features)
batch["kl_input_ids"] = kl_batch["input_ids"]
batch["kl_attention_mask"] = kl_batch["attention_mask"]
batch["kl_labels"] = kl_batch["labels"]
batch["tag"] = torch.tensor(tags)
return batch
batch["kto_tags"] = torch.tensor(kto_tags)
return batch

View File

@ -57,7 +57,7 @@ def load_single_dataset(
data_files.append(local_path)
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
else:
raise ValueError("File not found.")
raise ValueError("File {} not found.".format(local_path))
if data_path is None:
raise ValueError("File extension must be txt, csv, json or jsonl.")
@ -116,7 +116,7 @@ def get_dataset(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
stage: Literal["pt", "sft", "rm", "kto"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
) -> Union["Dataset", "IterableDataset"]:

View File

@ -25,21 +25,22 @@ class DatasetAttr:
folder: Optional[str] = None
ranking: bool = False
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
""" columns """
""" common columns """
system: Optional[str] = None
tools: Optional[str] = None
images: Optional[str] = None
tag: Optional[bool] = None
""" columns for the alpaca format """
""" rlhf columns """
chosen: Optional[str] = None
rejected: Optional[str] = None
kto_tag: Optional[str] = None
""" alpaca columns """
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
response: Optional[str] = "output"
chosen: Optional[str] = "chosen"
rejected: Optional[str] = "rejected"
history: Optional[str] = None
""" columns for the sharegpt format """
""" sharegpt columns """
messages: Optional[str] = "conversations"
tools: Optional[str] = None
""" tags for the sharegpt format """
""" sharegpt tags """
role_tag: Optional[str] = "from"
content_tag: Optional[str] = "value"
user_tag: Optional[str] = "human"
@ -107,11 +108,11 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
if "columns" in dataset_info[name]:
column_names = ["system", "images", "tag"]
column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
if dataset_attr.formatting == "alpaca":
column_names.extend(["prompt", "query", "response", "history"])
else:
column_names.extend(["messages", "tools"])
column_names.extend(["messages"])
for column_name in column_names:
dataset_attr.set_attr(column_name, dataset_info[name]["columns"])

View File

@ -70,7 +70,7 @@ def preprocess_supervised_dataset(
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": [], "tag": []}
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
if processor is not None:
model_inputs["pixel_values"] = []
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
@ -111,102 +111,11 @@ def preprocess_supervised_dataset(
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["tag"].append(examples["tag"])
if processor is not None:
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
return model_inputs
def preprocess_kto_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": [],"kl_input_ids": [], "kl_attention_mask": [], "kl_labels": [], "tag": []}
"""Creates mismatched pairs of prompts and completions for the KL dataset by reversing the order of completions."""
examples['kl_response'] = examples['response'][::-1]
if processor is not None:
model_inputs["pixel_values"] = []
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
continue
if processor is not None:
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
messages = examples["prompt"][i] + examples["response"][i]
kl_messages = examples["prompt"][i] + examples["kl_response"][i]
input_ids, labels = [], []
kl_input_ids, kl_labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(
template.encode_multiturn(
tokenizer,
messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
for turn_idx, (source_ids, target_ids) in enumerate(
template.encode_multiturn(
tokenizer,
kl_messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
):
if data_args.train_on_prompt:
source_mask = source_ids
elif turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
kl_input_ids += source_ids + target_ids
kl_labels += source_mask + target_ids
if template.efficient_eos:
kl_input_ids += [tokenizer.eos_token_id]
kl_labels += [tokenizer.eos_token_id]
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["kl_input_ids"].append(kl_input_ids)
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels)
model_inputs["tag"].append(examples["tag"][i])
if processor is not None:
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
desirable = sum([1 for tag in model_inputs["tag"] if tag is True])
undesirable = sum([1 for tag in model_inputs["tag"] if tag is False])
logger.info("desirable data in KTO dataset: {},undesirable data in KTO dataset: {}".format(desirable, undesirable))
if desirable == 0 or undesirable == 0:
logger.warning("Your dataset only has one preference type.")
return model_inputs
def preprocess_packed_supervised_dataset(
examples: Dict[str, List[Any]],
@ -352,6 +261,90 @@ def preprocess_pairwise_dataset(
return model_inputs
def preprocess_kto_dataset(
examples: Dict[str, List[Any]],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["response"][::-1]
model_inputs = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"kl_input_ids": [],
"kl_attention_mask": [],
"kl_labels": [],
"kto_tags": [],
}
if processor is not None:
model_inputs["pixel_values"] = []
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
continue
if processor is not None:
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
if examples["response"][i][0]["content"]: # desired example
kto_tag = True
messages = examples["prompt"][i] + [examples["response"][i][0]]
else: # undesired example
kto_tag = False
messages = examples["prompt"][i] + [examples["response"][i][1]]
if kl_response[i][0]["content"]:
kl_messages = examples["prompt"][i] + [kl_response[i][0]]
else:
kl_messages = examples["prompt"][i] + [kl_response[i][1]]
prompt_ids, response_ids = template.encode_oneturn(
tokenizer,
messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
_, kl_response_ids = template.encode_oneturn(
tokenizer,
kl_messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
if template.efficient_eos:
response_ids += [tokenizer.eos_token_id]
kl_response_ids += [tokenizer.eos_token_id]
input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
kl_input_ids = prompt_ids + kl_response_ids
kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
model_inputs["kl_input_ids"].append(kl_input_ids)
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels)
model_inputs["kto_tags"].append(kto_tag)
if processor is not None:
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
if desirable_num == 0 or undesirable_num == 0:
logger.warning("Your dataset only has one preference type.")
return model_inputs
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
@ -380,7 +373,7 @@ def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer:
def get_preprocess_and_print_func(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
stage: Literal["pt", "sft", "rm", "kto"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],

View File

@ -137,21 +137,21 @@ class RLHFArguments:
default=0.1,
metadata={"help": "The beta parameter for the KTO loss."},
)
kto_chosen_weight: float = field(
default=1.0,
metadata={"help": "The weight factor of the desirable losses in KTO training."},
)
kto_rejected_weight: float = field(
default=1.0,
metadata={"help": "The weight factor of the undesirable losses in KTO training."},
)
kto_ftx: float = field(
default=0.0,
metadata={"help": "The supervised fine-tuning loss coefficient in KTO training."},
)
kto_desirable_weight: float = field(
default=1.0,
metadata={"help": "The desirable weight for the KTO loss."},
)
kto_undesirable_weight: float = field(
default=1.0,
metadata={"help": "The undesirable weight for the KTO loss."},
)
orpo_beta: float = field(
default=0.1,
metadata={"help": "The beta (lambda) parameter in ORPO loss representing the weight of the SFT loss."},
metadata={"help": "The beta (lambda) parameter in the ORPO loss representing the weight of the SFT loss."},
)
ppo_buffer_size: int = field(
default=1,
@ -307,7 +307,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
default=False,
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
)
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo", "kto"] = field(
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto", "orpo"] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."},
)

View File

@ -47,11 +47,13 @@ class CustomDPOTrainer(DPOTrainer):
self._peft_has_been_casted_to_bf16 = False
self.ref_model = ref_model
self._stored_metrics = defaultdict(lambda: defaultdict(list))
# dpo hyperparams
self.beta = finetuning_args.dpo_beta
self.label_smoothing = finetuning_args.dpo_label_smoothing
self.loss_type = finetuning_args.dpo_loss
self.ftx_gamma = finetuning_args.dpo_ftx
self._stored_metrics = defaultdict(lambda: defaultdict(list))
Trainer.__init__(self, model=model, **kwargs)
if not hasattr(self, "accelerator"):
@ -143,6 +145,7 @@ class CustomDPOTrainer(DPOTrainer):
policy_chosen_logits,
policy_rejected_logits,
) = self.concatenated_forward(model, batch)
with torch.no_grad():
if self.ref_model is None:
ref_model = self.model

View File

@ -1,7 +1,7 @@
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
import torch
from transformers import Trainer
@ -13,7 +13,7 @@ from ..utils import create_custom_optimzer, create_custom_scheduler
if TYPE_CHECKING:
from transformers import PreTrainedModel
from transformers import PreTrainedModel, ProcessorMixin
from ...hparams import FinetuningArguments
@ -24,6 +24,7 @@ class CustomKTOTrainer(KTOTrainer):
model: Union["PreTrainedModel", torch.nn.Module],
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
finetuning_args: "FinetuningArguments",
processor: Optional["ProcessorMixin"],
disable_dropout: bool = True,
**kwargs,
):
@ -33,6 +34,7 @@ class CustomKTOTrainer(KTOTrainer):
disable_dropout_in_model(ref_model)
self.finetuning_args = finetuning_args
self.processor = processor
self.reference_free = False
self.use_dpo_data_collator = True # hack to avoid warning
self.generate_during_eval = False # disable at evaluation
@ -43,15 +45,15 @@ class CustomKTOTrainer(KTOTrainer):
self._precomputed_train_ref_log_probs = False
self._precomputed_eval_ref_log_probs = False
self._peft_has_been_casted_to_bf16 = False
self.ref_model = ref_model
self._stored_metrics = defaultdict(lambda: defaultdict(list))
# KTO parameter
# kto hyperparams
self.beta = finetuning_args.kto_beta
self.desirable_weight = finetuning_args.kto_chosen_weight
self.undesirable_weight = finetuning_args.kto_rejected_weight
self.ftx_gamma = finetuning_args.kto_ftx
self.desirable_weight = finetuning_args.kto_desirable_weight
self.undesirable_weight = finetuning_args.kto_undesirable_weight
Trainer.__init__(self, model=model, **kwargs)
if not hasattr(self, "accelerator"):
@ -82,78 +84,85 @@ class CustomKTOTrainer(KTOTrainer):
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
super()._save(output_dir, state_dict)
if self.processor is not None:
output_dir = output_dir if output_dir is not None else self.args.output_dir
getattr(self.processor, "image_processor").save_pretrained(output_dir)
def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: "torch.LongTensor") -> "torch.Tensor":
r"""
Computes supervised cross-entropy loss of given labels under the given logits.
Returns:
A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
"""
all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True)
return -all_logps.nanmean()
return -all_logps
def forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
with torch.no_grad():
KL_logits = model(
batch["KL_completion_input_ids"],
attention_mask=batch["KL_completion_attention_mask"],
).logits
kl_logits = model(
input_ids=batch["kl_input_ids"],
attention_mask=batch["kl_attention_mask"],
return_dict=True,
use_cache=False,
).logits.to(torch.float32)
completion_logits = model(
batch["input_ids"],
target_logits = model(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
).logits
return_dict=True,
use_cache=False,
).logits.to(torch.float32)
completion_logps = self.get_batch_logps(
completion_logits,
batch["labels"],
target_logps = self.get_batch_logps(
logits=target_logits,
labels=batch["labels"],
average_log_prob=False,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
KL_logps = self.get_batch_logps(
KL_logits,
batch["kl_labels"],
kl_logps = self.get_batch_logps(
logits=kl_logits,
labels=batch["kl_labels"],
average_log_prob=False,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
if completion_logps.shape[0] != len(batch["tag"]):
raise ValueError(
"There is a mismatch between the number of examples in this batch and the number of "
"examples for which an output sequence was predicted."
)
chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["tag"][i]]
rejected_idx = [i for i in range(completion_logps.shape[0]) if not batch["tag"][i]]
if len(target_logps) != len(batch["kto_tags"]):
raise ValueError("Mismatched shape of inputs and labels.")
chosen_logps = completion_logps[chosen_idx, ...]
rejected_logps = completion_logps[rejected_idx, ...]
chosen_idx = [i for i in range(len(target_logps)) if batch["kto_tags"][i]]
rejected_idx = [i for i in range(len(target_logps)) if not batch["kto_tags"][i]]
chosen_logits = completion_logits[chosen_idx, ...]
rejected_logits = completion_logits[rejected_idx, ...]
chosen_logps = target_logps[chosen_idx, ...]
rejected_logps = target_logps[rejected_idx, ...]
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)
chosen_logits = target_logits[chosen_idx, ...]
rejected_logits = target_logits[rejected_idx, ...]
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps
def get_batch_loss_metrics(
self,
model,
batch: Dict[str, Union[List, torch.LongTensor]],
):
"""Compute the KTO loss and other metrics for the given batch of inputs for train or test."""
model: "PreTrainedModel",
batch: Dict[str, "torch.Tensor"],
) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
r"""
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
"""
metrics = {}
batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_KL_logps,
_,
policy_kl_logps,
) = self.forward(model, batch)
with torch.no_grad():
@ -163,27 +172,29 @@ class CustomKTOTrainer(KTOTrainer):
else:
ref_model = self.ref_model
ref_context = nullcontext()
with ref_context:
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
reference_KL_logps,
reference_kl_logps,
) = self.forward(ref_model, batch)
losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
policy_chosen_logps,
policy_rejected_logps,
policy_KL_logps,
policy_kl_logps,
reference_chosen_logps,
reference_rejected_logps,
reference_KL_logps,
reference_kl_logps,
)
losses = losses.nanmean()
if self.ftx_gamma > 1e-6 and len(batch["labels"][batch['tag']])>0:
losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, batch["labels"][batch['tag']])
if self.ftx_gamma > 1e-6 and len(policy_chosen_logps) > 0: # remember to rescale
sft_loss = self.sft_loss(policy_chosen_logits, batch["labels"][batch["kto_tags"]])
losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logits) * len(batch["labels"])
num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
@ -203,4 +214,4 @@ class CustomKTOTrainer(KTOTrainer):
metrics["kl"] = kl.item()
return losses, metrics
return losses, metrics

View File

@ -48,9 +48,9 @@ def run_kto(
ref_model=ref_model,
args=training_args,
finetuning_args=finetuning_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
**tokenizer_module,
**split_dataset(dataset, data_args, training_args),
)

View File

@ -29,7 +29,7 @@ def run_ppo(
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
dataset = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
dataset = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training

View File

@ -9,12 +9,13 @@ from ..extras.logging import get_logger
from ..hparams import get_infer_args, get_train_args
from ..model import load_model, load_tokenizer
from .dpo import run_dpo
from .kto import run_kto
from .orpo import run_orpo
from .ppo import run_ppo
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft
from .kto import run_kto
if TYPE_CHECKING:
from transformers import TrainerCallback
@ -37,10 +38,10 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif finetuning_args.stage == "dpo":
run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "orpo":
run_orpo(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "kto":
run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "orpo":
run_orpo(model_args, data_args, training_args, finetuning_args, callbacks)
else:
raise ValueError("Unknown task.")