mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-14 23:58:11 +08:00
improve KTO impl., replace datasets
Former-commit-id: e56a57ddcf061de6e4acc8679f7dbf0b68364986
This commit is contained in:
parent
e4570e28a8
commit
2bff90719b
34
README.md
34
README.md
@ -45,7 +45,7 @@ Choose your path:
|
||||
## Features
|
||||
|
||||
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
|
||||
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO and ORPO.
|
||||
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO and ORPO.
|
||||
- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
|
||||
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning.
|
||||
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
|
||||
@ -69,14 +69,16 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
||||
|
||||
## Changelog
|
||||
|
||||
[24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.
|
||||
|
||||
[24/05/14] We supported training and inference on the Ascend NPU devices. Check [installation](#installation) section for details.
|
||||
|
||||
[24/05/13] We supported fine-tuning the **Yi-1.5** series models.
|
||||
|
||||
[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage.
|
||||
|
||||
<details><summary>Full Changelog</summary>
|
||||
|
||||
[24/04/26] We supported fine-tuning the **LLaVA-1.5** multimodal LLMs. See [examples](examples/README.md) for usage.
|
||||
|
||||
[24/04/22] We provided a **[Colab notebook](https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing)** for fine-tuning the Llama-3 model on a free T4 GPU. Two Llama-3-derived models fine-tuned using LLaMA Factory are available at Hugging Face, check [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) and [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese) for details.
|
||||
|
||||
[24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage.
|
||||
@ -188,6 +190,7 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
|
||||
| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| KTO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| ORPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
|
||||
## Provided Datasets
|
||||
@ -208,12 +211,12 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
|
||||
|
||||
<details><summary>Supervised fine-tuning datasets</summary>
|
||||
|
||||
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
||||
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
||||
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Identity (en&zh)](data/identity.json)
|
||||
- [Open Assistant (zh)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
|
||||
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
||||
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca-3)
|
||||
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Glaive Function Calling V2 (en&zh)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
|
||||
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
|
||||
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
||||
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
|
||||
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
||||
@ -222,7 +225,6 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
|
||||
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
|
||||
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
|
||||
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
|
||||
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
|
||||
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
|
||||
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
||||
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
||||
@ -235,15 +237,16 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
|
||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
|
||||
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||
- [Advertise Generating (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
|
||||
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
|
||||
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
|
||||
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
|
||||
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
|
||||
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
|
||||
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
|
||||
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
|
||||
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
|
||||
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
|
||||
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
||||
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
||||
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
|
||||
@ -259,13 +262,12 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/t
|
||||
|
||||
<details><summary>Preference datasets</summary>
|
||||
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
|
||||
- [Open Assistant (zh)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
|
||||
- [KTO mixed (en)](https://huggingface.co/datasets/argilla/kto-mix-15k)
|
||||
|
||||
</details>
|
||||
|
||||
|
34
README_zh.md
34
README_zh.md
@ -45,7 +45,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||
## 项目特色
|
||||
|
||||
- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
|
||||
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练和 ORPO 训练。
|
||||
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练和 ORPO 训练。
|
||||
- **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
|
||||
- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 Agent 微调。
|
||||
- **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
|
||||
@ -69,14 +69,16 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||
|
||||
## 更新日志
|
||||
|
||||
[24/05/18] 我们支持了 **[KTO](https://arxiv.org/abs/2402.01306)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
|
||||
|
||||
[24/05/14] 我们支持了昇腾 NPU 设备的训练和推理。详情请查阅[安装](#安装-llama-factory)部分。
|
||||
|
||||
[24/05/13] 我们支持了 Yi-1.5 系列模型的微调。
|
||||
|
||||
[24/04/26] 我们支持了多模态模型 **LLaVA-1.5** 的微调。详细用法请参照 [examples](examples/README_zh.md)。
|
||||
|
||||
<details><summary>展开日志</summary>
|
||||
|
||||
[24/04/26] 我们支持了多模态模型 **LLaVA-1.5** 的微调。详细用法请参照 [examples](examples/README_zh.md)。
|
||||
|
||||
[24/04/22] 我们提供了在免费 T4 GPU 上微调 Llama-3 模型的 **[Colab 笔记本](https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing)**。Hugging Face 社区公开了两个利用 LLaMA Factory 微调的 Llama-3 模型,详情请见 [Llama3-8B-Chinese-Chat](https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat) 和 [Llama3-Chinese](https://huggingface.co/zhichen/Llama3-Chinese)。
|
||||
|
||||
[24/04/21] 我们基于 [AstraMindAI 的仓库](https://github.com/astramind-ai/Mixture-of-depths)支持了 **[混合深度训练](https://arxiv.org/abs/2404.02258)**。详细用法请参照 [examples](examples/README_zh.md)。
|
||||
@ -188,6 +190,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||
| 奖励模型训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| PPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| DPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| KTO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| ORPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
|
||||
## 数据集
|
||||
@ -208,12 +211,12 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||
|
||||
<details><summary>指令微调数据集</summary>
|
||||
|
||||
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
||||
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
||||
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Identity (en&zh)](data/identity.json)
|
||||
- [Open Assistant (zh)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
|
||||
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
||||
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca-3)
|
||||
- [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Glaive Function Calling V2 (en&zh)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
|
||||
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
|
||||
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
||||
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
|
||||
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
||||
@ -222,7 +225,6 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
|
||||
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
|
||||
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
|
||||
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
|
||||
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
|
||||
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
||||
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
||||
@ -235,15 +237,16 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [deepctrl (en&zh)](https://www.modelscope.cn/datasets/deepctrl/deepctrl-sft-data)
|
||||
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||
- [Advertise Generating (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
|
||||
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
|
||||
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
|
||||
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
|
||||
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
|
||||
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
|
||||
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
|
||||
- [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
|
||||
- [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
|
||||
- [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
|
||||
- [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
|
||||
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
|
||||
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
|
||||
@ -259,13 +262,12 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||
|
||||
<details><summary>偏好数据集</summary>
|
||||
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [DPO mixed (en&zh)](https://huggingface.co/datasets/hiyouga/DPO-En-Zh-20k)
|
||||
- [Open Assistant (zh)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [Orca DPO Pairs (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
|
||||
- [KTO mixed (en)](https://huggingface.co/datasets/argilla/kto-mix-15k)
|
||||
|
||||
</details>
|
||||
|
||||
|
@ -19,7 +19,10 @@ If you are using a custom dataset, please add your **dataset description** to `d
|
||||
"messages": "the column name in the dataset containing the messages. (default: conversations)",
|
||||
"system": "the column name in the dataset containing the system prompts. (default: None)",
|
||||
"tools": "the column name in the dataset containing the tool description. (default: None)",
|
||||
"images": "the column name in the dataset containing the image inputs. (default: None)"
|
||||
"images": "the column name in the dataset containing the image inputs. (default: None)",
|
||||
"chosen": "the column name in the dataset containing the chosen answers. (default: None)",
|
||||
"rejected": "the column name in the dataset containing the rejected answers. (default: None)",
|
||||
"kto_tag": "the column name in the dataset containing the kto tags. (default: None)"
|
||||
},
|
||||
"tags (optional, used for the sharegpt format)": {
|
||||
"role_tag": "the key in the message represents the identity. (default: from)",
|
||||
@ -42,13 +45,13 @@ Currently we support dataset in **alpaca** or **sharegpt** format, the dataset i
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "user instruction (required)",
|
||||
"input": "user input (optional)",
|
||||
"instruction": "human instruction (required)",
|
||||
"input": "human input (optional)",
|
||||
"output": "model response (required)",
|
||||
"system": "system prompt (optional)",
|
||||
"history": [
|
||||
["user instruction in the first round (optional)", "model response in the first round (optional)"],
|
||||
["user instruction in the second round (optional)", "model response in the second round (optional)"]
|
||||
["human instruction in the first round (optional)", "model response in the first round (optional)"],
|
||||
["human instruction in the second round (optional)", "model response in the second round (optional)"]
|
||||
]
|
||||
}
|
||||
]
|
||||
@ -69,7 +72,7 @@ Regarding the above dataset, the description in `dataset_info.json` should be:
|
||||
}
|
||||
```
|
||||
|
||||
The `query` column will be concatenated with the `prompt` column and used as the user prompt, then the user prompt would be `prompt\nquery`. The `response` column represents the model response.
|
||||
The `query` column will be concatenated with the `prompt` column and used as the human prompt, then the human prompt would be `prompt\nquery`. The `response` column represents the model response.
|
||||
|
||||
The `system` column will be used as the system prompt. The `history` column is a list consisting string tuples representing prompt-response pairs in the history. Note that the responses in the history **will also be used for training** in supervised fine-tuning.
|
||||
|
||||
@ -98,12 +101,10 @@ For the **preference datasets**, the `response` column should be a string list w
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "user instruction",
|
||||
"input": "user input",
|
||||
"output": [
|
||||
"chosen answer",
|
||||
"rejected answer"
|
||||
]
|
||||
"instruction": "human instruction",
|
||||
"input": "human input",
|
||||
"chosen": "chosen answer",
|
||||
"rejected": "rejected answer"
|
||||
}
|
||||
]
|
||||
```
|
||||
@ -117,7 +118,8 @@ Regarding the above dataset, the description in `dataset_info.json` should be:
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
"chosen": "chosen",
|
||||
"rejected": "rejected"
|
||||
}
|
||||
}
|
||||
```
|
||||
@ -132,7 +134,7 @@ The dataset in **sharegpt** format should follow the below format:
|
||||
"conversations": [
|
||||
{
|
||||
"from": "human",
|
||||
"value": "user instruction"
|
||||
"value": "human instruction"
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
@ -179,7 +181,7 @@ We also supports the dataset in the **openai** format:
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "user instruction"
|
||||
"content": "human instruction"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
|
@ -19,7 +19,10 @@
|
||||
"messages": "数据集代表消息列表的表头名称(默认:conversations)",
|
||||
"system": "数据集代表系统提示的表头名称(默认:None)",
|
||||
"tools": "数据集代表工具描述的表头名称(默认:None)",
|
||||
"images": "数据集代表图像输入的表头名称(默认:None)"
|
||||
"images": "数据集代表图像输入的表头名称(默认:None)",
|
||||
"chosen": "数据集代表更优回复的表头名称(默认:None)",
|
||||
"rejected": "数据集代表更差回复的表头名称(默认:None)",
|
||||
"kto_tag": "数据集代表 KTO 标签的表头名称(默认:None)"
|
||||
},
|
||||
"tags(可选,用于 sharegpt 格式)": {
|
||||
"role_tag": "消息中代表发送者身份的键名(默认:from)",
|
||||
@ -42,8 +45,8 @@
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "用户指令(必填)",
|
||||
"input": "用户输入(选填)",
|
||||
"instruction": "人类指令(必填)",
|
||||
"input": "人类输入(选填)",
|
||||
"output": "模型回答(必填)",
|
||||
"system": "系统提示词(选填)",
|
||||
"history": [
|
||||
@ -69,7 +72,7 @@
|
||||
}
|
||||
```
|
||||
|
||||
其中 `query` 列对应的内容会与 `prompt` 列对应的内容拼接后作为用户指令,即用户指令为 `prompt\nquery`。`response` 列对应的内容为模型回答。
|
||||
其中 `query` 列对应的内容会与 `prompt` 列对应的内容拼接后作为人类指令,即人类指令为 `prompt\nquery`。`response` 列对应的内容为模型回答。
|
||||
|
||||
`system` 列对应的内容将被作为系统提示词。`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意在指令监督学习时,历史消息中的回答**也会被用于训练**。
|
||||
|
||||
@ -98,12 +101,10 @@
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "用户指令",
|
||||
"input": "用户输入",
|
||||
"output": [
|
||||
"优质回答",
|
||||
"劣质回答"
|
||||
]
|
||||
"instruction": "人类指令",
|
||||
"input": "人类输入",
|
||||
"chosen": "优质回答",
|
||||
"rejected": "劣质回答"
|
||||
}
|
||||
]
|
||||
```
|
||||
@ -117,7 +118,8 @@
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
"chosen": "chosen",
|
||||
"rejected": "rejected"
|
||||
}
|
||||
}
|
||||
```
|
||||
@ -132,7 +134,7 @@
|
||||
"conversations": [
|
||||
{
|
||||
"from": "human",
|
||||
"value": "用户指令"
|
||||
"value": "人类指令"
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
@ -165,7 +167,7 @@
|
||||
}
|
||||
```
|
||||
|
||||
其中 `messages` 列应当是一个列表,且符合 `用户/模型/用户/模型/用户/模型` 的顺序。
|
||||
其中 `messages` 列应当是一个列表,且符合 `人类/模型/人类/模型/人类/模型` 的顺序。
|
||||
|
||||
我们同样支持 **openai** 格式的数据集:
|
||||
|
||||
@ -179,7 +181,7 @@
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "用户指令"
|
||||
"content": "人类指令"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
|
@ -1 +0,0 @@
|
||||
3779ddbc040543ab1834ef216c983d6fcc06cc9a
|
@ -1 +0,0 @@
|
||||
a97cf9475291591843976554878568e046d8a46d
|
@ -1 +0,0 @@
|
||||
25508714b7879a1e5a6764ba7f979a980f549f1a
|
@ -1 +0,0 @@
|
||||
7cb6a7d11455bddc3d495750a2392683d775b184
|
@ -1 +0,0 @@
|
||||
f5cb08305ff5dc9c17a09809c54c8c8834aadc70
|
@ -1 +0,0 @@
|
||||
aee47b7b443496e37808d7f34ef10403ff99bcc3
|
@ -1,37 +0,0 @@
|
||||
import json
|
||||
from typing import Any, Dict, Generator, List, Tuple
|
||||
|
||||
import datasets
|
||||
|
||||
|
||||
_DESCRIPTION = "An example of dataset."
|
||||
_CITATION = ""
|
||||
_HOMEPAGE = ""
|
||||
_LICENSE = ""
|
||||
_URL = "examples.json"
|
||||
|
||||
|
||||
class ExampleDataset(datasets.GeneratorBasedBuilder):
|
||||
VERSION = datasets.Version("0.0.0")
|
||||
|
||||
def _info(self) -> datasets.DatasetInfo:
|
||||
features = datasets.Features(
|
||||
{
|
||||
"instruction": datasets.Value("string"),
|
||||
"input": datasets.Value("string"),
|
||||
"output": datasets.Value("string"),
|
||||
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
|
||||
}
|
||||
)
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
|
||||
)
|
||||
|
||||
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
|
||||
file_path = dl_manager.download(_URL)
|
||||
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]
|
||||
|
||||
def _generate_examples(self, filepath: str) -> Generator[Tuple[int, Dict[str, Any]], None, None]:
|
||||
example_dataset = json.load(open(filepath, "r", encoding="utf-8"))
|
||||
for key, example in enumerate(example_dataset):
|
||||
yield key, example
|
@ -1 +0,0 @@
|
||||
4748dff00d1dc42768a5b6cc772143c313017812
|
@ -79,5 +79,5 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
|
||||
break
|
||||
prompt = prompt[:human_idx]
|
||||
|
||||
yield key, {"instruction": query, "output": [r_accept, r_reject], "history": history}
|
||||
yield key, {"instruction": query, "chosen": r_accept, "rejected": r_reject, "history": history}
|
||||
key += 1
|
||||
|
@ -1 +0,0 @@
|
||||
736bcedea2b24a1414765c6d69cbdafaea839f3c
|
30
data/wiki_demo.txt
Normal file
30
data/wiki_demo.txt
Normal file
File diff suppressed because one or more lines are too long
@ -1 +0,0 @@
|
||||
c9cf509b7fdac5490cfd6dae72c2d7b8a60af6cb
|
@ -53,6 +53,12 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml
|
||||
```
|
||||
|
||||
#### KTO Training
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml
|
||||
```
|
||||
|
||||
#### ORPO Training
|
||||
|
||||
```bash
|
||||
|
@ -53,6 +53,12 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml
|
||||
```
|
||||
|
||||
#### KTO 训练
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml
|
||||
```
|
||||
|
||||
#### ORPO 训练
|
||||
|
||||
```bash
|
||||
|
@ -11,7 +11,7 @@ badam_switch_interval: 50
|
||||
badam_verbose: 2
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
@ -12,7 +12,7 @@ lora_target: q_proj,v_proj
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
@ -12,7 +12,7 @@ galore_rank: 128
|
||||
galore_scale: 2.0
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
@ -10,7 +10,7 @@ freeze_trainable_modules: all
|
||||
use_llama_pro: true
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
@ -9,7 +9,7 @@ lora_target: q_proj,v_proj
|
||||
loraplus_lr_ratio: 16.0
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
@ -8,7 +8,7 @@ finetuning_type: full
|
||||
mixture_of_depths: convert
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
@ -7,7 +7,7 @@ do_predict: true
|
||||
finetuning_type: full
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 50
|
||||
|
@ -11,7 +11,7 @@ ddp_timeout: 180000000
|
||||
deepspeed: examples/deepspeed/ds_z3_config.json
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
@ -11,7 +11,7 @@ lora_target: q_proj,v_proj
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
@ -12,7 +12,7 @@ ddp_timeout: 180000000
|
||||
deepspeed: examples/deepspeed/ds_z3_config.json
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
@ -12,7 +12,7 @@ ddp_timeout: 180000000
|
||||
deepspeed: examples/deepspeed/ds_z0_config.json
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
@ -9,7 +9,7 @@ lora_target: q_proj,v_proj
|
||||
dpo_ftx: 1.0
|
||||
|
||||
### dataset
|
||||
dataset: orca_rlhf
|
||||
dataset: dpo_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
@ -26,7 +26,7 @@ overwrite_output_dir: true
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.00001
|
||||
learning_rate: 0.000005
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
|
39
examples/lora_single_gpu/llama3_lora_kto.yaml
Normal file
39
examples/lora_single_gpu/llama3_lora_kto.yaml
Normal file
@ -0,0 +1,39 @@
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
|
||||
### method
|
||||
stage: kto
|
||||
do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
kto_ftx: 0.1
|
||||
|
||||
### dataset
|
||||
dataset: kto_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/kto
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.000005
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
fp16: true
|
||||
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
evaluation_strategy: steps
|
||||
eval_steps: 500
|
@ -8,7 +8,7 @@ finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
|
||||
### dataset
|
||||
dataset: orca_rlhf
|
||||
dataset: dpo_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
@ -25,7 +25,7 @@ overwrite_output_dir: true
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 0.00001
|
||||
learning_rate: 0.000005
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_steps: 0.1
|
||||
|
@ -9,7 +9,7 @@ finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
@ -8,7 +8,7 @@ do_predict: true
|
||||
finetuning_type: lora
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 50
|
||||
|
@ -8,7 +8,7 @@ finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
|
||||
### dataset
|
||||
dataset: orca_rlhf
|
||||
dataset: dpo_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
@ -8,7 +8,7 @@ finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
@ -8,7 +8,7 @@ finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
@ -8,7 +8,7 @@ finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
@ -8,7 +8,7 @@ finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
@ -9,7 +9,7 @@ finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
@ -8,7 +8,7 @@ finetuning_type: lora
|
||||
lora_target: q_proj,v_proj
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_gpt4_en
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
|
@ -1,12 +1,12 @@
|
||||
from .collator import PairwiseDataCollatorWithPadding,KTODataCollatorWithPadding
|
||||
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
|
||||
from .loader import get_dataset
|
||||
from .template import Template, get_template_and_fix_tokenizer, templates
|
||||
from .utils import Role, split_dataset
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PairwiseDataCollatorWithPadding",
|
||||
"KTODataCollatorWithPadding",
|
||||
"PairwiseDataCollatorWithPadding",
|
||||
"get_dataset",
|
||||
"Template",
|
||||
"get_template_and_fix_tokenizer",
|
||||
|
@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||
|
||||
from datasets import Features
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from .utils import Role
|
||||
|
||||
|
||||
@ -14,7 +15,13 @@ if TYPE_CHECKING:
|
||||
from .parser import DatasetAttr
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
|
||||
r"""
|
||||
Optionally concatenates image path to dataset dir when loading from local disk.
|
||||
"""
|
||||
outputs = []
|
||||
if dataset_attr.load_from in ["script", "file"]:
|
||||
for image in images:
|
||||
@ -29,7 +36,10 @@ def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "
|
||||
def convert_alpaca(
|
||||
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
||||
) -> Dict[str, List[Any]]:
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": [], "tag": []}
|
||||
r"""
|
||||
Converts alpaca format dataset to the standard format.
|
||||
"""
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
|
||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||
for i in range(len(examples[dataset_attr.prompt])):
|
||||
prompt = []
|
||||
@ -45,23 +55,33 @@ def convert_alpaca(
|
||||
if dataset_attr.query and examples[dataset_attr.query][i]:
|
||||
content.append(examples[dataset_attr.query][i])
|
||||
|
||||
prompt.append({"role": Role.USER.value, "content": "\n".join(content)})
|
||||
prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) # "prompt\nquery"
|
||||
|
||||
if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list):
|
||||
response = [
|
||||
{"role": Role.ASSISTANT.value, "content": content} for content in examples[dataset_attr.response][i]
|
||||
]
|
||||
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
|
||||
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag], bool): # kto example
|
||||
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
|
||||
else:
|
||||
if examples[dataset_attr.kto_tag]:
|
||||
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||
else:
|
||||
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
|
||||
elif (
|
||||
dataset_attr.ranking
|
||||
and isinstance(examples[dataset_attr.chosen][i], str)
|
||||
and isinstance(examples[dataset_attr.rejected][i], str)
|
||||
): # pairwise example
|
||||
response = [
|
||||
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.chosen][i]},
|
||||
{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.rejected][i]},
|
||||
]
|
||||
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str): # normal example
|
||||
response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
|
||||
else: # unsupervised
|
||||
response = []
|
||||
|
||||
outputs["prompt"].append(prompt)
|
||||
outputs["response"].append(response)
|
||||
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
||||
outputs["tools"].append("")
|
||||
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
|
||||
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
|
||||
outputs["tag"].append(examples[dataset_attr.tag][i] if dataset_attr.tag else True)
|
||||
|
||||
return outputs
|
||||
|
||||
@ -69,6 +89,9 @@ def convert_alpaca(
|
||||
def convert_sharegpt(
|
||||
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
||||
) -> Dict[str, List[Any]]:
|
||||
r"""
|
||||
Converts sharegpt format dataset to the standard format.
|
||||
"""
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
|
||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||
tag_mapping = {
|
||||
@ -88,21 +111,62 @@ def convert_sharegpt(
|
||||
else:
|
||||
system = examples[dataset_attr.system][i] if dataset_attr.system else ""
|
||||
|
||||
messages = messages[: len(messages) // 2 * 2] # should be multiples of 2
|
||||
if len(messages) == 0:
|
||||
continue
|
||||
|
||||
aligned_messages = []
|
||||
broken_data = False
|
||||
for turn_idx, message in enumerate(messages):
|
||||
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
|
||||
raise ValueError("Invalid role tag in {}.".format(messages))
|
||||
logger.warning("Invalid role tag in {}.".format(messages))
|
||||
broken_data = True
|
||||
|
||||
aligned_messages.append(
|
||||
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
|
||||
)
|
||||
|
||||
outputs["prompt"].append(aligned_messages[:-1])
|
||||
outputs["response"].append(aligned_messages[-1:])
|
||||
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
|
||||
dataset_attr.ranking and len(aligned_messages) % 2 == 0
|
||||
):
|
||||
logger.warning("Invalid message count in {}.".format(messages))
|
||||
broken_data = True
|
||||
|
||||
if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool): # kto example
|
||||
prompt = aligned_messages[:-1]
|
||||
response = aligned_messages[-1:]
|
||||
if examples[dataset_attr.kto_tag][i]:
|
||||
response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||
else:
|
||||
response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
|
||||
elif (
|
||||
dataset_attr.ranking
|
||||
and isinstance(examples[dataset_attr.chosen][i], dict)
|
||||
and isinstance(examples[dataset_attr.rejected][i], dict)
|
||||
): # pairwise example
|
||||
chosen = examples[dataset_attr.chosen][i]
|
||||
rejected = examples[dataset_attr.rejected][i]
|
||||
if (
|
||||
chosen[dataset_attr.role_tag] not in accept_tags[-1]
|
||||
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
|
||||
):
|
||||
logger.warning("Invalid role tag in {}.".format(messages))
|
||||
broken_data = True
|
||||
|
||||
prompt = aligned_messages
|
||||
response = [
|
||||
{"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
|
||||
{"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
|
||||
]
|
||||
else: # normal example
|
||||
prompt = aligned_messages[:-1]
|
||||
response = aligned_messages[-1:]
|
||||
|
||||
if broken_data:
|
||||
logger.warning("Skipping this abnormal example.")
|
||||
continue
|
||||
|
||||
outputs["prompt"].append(prompt)
|
||||
outputs["response"].append(response)
|
||||
outputs["system"].append(system)
|
||||
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
|
||||
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
|
||||
@ -138,7 +202,6 @@ def align_dataset(
|
||||
"system": {"dtype": "string", "_type": "Value"},
|
||||
"tools": {"dtype": "string", "_type": "Value"},
|
||||
"images": [{"_type": "Image"}],
|
||||
"tag": {"dtype": "bool", "_type": "Value"},
|
||||
}
|
||||
)
|
||||
kwargs = {}
|
||||
|
@ -50,35 +50,38 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
|
||||
return batch
|
||||
|
||||
|
||||
@dataclass
|
||||
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for KTO data.
|
||||
"""
|
||||
def __call__(self, features, return_tensors=None):
|
||||
concatenated_features = []
|
||||
kl_concatenated_features = []
|
||||
tags = []
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||
target_features = []
|
||||
kl_features = []
|
||||
kto_tags = []
|
||||
for feature in features:
|
||||
concatenated_features.append(
|
||||
target_features.append(
|
||||
{
|
||||
"input_ids": feature["input_ids"],
|
||||
"attention_mask": feature["attention_mask"],
|
||||
"labels": feature["labels"],
|
||||
}
|
||||
)
|
||||
kl_concatenated_features.append(
|
||||
kl_features.append(
|
||||
{
|
||||
"input_ids": feature["kl_input_ids"],
|
||||
"attention_mask": feature["kl_attention_mask"],
|
||||
"labels": feature["kl_labels"],
|
||||
}
|
||||
)
|
||||
tags.append(feature["tag"])
|
||||
batch = super().__call__(concatenated_features)
|
||||
kl_batch = super().__call__(kl_concatenated_features)
|
||||
batch["KL_completion_input_ids"] = kl_batch["input_ids"]
|
||||
batch["KL_completion_attention_mask"] = kl_batch["attention_mask"]
|
||||
kto_tags.append(feature["kto_tags"])
|
||||
|
||||
batch = super().__call__(target_features)
|
||||
kl_batch = super().__call__(kl_features)
|
||||
batch["kl_input_ids"] = kl_batch["input_ids"]
|
||||
batch["kl_attention_mask"] = kl_batch["attention_mask"]
|
||||
batch["kl_labels"] = kl_batch["labels"]
|
||||
batch["tag"] = torch.tensor(tags)
|
||||
return batch
|
||||
batch["kto_tags"] = torch.tensor(kto_tags)
|
||||
return batch
|
||||
|
@ -57,7 +57,7 @@ def load_single_dataset(
|
||||
data_files.append(local_path)
|
||||
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
|
||||
else:
|
||||
raise ValueError("File not found.")
|
||||
raise ValueError("File {} not found.".format(local_path))
|
||||
|
||||
if data_path is None:
|
||||
raise ValueError("File extension must be txt, csv, json or jsonl.")
|
||||
@ -116,7 +116,7 @@ def get_dataset(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
stage: Literal["pt", "sft", "rm", "kto"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"] = None,
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
|
@ -25,21 +25,22 @@ class DatasetAttr:
|
||||
folder: Optional[str] = None
|
||||
ranking: bool = False
|
||||
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
|
||||
""" columns """
|
||||
""" common columns """
|
||||
system: Optional[str] = None
|
||||
tools: Optional[str] = None
|
||||
images: Optional[str] = None
|
||||
tag: Optional[bool] = None
|
||||
""" columns for the alpaca format """
|
||||
""" rlhf columns """
|
||||
chosen: Optional[str] = None
|
||||
rejected: Optional[str] = None
|
||||
kto_tag: Optional[str] = None
|
||||
""" alpaca columns """
|
||||
prompt: Optional[str] = "instruction"
|
||||
query: Optional[str] = "input"
|
||||
response: Optional[str] = "output"
|
||||
chosen: Optional[str] = "chosen"
|
||||
rejected: Optional[str] = "rejected"
|
||||
history: Optional[str] = None
|
||||
""" columns for the sharegpt format """
|
||||
""" sharegpt columns """
|
||||
messages: Optional[str] = "conversations"
|
||||
tools: Optional[str] = None
|
||||
""" tags for the sharegpt format """
|
||||
""" sharegpt tags """
|
||||
role_tag: Optional[str] = "from"
|
||||
content_tag: Optional[str] = "value"
|
||||
user_tag: Optional[str] = "human"
|
||||
@ -107,11 +108,11 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
||||
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
column_names = ["system", "images", "tag"]
|
||||
column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
|
||||
if dataset_attr.formatting == "alpaca":
|
||||
column_names.extend(["prompt", "query", "response", "history"])
|
||||
else:
|
||||
column_names.extend(["messages", "tools"])
|
||||
column_names.extend(["messages"])
|
||||
|
||||
for column_name in column_names:
|
||||
dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
|
||||
|
@ -70,7 +70,7 @@ def preprocess_supervised_dataset(
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": [], "tag": []}
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
|
||||
@ -111,102 +111,11 @@ def preprocess_supervised_dataset(
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
model_inputs["tag"].append(examples["tag"])
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
|
||||
|
||||
return model_inputs
|
||||
|
||||
def preprocess_kto_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": [],"kl_input_ids": [], "kl_attention_mask": [], "kl_labels": [], "tag": []}
|
||||
"""Creates mismatched pairs of prompts and completions for the KL dataset by reversing the order of completions."""
|
||||
examples['kl_response'] = examples['response'][::-1]
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
|
||||
continue
|
||||
|
||||
if processor is not None:
|
||||
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
|
||||
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
kl_messages = examples["prompt"][i] + examples["kl_response"][i]
|
||||
input_ids, labels = [], []
|
||||
kl_input_ids, kl_labels = [], []
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(
|
||||
template.encode_multiturn(
|
||||
tokenizer,
|
||||
messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
|
||||
if template.efficient_eos:
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(
|
||||
template.encode_multiturn(
|
||||
tokenizer,
|
||||
kl_messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
||||
kl_input_ids += source_ids + target_ids
|
||||
kl_labels += source_mask + target_ids
|
||||
|
||||
if template.efficient_eos:
|
||||
kl_input_ids += [tokenizer.eos_token_id]
|
||||
kl_labels += [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
model_inputs["kl_input_ids"].append(kl_input_ids)
|
||||
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
|
||||
model_inputs["kl_labels"].append(kl_labels)
|
||||
model_inputs["tag"].append(examples["tag"][i])
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
|
||||
desirable = sum([1 for tag in model_inputs["tag"] if tag is True])
|
||||
undesirable = sum([1 for tag in model_inputs["tag"] if tag is False])
|
||||
logger.info("desirable data in KTO dataset: {},undesirable data in KTO dataset: {}".format(desirable, undesirable))
|
||||
if desirable == 0 or undesirable == 0:
|
||||
logger.warning("Your dataset only has one preference type.")
|
||||
return model_inputs
|
||||
|
||||
def preprocess_packed_supervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
@ -352,6 +261,90 @@ def preprocess_pairwise_dataset(
|
||||
return model_inputs
|
||||
|
||||
|
||||
def preprocess_kto_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
|
||||
kl_response = examples["response"][::-1]
|
||||
model_inputs = {
|
||||
"input_ids": [],
|
||||
"attention_mask": [],
|
||||
"labels": [],
|
||||
"kl_input_ids": [],
|
||||
"kl_attention_mask": [],
|
||||
"kl_labels": [],
|
||||
"kto_tags": [],
|
||||
}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
|
||||
continue
|
||||
|
||||
if processor is not None:
|
||||
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
|
||||
|
||||
if examples["response"][i][0]["content"]: # desired example
|
||||
kto_tag = True
|
||||
messages = examples["prompt"][i] + [examples["response"][i][0]]
|
||||
else: # undesired example
|
||||
kto_tag = False
|
||||
messages = examples["prompt"][i] + [examples["response"][i][1]]
|
||||
|
||||
if kl_response[i][0]["content"]:
|
||||
kl_messages = examples["prompt"][i] + [kl_response[i][0]]
|
||||
else:
|
||||
kl_messages = examples["prompt"][i] + [kl_response[i][1]]
|
||||
|
||||
prompt_ids, response_ids = template.encode_oneturn(
|
||||
tokenizer,
|
||||
messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
_, kl_response_ids = template.encode_oneturn(
|
||||
tokenizer,
|
||||
kl_messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
|
||||
if template.efficient_eos:
|
||||
response_ids += [tokenizer.eos_token_id]
|
||||
kl_response_ids += [tokenizer.eos_token_id]
|
||||
|
||||
input_ids = prompt_ids + response_ids
|
||||
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
|
||||
kl_input_ids = prompt_ids + kl_response_ids
|
||||
kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
model_inputs["kl_input_ids"].append(kl_input_ids)
|
||||
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
|
||||
model_inputs["kl_labels"].append(kl_labels)
|
||||
model_inputs["kto_tags"].append(kto_tag)
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
|
||||
|
||||
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
|
||||
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
|
||||
if desirable_num == 0 or undesirable_num == 0:
|
||||
logger.warning("Your dataset only has one preference type.")
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
||||
def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
@ -380,7 +373,7 @@ def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer:
|
||||
def get_preprocess_and_print_func(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
stage: Literal["pt", "sft", "rm", "kto"],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
|
@ -137,21 +137,21 @@ class RLHFArguments:
|
||||
default=0.1,
|
||||
metadata={"help": "The beta parameter for the KTO loss."},
|
||||
)
|
||||
kto_chosen_weight: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "The weight factor of the desirable losses in KTO training."},
|
||||
)
|
||||
kto_rejected_weight: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "The weight factor of the undesirable losses in KTO training."},
|
||||
)
|
||||
kto_ftx: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "The supervised fine-tuning loss coefficient in KTO training."},
|
||||
)
|
||||
kto_desirable_weight: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "The desirable weight for the KTO loss."},
|
||||
)
|
||||
kto_undesirable_weight: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "The undesirable weight for the KTO loss."},
|
||||
)
|
||||
orpo_beta: float = field(
|
||||
default=0.1,
|
||||
metadata={"help": "The beta (lambda) parameter in ORPO loss representing the weight of the SFT loss."},
|
||||
metadata={"help": "The beta (lambda) parameter in the ORPO loss representing the weight of the SFT loss."},
|
||||
)
|
||||
ppo_buffer_size: int = field(
|
||||
default=1,
|
||||
@ -307,7 +307,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
|
||||
)
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo", "kto"] = field(
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto", "orpo"] = field(
|
||||
default="sft",
|
||||
metadata={"help": "Which stage will be performed in training."},
|
||||
)
|
||||
|
@ -47,11 +47,13 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
self._peft_has_been_casted_to_bf16 = False
|
||||
|
||||
self.ref_model = ref_model
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
# dpo hyperparams
|
||||
self.beta = finetuning_args.dpo_beta
|
||||
self.label_smoothing = finetuning_args.dpo_label_smoothing
|
||||
self.loss_type = finetuning_args.dpo_loss
|
||||
self.ftx_gamma = finetuning_args.dpo_ftx
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
Trainer.__init__(self, model=model, **kwargs)
|
||||
if not hasattr(self, "accelerator"):
|
||||
@ -143,6 +145,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
policy_chosen_logits,
|
||||
policy_rejected_logits,
|
||||
) = self.concatenated_forward(model, batch)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.ref_model is None:
|
||||
ref_model = self.model
|
||||
|
@ -1,7 +1,7 @@
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import Trainer
|
||||
@ -13,7 +13,7 @@ from ..utils import create_custom_optimzer, create_custom_scheduler
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel
|
||||
from transformers import PreTrainedModel, ProcessorMixin
|
||||
|
||||
from ...hparams import FinetuningArguments
|
||||
|
||||
@ -24,6 +24,7 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
model: Union["PreTrainedModel", torch.nn.Module],
|
||||
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
|
||||
finetuning_args: "FinetuningArguments",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
disable_dropout: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
@ -33,6 +34,7 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
disable_dropout_in_model(ref_model)
|
||||
|
||||
self.finetuning_args = finetuning_args
|
||||
self.processor = processor
|
||||
self.reference_free = False
|
||||
self.use_dpo_data_collator = True # hack to avoid warning
|
||||
self.generate_during_eval = False # disable at evaluation
|
||||
@ -43,15 +45,15 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
self._precomputed_train_ref_log_probs = False
|
||||
self._precomputed_eval_ref_log_probs = False
|
||||
self._peft_has_been_casted_to_bf16 = False
|
||||
|
||||
self.ref_model = ref_model
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
# KTO parameter
|
||||
# kto hyperparams
|
||||
self.beta = finetuning_args.kto_beta
|
||||
self.desirable_weight = finetuning_args.kto_chosen_weight
|
||||
self.undesirable_weight = finetuning_args.kto_rejected_weight
|
||||
self.ftx_gamma = finetuning_args.kto_ftx
|
||||
self.desirable_weight = finetuning_args.kto_desirable_weight
|
||||
self.undesirable_weight = finetuning_args.kto_undesirable_weight
|
||||
|
||||
|
||||
Trainer.__init__(self, model=model, **kwargs)
|
||||
if not hasattr(self, "accelerator"):
|
||||
@ -82,78 +84,85 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
create_custom_scheduler(self.args, num_training_steps, optimizer)
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
|
||||
super()._save(output_dir, state_dict)
|
||||
if self.processor is not None:
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
getattr(self.processor, "image_processor").save_pretrained(output_dir)
|
||||
|
||||
def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: "torch.LongTensor") -> "torch.Tensor":
|
||||
r"""
|
||||
Computes supervised cross-entropy loss of given labels under the given logits.
|
||||
|
||||
Returns:
|
||||
A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
|
||||
"""
|
||||
all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True)
|
||||
return -all_logps.nanmean()
|
||||
|
||||
return -all_logps
|
||||
|
||||
def forward(
|
||||
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
|
||||
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
with torch.no_grad():
|
||||
KL_logits = model(
|
||||
batch["KL_completion_input_ids"],
|
||||
attention_mask=batch["KL_completion_attention_mask"],
|
||||
).logits
|
||||
kl_logits = model(
|
||||
input_ids=batch["kl_input_ids"],
|
||||
attention_mask=batch["kl_attention_mask"],
|
||||
return_dict=True,
|
||||
use_cache=False,
|
||||
).logits.to(torch.float32)
|
||||
|
||||
completion_logits = model(
|
||||
batch["input_ids"],
|
||||
target_logits = model(
|
||||
input_ids=batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
).logits
|
||||
return_dict=True,
|
||||
use_cache=False,
|
||||
).logits.to(torch.float32)
|
||||
|
||||
completion_logps = self.get_batch_logps(
|
||||
completion_logits,
|
||||
batch["labels"],
|
||||
target_logps = self.get_batch_logps(
|
||||
logits=target_logits,
|
||||
labels=batch["labels"],
|
||||
average_log_prob=False,
|
||||
is_encoder_decoder=self.is_encoder_decoder,
|
||||
label_pad_token_id=self.label_pad_token_id,
|
||||
)
|
||||
|
||||
KL_logps = self.get_batch_logps(
|
||||
KL_logits,
|
||||
batch["kl_labels"],
|
||||
kl_logps = self.get_batch_logps(
|
||||
logits=kl_logits,
|
||||
labels=batch["kl_labels"],
|
||||
average_log_prob=False,
|
||||
is_encoder_decoder=self.is_encoder_decoder,
|
||||
label_pad_token_id=self.label_pad_token_id,
|
||||
)
|
||||
|
||||
if completion_logps.shape[0] != len(batch["tag"]):
|
||||
raise ValueError(
|
||||
"There is a mismatch between the number of examples in this batch and the number of "
|
||||
"examples for which an output sequence was predicted."
|
||||
)
|
||||
chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["tag"][i]]
|
||||
rejected_idx = [i for i in range(completion_logps.shape[0]) if not batch["tag"][i]]
|
||||
if len(target_logps) != len(batch["kto_tags"]):
|
||||
raise ValueError("Mismatched shape of inputs and labels.")
|
||||
|
||||
chosen_logps = completion_logps[chosen_idx, ...]
|
||||
rejected_logps = completion_logps[rejected_idx, ...]
|
||||
chosen_idx = [i for i in range(len(target_logps)) if batch["kto_tags"][i]]
|
||||
rejected_idx = [i for i in range(len(target_logps)) if not batch["kto_tags"][i]]
|
||||
|
||||
chosen_logits = completion_logits[chosen_idx, ...]
|
||||
rejected_logits = completion_logits[rejected_idx, ...]
|
||||
chosen_logps = target_logps[chosen_idx, ...]
|
||||
rejected_logps = target_logps[rejected_idx, ...]
|
||||
|
||||
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)
|
||||
chosen_logits = target_logits[chosen_idx, ...]
|
||||
rejected_logits = target_logits[rejected_idx, ...]
|
||||
|
||||
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps
|
||||
|
||||
def get_batch_loss_metrics(
|
||||
self,
|
||||
model,
|
||||
batch: Dict[str, Union[List, torch.LongTensor]],
|
||||
):
|
||||
"""Compute the KTO loss and other metrics for the given batch of inputs for train or test."""
|
||||
model: "PreTrainedModel",
|
||||
batch: Dict[str, "torch.Tensor"],
|
||||
) -> Tuple["torch.Tensor", Dict[str, "torch.Tensor"]]:
|
||||
r"""
|
||||
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
|
||||
"""
|
||||
metrics = {}
|
||||
batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
|
||||
|
||||
(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_chosen_logits,
|
||||
policy_rejected_logits,
|
||||
policy_KL_logps,
|
||||
_,
|
||||
policy_kl_logps,
|
||||
) = self.forward(model, batch)
|
||||
|
||||
with torch.no_grad():
|
||||
@ -163,27 +172,29 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
else:
|
||||
ref_model = self.ref_model
|
||||
ref_context = nullcontext()
|
||||
|
||||
with ref_context:
|
||||
(
|
||||
reference_chosen_logps,
|
||||
reference_rejected_logps,
|
||||
_,
|
||||
_,
|
||||
reference_KL_logps,
|
||||
reference_kl_logps,
|
||||
) = self.forward(ref_model, batch)
|
||||
|
||||
losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_KL_logps,
|
||||
policy_kl_logps,
|
||||
reference_chosen_logps,
|
||||
reference_rejected_logps,
|
||||
reference_KL_logps,
|
||||
reference_kl_logps,
|
||||
)
|
||||
losses = losses.nanmean()
|
||||
if self.ftx_gamma > 1e-6 and len(batch["labels"][batch['tag']])>0:
|
||||
losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, batch["labels"][batch['tag']])
|
||||
|
||||
if self.ftx_gamma > 1e-6 and len(policy_chosen_logps) > 0: # remember to rescale
|
||||
sft_loss = self.sft_loss(policy_chosen_logits, batch["labels"][batch["kto_tags"]])
|
||||
losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logits) * len(batch["labels"])
|
||||
|
||||
num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
|
||||
num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
|
||||
@ -203,4 +214,4 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
|
||||
metrics["kl"] = kl.item()
|
||||
|
||||
return losses, metrics
|
||||
return losses, metrics
|
||||
|
@ -48,9 +48,9 @@ def run_kto(
|
||||
ref_model=ref_model,
|
||||
args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
**tokenizer_module,
|
||||
**split_dataset(dataset, data_args, training_args),
|
||||
)
|
||||
|
||||
|
@ -29,7 +29,7 @@ def run_ppo(
|
||||
):
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
|
||||
|
||||
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
|
||||
|
@ -9,12 +9,13 @@ from ..extras.logging import get_logger
|
||||
from ..hparams import get_infer_args, get_train_args
|
||||
from ..model import load_model, load_tokenizer
|
||||
from .dpo import run_dpo
|
||||
from .kto import run_kto
|
||||
from .orpo import run_orpo
|
||||
from .ppo import run_ppo
|
||||
from .pt import run_pt
|
||||
from .rm import run_rm
|
||||
from .sft import run_sft
|
||||
from .kto import run_kto
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerCallback
|
||||
@ -37,10 +38,10 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
|
||||
run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
|
||||
elif finetuning_args.stage == "dpo":
|
||||
run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
elif finetuning_args.stage == "orpo":
|
||||
run_orpo(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
elif finetuning_args.stage == "kto":
|
||||
run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
elif finetuning_args.stage == "orpo":
|
||||
run_orpo(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
else:
|
||||
raise ValueError("Unknown task.")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user