refactor mm training

Former-commit-id: 3382317e32f88ed377d3e7759bdeaf0f2559d22a
hiyouga 2024-08-30 02:14:31 +08:00
parent 98b0c7530c
commit a83756b5e9
32 changed files with 505 additions and 472 deletions


@ -47,7 +47,7 @@ Choose your path:
## Features ## Features
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc. - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc. - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ. - **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
- **Advanced algorithms**: GaLore, BAdam, Adam-mini, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning. - **Advanced algorithms**: GaLore, BAdam, Adam-mini, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
@ -72,14 +72,16 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Changelog ## Changelog
[24/08/27] We support **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `use_liger_kernel: true` for efficient training. [24/08/30] We supported fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models.
[24/08/27] We support **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `enable_liger_kernel: true` for efficient training.
[24/08/09] We support **[Adam-mini](https://arxiv.org/abs/2406.16793)** optimizer. See [examples](examples/README.md) for usage. Thank [@relic-yuexi](https://github.com/relic-yuexi)'s PR. [24/08/09] We support **[Adam-mini](https://arxiv.org/abs/2406.16793)** optimizer. See [examples](examples/README.md) for usage. Thank [@relic-yuexi](https://github.com/relic-yuexi)'s PR.
[24/07/04] We support [contamination-free packed training](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing). Use `neat_packing: true` to activate it. Thank [@chuan298](https://github.com/chuan298)'s PR.
<details><summary>Full Changelog</summary> <details><summary>Full Changelog</summary>
[24/07/04] We support [contamination-free packed training](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing). Use `neat_packing: true` to activate it. Thank [@chuan298](https://github.com/chuan298)'s PR.
[24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage. [24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models. [24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
@ -172,14 +174,15 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 | | [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna | | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B | cpm | | [MiniCPM](https://huggingface.co/openbmb) | 1B/2B | cpm |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - | | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma](https://huggingface.co/google) | 3B | gemma | | [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - | | [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi | | [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
| [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen | | [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B | qwen2_vl |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - | | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse | | [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi/Yi-1.5](https://huggingface.co/01-ai) | 6B/9B/34B | yi | | [Yi/Yi-1.5](https://huggingface.co/01-ai) | 6B/9B/34B | yi |


@ -48,7 +48,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
## 项目特色 ## 项目特色
- **多种模型**LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。 - **多种模型**LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Qwen2-VL、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
- **集成方法**增量预训练、多模态指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。 - **集成方法**增量预训练、多模态指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
- **多种精度**16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。 - **多种精度**16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
- **先进算法**GaLore、BAdam、Adam-mini、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。 - **先进算法**GaLore、BAdam、Adam-mini、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
@ -73,14 +73,16 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
## 更新日志 ## 更新日志
[24/08/27] 我们支持了 **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**。请使用 `use_liger_kernel: true` 来加速训练。 [24/08/30] 我们支持了 **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** 模型的微调。
[24/08/27] 我们支持了 **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**。请使用 `enable_liger_kernel: true` 来加速训练。
[24/08/09] 我们支持了 **[Adam-mini](https://arxiv.org/abs/2406.16793)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。感谢 [@relic-yuexi](https://github.com/relic-yuexi) 的 PR。 [24/08/09] 我们支持了 **[Adam-mini](https://arxiv.org/abs/2406.16793)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。感谢 [@relic-yuexi](https://github.com/relic-yuexi) 的 PR。
[24/07/04] 我们支持了[无污染打包训练](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing)。请使用 `neat_packing: true` 参数。感谢 [@chuan298](https://github.com/chuan298) 的 PR。
<details><summary>展开日志</summary> <details><summary>展开日志</summary>
[24/07/04] 我们支持了[无污染打包训练](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing)。请使用 `neat_packing: true` 参数。感谢 [@chuan298](https://github.com/chuan298) 的 PR。
[24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。 [24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。
[24/06/07] 我们支持了 **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** 和 **[GLM-4](https://github.com/THUDM/GLM-4)** 模型的微调。 [24/06/07] 我们支持了 **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** 和 **[GLM-4](https://github.com/THUDM/GLM-4)** 模型的微调。
@ -173,14 +175,15 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 | | [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna | | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B | cpm | | [MiniCPM](https://huggingface.co/openbmb) | 1B/2B | cpm |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - | | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma](https://huggingface.co/google) | 3B | gemma | | [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - | | [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi | | [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
| [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen | | [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B | qwen2_vl |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - | | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse | | [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi/Yi-1.5](https://huggingface.co/01-ai) | 6B/9B/34B | yi | | [Yi/Yi-1.5](https://huggingface.co/01-ai) | 6B/9B/34B | yi |


@ -38,20 +38,6 @@
"assistant_tag": "assistant" "assistant_tag": "assistant"
} }
}, },
"qwen2vl_demo": {
"file_name": "qwen2vl_demo.json",
"formatting": "sharegpt",
"columns": {
"messages": "messages",
"images": "images"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "user",
"assistant_tag": "assistant"
}
},
"alpaca_en": { "alpaca_en": {
"hf_hub_url": "llamafactory/alpaca_en", "hf_hub_url": "llamafactory/alpaca_en",
"ms_hub_url": "llamafactory/alpaca_en" "ms_hub_url": "llamafactory/alpaca_en"


@ -2,7 +2,7 @@
{ {
"messages": [ "messages": [
{ {
"content": "Who are they?", "content": "<image>Who are they?",
"role": "user" "role": "user"
}, },
{ {
@ -25,7 +25,7 @@
{ {
"messages": [ "messages": [
{ {
"content": "Who is he?", "content": "<image>Who is he?",
"role": "user" "role": "user"
}, },
{ {
@ -48,7 +48,7 @@
{ {
"messages": [ "messages": [
{ {
"content": "Please describe this image", "content": "<image>Please describe this image",
"role": "user" "role": "user"
}, },
{ {
@ -71,7 +71,7 @@
{ {
"messages": [ "messages": [
{ {
"content": "他们是谁?", "content": "<image>他们是谁?",
"role": "user" "role": "user"
}, },
{ {
@ -94,7 +94,7 @@
{ {
"messages": [ "messages": [
{ {
"content": "他是谁?", "content": "<image>他是谁?",
"role": "user" "role": "user"
}, },
{ {
@ -117,7 +117,7 @@
{ {
"messages": [ "messages": [
{ {
"content": "请描述这张图片", "content": "<image>请描述这张图片",
"role": "user" "role": "user"
}, },
{ {
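The hunks above add one `<image>` placeholder per user turn that refers to a picture. A minimal sketch of a sanity check for such samples follows, assuming the placeholder string is `<image>` (as in the demo data) and that each placeholder should pair with one entry in `images`; the helper names are hypothetical and not part of the commit.

```python
# Assumed placeholder string, taken from the demo data in this diff.
IMAGE_PLACEHOLDER = "<image>"


def count_placeholders(sample):
    """Count image placeholders across all messages of one sample."""
    return sum(m["content"].count(IMAGE_PLACEHOLDER) for m in sample["messages"])


def is_consistent(sample):
    """True when the number of placeholders equals the number of images."""
    return count_placeholders(sample) == len(sample.get("images", []))


good_sample = {
    "messages": [
        {"content": "<image>Who are they?", "role": "user"},
        {"content": "They're Kane and Gretzka from Bayern Munich.", "role": "assistant"},
    ],
    "images": ["mllm_demo_data/1.jpg"],
}
# Same sample with the placeholder missing, as in the pre-commit data.
bad_sample = {
    "messages": [{"content": "Who are they?", "role": "user"}],
    "images": ["mllm_demo_data/1.jpg"],
}
```

A check like this could run once over a dataset file before training to catch samples that still use the old format.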


@ -1,140 +0,0 @@
[
{
"messages": [
{
"content": "<|image_pad|>Who are they?",
"role": "user"
},
{
"content": "They're Kane and Gretzka from Bayern Munich.",
"role": "assistant"
},
{
"content": "What are they doing?",
"role": "user"
},
{
"content": "They are celebrating on the soccer field.",
"role": "assistant"
}
],
"images": [
"mllm_demo_data/1.jpg"
]
},
{
"messages": [
{
"content": "<|image_pad|>Who is he?",
"role": "user"
},
{
"content": "He's Thomas Muller from Bayern Munich.",
"role": "assistant"
},
{
"content": "<|image_pad|>Why is he on the ground?",
"role": "user"
},
{
"content": "Because he's sliding on his knees to celebrate.",
"role": "assistant"
}
],
"images": [
"mllm_demo_data/2.jpg","mllm_demo_data/2.jpg"
]
},
{
"messages": [
{
"content": "<|image_pad|>Please describe this image",
"role": "user"
},
{
"content": "Chinese astronaut Gui Haichao is giving a speech.",
"role": "assistant"
},
{
"content": "What has he accomplished?",
"role": "user"
},
{
"content": "He was appointed to be a payload specialist on Shenzhou 16 mission in June 2022, thus becoming the first Chinese civilian of Group 3 in space on 30 May 2023. He is responsible for the on-orbit operation of space science experimental payloads.",
"role": "assistant"
}
],
"images": [
"mllm_demo_data/3.jpg"
]
},
{
"messages": [
{
"content": "<|image_pad|>他们是谁?",
"role": "user"
},
{
"content": "他们是拜仁慕尼黑的凯恩和格雷茨卡。",
"role": "assistant"
},
{
"content": "<|image_pad|>他们在做什么?",
"role": "user"
},
{
"content": "他们在足球场上庆祝。",
"role": "assistant"
}
],
"images": [
"mllm_demo_data/1.jpg","mllm_demo_data/1.jpg"
]
},
{
"messages": [
{
"content": "<|image_pad|>他是谁?",
"role": "user"
},
{
"content": "他是来自拜仁慕尼黑的托马斯·穆勒。",
"role": "assistant"
},
{
"content": "他为什么在地上?",
"role": "user"
},
{
"content": "因为他正在双膝跪地滑行庆祝。",
"role": "assistant"
}
],
"images": [
"mllm_demo_data/2.jpg"
]
},
{
"messages": [
{
"content": "<|image_pad|>请描述这张图片",
"role": "user"
},
{
"content": "中国宇航员桂海潮正在讲话。",
"role": "assistant"
},
{
"content": "他取得过哪些成就?",
"role": "user"
},
{
"content": "他于2022年6月被任命为神舟十六号任务的有效载荷专家从而成为2023年5月30日进入太空的首位平民宇航员。他负责在轨操作空间科学实验有效载荷。",
"role": "assistant"
}
],
"images": [
"mllm_demo_data/3.jpg"
]
}
]


@ -33,6 +33,7 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```bash ```bash
llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
``` ```
#### Reward Modeling #### Reward Modeling


@ -33,6 +33,7 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```bash ```bash
llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
``` ```
#### 奖励模型训练 #### 奖励模型训练


@ -1,40 +0,0 @@
### model
model_name_or_path: qwen2-vl-hf/qwen2-vl-7b-hf
visual_inputs: true
### method
stage: sft
do_train: true
finetuning_type: full
deepspeed: examples/deepspeed/ds_z3_config.json
### dataset
dataset: qwen2vl_demo
template: qwen2vl
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: saves/qwen2-vl-7b/full/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 1
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500


@ -1,5 +1,5 @@
### model ### model
model_name_or_path: qwen2-vl-hf/qwen2-vl-7b-hf model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
visual_inputs: true visual_inputs: true
### method ### method
@ -9,23 +9,23 @@ finetuning_type: lora
lora_target: all lora_target: all
### dataset ### dataset
dataset: qwen2vl_demo dataset: mllm_demo
template: qwen2vl template: qwen2_vl
cutoff_len: 1024 cutoff_len: 1024
max_samples: 1000 max_samples: 1000
overwrite_cache: true overwrite_cache: true
preprocessing_num_workers: 16 preprocessing_num_workers: 16
### output ### output
output_dir: saves/qwen2-vl-7b/lora/sft output_dir: saves/qwen2_vl-7b/lora/sft
logging_steps: 10 logging_steps: 10
save_steps: 500 save_steps: 500
plot_loss: true plot_loss: true
overwrite_output_dir: true overwrite_output_dir: true
### train ### train
per_device_train_batch_size: 2 per_device_train_batch_size: 1
gradient_accumulation_steps: 1 gradient_accumulation_steps: 8
learning_rate: 1.0e-4 learning_rate: 1.0e-4
num_train_epochs: 3.0 num_train_epochs: 3.0
lr_scheduler_type: cosine lr_scheduler_type: cosine
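The train section above changes `per_device_train_batch_size` from 2 to 1 and `gradient_accumulation_steps` from 1 to 8. A small sketch of the resulting effective batch size, assuming the usual product rule and a single device (the device count is not stated in the config, and the helper name is hypothetical):

```python
def effective_batch_size(per_device, accumulation_steps, num_devices=1):
    """Samples contributing to one optimizer step (standard product rule)."""
    return per_device * accumulation_steps * num_devices


old_batch = effective_batch_size(per_device=2, accumulation_steps=1)
new_batch = effective_batch_size(per_device=1, accumulation_steps=8)
print(old_batch, new_batch)
```

Under these assumptions the per-step memory footprint drops while the number of samples per optimizer step grows.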


@ -1,4 +1,4 @@
transformers>=4.41.2,<=4.43.4 transformers>=4.41.2,<=4.45.0
datasets>=2.16.0,<=2.21.0 datasets>=2.16.0,<=2.21.0
accelerate>=0.30.1,<=0.33.0 accelerate>=0.30.1,<=0.33.0
peft>=0.11.1,<=0.12.0 peft>=0.11.1,<=0.12.0


@ -20,7 +20,7 @@ Level:
Dependency graph: Dependency graph:
main: main:
transformers>=4.41.2,<=4.43.4 transformers>=4.41.2,<=4.44.3
datasets>=2.16.0,<=2.21.0 datasets>=2.16.0,<=2.21.0
accelerate>=0.30.1,<=0.33.0 accelerate>=0.30.1,<=0.33.0
peft>=0.11.1,<=0.12.0 peft>=0.11.1,<=0.12.0
@ -28,9 +28,9 @@ Dependency graph:
attention: attention:
transformers>=4.42.4 (gemma+fa2) transformers>=4.42.4 (gemma+fa2)
longlora: longlora:
transformers>=4.41.2,<=4.43.4 transformers>=4.41.2,<=4.44.3
packing: packing:
transformers>=4.41.2,<=4.43.4 transformers>=4.41.2,<=4.44.3
""" """
from .cli import VERSION from .cli import VERSION
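The dependency graph above lists inclusive version ranges. A minimal sketch of how such a range could be checked, assuming purely numeric dotted versions; `parse_version` and `in_range` are hypothetical helpers, and a real project would more likely rely on `packaging.version`.

```python
def parse_version(text):
    """Turn '4.44.3' into (4, 44, 3); only numeric components are handled."""
    return tuple(int(part) for part in text.split("."))


def in_range(version, low, high):
    """Inclusive range check, mirroring a '>=low,<=high' requirement."""
    return parse_version(low) <= parse_version(version) <= parse_version(high)


# Range taken from the 'main' entry of the dependency graph in this diff.
print(in_range("4.44.0", "4.41.2", "4.44.3"))
```

Tuple comparison is used because comparing the raw strings would order "4.9.0" after "4.41.2".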


@ -22,6 +22,7 @@ import torch
from transformers import GenerationConfig, TextIteratorStreamer from transformers import GenerationConfig, TextIteratorStreamer
from ..data import get_template_and_fix_tokenizer from ..data import get_template_and_fix_tokenizer
from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.logging import get_logger from ..extras.logging import get_logger
from ..extras.misc import get_logits_processor from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer from ..model import load_model, load_tokenizer
@ -31,7 +32,6 @@ from .base_engine import BaseEngine, Response
if TYPE_CHECKING: if TYPE_CHECKING:
from numpy.typing import NDArray from numpy.typing import NDArray
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
from trl import PreTrainedModelWrapper from trl import PreTrainedModelWrapper
from ..data import Template from ..data import Template
@ -81,27 +81,19 @@ class HuggingfaceEngine(BaseEngine):
image: Optional["NDArray"] = None, image: Optional["NDArray"] = None,
input_kwargs: Optional[Dict[str, Any]] = {}, input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]: ) -> Tuple[Dict[str, Any], int]:
if ( if image is not None:
processor is not None if IMAGE_PLACEHOLDER not in messages[0]["content"]:
and image is not None messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
and not hasattr(processor, "image_seq_length")
and template.image_token not in messages[0]["content"] messages = template.mm_plugin.process_messages(messages, [image], processor)
): # llava-like models
messages[0]["content"] = template.image_token + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}] paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or generating_args["default_system"] system = system or generating_args["default_system"]
pixel_values = None
prompt_ids, _ = template.encode_oneturn( prompt_ids, _ = template.encode_oneturn(
tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
) )
if processor is not None and image is not None: # add image features if image is not None:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
batch_feature = image_processor(image, return_tensors="pt")
pixel_values = batch_feature.to(model.device)["pixel_values"] # shape (B, C, H, W)
if hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
prompt_length = len(prompt_ids) prompt_length = len(prompt_ids)
inputs = torch.tensor([prompt_ids], device=model.device) inputs = torch.tensor([prompt_ids], device=model.device)
@ -164,8 +156,13 @@ class HuggingfaceEngine(BaseEngine):
logits_processor=get_logits_processor(), logits_processor=get_logits_processor(),
) )
if pixel_values is not None: if image is not None:
gen_kwargs["pixel_values"] = pixel_values mm_inputs = template.mm_plugin.get_mm_inputs(
images=[image], feature_seqlens={"token_type_ids": prompt_length}, processor=processor
)
for key, value in mm_inputs.items():
value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
gen_kwargs[key] = value.to(model.device)
return gen_kwargs, prompt_length return gen_kwargs, prompt_length
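In the engine hunk above, the model-specific branches are replaced by one rule: if an image is supplied and the first message has no placeholder, prepend one and let the template plugin expand it. A standalone sketch of just the prepend step, assuming `IMAGE_PLACEHOLDER` is `<image>`; unlike the diff, this hypothetical helper copies the messages instead of mutating them in place.

```python
import copy

# Assumed value of the constant imported from extras.constants in the diff.
IMAGE_PLACEHOLDER = "<image>"


def with_image_placeholder(messages, has_image):
    """Return a copy of messages whose first turn carries the placeholder."""
    messages = copy.deepcopy(messages)
    if has_image and IMAGE_PLACEHOLDER not in messages[0]["content"]:
        messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
    return messages


chat = [{"role": "user", "content": "Who is he?"}]
patched = with_image_placeholder(chat, has_image=True)
```

Copying avoids surprising callers that reuse the same message list across requests; whether that matters for the engine is not shown in this diff.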


@ -12,13 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding, SFTDataCollatorWith4DAttentionMask from .collator import (
CustomDataCollatorForSeq2Seq,
KTODataCollatorWithPadding,
PairwiseDataCollatorWithPadding,
SFTDataCollatorWith4DAttentionMask,
)
from .data_utils import Role, split_dataset from .data_utils import Role, split_dataset
from .loader import get_dataset from .loader import get_dataset
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
__all__ = [ __all__ = [
"CustomDataCollatorForSeq2Seq",
"KTODataCollatorWithPadding", "KTODataCollatorWithPadding",
"PairwiseDataCollatorWithPadding", "PairwiseDataCollatorWithPadding",
"SFTDataCollatorWith4DAttentionMask", "SFTDataCollatorWith4DAttentionMask",


@ -62,15 +62,11 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
@dataclass @dataclass
class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq): class CustomDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
r""" r"""
Data collator for 4d attention mask. Data collator for custom models (like Qwen2-VL).
""" """
block_diag_attn: bool = False
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
compute_dtype: "torch.dtype" = torch.float32
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]: def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
image_grid_thw = None image_grid_thw = None
if "image_grid_thw" in features[0]: if "image_grid_thw" in features[0]:
@ -83,23 +79,18 @@ class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
torch.Tensor(feature["pixel_values"]) for feature in features if feature["image_grid_thw"][0][0] > 0 torch.Tensor(feature["pixel_values"]) for feature in features if feature["image_grid_thw"][0][0] > 0
] ]
if image_grid_thw_list: if image_grid_thw_list:
image_grid_thw = torch.cat(image_grid_thw_list, 0) image_grid_thw = torch.cat(image_grid_thw_list, dim=0)
pixel_values = torch.cat(pixel_values_list, dim=0)
else: else:
# Handle the case where the list is empty, for example:
image_grid_thw = None image_grid_thw = None
if pixel_values_list:
pixel_values = torch.cat(pixel_values_list, 0)
else:
# Handle the case where the list is empty, for example:
pixel_values = None pixel_values = None
features = [ features = [
{key: feature[key] for key in feature if key not in ["image_grid_thw", "pixel_values"]} {key: feature[key] for key in feature if key not in ["image_grid_thw", "pixel_values"]}
for feature in features for feature in features
] ]
features = super().__call__(features) features = super().__call__(features)
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
if image_grid_thw is not None: if image_grid_thw is not None:
features["image_grid_thw"] = image_grid_thw features["image_grid_thw"] = image_grid_thw
features["pixel_values"] = pixel_values features["pixel_values"] = pixel_values
@ -108,7 +99,25 @@ class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
@dataclass @dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq): class SFTDataCollatorWith4DAttentionMask(CustomDataCollatorForSeq2Seq):
r"""
Data collator for 4d attention mask.
"""
block_diag_attn: bool = False
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
compute_dtype: "torch.dtype" = torch.float32
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
features = super().__call__(features)
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
return features
@dataclass
class PairwiseDataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
r""" r"""
Data collator for pairwise data. Data collator for pairwise data.
""" """
@ -128,9 +137,12 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
"attention_mask": feature["{}_attention_mask".format(key)], "attention_mask": feature["{}_attention_mask".format(key)],
"labels": feature["{}_labels".format(key)], "labels": feature["{}_labels".format(key)],
} }
if "pixel_values" in feature: if "pixel_values" in feature: # image data are same for chosen and rejected
target_feature["pixel_values"] = feature["pixel_values"] target_feature["pixel_values"] = feature["pixel_values"]
if "image_grid_thw" in feature:
target_feature["image_grid_thw"] = feature["image_grid_thw"]
if "{}_token_type_ids".format(key) in feature: if "{}_token_type_ids".format(key) in feature:
target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)] target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
@ -140,7 +152,7 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
@dataclass @dataclass
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq): class KTODataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
r""" r"""
Data collator for KTO data. Data collator for KTO data.
""" """
@ -163,6 +175,9 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
if "pixel_values" in feature: if "pixel_values" in feature:
target_feature["pixel_values"] = feature["pixel_values"] target_feature["pixel_values"] = feature["pixel_values"]
if "image_grid_thw" in feature:
target_feature["image_grid_thw"] = feature["image_grid_thw"]
if "token_type_ids" in feature: if "token_type_ids" in feature:
target_feature["token_type_ids"] = feature["token_type_ids"] target_feature["token_type_ids"] = feature["token_type_ids"]
kl_feature["token_type_ids"] = feature["kl_token_type_ids"] kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
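The `CustomDataCollatorForSeq2Seq` introduced above drops samples whose `image_grid_thw` starts with 0 (the fake image used for text-only samples) and concatenates the rest along the first dimension. A simplified sketch of that filter-and-concatenate step, using plain Python lists in place of tensors; the function name is hypothetical.

```python
def merge_image_features(features):
    """Collect grids and patches from samples that carry a real image.

    Each feature holds 'image_grid_thw' (list of [t, h, w]) and
    'pixel_values' (list of patch rows). A grid with t == 0 marks a
    fake image and is skipped, mirroring the collator in the diff.
    """
    grids, patches = [], []
    for feature in features:
        if feature["image_grid_thw"][0][0] > 0:
            grids.extend(feature["image_grid_thw"])
            patches.extend(feature["pixel_values"])
    if not grids:
        return None, None  # batch contains no real images
    return grids, patches


batch = [
    {"image_grid_thw": [[1, 2, 2]], "pixel_values": [[0.1], [0.2], [0.3], [0.4]]},
    {"image_grid_thw": [[0, 4, 4]], "pixel_values": [[0.0]] * 16},  # fake image
]
grids, patches = merge_image_features(batch)
```

With tensors, the two `extend` calls would correspond to the `torch.cat(..., dim=0)` calls shown in the hunk.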


@ -0,0 +1,271 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from PIL.Image import Image
from transformers import ProcessorMixin
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
from ..extras.packages import is_pillow_available
if is_pillow_available():
import torch
from PIL import Image
if TYPE_CHECKING:
from PIL.Image import Image as ImageObject
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "torch.Tensor":
r"""
Processes visual inputs. (currently only supports a single image)
Returns:
pixel_values: tensor with shape (B, C, H, W)
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
return image_processor([image], return_tensors="pt")["pixel_values"]
def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[List[int]]:
r"""
Gets paligemma token type ids for computing loss.
Returns:
token_type_ids: shape (1, seq_len)
"""
image_seq_length = getattr(processor, "image_seq_length")
return [[0] * image_seq_length + [1] * (input_len - image_seq_length)]
def get_qwen2vl_image_inputs(
images: Sequence["ImageObject"], processor: "ProcessorMixin"
) -> Dict[str, "torch.Tensor"]:
r"""
Processes qwen2-vl visual inputs. Supports multiple images.
Returns:
pixel_values: tensor with shape (num_patches, patch_dim)
image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, height, width
It holds num_patches == torch.prod(image_grid_thw)
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
if len(images) != 0:
image_inputs = image_processor(images=images, return_tensors="pt")
else:
image = Image.new("RGB", (56, 56), (255, 255, 255))
image_inputs = image_processor(images=[image], return_tensors="pt")
image_inputs["image_grid_thw"][0][0] = 0 # fake image
return {"pixel_values": image_inputs["pixel_values"], "image_grid_thw": image_inputs["image_grid_thw"]}
class BasePlugin:
def __init__(self, image_token: str) -> None:
self.image_token = image_token
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageObject"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
return messages
def process_token_ids(
self,
input_ids: List[int],
labels: Optional[List[int]],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
) -> Tuple[List[int], Optional[List[int]]]:
return input_ids, labels
def get_mm_inputs(
self,
images: Sequence["ImageObject"],
feature_seqlens: Dict[str, int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Any]:
return {}
def process_model_inputs(
self,
model_inputs: Dict[str, List[Any]],
images: Sequence["ImageObject"],
feature_seqlens: Dict[str, int],
processor: Optional["ProcessorMixin"],
) -> None:
return
class LlavaPlugin(BasePlugin):
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageObject"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
image_count = 0
new_messages = []
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
image_count += 1
if image_count > 1:
raise ValueError("Llava model only accepts one image per sample.")
content = content.replace(IMAGE_PLACEHOLDER, self.image_token, 1)
new_messages.append({"role": message["role"], "content": content})
return new_messages
def get_mm_inputs(
self,
images: Sequence["ImageObject"],
feature_seqlens: Dict[str, int],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Any]:
return {"pixel_values": get_pixel_values(images, processor)}
def process_model_inputs(
self,
model_inputs: Dict[str, List[Any]],
images: Sequence["ImageObject"],
feature_seqlens: Dict[str, int],
processor: Optional["ProcessorMixin"],
) -> None:
mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
model_inputs["pixel_values"].append(mm_inputs["pixel_values"][0])
class PaliGemmaPlugin(BasePlugin):
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageObject"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        image_count = 0
        new_messages = []
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                image_count += 1
                if image_count > 1:
                    raise ValueError("PaliGemma model only accepts one image per sample.")

                content = content.replace(IMAGE_PLACEHOLDER, "", 1)

            new_messages.append({"role": message["role"], "content": content})

        return new_messages

    def process_token_ids(
        self,
        input_ids: List[int],
        labels: Optional[List[int]],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
    ) -> Tuple[List[int], Optional[List[int]]]:
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
        image_seq_length: int = getattr(image_processor, "image_seq_length")
        image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
        input_ids = [image_token_id] * image_seq_length + input_ids
        if labels is not None:
            labels = [IGNORE_INDEX] * image_seq_length + labels

        return input_ids, labels

    def get_mm_inputs(
        self,
        images: Sequence["ImageObject"],
        feature_seqlens: Dict[str, int],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Any]:
        mm_inputs = {"pixel_values": get_pixel_values(images, processor)}
        for feature_name, feature_length in feature_seqlens.items():
            mm_inputs[feature_name] = get_paligemma_token_type_ids(feature_length, processor)

        return mm_inputs

    def process_model_inputs(
        self,
        model_inputs: Dict[str, List[Any]],
        images: Sequence["ImageObject"],
        feature_seqlens: Dict[str, int],
        processor: Optional["ProcessorMixin"],
    ) -> None:
        mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
        model_inputs["pixel_values"].append(mm_inputs["pixel_values"][0])
        for feature_name in feature_seqlens.keys():
            model_inputs[feature_name].append(mm_inputs[feature_name][0])


class Qwen2vlPlugin(BasePlugin):
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageObject"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
        merge_length: int = getattr(image_processor, "merge_size") ** 2
        if len(images) > 0:
            image_grid_thw = get_qwen2vl_image_inputs(images, processor)["image_grid_thw"]

        index = 0
        new_messages = []
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(
                    IMAGE_PLACEHOLDER,
                    "<|vision_start|>{}<|vision_end|>".format(
                        self.image_token * (image_grid_thw[index].prod() // merge_length)
                    ),
                    1,
                )
                index += 1

            new_messages.append({"role": message["role"], "content": content})

        return new_messages

    def get_mm_inputs(
        self,
        images: Sequence["ImageObject"],
        feature_seqlens: Dict[str, int],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Any]:
        return get_qwen2vl_image_inputs(images, processor)

    def process_model_inputs(
        self,
        model_inputs: Dict[str, List[Any]],
        images: Sequence["ImageObject"],
        feature_seqlens: Dict[str, int],
        processor: Optional["ProcessorMixin"],
    ) -> None:
        mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
        model_inputs["pixel_values"].append(mm_inputs["pixel_values"])
        model_inputs["image_grid_thw"].append(mm_inputs["image_grid_thw"])


PLUGINS = {
    "llava": LlavaPlugin,
    "paligemma": PaliGemmaPlugin,
    "qwen2_vl": Qwen2vlPlugin,
}


def get_mm_plugin(name: str, image_token: str) -> "BasePlugin":
    if name not in PLUGINS:
        raise ValueError("{} not found.".format(name))

    return PLUGINS[name](image_token)
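
The registry lookup above can be illustrated with a minimal, dependency-free sketch. The names `ToyBasePlugin`, `ToyLlavaPlugin`, `TOY_PLUGINS` and `get_toy_plugin` are hypothetical stand-ins invented for this example, not classes from this commit, and the toy `process_messages` takes only the messages (no images or processor).

```python
from typing import Dict, List, Sequence, Type

IMAGE_PLACEHOLDER = "<image>"  # same literal as the constant added in this commit


class ToyBasePlugin:
    """Hypothetical stand-in for BasePlugin: returns copies of the messages unchanged."""

    def __init__(self, image_token: str) -> None:
        self.image_token = image_token

    def process_messages(self, messages: Sequence[Dict[str, str]]) -> List[Dict[str, str]]:
        return [dict(message) for message in messages]


class ToyLlavaPlugin(ToyBasePlugin):
    """Hypothetical single-image plugin: swaps the placeholder for the model's image token."""

    def process_messages(self, messages: Sequence[Dict[str, str]]) -> List[Dict[str, str]]:
        # Count all placeholders across the sample before replacing anything.
        total = sum(message["content"].count(IMAGE_PLACEHOLDER) for message in messages)
        if total > 1:
            raise ValueError("Only one image per sample is accepted.")

        return [
            {
                "role": message["role"],
                "content": message["content"].replace(IMAGE_PLACEHOLDER, self.image_token),
            }
            for message in messages
        ]


# Assumed registry for the sketch; the real PLUGINS dict maps to the classes above.
TOY_PLUGINS: Dict[str, Type[ToyBasePlugin]] = {"base": ToyBasePlugin, "llava": ToyLlavaPlugin}


def get_toy_plugin(name: str, image_token: str) -> ToyBasePlugin:
    if name not in TOY_PLUGINS:
        raise ValueError("{} not found.".format(name))

    return TOY_PLUGINS[name](image_token)


plugin = get_toy_plugin("llava", image_token="<img>")
processed = plugin.process_messages([{"role": "user", "content": "<image>Describe this."}])
```

Keeping the model-specific logic behind one lookup is what lets the dataset processors call `template.mm_plugin` without branching on the processor type.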


@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen from .processor_utils import infer_seqlen
if TYPE_CHECKING: if TYPE_CHECKING:
@ -40,9 +41,6 @@ def _encode_feedback_example(
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
cutoff_len: int, cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int], bool]: ) -> Tuple[List[int], List[int], List[int], List[int], bool]:
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
if response[0]["content"]: # desired example if response[0]["content"]: # desired example
kto_tag = True kto_tag = True
messages = prompt + [response[0]] messages = prompt + [response[0]]
@ -62,10 +60,8 @@ def _encode_feedback_example(
response_ids += [tokenizer.eos_token_id] response_ids += [tokenizer.eos_token_id]
kl_response_ids += [tokenizer.eos_token_id] kl_response_ids += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token) kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, tokenizer, processor)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
kl_prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + kl_prompt_ids
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len) source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
prompt_ids = prompt_ids[:source_len] prompt_ids = prompt_ids[:source_len]
@ -91,28 +87,15 @@ def preprocess_feedback_dataset(
) -> Dict[str, List[List[int]]]: ) -> Dict[str, List[List[int]]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["response"][::-1] kl_response = examples["response"][::-1]
model_inputs = { model_inputs = defaultdict(list)
"input_ids": [],
"attention_mask": [],
"labels": [],
"kl_input_ids": [],
"kl_attention_mask": [],
"kl_labels": [],
"kto_tags": [],
}
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"] = []
model_inputs["kl_token_type_ids"] = []
for i in range(len(examples["prompt"])): for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2: if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue continue
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example( input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
prompt=examples["prompt"][i], prompt=prompt,
response=examples["response"][i], response=examples["response"][i],
kl_response=kl_response[i], kl_response=kl_response[i],
system=examples["system"][i], system=examples["system"][i],
@ -129,11 +112,15 @@ def preprocess_feedback_dataset(
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids)) model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels) model_inputs["kl_labels"].append(kl_labels)
model_inputs["kto_tags"].append(kto_tag) model_inputs["kto_tags"].append(kto_tag)
if processor is not None: template.mm_plugin.process_model_inputs(
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor)) model_inputs=model_inputs,
if hasattr(processor, "image_seq_length"): # paligemma models images=examples["images"][i],
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor)) feature_seqlens={
model_inputs["kl_token_type_ids"].append(get_paligemma_token_type_ids(len(kl_input_ids), processor)) "token_type_ids": len(input_ids),
"kl_token_type_ids": len(kl_input_ids),
},
processor=processor,
)
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag]) desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num undesirable_num = len(model_inputs["kto_tags"]) - desirable_num


@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen from .processor_utils import infer_seqlen
if TYPE_CHECKING: if TYPE_CHECKING:
@ -39,9 +40,6 @@ def _encode_pairwise_example(
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
cutoff_len: int, cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int]]: ) -> Tuple[List[int], List[int], List[int], List[int]]:
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
chosen_messages = prompt + [response[0]] chosen_messages = prompt + [response[0]]
rejected_messages = prompt + [response[1]] rejected_messages = prompt + [response[1]]
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools) prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
@ -51,10 +49,7 @@ def _encode_pairwise_example(
chosen_ids += [tokenizer.eos_token_id] chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id] rejected_ids += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
# consider the response is more important # consider the response is more important
source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len) source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
prompt_ids = prompt_ids[:source_len] prompt_ids = prompt_ids[:source_len]
@ -77,27 +72,15 @@ def preprocess_pairwise_dataset(
data_args: "DataArguments", data_args: "DataArguments",
) -> Dict[str, List[List[int]]]: ) -> Dict[str, List[List[int]]]:
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>` # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = { model_inputs = defaultdict(list)
"chosen_input_ids": [],
"chosen_attention_mask": [],
"chosen_labels": [],
"rejected_input_ids": [],
"rejected_attention_mask": [],
"rejected_labels": [],
}
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["chosen_token_type_ids"] = []
model_inputs["rejected_token_type_ids"] = []
for i in range(len(examples["prompt"])): for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2: if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue continue
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example( chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
prompt=examples["prompt"][i], prompt=prompt,
response=examples["response"][i], response=examples["response"][i],
system=examples["system"][i], system=examples["system"][i],
tools=examples["tools"][i], tools=examples["tools"][i],
@ -112,15 +95,15 @@ def preprocess_pairwise_dataset(
model_inputs["rejected_input_ids"].append(rejected_input_ids) model_inputs["rejected_input_ids"].append(rejected_input_ids)
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids)) model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
model_inputs["rejected_labels"].append(rejected_labels) model_inputs["rejected_labels"].append(rejected_labels)
if processor is not None: template.mm_plugin.process_model_inputs(
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor)) model_inputs=model_inputs,
if hasattr(processor, "image_seq_length"): # paligemma models images=examples["images"][i],
model_inputs["chosen_token_type_ids"].append( feature_seqlens={
get_paligemma_token_type_ids(len(chosen_input_ids), processor) "chosen_token_type_ids": len(chosen_input_ids),
) "rejected_token_type_ids": len(rejected_input_ids),
model_inputs["rejected_token_type_ids"].append( },
get_paligemma_token_type_ids(len(rejected_input_ids), processor) processor=processor,
) )
return model_inputs return model_inputs
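
The switch from a hand-built dict to `defaultdict(list)` in the processors can be sketched as follows. This is an illustrative example only: `add_example` and its `extra_features` argument are hypothetical and stand in for the combination of the explicit `append` calls and `mm_plugin.process_model_inputs`.

```python
from collections import defaultdict
from typing import Any, Dict, List


def add_example(model_inputs: Dict[str, List[Any]], input_ids: List[int], extra_features: Dict[str, Any]) -> None:
    """Hypothetical helper: appends text features, then whatever a plugin chooses to add."""
    model_inputs["input_ids"].append(input_ids)
    model_inputs["attention_mask"].append([1] * len(input_ids))
    for name, value in extra_features.items():
        # With defaultdict(list), a key is created the first time something is appended to it.
        model_inputs[name].append(value)


text_only: Dict[str, List[Any]] = defaultdict(list)
add_example(text_only, [1, 2, 3], extra_features={})

multimodal: Dict[str, List[Any]] = defaultdict(list)
add_example(multimodal, [1, 2, 3], extra_features={"pixel_values": "dummy-pixels"})
```

Under this pattern a text-only run never gains a `pixel_values` column, so the processor does not need to know in advance which features a given plugin emits.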


@ -13,20 +13,7 @@
# limitations under the License. # limitations under the License.
import bisect import bisect
from typing import TYPE_CHECKING, List, Sequence, Tuple from typing import List, Sequence, Tuple
from ...extras.packages import is_pillow_available
if is_pillow_available():
from PIL import Image
if TYPE_CHECKING:
from numpy.typing import NDArray
from PIL.Image import Image as ImageObject
from transformers import ProcessorMixin
from transformers.image_processing_utils import BaseImageProcessor
def search_for_fit(numbers: Sequence[int], capacity: int) -> int: def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
@ -61,37 +48,6 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
return knapsacks return knapsacks
def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
r"""
Processes visual inputs. (currently only supports a single image)
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
return image_processor(image, return_tensors="pt")["pixel_values"][0] # shape (C, H, W)
def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
r"""
Gets paligemma token type ids for computing loss.
"""
image_seq_length = getattr(processor, "image_seq_length")
return [0] * image_seq_length + [1] * (input_len - image_seq_length)
def get_qwen2vl_image_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
r"""
Processes visual inputs. support multi images
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
if len(images) != 0:
image_inputs = image_processor(images=images, return_tensors="pt")
else:
image = Image.new("RGB", (56, 56), (255, 255, 255))
image_inputs = image_processor(images=[image], return_tensors="pt")
image_inputs["image_grid_thw"][0][0] = 0
return {"pixel_values": image_inputs["pixel_values"], "image_grid_thw": image_inputs["image_grid_thw"]}
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]: def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
r""" r"""
Computes the real sequence length after truncation by the cutoff_len. Computes the real sequence length after truncation by the cutoff_len.


@ -17,17 +17,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger from ...extras.logging import get_logger
from .processor_utils import ( from .processor_utils import greedy_knapsack, infer_seqlen
get_paligemma_token_type_ids,
get_pixel_values,
get_qwen2vl_image_inputs,
greedy_knapsack,
infer_seqlen,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from PIL.Image import Image as ImageObject
from transformers import PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedTokenizer, ProcessorMixin
from ...hparams import DataArguments from ...hparams import DataArguments
@ -43,41 +36,15 @@ def _encode_supervised_example(
system: Optional[str], system: Optional[str],
tools: Optional[str], tools: Optional[str],
template: "Template", template: "Template",
images: Sequence["ImageObject"],
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
cutoff_len: int, cutoff_len: int,
train_on_prompt: bool, train_on_prompt: bool,
mask_history: bool, mask_history: bool,
) -> Tuple[List[int], List[int]]: ) -> Tuple[List[int], List[int]]:
if processor is not None and "image_grid_thw" in processor.model_input_names: # qwen2_vl models
image_processor = getattr(processor, "image_processor")
merge_length = image_processor.merge_size**2
if len(images) > 0:
image_grid_thw = get_qwen2vl_image_inputs(images, processor)["image_grid_thw"]
index = 0
for message in prompt:
content = message["content"]
while "<|image_pad|>" in content:
content = content.replace(
"<|image_pad|>",
template.vision_start_token
+ "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length)
+ template.vision_end_token,
1,
)
index += 1
message["content"] = content.replace("<|placeholder|>", "<|image_pad|>")
elif processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
messages = prompt + response messages = prompt + response
input_ids, labels = [], [] input_ids, labels = [], []
input_ids, labels = template.mm_plugin.process_token_ids(input_ids, labels, tokenizer, processor)
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
input_ids += [image_token_id] * getattr(processor, "image_seq_length")
labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools) encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
total_length = 1 if template.efficient_eos else 0 total_length = 1 if template.efficient_eos else 0
@ -125,28 +92,21 @@ def preprocess_supervised_dataset(
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
data_args: "DataArguments", data_args: "DataArguments",
) -> Dict[str, List[List[int]]]: ) -> Dict[str, List[Any]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>` # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair. # for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} model_inputs = defaultdict(list)
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"] = []
if "image_grid_thw" in processor.model_input_names: # qwen2_vl models
model_inputs["image_grid_thw"] = []
for i in range(len(examples["prompt"])): for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1: if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue continue
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
input_ids, labels = _encode_supervised_example( input_ids, labels = _encode_supervised_example(
prompt=examples["prompt"][i], prompt=prompt,
response=examples["response"][i], response=examples["response"][i],
system=examples["system"][i], system=examples["system"][i],
tools=examples["tools"][i], tools=examples["tools"][i],
images=examples["images"][i],
template=template, template=template,
tokenizer=tokenizer, tokenizer=tokenizer,
processor=processor, processor=processor,
@ -157,15 +117,12 @@ def preprocess_supervised_dataset(
model_inputs["input_ids"].append(input_ids) model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels) model_inputs["labels"].append(labels)
if processor is not None: template.mm_plugin.process_model_inputs(
if "image_grid_thw" in processor.model_input_names: # qwen2_vl models model_inputs=model_inputs,
image_inputs = get_qwen2vl_image_inputs(examples["images"][i], processor) images=examples["images"][i],
model_inputs["pixel_values"].append(image_inputs["pixel_values"]) feature_seqlens={"token_type_ids": len(input_ids)},
model_inputs["image_grid_thw"].append(image_inputs["image_grid_thw"]) processor=processor,
else: )
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
return model_inputs return model_inputs
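
The Qwen2-VL placeholder expansion that this hunk moves out of the supervised processor can be sketched without `torch` or `transformers`. The function `expand_image_placeholders` and the plain-tuple `grids` argument are assumptions for illustration; the commit itself reads the grid from `image_grid_thw` tensors returned by the image processor.

```python
from math import prod
from typing import Sequence, Tuple

IMAGE_PLACEHOLDER = "<image>"


def expand_image_placeholders(
    content: str,
    grids: Sequence[Tuple[int, int, int]],
    merge_size: int,
    image_token: str = "<|image_pad|>",
) -> str:
    """Hypothetical helper: replaces each placeholder with t*h*w // merge_size**2 image tokens."""
    merge_length = merge_size**2
    index = 0
    while IMAGE_PLACEHOLDER in content:
        if index >= len(grids):
            raise ValueError("More image placeholders than images.")

        num_tokens = prod(grids[index]) // merge_length
        content = content.replace(
            IMAGE_PLACEHOLDER,
            "<|vision_start|>{}<|vision_end|>".format(image_token * num_tokens),
            1,
        )
        index += 1

    return content


# One image with a (t, h, w) grid of (1, 4, 4) patches and 2x2 merging: 16 // 4 = 4 tokens.
expanded = expand_image_placeholders("<image>What is shown?", grids=[(1, 4, 4)], merge_size=2)
```

The loop terminates because the inserted text never contains the literal `<image>` placeholder.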
@ -175,7 +132,7 @@ def preprocess_packed_supervised_dataset(
template: "Template", template: "Template",
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
data_args: "DataArguments", data_args: "DataArguments",
) -> Dict[str, List[List[int]]]: ) -> Dict[str, List[Any]]:
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>` # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>` # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
valid_num = 0 valid_num = 0
@ -209,7 +166,7 @@ def preprocess_packed_supervised_dataset(
batch_labels.append(labels) batch_labels.append(labels)
valid_num += 1 valid_num += 1
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} model_inputs = defaultdict(list)
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1) # reserved for the padding token knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1) # reserved for the padding token
for knapsack in knapsacks: for knapsack in knapsacks:
packed_input_ids, packed_attention_masks, packed_labels = [], [], [] packed_input_ids, packed_attention_masks, packed_labels = [], [], []


@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.logging import get_logger from ...extras.logging import get_logger
from ..data_utils import Role from ..data_utils import Role
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen from .processor_utils import infer_seqlen
if TYPE_CHECKING: if TYPE_CHECKING:
@ -39,9 +40,6 @@ def _encode_unsupervised_example(
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
cutoff_len: int, cutoff_len: int,
) -> Tuple[List[int], List[int]]: ) -> Tuple[List[int], List[int]]:
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
if len(response) == 1: if len(response) == 1:
messages = prompt + response messages = prompt + response
else: else:
@ -51,10 +49,7 @@ def _encode_unsupervised_example(
if template.efficient_eos: if template.efficient_eos:
labels += [tokenizer.eos_token_id] labels += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, tokenizer, processor)
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len) source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
input_ids = input_ids[:source_len] input_ids = input_ids[:source_len]
labels = labels[:target_len] labels = labels[:target_len]
@ -69,19 +64,15 @@ def preprocess_unsupervised_dataset(
data_args: "DataArguments", data_args: "DataArguments",
) -> Dict[str, List[List[int]]]: ) -> Dict[str, List[List[int]]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>` # build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} model_inputs = defaultdict(list)
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"] = []
for i in range(len(examples["prompt"])): for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1: if len(examples["prompt"][i]) % 2 != 1:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i])) logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue continue
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
input_ids, labels = _encode_unsupervised_example( input_ids, labels = _encode_unsupervised_example(
prompt=examples["prompt"][i], prompt=prompt,
response=examples["response"][i], response=examples["response"][i],
system=examples["system"][i], system=examples["system"][i],
tools=examples["tools"][i], tools=examples["tools"][i],
@ -93,10 +84,12 @@ def preprocess_unsupervised_dataset(
model_inputs["input_ids"].append(input_ids) model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids)) model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels) model_inputs["labels"].append(labels)
if processor is not None: template.mm_plugin.process_model_inputs(
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor)) model_inputs=model_inputs,
if hasattr(processor, "image_seq_length"): # paligemma models images=examples["images"][i],
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor)) feature_seqlens={"token_type_ids": len(input_ids)},
processor=processor,
)
return model_inputs return model_inputs
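
The PaliGemma-style handling that `process_token_ids` and the token type ids now encapsulate can be sketched in plain Python. The helper names `prepend_image_tokens` and `make_token_type_ids`, and the token id `99`, are hypothetical values chosen for this example.

```python
from typing import List, Optional, Tuple

IGNORE_INDEX = -100  # same value as the constant used in this repository


def prepend_image_tokens(
    input_ids: List[int],
    labels: Optional[List[int]],
    image_token_id: int,
    image_seq_length: int,
) -> Tuple[List[int], Optional[List[int]]]:
    """Hypothetical helper: prefixes image tokens and masks them out of the loss."""
    input_ids = [image_token_id] * image_seq_length + input_ids
    if labels is not None:
        labels = [IGNORE_INDEX] * image_seq_length + labels

    return input_ids, labels


def make_token_type_ids(input_len: int, image_seq_length: int) -> List[int]:
    """Hypothetical helper: 0 marks image positions, 1 marks text positions."""
    return [0] * image_seq_length + [1] * (input_len - image_seq_length)


ids, labels = prepend_image_tokens([11, 12], [11, 12], image_token_id=99, image_seq_length=3)
types = make_token_type_ids(len(ids), image_seq_length=3)

# Passing labels=None, as the unsupervised processor does, leaves labels untouched.
prompt_ids, no_labels = prepend_image_tokens([11, 12], None, image_token_id=99, image_seq_length=3)
```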


@ -15,9 +15,11 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.logging import get_logger from ..extras.logging import get_logger
from .data_utils import Role from .data_utils import Role
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
from .mm_plugin import BasePlugin, get_mm_plugin
if TYPE_CHECKING: if TYPE_CHECKING:
@ -41,11 +43,9 @@ class Template:
format_prefix: "Formatter" format_prefix: "Formatter"
default_system: str default_system: str
stop_words: List[str] stop_words: List[str]
image_token: str
vision_start_token: str
vision_end_token: str
efficient_eos: bool efficient_eos: bool
replace_eos: bool replace_eos: bool
mm_plugin: "BasePlugin"
def encode_oneturn( def encode_oneturn(
self, self,
@ -207,11 +207,9 @@ def _register_template(
format_prefix: Optional["Formatter"] = None, format_prefix: Optional["Formatter"] = None,
default_system: str = "", default_system: str = "",
stop_words: Sequence[str] = [], stop_words: Sequence[str] = [],
image_token: str = "<image>",
vision_start_token: str = "<|vision_start|>",
vision_end_token: str = "<|vision_end|>",
efficient_eos: bool = False, efficient_eos: bool = False,
replace_eos: bool = False, replace_eos: bool = False,
mm_plugin: "BasePlugin" = BasePlugin(IMAGE_PLACEHOLDER),
) -> None: ) -> None:
r""" r"""
Registers a chat template. Registers a chat template.
@ -258,11 +256,9 @@ def _register_template(
format_prefix=format_prefix or default_prefix_formatter, format_prefix=format_prefix or default_prefix_formatter,
default_system=default_system, default_system=default_system,
stop_words=stop_words, stop_words=stop_words,
image_token=image_token,
vision_start_token=vision_start_token,
vision_end_token=vision_end_token,
efficient_eos=efficient_eos, efficient_eos=efficient_eos,
replace_eos=replace_eos, replace_eos=replace_eos,
mm_plugin=mm_plugin,
) )
@ -722,6 +718,17 @@ _register_template(
) )
_register_template(
name="llava",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
)
_register_template( _register_template(
name="mistral", name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]), format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
@ -766,6 +773,19 @@ _register_template(
) )
_register_template(
name="paligemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
efficient_eos=True,
mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
)
_register_template( _register_template(
name="phi", name="phi",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]), format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
@ -790,17 +810,15 @@ _register_template(
_register_template( _register_template(
name="qwen2vl", name="qwen2_vl",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]), format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.", default_system="You are a helpful assistant.",
image_token="<|image_pad|>",
vision_start_token="<|vision_start|>",
vision_end_token="<|vision_end|>",
stop_words=["<|im_end|>"], stop_words=["<|im_end|>"],
replace_eos=True, replace_eos=True,
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>"),
) )
@ -915,6 +933,7 @@ _register_template(
), ),
stop_words=["###"], stop_words=["###"],
efficient_eos=True, efficient_eos=True,
mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
) )
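
The templates registered above rely on `{{content}}` placeholders inside their slots. As a rough illustration of how such a slot could be filled, here is a minimal sketch. `render_slots` is a hypothetical helper written for this note; it is not the repository's `StringFormatter`, whose internals this diff does not show.

```python
from typing import List


def render_slots(slots: List[str], content: str) -> str:
    """Fill every `{{content}}` placeholder in the slots and join them.

    Hypothetical helper for illustration only; not the project's StringFormatter.
    """
    return "".join(slot.replace("{{content}}", content) for slot in slots)


# Slot copied from the `qwen2_vl` template's format_user in the hunk above.
QWEN2_VL_USER = ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]

prompt = render_slots(QWEN2_VL_USER, "Describe this picture.")
print(prompt)
```

The sketch covers plain string slots only; set-valued slots such as `{"bos_token"}` in the `paligemma` prefix are outside its scope.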

@ -47,6 +47,8 @@ FILEEXT2TYPE = {
IGNORE_INDEX = -100 IGNORE_INDEX = -100
IMAGE_PLACEHOLDER = "<image>"
LAYERNORM_NAMES = {"norm", "ln"} LAYERNORM_NAMES = {"norm", "ln"}
LLAMABOARD_CONFIG = "llamaboard_config.yaml" LLAMABOARD_CONFIG = "llamaboard_config.yaml"
@ -785,7 +787,7 @@ register_model_group(
DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf", DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
}, },
}, },
template="vicuna", template="llava",
vision=True, vision=True,
) )
@ -930,27 +932,28 @@ register_model_group(
register_model_group( register_model_group(
models={ models={
"PaliGemma-3B-pt-224": { "PaliGemma-3B-pt-224-Chat": {
DownloadSource.DEFAULT: "google/paligemma-3b-pt-224", DownloadSource.DEFAULT: "google/paligemma-3b-pt-224",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-224", DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-224",
}, },
"PaliGemma-3B-pt-448": { "PaliGemma-3B-pt-448-Chat": {
DownloadSource.DEFAULT: "google/paligemma-3b-pt-448", DownloadSource.DEFAULT: "google/paligemma-3b-pt-448",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-448", DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-448",
}, },
"PaliGemma-3B-pt-896": { "PaliGemma-3B-pt-896-Chat": {
DownloadSource.DEFAULT: "google/paligemma-3b-pt-896", DownloadSource.DEFAULT: "google/paligemma-3b-pt-896",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-896", DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-896",
}, },
"PaliGemma-3B-mix-224": { "PaliGemma-3B-mix-224-Chat": {
DownloadSource.DEFAULT: "google/paligemma-3b-mix-224", DownloadSource.DEFAULT: "google/paligemma-3b-mix-224",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-224", DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-224",
}, },
"PaliGemma-3B-mix-448": { "PaliGemma-3B-mix-448-Chat": {
DownloadSource.DEFAULT: "google/paligemma-3b-mix-448", DownloadSource.DEFAULT: "google/paligemma-3b-mix-448",
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-448", DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-448",
}, },
}, },
template="paligemma",
vision=True, vision=True,
) )
@ -1329,6 +1332,34 @@ register_model_group(
) )
register_model_group(
models={
"Qwen2VL-2B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct",
},
"Qwen2VL-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct",
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct",
},
"Qwen2VL-2B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
},
"Qwen2VL-2B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-AWQ",
},
"Qwen2VL-7B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
},
"Qwen2VL-7B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-AWQ",
},
},
template="qwen2_vl",
vision=True,
)
register_model_group( register_model_group(
models={ models={
"SOLAR-10.7B": { "SOLAR-10.7B": {

@ -79,7 +79,7 @@ def check_dependencies() -> None:
if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]: if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.") logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
else: else:
require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4") require_version("transformers>=4.41.2,<=4.45.0", "To fix: pip install transformers>=4.41.2,<=4.45.0")
require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0") require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0")
require_version("accelerate>=0.30.1,<=0.33.0", "To fix: pip install accelerate>=0.30.1,<=0.33.0") require_version("accelerate>=0.30.1,<=0.33.0", "To fix: pip install accelerate>=0.30.1,<=0.33.0")
require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0") require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
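
The hunk above widens the accepted `transformers` range and keeps the `DISABLE_VERSION_CHECK` escape hatch. The following is a simplified sketch of what an inclusive range check amounts to. The helpers `parse_version`, `in_pinned_range` and `version_check_disabled` are hypothetical and only handle plain dotted release strings; the real code delegates to `require_version`, which this sketch does not reimplement.

```python
import os
from typing import Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn a plain dotted release string such as "4.44.2" into (4, 44, 2)."""
    return tuple(int(part) for part in version.split("."))


def in_pinned_range(installed: str, lower: str, upper: str) -> bool:
    """Return True when lower <= installed <= upper (both bounds inclusive)."""
    return parse_version(lower) <= parse_version(installed) <= parse_version(upper)


def version_check_disabled() -> bool:
    """Mirror the DISABLE_VERSION_CHECK test shown in the diff."""
    return os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]


if not version_check_disabled():
    # Bounds taken from the updated transformers pin in check_dependencies().
    print(in_pinned_range("4.44.2", "4.41.2", "4.45.0"))
```

Note that the two patch functions later in this diff pin a tighter upper bound (`<=4.44.3`) than `check_dependencies()` does (`<=4.45.0`), so a version can pass the general check and still be rejected by a patch.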

@ -117,7 +117,7 @@ class ModelArguments:
default=False, default=False,
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."}, metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
) )
use_liger_kernel: bool = field( enable_liger_kernel: bool = field(
default=False, default=False,
metadata={"help": "Whether or not to enable liger kernel for faster training."}, metadata={"help": "Whether or not to enable liger kernel for faster training."},
) )

@ -116,7 +116,7 @@ def _check_extra_dependencies(
if model_args.use_unsloth: if model_args.use_unsloth:
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth") require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
if model_args.use_liger_kernel: if model_args.enable_liger_kernel:
require_version("liger-kernel", "To fix: pip install liger-kernel") require_version("liger-kernel", "To fix: pip install liger-kernel")
if model_args.mixture_of_depths is not None: if model_args.mixture_of_depths is not None:

@ -27,7 +27,7 @@ logger = get_logger(__name__)
def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.use_liger_kernel: if not is_trainable or not model_args.enable_liger_kernel:
return return
if getattr(config, "model_type", None) == "gemma": if getattr(config, "model_type", None) == "gemma":

@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(
def _apply_llama_patch() -> None: def _apply_llama_patch() -> None:
require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4") require_version("transformers>=4.41.2,<=4.44.3", "To fix: pip install transformers>=4.41.2,<=4.44.3")
LlamaAttention.forward = llama_attention_forward LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward LlamaSdpaAttention.forward = llama_sdpa_attention_forward

@ -36,11 +36,14 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
forbidden_modules.add("output") forbidden_modules.add("output")
elif model.config.model_type in ["llava", "paligemma"]: elif model.config.model_type in ["llava", "paligemma"]:
forbidden_modules.add("multi_modal_projector") forbidden_modules.add("multi_modal_projector")
elif model.config.model_type in ["qwen2_vl"]: elif model.config.model_type == "qwen2_vl":
forbidden_modules.add("merger") forbidden_modules.add("merger")
if freeze_vision_tower: if freeze_vision_tower:
forbidden_modules.add("vision_tower") if model.config.model_type == "qwen2_vl":
forbidden_modules.add("visual")
else:
forbidden_modules.add("vision_tower")
module_names = set() module_names = set()
for name, module in model.named_modules(): for name, module in model.named_modules():
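
The branching in this hunk can be read in isolation as a small pure function. The sketch below restates only the branches visible in the diff; `pick_forbidden_modules` is a hypothetical name, and any entries added to `forbidden_modules` before the visible lines are deliberately left out because the hunk does not show them.

```python
from typing import Set


def pick_forbidden_modules(model_type: str, freeze_vision_tower: bool) -> Set[str]:
    """Restate the visible branches of find_all_linear_modules (illustrative only)."""
    forbidden: Set[str] = set()  # Earlier entries from the real function are not shown in the hunk.

    # The projector between vision and language parts is always skipped.
    if model_type in ["llava", "paligemma"]:
        forbidden.add("multi_modal_projector")
    elif model_type == "qwen2_vl":
        forbidden.add("merger")

    # When the vision tower is frozen, its module name depends on the model family.
    if freeze_vision_tower:
        if model_type == "qwen2_vl":
            forbidden.add("visual")
        else:
            forbidden.add("vision_tower")

    return forbidden


print(sorted(pick_forbidden_modules("qwen2_vl", freeze_vision_tower=True)))
```

The point of the change is that Qwen2-VL names its vision encoder `visual` rather than `vision_tower`, so the old unconditional `add("vision_tower")` would not have excluded it.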

@ -114,7 +114,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
def _patch_for_block_diag_attn(model_type: str) -> None: def _patch_for_block_diag_attn(model_type: str) -> None:
require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4") require_version("transformers>=4.41.2,<=4.44.3", "To fix: pip install transformers>=4.41.2,<=4.44.3")
if is_transformers_version_greater_than_4_43(): if is_transformers_version_greater_than_4_43():
import transformers.modeling_flash_attention_utils import transformers.modeling_flash_attention_utils

@ -130,6 +130,9 @@ class CustomKTOTrainer(KTOTrainer):
if "pixel_values" in batch: if "pixel_values" in batch:
model_inputs["pixel_values"] = batch["pixel_values"] model_inputs["pixel_values"] = batch["pixel_values"]
if "image_grid_thw" in batch:
model_inputs["image_grid_thw"] = batch["image_grid_thw"]
if "{}token_type_ids".format(prefix) in batch: if "{}token_type_ids".format(prefix) in batch:
model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)] model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
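
The pattern in this hunk is to forward optional multimodal keys only when the batch carries them. A minimal sketch of that pattern follows, using plain lists in place of tensors. `collect_optional_inputs` is a hypothetical helper, and the `"kl_"` prefix in the usage lines is an illustrative value rather than something this hunk confirms.

```python
from typing import Any, Dict


def collect_optional_inputs(batch: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Copy optional multimodal keys from the batch when present (illustrative only)."""
    model_inputs: Dict[str, Any] = {}

    # These keys are shared across prefixes, so they are looked up without one.
    for key in ("pixel_values", "image_grid_thw"):
        if key in batch:
            model_inputs[key] = batch[key]

    # token_type_ids is stored per prefix in the batch but passed on unprefixed.
    token_type_key = "{}token_type_ids".format(prefix)
    if token_type_key in batch:
        model_inputs["token_type_ids"] = batch[token_type_key]

    return model_inputs


# Placeholder values stand in for real tensors.
batch = {"pixel_values": [[0.1]], "image_grid_thw": [[1, 2, 2]], "kl_token_type_ids": [0, 1]}
print(collect_optional_inputs(batch, prefix="kl_"))
```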

@ -17,9 +17,7 @@
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, List, Optional
from transformers import DataCollatorWithPadding from ...data import CustomDataCollatorForSeq2Seq, get_dataset
from ...data import get_dataset
from ...extras.ploting import plot_loss from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer from ...model import load_model, load_tokenizer
from ..callbacks import fix_valuehead_checkpoint from ..callbacks import fix_valuehead_checkpoint
@ -47,7 +45,7 @@ def run_ppo(
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True) model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
data_collator = DataCollatorWithPadding(tokenizer=tokenizer) data_collator = CustomDataCollatorForSeq2Seq(tokenizer=tokenizer)
# Create reference model and reward model # Create reference model and reward model
ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True) ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)

@ -115,7 +115,7 @@ class Runner:
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto", flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"), use_unsloth=(get("top.booster") == "unsloth"),
use_liger_kernel=(get("top.booster") == "liger_kernel"), enable_liger_kernel=(get("top.booster") == "liger_kernel"),
visual_inputs=get("top.visual_inputs"), visual_inputs=get("top.visual_inputs"),
dataset_dir=get("train.dataset_dir"), dataset_dir=get("train.dataset_dir"),
dataset=",".join(get("train.dataset")), dataset=",".join(get("train.dataset")),