mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 03:32:50 +08:00
refactor mm training
Former-commit-id: 3382317e32f88ed377d3e7759bdeaf0f2559d22a
This commit is contained in:
parent
98b0c7530c
commit
a83756b5e9
15
README.md
15
README.md
@ -47,7 +47,7 @@ Choose your path:
|
||||
|
||||
## Features
|
||||
|
||||
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
|
||||
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
|
||||
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
|
||||
- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
|
||||
- **Advanced algorithms**: GaLore, BAdam, Adam-mini, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
|
||||
@ -72,14 +72,16 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
||||
|
||||
## Changelog
|
||||
|
||||
[24/08/27] We support **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `use_liger_kernel: true` for efficient training.
|
||||
[24/08/30] We supported fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models.
|
||||
|
||||
[24/08/27] We support **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `enable_liger_kernel: true` for efficient training.
|
||||
|
||||
[24/08/09] We support **[Adam-mini](https://arxiv.org/abs/2406.16793)** optimizer. See [examples](examples/README.md) for usage. Thank [@relic-yuexi](https://github.com/relic-yuexi)'s PR.
|
||||
|
||||
[24/07/04] We support [contamination-free packed training](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing). Use `neat_packing: true` to activate it. Thank [@chuan298](https://github.com/chuan298)'s PR.
|
||||
|
||||
<details><summary>Full Changelog</summary>
|
||||
|
||||
[24/07/04] We support [contamination-free packed training](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing). Use `neat_packing: true` to activate it. Thank [@chuan298](https://github.com/chuan298)'s PR.
|
||||
|
||||
[24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
|
||||
|
||||
[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
|
||||
@ -172,14 +174,15 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
||||
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
||||
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
||||
| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
|
||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
|
||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
|
||||
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B | cpm |
|
||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||
| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
|
||||
| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
|
||||
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
||||
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
|
||||
| [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
|
||||
| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B | qwen2_vl |
|
||||
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
|
||||
| [Yi/Yi-1.5](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
|
||||
|
15
README_zh.md
15
README_zh.md
@ -48,7 +48,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
|
||||
|
||||
## 项目特色
|
||||
|
||||
- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
|
||||
- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Qwen2-VL、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
|
||||
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
|
||||
- **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
|
||||
- **先进算法**:GaLore、BAdam、Adam-mini、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
|
||||
@ -73,14 +73,16 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
|
||||
|
||||
## 更新日志
|
||||
|
||||
[24/08/27] 我们支持了 **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**。请使用 `use_liger_kernel: true` 来加速训练。
|
||||
[24/08/30] 我们支持了 **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** 模型的微调。
|
||||
|
||||
[24/08/27] 我们支持了 **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**。请使用 `enable_liger_kernel: true` 来加速训练。
|
||||
|
||||
[24/08/09] 我们支持了 **[Adam-mini](https://arxiv.org/abs/2406.16793)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。感谢 [@relic-yuexi](https://github.com/relic-yuexi) 的 PR。
|
||||
|
||||
[24/07/04] 我们支持了[无污染打包训练](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing)。请使用 `neat_packing: true` 参数。感谢 [@chuan298](https://github.com/chuan298) 的 PR。
|
||||
|
||||
<details><summary>展开日志</summary>
|
||||
|
||||
[24/07/04] 我们支持了[无污染打包训练](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing)。请使用 `neat_packing: true` 参数。感谢 [@chuan298](https://github.com/chuan298) 的 PR。
|
||||
|
||||
[24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。
|
||||
|
||||
[24/06/07] 我们支持了 **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** 和 **[GLM-4](https://github.com/THUDM/GLM-4)** 模型的微调。
|
||||
@ -173,14 +175,15 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
|
||||
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
||||
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
||||
| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
|
||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
|
||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
|
||||
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B | cpm |
|
||||
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
|
||||
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
|
||||
| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
|
||||
| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
|
||||
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
|
||||
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
|
||||
| [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
|
||||
| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B | qwen2_vl |
|
||||
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
|
||||
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
|
||||
| [Yi/Yi-1.5](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
|
||||
|
@ -38,20 +38,6 @@
|
||||
"assistant_tag": "assistant"
|
||||
}
|
||||
},
|
||||
"qwen2vl_demo": {
|
||||
"file_name": "qwen2vl_demo.json",
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"messages": "messages",
|
||||
"images": "images"
|
||||
},
|
||||
"tags": {
|
||||
"role_tag": "role",
|
||||
"content_tag": "content",
|
||||
"user_tag": "user",
|
||||
"assistant_tag": "assistant"
|
||||
}
|
||||
},
|
||||
"alpaca_en": {
|
||||
"hf_hub_url": "llamafactory/alpaca_en",
|
||||
"ms_hub_url": "llamafactory/alpaca_en"
|
||||
|
@ -2,7 +2,7 @@
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"content": "Who are they?",
|
||||
"content": "<image>Who are they?",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
@ -25,7 +25,7 @@
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"content": "Who is he?",
|
||||
"content": "<image>Who is he?",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
@ -48,7 +48,7 @@
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"content": "Please describe this image",
|
||||
"content": "<image>Please describe this image",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
@ -71,7 +71,7 @@
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"content": "他们是谁?",
|
||||
"content": "<image>他们是谁?",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
@ -94,7 +94,7 @@
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"content": "他是谁?",
|
||||
"content": "<image>他是谁?",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
@ -117,7 +117,7 @@
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"content": "请描述这张图片",
|
||||
"content": "<image>请描述这张图片",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
|
@ -1,140 +0,0 @@
|
||||
[
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"content": "<|image_pad|>Who are they?",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": "They're Kane and Gretzka from Bayern Munich.",
|
||||
"role": "assistant"
|
||||
},
|
||||
{
|
||||
"content": "What are they doing?",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": "They are celebrating on the soccer field.",
|
||||
"role": "assistant"
|
||||
}
|
||||
],
|
||||
"images": [
|
||||
"mllm_demo_data/1.jpg"
|
||||
]
|
||||
},
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"content": "<|image_pad|>Who is he?",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": "He's Thomas Muller from Bayern Munich.",
|
||||
"role": "assistant"
|
||||
},
|
||||
{
|
||||
"content": "<|image_pad|>Why is he on the ground?",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": "Because he's sliding on his knees to celebrate.",
|
||||
"role": "assistant"
|
||||
}
|
||||
],
|
||||
"images": [
|
||||
"mllm_demo_data/2.jpg","mllm_demo_data/2.jpg"
|
||||
]
|
||||
},
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"content": "<|image_pad|>Please describe this image",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": "Chinese astronaut Gui Haichao is giving a speech.",
|
||||
"role": "assistant"
|
||||
},
|
||||
{
|
||||
"content": "What has he accomplished?",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": "He was appointed to be a payload specialist on Shenzhou 16 mission in June 2022, thus becoming the first Chinese civilian of Group 3 in space on 30 May 2023. He is responsible for the on-orbit operation of space science experimental payloads.",
|
||||
"role": "assistant"
|
||||
}
|
||||
],
|
||||
"images": [
|
||||
"mllm_demo_data/3.jpg"
|
||||
]
|
||||
},
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"content": "<|image_pad|>他们是谁?",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": "他们是拜仁慕尼黑的凯恩和格雷茨卡。",
|
||||
"role": "assistant"
|
||||
},
|
||||
{
|
||||
"content": "<|image_pad|>他们在做什么?",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": "他们在足球场上庆祝。",
|
||||
"role": "assistant"
|
||||
}
|
||||
],
|
||||
"images": [
|
||||
"mllm_demo_data/1.jpg","mllm_demo_data/1.jpg"
|
||||
]
|
||||
},
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"content": "<|image_pad|>他是谁?",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": "他是来自拜仁慕尼黑的托马斯·穆勒。",
|
||||
"role": "assistant"
|
||||
},
|
||||
{
|
||||
"content": "他为什么在地上?",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": "因为他正在双膝跪地滑行庆祝。",
|
||||
"role": "assistant"
|
||||
}
|
||||
],
|
||||
"images": [
|
||||
"mllm_demo_data/2.jpg"
|
||||
]
|
||||
},
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"content": "<|image_pad|>请描述这张图片",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": "中国宇航员桂海潮正在讲话。",
|
||||
"role": "assistant"
|
||||
},
|
||||
{
|
||||
"content": "他取得过哪些成就?",
|
||||
"role": "user"
|
||||
},
|
||||
{
|
||||
"content": "他于2022年6月被任命为神舟十六号任务的有效载荷专家,从而成为2023年5月30日进入太空的首位平民宇航员。他负责在轨操作空间科学实验有效载荷。",
|
||||
"role": "assistant"
|
||||
}
|
||||
],
|
||||
"images": [
|
||||
"mllm_demo_data/3.jpg"
|
||||
]
|
||||
}
|
||||
]
|
@ -33,6 +33,7 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
|
||||
llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### Reward Modeling
|
||||
|
@ -33,6 +33,7 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
|
||||
llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
|
||||
```
|
||||
|
||||
#### 奖励模型训练
|
||||
|
@ -1,40 +0,0 @@
|
||||
### model
|
||||
model_name_or_path: qwen2-vl-hf/qwen2-vl-7b-hf
|
||||
visual_inputs: true
|
||||
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: full
|
||||
deepspeed: examples/deepspeed/ds_z3_config.json
|
||||
|
||||
### dataset
|
||||
dataset: qwen2vl_demo
|
||||
template: qwen2vl
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
### output
|
||||
output_dir: saves/qwen2-vl-7b/full/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 1
|
||||
learning_rate: 1.0e-5
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
@ -1,5 +1,5 @@
|
||||
### model
|
||||
model_name_or_path: qwen2-vl-hf/qwen2-vl-7b-hf
|
||||
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
|
||||
visual_inputs: true
|
||||
|
||||
### method
|
||||
@ -9,23 +9,23 @@ finetuning_type: lora
|
||||
lora_target: all
|
||||
|
||||
### dataset
|
||||
dataset: qwen2vl_demo
|
||||
template: qwen2vl
|
||||
dataset: mllm_demo
|
||||
template: qwen2_vl
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
### output
|
||||
output_dir: saves/qwen2-vl-7b/lora/sft
|
||||
output_dir: saves/qwen2_vl-7b/lora/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
### train
|
||||
per_device_train_batch_size: 2
|
||||
gradient_accumulation_steps: 1
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
|
@ -1,4 +1,4 @@
|
||||
transformers>=4.41.2,<=4.43.4
|
||||
transformers>=4.41.2,<=4.45.0
|
||||
datasets>=2.16.0,<=2.21.0
|
||||
accelerate>=0.30.1,<=0.33.0
|
||||
peft>=0.11.1,<=0.12.0
|
||||
|
@ -20,7 +20,7 @@ Level:
|
||||
|
||||
Dependency graph:
|
||||
main:
|
||||
transformers>=4.41.2,<=4.43.4
|
||||
transformers>=4.41.2,<=4.44.3
|
||||
datasets>=2.16.0,<=2.21.0
|
||||
accelerate>=0.30.1,<=0.33.0
|
||||
peft>=0.11.1,<=0.12.0
|
||||
@ -28,9 +28,9 @@ Dependency graph:
|
||||
attention:
|
||||
transformers>=4.42.4 (gemma+fa2)
|
||||
longlora:
|
||||
transformers>=4.41.2,<=4.43.4
|
||||
transformers>=4.41.2,<=4.44.3
|
||||
packing:
|
||||
transformers>=4.41.2,<=4.43.4
|
||||
transformers>=4.41.2,<=4.44.3
|
||||
"""
|
||||
|
||||
from .cli import VERSION
|
||||
|
@ -22,6 +22,7 @@ import torch
|
||||
from transformers import GenerationConfig, TextIteratorStreamer
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras.constants import IMAGE_PLACEHOLDER
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import get_logits_processor
|
||||
from ..model import load_model, load_tokenizer
|
||||
@ -31,7 +32,6 @@ from .base_engine import BaseEngine, Response
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
from trl import PreTrainedModelWrapper
|
||||
|
||||
from ..data import Template
|
||||
@ -81,27 +81,19 @@ class HuggingfaceEngine(BaseEngine):
|
||||
image: Optional["NDArray"] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
if (
|
||||
processor is not None
|
||||
and image is not None
|
||||
and not hasattr(processor, "image_seq_length")
|
||||
and template.image_token not in messages[0]["content"]
|
||||
): # llava-like models
|
||||
messages[0]["content"] = template.image_token + messages[0]["content"]
|
||||
if image is not None:
|
||||
if IMAGE_PLACEHOLDER not in messages[0]["content"]:
|
||||
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
|
||||
|
||||
messages = template.mm_plugin.process_messages(messages, [image], processor)
|
||||
|
||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||
system = system or generating_args["default_system"]
|
||||
pixel_values = None
|
||||
prompt_ids, _ = template.encode_oneturn(
|
||||
tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
|
||||
)
|
||||
if processor is not None and image is not None: # add image features
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
batch_feature = image_processor(image, return_tensors="pt")
|
||||
pixel_values = batch_feature.to(model.device)["pixel_values"] # shape (B, C, H, W)
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
||||
if image is not None:
|
||||
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
|
||||
|
||||
prompt_length = len(prompt_ids)
|
||||
inputs = torch.tensor([prompt_ids], device=model.device)
|
||||
@ -164,8 +156,13 @@ class HuggingfaceEngine(BaseEngine):
|
||||
logits_processor=get_logits_processor(),
|
||||
)
|
||||
|
||||
if pixel_values is not None:
|
||||
gen_kwargs["pixel_values"] = pixel_values
|
||||
if image is not None:
|
||||
mm_inputs = template.mm_plugin.get_mm_inputs(
|
||||
images=[image], feature_seqlens={"token_type_ids": prompt_length}, processor=processor
|
||||
)
|
||||
for key, value in mm_inputs.items():
|
||||
value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
|
||||
gen_kwargs[key] = value.to(model.device)
|
||||
|
||||
return gen_kwargs, prompt_length
|
||||
|
||||
|
@ -12,13 +12,19 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding, SFTDataCollatorWith4DAttentionMask
|
||||
from .collator import (
|
||||
CustomDataCollatorForSeq2Seq,
|
||||
KTODataCollatorWithPadding,
|
||||
PairwiseDataCollatorWithPadding,
|
||||
SFTDataCollatorWith4DAttentionMask,
|
||||
)
|
||||
from .data_utils import Role, split_dataset
|
||||
from .loader import get_dataset
|
||||
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CustomDataCollatorForSeq2Seq",
|
||||
"KTODataCollatorWithPadding",
|
||||
"PairwiseDataCollatorWithPadding",
|
||||
"SFTDataCollatorWith4DAttentionMask",
|
||||
|
@ -62,15 +62,11 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
|
||||
|
||||
|
||||
@dataclass
|
||||
class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
|
||||
class CustomDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for 4d attention mask.
|
||||
Data collator for custom models (like Qwen2-VL).
|
||||
"""
|
||||
|
||||
block_diag_attn: bool = False
|
||||
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
|
||||
compute_dtype: "torch.dtype" = torch.float32
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||
image_grid_thw = None
|
||||
if "image_grid_thw" in features[0]:
|
||||
@ -83,23 +79,18 @@ class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
|
||||
torch.Tensor(feature["pixel_values"]) for feature in features if feature["image_grid_thw"][0][0] > 0
|
||||
]
|
||||
if image_grid_thw_list:
|
||||
image_grid_thw = torch.cat(image_grid_thw_list, 0)
|
||||
image_grid_thw = torch.cat(image_grid_thw_list, dim=0)
|
||||
pixel_values = torch.cat(pixel_values_list, dim=0)
|
||||
else:
|
||||
# Handle the case where the list is empty, for example:
|
||||
image_grid_thw = None
|
||||
if pixel_values_list:
|
||||
pixel_values = torch.cat(pixel_values_list, 0)
|
||||
else:
|
||||
# Handle the case where the list is empty, for example:
|
||||
pixel_values = None
|
||||
|
||||
features = [
|
||||
{key: feature[key] for key in feature if key not in ["image_grid_thw", "pixel_values"]}
|
||||
for feature in features
|
||||
]
|
||||
|
||||
features = super().__call__(features)
|
||||
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
|
||||
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
|
||||
if image_grid_thw is not None:
|
||||
features["image_grid_thw"] = image_grid_thw
|
||||
features["pixel_values"] = pixel_values
|
||||
@ -108,7 +99,25 @@ class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
|
||||
|
||||
|
||||
@dataclass
|
||||
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
class SFTDataCollatorWith4DAttentionMask(CustomDataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for 4d attention mask.
|
||||
"""
|
||||
|
||||
block_diag_attn: bool = False
|
||||
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
|
||||
compute_dtype: "torch.dtype" = torch.float32
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||
features = super().__call__(features)
|
||||
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
|
||||
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
class PairwiseDataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for pairwise data.
|
||||
"""
|
||||
@ -128,9 +137,12 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
"attention_mask": feature["{}_attention_mask".format(key)],
|
||||
"labels": feature["{}_labels".format(key)],
|
||||
}
|
||||
if "pixel_values" in feature:
|
||||
if "pixel_values" in feature: # image data are same for chosen and rejected
|
||||
target_feature["pixel_values"] = feature["pixel_values"]
|
||||
|
||||
if "image_grid_thw" in feature:
|
||||
target_feature["image_grid_thw"] = feature["image_grid_thw"]
|
||||
|
||||
if "{}_token_type_ids".format(key) in feature:
|
||||
target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
|
||||
|
||||
@ -140,7 +152,7 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
|
||||
|
||||
@dataclass
|
||||
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
class KTODataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for KTO data.
|
||||
"""
|
||||
@ -163,6 +175,9 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
if "pixel_values" in feature:
|
||||
target_feature["pixel_values"] = feature["pixel_values"]
|
||||
|
||||
if "image_grid_thw" in feature:
|
||||
target_feature["image_grid_thw"] = feature["image_grid_thw"]
|
||||
|
||||
if "token_type_ids" in feature:
|
||||
target_feature["token_type_ids"] = feature["token_type_ids"]
|
||||
kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
|
||||
|
271
src/llamafactory/data/mm_plugin.py
Normal file
271
src/llamafactory/data/mm_plugin.py
Normal file
@ -0,0 +1,271 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from PIL.Image import Image
|
||||
from transformers import ProcessorMixin
|
||||
|
||||
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
|
||||
from ..extras.packages import is_pillow_available
|
||||
|
||||
|
||||
if is_pillow_available():
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from PIL.Image import Image as ImageObject
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
|
||||
|
||||
def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "torch.Tensor":
|
||||
r"""
|
||||
Processes visual inputs. (currently only supports a single image)
|
||||
|
||||
Returns:
|
||||
pixel_values: tensor with shape (B, C, H, W)
|
||||
"""
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
|
||||
return image_processor([image], return_tensors="pt")["pixel_values"]
|
||||
|
||||
|
||||
def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[List[int]]:
|
||||
r"""
|
||||
Gets paligemma token type ids for computing loss.
|
||||
|
||||
Returns:
|
||||
token_type_ids: shape (1, seq_len)
|
||||
"""
|
||||
image_seq_length = getattr(processor, "image_seq_length")
|
||||
return [[0] * image_seq_length + [1] * (input_len - image_seq_length)]
|
||||
|
||||
|
||||
def get_qwen2vl_image_inputs(
|
||||
images: Sequence["ImageObject"], processor: "ProcessorMixin"
|
||||
) -> Dict[str, "torch.Tensor"]:
|
||||
r"""
|
||||
Processes qwen2-vl visual inputs. Supports multiple images.
|
||||
|
||||
Returns:
|
||||
pixel_values: tensor with shape (num_patches, patch_dim)
|
||||
image_grid_thw: tensot with shape (num_images, 3), where the three numbers are time, width, height
|
||||
|
||||
It holds num_patches == torch.prod(image_grid_thw)
|
||||
"""
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
if len(images) != 0:
|
||||
image_inputs = image_processor(images=images, return_tensors="pt")
|
||||
else:
|
||||
image = Image.new("RGB", (56, 56), (255, 255, 255))
|
||||
image_inputs = image_processor(images=[image], return_tensors="pt")
|
||||
image_inputs["image_grid_thw"][0][0] = 0 # fake image
|
||||
|
||||
return {"pixel_values": image_inputs["pixel_values"], "image_grid_thw": image_inputs["image_grid_thw"]}
|
||||
|
||||
|
||||
class BasePlugin:
|
||||
def __init__(self, image_token: str) -> None:
|
||||
self.image_token = image_token
|
||||
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageObject"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
return messages
|
||||
|
||||
def process_token_ids(
|
||||
self,
|
||||
input_ids: List[int],
|
||||
labels: Optional[List[int]],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Tuple[List[int], Optional[List[int]]]:
|
||||
return input_ids, labels
|
||||
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageObject"],
|
||||
feature_seqlens: Dict[str, int],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
def process_model_inputs(
|
||||
self,
|
||||
model_inputs: Dict[str, List[Any]],
|
||||
images: Sequence["ImageObject"],
|
||||
feature_seqlens: Dict[str, int],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
class LlavaPlugin(BasePlugin):
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageObject"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
image_count = 0
|
||||
new_messages = []
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
image_count += 1
|
||||
if image_count > 1:
|
||||
raise ValueError("Llava model only accepts one image per sample.")
|
||||
|
||||
content = content.replace(IMAGE_PLACEHOLDER, self.image_token, 1)
|
||||
|
||||
new_messages.append({"role": message["role"], "content": content})
|
||||
|
||||
return new_messages
|
||||
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageObject"],
|
||||
feature_seqlens: Dict[str, int],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Any]:
|
||||
return {"pixel_values": get_pixel_values(images, processor)}
|
||||
|
||||
def process_model_inputs(
|
||||
self,
|
||||
model_inputs: Dict[str, List[Any]],
|
||||
images: Sequence["ImageObject"],
|
||||
feature_seqlens: Dict[str, int],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> None:
|
||||
mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
|
||||
model_inputs["pixel_values"].append(mm_inputs["pixel_values"][0])
|
||||
|
||||
|
||||
class PaliGemmaPlugin(BasePlugin):
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageObject"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
image_count = 0
|
||||
new_messages = []
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
image_count += 1
|
||||
if image_count > 1:
|
||||
raise ValueError("PaliGemma model only accepts one image per sample.")
|
||||
|
||||
content = content.replace(IMAGE_PLACEHOLDER, "", 1)
|
||||
|
||||
new_messages.append({"role": message["role"], "content": content})
|
||||
|
||||
return new_messages
|
||||
|
||||
def process_token_ids(
|
||||
self,
|
||||
input_ids: List[int],
|
||||
labels: Optional[List[int]],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Tuple[List[int], Optional[List[int]]]:
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
image_seq_length: int = getattr(image_processor, "image_seq_length")
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
|
||||
input_ids = [image_token_id] * image_seq_length + input_ids
|
||||
if labels is not None:
|
||||
labels = [IGNORE_INDEX] * image_seq_length + labels
|
||||
|
||||
return input_ids, labels
|
||||
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageObject"],
|
||||
feature_seqlens: Dict[str, int],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Any]:
|
||||
mm_inputs = {"pixel_values": get_pixel_values(images, processor)}
|
||||
for feature_name, feature_length in feature_seqlens.items():
|
||||
mm_inputs[feature_name] = get_paligemma_token_type_ids(feature_length, processor)
|
||||
|
||||
return mm_inputs
|
||||
|
||||
def process_model_inputs(
|
||||
self,
|
||||
model_inputs: Dict[str, List[Any]],
|
||||
images: Sequence["ImageObject"],
|
||||
feature_seqlens: Dict[str, int],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> None:
|
||||
mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
|
||||
model_inputs["pixel_values"].append(mm_inputs["pixel_values"][0])
|
||||
for feature_name in feature_seqlens.keys():
|
||||
model_inputs[feature_name].append(mm_inputs[feature_name][0])
|
||||
|
||||
|
||||
class Qwen2vlPlugin(BasePlugin):
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
images: Sequence["ImageObject"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
merge_length: int = getattr(image_processor, "merge_size") ** 2
|
||||
if len(images) > 0:
|
||||
image_grid_thw = get_qwen2vl_image_inputs(images, processor)["image_grid_thw"]
|
||||
|
||||
index = 0
|
||||
new_messages = []
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
content = content.replace(
|
||||
IMAGE_PLACEHOLDER,
|
||||
"<|vision_start|>{}<|vision_end|>".format(
|
||||
self.image_token * (image_grid_thw[index].prod() // merge_length)
|
||||
),
|
||||
1,
|
||||
)
|
||||
index += 1
|
||||
|
||||
new_messages.append({"role": message["role"], "content": content})
|
||||
|
||||
return new_messages
|
||||
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageObject"],
|
||||
feature_seqlens: Dict[str, int],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Any]:
|
||||
return get_qwen2vl_image_inputs(images, processor)
|
||||
|
||||
def process_model_inputs(
|
||||
self,
|
||||
model_inputs: Dict[str, List[Any]],
|
||||
images: Sequence["ImageObject"],
|
||||
feature_seqlens: Dict[str, int],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> None:
|
||||
mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
|
||||
model_inputs["pixel_values"].append(mm_inputs["pixel_values"])
|
||||
model_inputs["image_grid_thw"].append(mm_inputs["image_grid_thw"])
|
||||
|
||||
|
||||
PLUGINS = {
|
||||
"llava": LlavaPlugin,
|
||||
"paligemma": PaliGemmaPlugin,
|
||||
"qwen2_vl": Qwen2vlPlugin,
|
||||
}
|
||||
|
||||
|
||||
def get_mm_plugin(name: str, image_token: str) -> "BasePlugin":
|
||||
if name not in PLUGINS:
|
||||
raise ValueError("{} not found.".format(name))
|
||||
|
||||
return PLUGINS[name](image_token)
|
@ -12,11 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
|
||||
from .processor_utils import infer_seqlen
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -40,9 +41,6 @@ def _encode_feedback_example(
|
||||
processor: Optional["ProcessorMixin"],
|
||||
cutoff_len: int,
|
||||
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
|
||||
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
||||
prompt[0]["content"] = template.image_token + prompt[0]["content"]
|
||||
|
||||
if response[0]["content"]: # desired example
|
||||
kto_tag = True
|
||||
messages = prompt + [response[0]]
|
||||
@ -62,10 +60,8 @@ def _encode_feedback_example(
|
||||
response_ids += [tokenizer.eos_token_id]
|
||||
kl_response_ids += [tokenizer.eos_token_id]
|
||||
|
||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
||||
kl_prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + kl_prompt_ids
|
||||
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
|
||||
kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, tokenizer, processor)
|
||||
|
||||
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
|
||||
prompt_ids = prompt_ids[:source_len]
|
||||
@ -91,28 +87,15 @@ def preprocess_feedback_dataset(
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
|
||||
kl_response = examples["response"][::-1]
|
||||
model_inputs = {
|
||||
"input_ids": [],
|
||||
"attention_mask": [],
|
||||
"labels": [],
|
||||
"kl_input_ids": [],
|
||||
"kl_attention_mask": [],
|
||||
"kl_labels": [],
|
||||
"kto_tags": [],
|
||||
}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["token_type_ids"] = []
|
||||
model_inputs["kl_token_type_ids"] = []
|
||||
|
||||
model_inputs = defaultdict(list)
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
|
||||
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
|
||||
prompt=examples["prompt"][i],
|
||||
prompt=prompt,
|
||||
response=examples["response"][i],
|
||||
kl_response=kl_response[i],
|
||||
system=examples["system"][i],
|
||||
@ -129,11 +112,15 @@ def preprocess_feedback_dataset(
|
||||
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
|
||||
model_inputs["kl_labels"].append(kl_labels)
|
||||
model_inputs["kto_tags"].append(kto_tag)
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
|
||||
model_inputs["kl_token_type_ids"].append(get_paligemma_token_type_ids(len(kl_input_ids), processor))
|
||||
template.mm_plugin.process_model_inputs(
|
||||
model_inputs=model_inputs,
|
||||
images=examples["images"][i],
|
||||
feature_seqlens={
|
||||
"token_type_ids": len(input_ids),
|
||||
"kl_token_type_ids": len(kl_input_ids),
|
||||
},
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
|
||||
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
|
||||
|
@ -12,11 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
|
||||
from .processor_utils import infer_seqlen
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -39,9 +40,6 @@ def _encode_pairwise_example(
|
||||
processor: Optional["ProcessorMixin"],
|
||||
cutoff_len: int,
|
||||
) -> Tuple[List[int], List[int], List[int], List[int]]:
|
||||
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
||||
prompt[0]["content"] = template.image_token + prompt[0]["content"]
|
||||
|
||||
chosen_messages = prompt + [response[0]]
|
||||
rejected_messages = prompt + [response[1]]
|
||||
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
|
||||
@ -51,10 +49,7 @@ def _encode_pairwise_example(
|
||||
chosen_ids += [tokenizer.eos_token_id]
|
||||
rejected_ids += [tokenizer.eos_token_id]
|
||||
|
||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
||||
|
||||
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
|
||||
# consider the response is more important
|
||||
source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
|
||||
prompt_ids = prompt_ids[:source_len]
|
||||
@ -77,27 +72,15 @@ def preprocess_pairwise_dataset(
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
|
||||
model_inputs = {
|
||||
"chosen_input_ids": [],
|
||||
"chosen_attention_mask": [],
|
||||
"chosen_labels": [],
|
||||
"rejected_input_ids": [],
|
||||
"rejected_attention_mask": [],
|
||||
"rejected_labels": [],
|
||||
}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["chosen_token_type_ids"] = []
|
||||
model_inputs["rejected_token_type_ids"] = []
|
||||
|
||||
model_inputs = defaultdict(list)
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
|
||||
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
|
||||
prompt=examples["prompt"][i],
|
||||
prompt=prompt,
|
||||
response=examples["response"][i],
|
||||
system=examples["system"][i],
|
||||
tools=examples["tools"][i],
|
||||
@ -112,15 +95,15 @@ def preprocess_pairwise_dataset(
|
||||
model_inputs["rejected_input_ids"].append(rejected_input_ids)
|
||||
model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
|
||||
model_inputs["rejected_labels"].append(rejected_labels)
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["chosen_token_type_ids"].append(
|
||||
get_paligemma_token_type_ids(len(chosen_input_ids), processor)
|
||||
)
|
||||
model_inputs["rejected_token_type_ids"].append(
|
||||
get_paligemma_token_type_ids(len(rejected_input_ids), processor)
|
||||
)
|
||||
template.mm_plugin.process_model_inputs(
|
||||
model_inputs=model_inputs,
|
||||
images=examples["images"][i],
|
||||
feature_seqlens={
|
||||
"chosen_token_type_ids": len(chosen_input_ids),
|
||||
"rejected_token_type_ids": len(rejected_input_ids),
|
||||
},
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
@ -13,20 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import bisect
|
||||
from typing import TYPE_CHECKING, List, Sequence, Tuple
|
||||
|
||||
from ...extras.packages import is_pillow_available
|
||||
|
||||
|
||||
if is_pillow_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
from PIL.Image import Image as ImageObject
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
from typing import List, Sequence, Tuple
|
||||
|
||||
|
||||
def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
|
||||
@ -61,37 +48,6 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
|
||||
return knapsacks
|
||||
|
||||
|
||||
def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
|
||||
r"""
|
||||
Processes visual inputs. (currently only supports a single image)
|
||||
"""
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
|
||||
return image_processor(image, return_tensors="pt")["pixel_values"][0] # shape (C, H, W)
|
||||
|
||||
|
||||
def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
|
||||
r"""
|
||||
Gets paligemma token type ids for computing loss.
|
||||
"""
|
||||
image_seq_length = getattr(processor, "image_seq_length")
|
||||
return [0] * image_seq_length + [1] * (input_len - image_seq_length)
|
||||
|
||||
|
||||
def get_qwen2vl_image_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
|
||||
r"""
|
||||
Processes visual inputs. support multi images
|
||||
"""
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
if len(images) != 0:
|
||||
image_inputs = image_processor(images=images, return_tensors="pt")
|
||||
else:
|
||||
image = Image.new("RGB", (56, 56), (255, 255, 255))
|
||||
image_inputs = image_processor(images=[image], return_tensors="pt")
|
||||
image_inputs["image_grid_thw"][0][0] = 0
|
||||
return {"pixel_values": image_inputs["pixel_values"], "image_grid_thw": image_inputs["image_grid_thw"]}
|
||||
|
||||
|
||||
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
|
||||
r"""
|
||||
Computes the real sequence length after truncation by the cutoff_len.
|
||||
|
@ -17,17 +17,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from .processor_utils import (
|
||||
get_paligemma_token_type_ids,
|
||||
get_pixel_values,
|
||||
get_qwen2vl_image_inputs,
|
||||
greedy_knapsack,
|
||||
infer_seqlen,
|
||||
)
|
||||
from .processor_utils import greedy_knapsack, infer_seqlen
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from PIL.Image import Image as ImageObject
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
|
||||
from ...hparams import DataArguments
|
||||
@ -43,41 +36,15 @@ def _encode_supervised_example(
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
template: "Template",
|
||||
images: Sequence["ImageObject"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
cutoff_len: int,
|
||||
train_on_prompt: bool,
|
||||
mask_history: bool,
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
if processor is not None and "image_grid_thw" in processor.model_input_names: # qwen2_vl models
|
||||
image_processor = getattr(processor, "image_processor")
|
||||
merge_length = image_processor.merge_size**2
|
||||
if len(images) > 0:
|
||||
image_grid_thw = get_qwen2vl_image_inputs(images, processor)["image_grid_thw"]
|
||||
index = 0
|
||||
for message in prompt:
|
||||
content = message["content"]
|
||||
while "<|image_pad|>" in content:
|
||||
content = content.replace(
|
||||
"<|image_pad|>",
|
||||
template.vision_start_token
|
||||
+ "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length)
|
||||
+ template.vision_end_token,
|
||||
1,
|
||||
)
|
||||
index += 1
|
||||
message["content"] = content.replace("<|placeholder|>", "<|image_pad|>")
|
||||
elif processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
||||
prompt[0]["content"] = template.image_token + prompt[0]["content"]
|
||||
|
||||
messages = prompt + response
|
||||
input_ids, labels = [], []
|
||||
|
||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
input_ids += [image_token_id] * getattr(processor, "image_seq_length")
|
||||
labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
|
||||
input_ids, labels = template.mm_plugin.process_token_ids(input_ids, labels, tokenizer, processor)
|
||||
|
||||
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
|
||||
total_length = 1 if template.efficient_eos else 0
|
||||
@ -125,28 +92,21 @@ def preprocess_supervised_dataset(
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
) -> Dict[str, List[Any]]:
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["token_type_ids"] = []
|
||||
if "image_grid_thw" in processor.model_input_names: # qwen2_vl models
|
||||
model_inputs["image_grid_thw"] = []
|
||||
|
||||
model_inputs = defaultdict(list)
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
|
||||
input_ids, labels = _encode_supervised_example(
|
||||
prompt=examples["prompt"][i],
|
||||
prompt=prompt,
|
||||
response=examples["response"][i],
|
||||
system=examples["system"][i],
|
||||
tools=examples["tools"][i],
|
||||
images=examples["images"][i],
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
@ -157,15 +117,12 @@ def preprocess_supervised_dataset(
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
if processor is not None:
|
||||
if "image_grid_thw" in processor.model_input_names: # qwen2_vl models
|
||||
image_inputs = get_qwen2vl_image_inputs(examples["images"][i], processor)
|
||||
model_inputs["pixel_values"].append(image_inputs["pixel_values"])
|
||||
model_inputs["image_grid_thw"].append(image_inputs["image_grid_thw"])
|
||||
else:
|
||||
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
|
||||
template.mm_plugin.process_model_inputs(
|
||||
model_inputs=model_inputs,
|
||||
images=examples["images"][i],
|
||||
feature_seqlens={"token_type_ids": len(input_ids)},
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
return model_inputs
|
||||
|
||||
@ -175,7 +132,7 @@ def preprocess_packed_supervised_dataset(
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
) -> Dict[str, List[Any]]:
|
||||
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
|
||||
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
|
||||
valid_num = 0
|
||||
@ -209,7 +166,7 @@ def preprocess_packed_supervised_dataset(
|
||||
batch_labels.append(labels)
|
||||
valid_num += 1
|
||||
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
model_inputs = defaultdict(list)
|
||||
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1) # reserved for the padding token
|
||||
for knapsack in knapsacks:
|
||||
packed_input_ids, packed_attention_masks, packed_labels = [], [], []
|
||||
|
@ -12,11 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
from ..data_utils import Role
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
|
||||
from .processor_utils import infer_seqlen
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -39,9 +40,6 @@ def _encode_unsupervised_example(
|
||||
processor: Optional["ProcessorMixin"],
|
||||
cutoff_len: int,
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
||||
prompt[0]["content"] = template.image_token + prompt[0]["content"]
|
||||
|
||||
if len(response) == 1:
|
||||
messages = prompt + response
|
||||
else:
|
||||
@ -51,10 +49,7 @@ def _encode_unsupervised_example(
|
||||
if template.efficient_eos:
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
|
||||
|
||||
input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, tokenizer, processor)
|
||||
source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
|
||||
input_ids = input_ids[:source_len]
|
||||
labels = labels[:target_len]
|
||||
@ -69,19 +64,15 @@ def preprocess_unsupervised_dataset(
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X` and labels with format `Y <eos>`
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["token_type_ids"] = []
|
||||
|
||||
model_inputs = defaultdict(list)
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
|
||||
input_ids, labels = _encode_unsupervised_example(
|
||||
prompt=examples["prompt"][i],
|
||||
prompt=prompt,
|
||||
response=examples["response"][i],
|
||||
system=examples["system"][i],
|
||||
tools=examples["tools"][i],
|
||||
@ -93,10 +84,12 @@ def preprocess_unsupervised_dataset(
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
|
||||
template.mm_plugin.process_model_inputs(
|
||||
model_inputs=model_inputs,
|
||||
images=examples["images"][i],
|
||||
feature_seqlens={"token_type_ids": len(input_ids)},
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
return model_inputs
|
||||
|
||||
|
@ -15,9 +15,11 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from ..extras.constants import IMAGE_PLACEHOLDER
|
||||
from ..extras.logging import get_logger
|
||||
from .data_utils import Role
|
||||
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
|
||||
from .mm_plugin import BasePlugin, get_mm_plugin
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -41,11 +43,9 @@ class Template:
|
||||
format_prefix: "Formatter"
|
||||
default_system: str
|
||||
stop_words: List[str]
|
||||
image_token: str
|
||||
vision_start_token: str
|
||||
vision_end_token: str
|
||||
efficient_eos: bool
|
||||
replace_eos: bool
|
||||
mm_plugin: "BasePlugin"
|
||||
|
||||
def encode_oneturn(
|
||||
self,
|
||||
@ -207,11 +207,9 @@ def _register_template(
|
||||
format_prefix: Optional["Formatter"] = None,
|
||||
default_system: str = "",
|
||||
stop_words: Sequence[str] = [],
|
||||
image_token: str = "<image>",
|
||||
vision_start_token: str = "<|vision_start|>",
|
||||
vision_end_token: str = "<|vision_end|>",
|
||||
efficient_eos: bool = False,
|
||||
replace_eos: bool = False,
|
||||
mm_plugin: "BasePlugin" = BasePlugin(IMAGE_PLACEHOLDER),
|
||||
) -> None:
|
||||
r"""
|
||||
Registers a chat template.
|
||||
@ -258,11 +256,9 @@ def _register_template(
|
||||
format_prefix=format_prefix or default_prefix_formatter,
|
||||
default_system=default_system,
|
||||
stop_words=stop_words,
|
||||
image_token=image_token,
|
||||
vision_start_token=vision_start_token,
|
||||
vision_end_token=vision_end_token,
|
||||
efficient_eos=efficient_eos,
|
||||
replace_eos=replace_eos,
|
||||
mm_plugin=mm_plugin,
|
||||
)
|
||||
|
||||
|
||||
@ -722,6 +718,17 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="llava",
|
||||
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
||||
default_system=(
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
),
|
||||
mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="mistral",
|
||||
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
||||
@ -766,6 +773,19 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="paligemma",
|
||||
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
|
||||
format_observation=StringFormatter(
|
||||
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
|
||||
),
|
||||
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
efficient_eos=True,
|
||||
mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="phi",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
|
||||
@ -790,17 +810,15 @@ _register_template(
|
||||
|
||||
|
||||
_register_template(
|
||||
name="qwen2vl",
|
||||
name="qwen2_vl",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
default_system="You are a helpful assistant.",
|
||||
image_token="<|image_pad|>",
|
||||
vision_start_token="<|vision_start|>",
|
||||
vision_end_token="<|vision_end|>",
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>"),
|
||||
)
|
||||
|
||||
|
||||
@ -915,6 +933,7 @@ _register_template(
|
||||
),
|
||||
stop_words=["###"],
|
||||
efficient_eos=True,
|
||||
mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
|
||||
)
|
||||
|
||||
|
||||
|
@ -47,6 +47,8 @@ FILEEXT2TYPE = {
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
|
||||
IMAGE_PLACEHOLDER = "<image>"
|
||||
|
||||
LAYERNORM_NAMES = {"norm", "ln"}
|
||||
|
||||
LLAMABOARD_CONFIG = "llamaboard_config.yaml"
|
||||
@ -785,7 +787,7 @@ register_model_group(
|
||||
DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
|
||||
},
|
||||
},
|
||||
template="vicuna",
|
||||
template="llava",
|
||||
vision=True,
|
||||
)
|
||||
|
||||
@ -930,27 +932,28 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"PaliGemma-3B-pt-224": {
|
||||
"PaliGemma-3B-pt-224-Chat": {
|
||||
DownloadSource.DEFAULT: "google/paligemma-3b-pt-224",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-224",
|
||||
},
|
||||
"PaliGemma-3B-pt-448": {
|
||||
"PaliGemma-3B-pt-448-Chat": {
|
||||
DownloadSource.DEFAULT: "google/paligemma-3b-pt-448",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-448",
|
||||
},
|
||||
"PaliGemma-3B-pt-896": {
|
||||
"PaliGemma-3B-pt-896-Chat": {
|
||||
DownloadSource.DEFAULT: "google/paligemma-3b-pt-896",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-896",
|
||||
},
|
||||
"PaliGemma-3B-mix-224": {
|
||||
"PaliGemma-3B-mix-224-Chat": {
|
||||
DownloadSource.DEFAULT: "google/paligemma-3b-mix-224",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-224",
|
||||
},
|
||||
"PaliGemma-3B-mix-448": {
|
||||
"PaliGemma-3B-mix-448-Chat": {
|
||||
DownloadSource.DEFAULT: "google/paligemma-3b-mix-448",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-448",
|
||||
},
|
||||
},
|
||||
template="paligemma",
|
||||
vision=True,
|
||||
)
|
||||
|
||||
@ -1329,6 +1332,34 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Qwen2VL-2B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct",
|
||||
},
|
||||
"Qwen2VL-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct",
|
||||
},
|
||||
"Qwen2VL-2B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
|
||||
},
|
||||
"Qwen2VL-2B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-AWQ",
|
||||
},
|
||||
"Qwen2VL-7B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
|
||||
},
|
||||
"Qwen2VL-7B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-AWQ",
|
||||
},
|
||||
},
|
||||
template="qwen2_vl",
|
||||
vision=True,
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"SOLAR-10.7B": {
|
||||
|
@ -79,7 +79,7 @@ def check_dependencies() -> None:
|
||||
if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
|
||||
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
|
||||
else:
|
||||
require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
|
||||
require_version("transformers>=4.41.2,<=4.45.0", "To fix: pip install transformers>=4.41.2,<=4.45.0")
|
||||
require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0")
|
||||
require_version("accelerate>=0.30.1,<=0.33.0", "To fix: pip install accelerate>=0.30.1,<=0.33.0")
|
||||
require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
|
||||
|
@ -117,7 +117,7 @@ class ModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
|
||||
)
|
||||
use_liger_kernel: bool = field(
|
||||
enable_liger_kernel: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to enable liger kernel for faster training."},
|
||||
)
|
||||
|
@ -116,7 +116,7 @@ def _check_extra_dependencies(
|
||||
if model_args.use_unsloth:
|
||||
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
|
||||
|
||||
if model_args.use_liger_kernel:
|
||||
if model_args.enable_liger_kernel:
|
||||
require_version("liger-kernel", "To fix: pip install liger-kernel")
|
||||
|
||||
if model_args.mixture_of_depths is not None:
|
||||
|
@ -27,7 +27,7 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
if not is_trainable or not model_args.use_liger_kernel:
|
||||
if not is_trainable or not model_args.enable_liger_kernel:
|
||||
return
|
||||
|
||||
if getattr(config, "model_type", None) == "gemma":
|
||||
|
@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(
|
||||
|
||||
|
||||
def _apply_llama_patch() -> None:
|
||||
require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
|
||||
require_version("transformers>=4.41.2,<=4.44.3", "To fix: pip install transformers>=4.41.2,<=4.44.3")
|
||||
LlamaAttention.forward = llama_attention_forward
|
||||
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
|
||||
LlamaSdpaAttention.forward = llama_sdpa_attention_forward
|
||||
|
@ -36,11 +36,14 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
|
||||
forbidden_modules.add("output")
|
||||
elif model.config.model_type in ["llava", "paligemma"]:
|
||||
forbidden_modules.add("multi_modal_projector")
|
||||
elif model.config.model_type in ["qwen2_vl"]:
|
||||
elif model.config.model_type == "qwen2_vl":
|
||||
forbidden_modules.add("merger")
|
||||
|
||||
if freeze_vision_tower:
|
||||
forbidden_modules.add("vision_tower")
|
||||
if model.config.model_type == "qwen2_vl":
|
||||
forbidden_modules.add("visual")
|
||||
else:
|
||||
forbidden_modules.add("vision_tower")
|
||||
|
||||
module_names = set()
|
||||
for name, module in model.named_modules():
|
||||
|
@ -114,7 +114,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
|
||||
|
||||
|
||||
def _patch_for_block_diag_attn(model_type: str) -> None:
|
||||
require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
|
||||
require_version("transformers>=4.41.2,<=4.44.3", "To fix: pip install transformers>=4.41.2,<=4.44.3")
|
||||
if is_transformers_version_greater_than_4_43():
|
||||
import transformers.modeling_flash_attention_utils
|
||||
|
||||
|
@ -130,6 +130,9 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
if "pixel_values" in batch:
|
||||
model_inputs["pixel_values"] = batch["pixel_values"]
|
||||
|
||||
if "image_grid_thw" in batch:
|
||||
model_inputs["image_grid_thw"] = batch["image_grid_thw"]
|
||||
|
||||
if "{}token_type_ids".format(prefix) in batch:
|
||||
model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
|
||||
|
||||
|
@ -17,9 +17,7 @@
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from transformers import DataCollatorWithPadding
|
||||
|
||||
from ...data import get_dataset
|
||||
from ...data import CustomDataCollatorForSeq2Seq, get_dataset
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ..callbacks import fix_valuehead_checkpoint
|
||||
@ -47,7 +45,7 @@ def run_ppo(
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
|
||||
|
||||
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
|
||||
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
||||
data_collator = CustomDataCollatorForSeq2Seq(tokenizer=tokenizer)
|
||||
|
||||
# Create reference model and reward model
|
||||
ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
|
||||
|
@ -115,7 +115,7 @@ class Runner:
|
||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
|
||||
use_unsloth=(get("top.booster") == "unsloth"),
|
||||
use_liger_kernel=(get("top.booster") == "liger_kernel"),
|
||||
enable_liger_kernel=(get("top.booster") == "liger_kernel"),
|
||||
visual_inputs=get("top.visual_inputs"),
|
||||
dataset_dir=get("train.dataset_dir"),
|
||||
dataset=",".join(get("train.dataset")),
|
||||
|
Loading…
x
Reference in New Issue
Block a user