mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)

refactor mm training

Former-commit-id: 179c0558699e287cbf38a2d73bff47e86d589c5a
This commit is contained in: parent 77c2c7076b, commit c62a6ca59d

README.md | 15
@@ -47,7 +47,7 @@ Choose your path:
 
 ## Features
 
-- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
+- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
 - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
 - **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
 - **Advanced algorithms**: GaLore, BAdam, Adam-mini, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
@@ -72,14 +72,16 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 
 ## Changelog
 
-[24/08/27] We support **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `use_liger_kernel: true` for efficient training.
+[24/08/30] We supported fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models.
+
+[24/08/27] We support **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `enable_liger_kernel: true` for efficient training.
 
 [24/08/09] We support **[Adam-mini](https://arxiv.org/abs/2406.16793)** optimizer. See [examples](examples/README.md) for usage. Thank [@relic-yuexi](https://github.com/relic-yuexi)'s PR.
 
 [24/07/04] We support [contamination-free packed training](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing). Use `neat_packing: true` to activate it. Thank [@chuan298](https://github.com/chuan298)'s PR.
 
 <details><summary>Full Changelog</summary>
 
 [24/07/04] We support [contamination-free packed training](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing). Use `neat_packing: true` to activate it. Thank [@chuan298](https://github.com/chuan298)'s PR.
 
 [24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
 
 [24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
@@ -172,14 +174,15 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [Llama](https://github.com/facebookresearch/llama)                | 7B/13B/33B/65B                   | -         |
 | [Llama 2](https://huggingface.co/meta-llama)                      | 7B/13B/70B                       | llama2    |
 | [Llama 3/Llama 3.1](https://huggingface.co/meta-llama)            | 8B/70B                           | llama3    |
-| [LLaVA-1.5](https://huggingface.co/llava-hf)                      | 7B/13B                           | vicuna    |
+| [LLaVA-1.5](https://huggingface.co/llava-hf)                      | 7B/13B                           | llava     |
 | [MiniCPM](https://huggingface.co/openbmb)                         | 1B/2B                            | cpm       |
 | [Mistral/Mixtral](https://huggingface.co/mistralai)               | 7B/8x7B/8x22B                    | mistral   |
 | [OLMo](https://huggingface.co/allenai)                            | 1B/7B                            | -         |
-| [PaliGemma](https://huggingface.co/google)                        | 3B                               | gemma     |
+| [PaliGemma](https://huggingface.co/google)                        | 3B                               | paligemma |
 | [Phi-1.5/Phi-2](https://huggingface.co/microsoft)                 | 1.3B/2.7B                        | -         |
 | [Phi-3](https://huggingface.co/microsoft)                         | 4B/7B/14B                        | phi       |
 | [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen      |
+| [Qwen2-VL](https://huggingface.co/Qwen)                           | 2B/7B                            | qwen2_vl  |
 | [StarCoder 2](https://huggingface.co/bigcode)                     | 3B/7B/15B                        | -         |
 | [XVERSE](https://huggingface.co/xverse)                           | 7B/13B/65B                       | xverse    |
 | [Yi/Yi-1.5](https://huggingface.co/01-ai)                         | 6B/9B/34B                        | yi        |

README_zh.md | 15
@@ -48,7 +48,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 
 ## Features
 
-- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
+- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
 - **Integrated methods**: (continual) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
 - **Multiple precisions**: 16-bit full-parameter fine-tuning, freeze-tuning, LoRA fine-tuning, and 2/3/4/5/6/8-bit QLoRA fine-tuning via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
 - **Advanced algorithms**: GaLore, BAdam, Adam-mini, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA, and agent tuning.
@@ -73,14 +73,16 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 
 ## Changelog
 
-[24/08/27] We added support for **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Use `use_liger_kernel: true` to accelerate training.
+[24/08/30] We added support for fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models.
+
+[24/08/27] We added support for **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Use `enable_liger_kernel: true` to accelerate training.
 
 [24/08/09] We added support for the **[Adam-mini](https://arxiv.org/abs/2406.16793)** optimizer. See [examples](examples/README_zh.md) for detailed usage. Thanks to [@relic-yuexi](https://github.com/relic-yuexi) for the PR.
 
 [24/07/04] We added support for [contamination-free packed training](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing). Use the `neat_packing: true` option. Thanks to [@chuan298](https://github.com/chuan298) for the PR.
 
 <details><summary>Expand changelog</summary>
 
 [24/07/04] We added support for [contamination-free packed training](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing). Use the `neat_packing: true` option. Thanks to [@chuan298](https://github.com/chuan298) for the PR.
 
 [24/06/16] We added support for the **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README_zh.md) for detailed usage.
 
 [24/06/07] We added support for fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
@@ -173,14 +175,15 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 | [Llama](https://github.com/facebookresearch/llama)                | 7B/13B/33B/65B                   | -         |
 | [Llama 2](https://huggingface.co/meta-llama)                      | 7B/13B/70B                       | llama2    |
 | [Llama 3/Llama 3.1](https://huggingface.co/meta-llama)            | 8B/70B                           | llama3    |
-| [LLaVA-1.5](https://huggingface.co/llava-hf)                      | 7B/13B                           | vicuna    |
+| [LLaVA-1.5](https://huggingface.co/llava-hf)                      | 7B/13B                           | llava     |
 | [MiniCPM](https://huggingface.co/openbmb)                         | 1B/2B                            | cpm       |
 | [Mistral/Mixtral](https://huggingface.co/mistralai)               | 7B/8x7B/8x22B                    | mistral   |
 | [OLMo](https://huggingface.co/allenai)                            | 1B/7B                            | -         |
-| [PaliGemma](https://huggingface.co/google)                        | 3B                               | gemma     |
+| [PaliGemma](https://huggingface.co/google)                        | 3B                               | paligemma |
 | [Phi-1.5/Phi-2](https://huggingface.co/microsoft)                 | 1.3B/2.7B                        | -         |
 | [Phi-3](https://huggingface.co/microsoft)                         | 4B/7B/14B                        | phi       |
 | [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen      |
+| [Qwen2-VL](https://huggingface.co/Qwen)                           | 2B/7B                            | qwen2_vl  |
 | [StarCoder 2](https://huggingface.co/bigcode)                     | 3B/7B/15B                        | -         |
 | [XVERSE](https://huggingface.co/xverse)                           | 7B/13B/65B                       | xverse    |
 | [Yi/Yi-1.5](https://huggingface.co/01-ai)                         | 6B/9B/34B                        | yi        |
@@ -33,6 +33,7 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
 
 ```bash
 llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
+llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
 ```
 
 #### Reward Modeling
@@ -33,6 +33,7 @@ llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
 
 ```bash
 llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
+llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml
 ```
 
 #### Reward Modeling
@@ -1,40 +0,0 @@
-### model
-model_name_or_path: qwen2-vl-hf/qwen2-vl-7b-hf
-visual_inputs: true
-
-### method
-stage: sft
-do_train: true
-finetuning_type: full
-deepspeed: examples/deepspeed/ds_z3_config.json
-
-### dataset
-dataset: qwen2vl_demo
-template: qwen2vl
-cutoff_len: 1024
-max_samples: 1000
-overwrite_cache: true
-preprocessing_num_workers: 16
-
-### output
-output_dir: saves/qwen2-vl-7b/full/sft
-logging_steps: 10
-save_steps: 500
-plot_loss: true
-overwrite_output_dir: true
-
-### train
-per_device_train_batch_size: 1
-gradient_accumulation_steps: 1
-learning_rate: 1.0e-5
-num_train_epochs: 3.0
-lr_scheduler_type: cosine
-warmup_ratio: 0.1
-bf16: true
-ddp_timeout: 180000000
-
-### eval
-val_size: 0.1
-per_device_eval_batch_size: 1
-eval_strategy: steps
-eval_steps: 500
@@ -1,5 +1,5 @@
 ### model
-model_name_or_path: qwen2-vl-hf/qwen2-vl-7b-hf
+model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
 visual_inputs: true
 
 ### method
@@ -9,23 +9,23 @@ finetuning_type: lora
 lora_target: all
 
 ### dataset
-dataset: qwen2vl_demo
-template: qwen2vl
+dataset: mllm_demo
+template: qwen2_vl
 cutoff_len: 1024
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16
 
 ### output
-output_dir: saves/qwen2-vl-7b/lora/sft
+output_dir: saves/qwen2_vl-7b/lora/sft
 logging_steps: 10
 save_steps: 500
 plot_loss: true
 overwrite_output_dir: true
 
 ### train
-per_device_train_batch_size: 2
-gradient_accumulation_steps: 1
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 8
 learning_rate: 1.0e-4
 num_train_epochs: 3.0
 lr_scheduler_type: cosine
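A note on the train-section change in the LoRA config above: the per-device batch shrinks while gradient accumulation grows, so the effective batch per optimizer step actually increases fourfold. Plain arithmetic, not repository code:

```python
# effective samples per optimizer step = per_device_train_batch_size * gradient_accumulation_steps
old_effective = 2 * 1  # before this commit
new_effective = 1 * 8  # after this commit
print(old_effective, new_effective)  # 2 8
```

The smaller per-device batch leaves headroom for Qwen2-VL's larger image inputs, while accumulation preserves (and here raises) the gradient batch size.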
@@ -1,4 +1,4 @@
-transformers>=4.41.2,<=4.43.4
+transformers>=4.41.2,<=4.45.0
 datasets>=2.16.0,<=2.21.0
 accelerate>=0.30.1,<=0.33.0
 peft>=0.11.1,<=0.12.0
@@ -20,7 +20,7 @@ Level:
 
 Dependency graph:
   main:
-    transformers>=4.41.2,<=4.43.4
+    transformers>=4.41.2,<=4.44.3
     datasets>=2.16.0,<=2.21.0
     accelerate>=0.30.1,<=0.33.0
     peft>=0.11.1,<=0.12.0
@@ -28,9 +28,9 @@ Dependency graph:
   attention:
     transformers>=4.42.4 (gemma+fa2)
   longlora:
-    transformers>=4.41.2,<=4.43.4
+    transformers>=4.41.2,<=4.44.3
   packing:
-    transformers>=4.41.2,<=4.43.4
+    transformers>=4.41.2,<=4.44.3
 """
 
 from .cli import VERSION
@@ -22,6 +22,7 @@ import torch
 from transformers import GenerationConfig, TextIteratorStreamer
 
 from ..data import get_template_and_fix_tokenizer
+from ..extras.constants import IMAGE_PLACEHOLDER
 from ..extras.logging import get_logger
 from ..extras.misc import get_logits_processor
 from ..model import load_model, load_tokenizer
@@ -31,7 +32,6 @@ from .base_engine import BaseEngine, Response
 if TYPE_CHECKING:
     from numpy.typing import NDArray
     from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
-    from transformers.image_processing_utils import BaseImageProcessor
     from trl import PreTrainedModelWrapper
 
     from ..data import Template
@@ -81,27 +81,19 @@ class HuggingfaceEngine(BaseEngine):
         image: Optional["NDArray"] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Tuple[Dict[str, Any], int]:
-        if (
-            processor is not None
-            and image is not None
-            and not hasattr(processor, "image_seq_length")
-            and template.image_token not in messages[0]["content"]
-        ):  # llava-like models
-            messages[0]["content"] = template.image_token + messages[0]["content"]
+        if image is not None:
+            if IMAGE_PLACEHOLDER not in messages[0]["content"]:
+                messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
+
+            messages = template.mm_plugin.process_messages(messages, [image], processor)
 
         paired_messages = messages + [{"role": "assistant", "content": ""}]
         system = system or generating_args["default_system"]
-        pixel_values = None
         prompt_ids, _ = template.encode_oneturn(
             tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
         )
-        if processor is not None and image is not None:  # add image features
-            image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-            batch_feature = image_processor(image, return_tensors="pt")
-            pixel_values = batch_feature.to(model.device)["pixel_values"]  # shape (B, C, H, W)
-            if hasattr(processor, "image_seq_length"):  # paligemma models
-                image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
-                prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
+        if image is not None:
+            prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
 
         prompt_length = len(prompt_ids)
         inputs = torch.tensor([prompt_ids], device=model.device)
@@ -164,8 +156,13 @@ class HuggingfaceEngine(BaseEngine):
             logits_processor=get_logits_processor(),
         )
 
-        if pixel_values is not None:
-            gen_kwargs["pixel_values"] = pixel_values
+        if image is not None:
+            mm_inputs = template.mm_plugin.get_mm_inputs(
+                images=[image], feature_seqlens={"token_type_ids": prompt_length}, processor=processor
+            )
+            for key, value in mm_inputs.items():
+                value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
+                gen_kwargs[key] = value.to(model.device)
 
         return gen_kwargs, prompt_length
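End to end, the engine change means callers no longer special-case vision models: they pass the raw image and the template's plugin decides how to splice it in. A hedged sketch of inference after this commit; the `ChatModel` entry point, its argument dict, and the `Response.response_text` field are my assumptions pieced together from the surrounding imports, not part of this diff:

```python
import numpy as np
from llamafactory.chat import ChatModel  # assumed entry point wrapping HuggingfaceEngine

chat_model = ChatModel({
    "model_name_or_path": "Qwen/Qwen2-VL-7B-Instruct",
    "template": "qwen2_vl",
    "visual_inputs": True,  # matches the YAML option above
})
image = np.zeros((224, 224, 3), dtype=np.uint8)  # stand-in NDArray image
messages = [{"role": "user", "content": "What is in this image?"}]
# IMAGE_PLACEHOLDER is prepended automatically when the prompt lacks it (see _process_args)
responses = chat_model.chat(messages, image=image)
print(responses[0].response_text)
```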
@@ -12,13 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding, SFTDataCollatorWith4DAttentionMask
+from .collator import (
+    CustomDataCollatorForSeq2Seq,
+    KTODataCollatorWithPadding,
+    PairwiseDataCollatorWithPadding,
+    SFTDataCollatorWith4DAttentionMask,
+)
 from .data_utils import Role, split_dataset
 from .loader import get_dataset
 from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
 
 
 __all__ = [
+    "CustomDataCollatorForSeq2Seq",
     "KTODataCollatorWithPadding",
     "PairwiseDataCollatorWithPadding",
    "SFTDataCollatorWith4DAttentionMask",
@@ -62,15 +62,11 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
 
 
 @dataclass
-class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
+class CustomDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     r"""
-    Data collator for 4d attention mask.
+    Data collator for custom models (like Qwen2-VL).
     """
 
-    block_diag_attn: bool = False
-    attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
-    compute_dtype: "torch.dtype" = torch.float32
-
     def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
         image_grid_thw = None
         if "image_grid_thw" in features[0]:
@@ -83,23 +79,18 @@ class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
                 torch.Tensor(feature["pixel_values"]) for feature in features if feature["image_grid_thw"][0][0] > 0
             ]
             if image_grid_thw_list:
-                image_grid_thw = torch.cat(image_grid_thw_list, 0)
+                image_grid_thw = torch.cat(image_grid_thw_list, dim=0)
+                pixel_values = torch.cat(pixel_values_list, dim=0)
             else:
                 # Handle the case where the list is empty, for example:
                 image_grid_thw = None
-            if pixel_values_list:
-                pixel_values = torch.cat(pixel_values_list, 0)
-            else:
-                # Handle the case where the list is empty, for example:
                 pixel_values = None
 
             features = [
                 {key: feature[key] for key in feature if key not in ["image_grid_thw", "pixel_values"]}
                 for feature in features
             ]
 
         features = super().__call__(features)
-        if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
-            features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
         if image_grid_thw is not None:
             features["image_grid_thw"] = image_grid_thw
             features["pixel_values"] = pixel_values
@@ -108,7 +99,25 @@ class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
 
 
 @dataclass
-class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
+class SFTDataCollatorWith4DAttentionMask(CustomDataCollatorForSeq2Seq):
+    r"""
+    Data collator for 4d attention mask.
+    """
+
+    block_diag_attn: bool = False
+    attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
+    compute_dtype: "torch.dtype" = torch.float32
+
+    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
+        features = super().__call__(features)
+        if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
+            features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
+
+        return features
+
+
+@dataclass
+class PairwiseDataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
     r"""
     Data collator for pairwise data.
     """
@@ -128,9 +137,12 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
                     "attention_mask": feature["{}_attention_mask".format(key)],
                     "labels": feature["{}_labels".format(key)],
                 }
-                if "pixel_values" in feature:
+                if "pixel_values" in feature:  # image data are same for chosen and rejected
                     target_feature["pixel_values"] = feature["pixel_values"]
+
+                if "image_grid_thw" in feature:
+                    target_feature["image_grid_thw"] = feature["image_grid_thw"]
+
                 if "{}_token_type_ids".format(key) in feature:
                     target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
@@ -140,7 +152,7 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
 
 
 @dataclass
-class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
+class KTODataCollatorWithPadding(CustomDataCollatorForSeq2Seq):
     r"""
     Data collator for KTO data.
     """
@@ -163,6 +175,9 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
             if "pixel_values" in feature:
                 target_feature["pixel_values"] = feature["pixel_values"]
 
+            if "image_grid_thw" in feature:
+                target_feature["image_grid_thw"] = feature["image_grid_thw"]
+
             if "token_type_ids" in feature:
                 target_feature["token_type_ids"] = feature["token_type_ids"]
                 kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
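For intuition, here is roughly what `prepare_4d_attention_mask` computes when `neat_packing` is enabled. This is a minimal sketch under the assumption that the 2D mask stores a packed-sequence index per position (0 for padding), which matches how the collator invokes it above; it is not the repository implementation:

```python
import torch

def sketch_prepare_4d_attention_mask(mask_with_indices: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Turn a (batch, seq_len) packed-index mask into a (batch, 1, seq_len, seq_len) additive mask."""
    bsz, seq_len = mask_with_indices.size()
    min_value = torch.finfo(dtype).min
    # pairwise comparison: position i may attend to position j only if both carry
    # the same nonzero packed-sequence index and j <= i (causal)
    idx = mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
    same_seq = idx.eq(idx.transpose(-1, -2)) & idx.ne(0)
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=mask_with_indices.device))
    allowed = same_seq & causal
    # additive mask: 0 where attention is allowed, a large negative value elsewhere
    return torch.where(allowed, torch.tensor(0.0, dtype=dtype), torch.tensor(min_value, dtype=dtype))

# two sequences packed into one row: indices [1, 1, 2, 2, 0] (0 = padding)
mask = torch.tensor([[1, 1, 2, 2, 0]])
print(sketch_prepare_4d_attention_mask(mask, torch.float32)[0, 0])
```

Positions from different packed sequences mask each other out, which is what makes `neat_packing: true` contamination-free.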
src/llamafactory/data/mm_plugin.py | 271 (new file)
@@ -0,0 +1,271 @@
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
+
+from PIL.Image import Image
+from transformers import ProcessorMixin
+
+from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
+from ..extras.packages import is_pillow_available
+
+
+if is_pillow_available():
+    import torch
+    from PIL import Image
+
+
+if TYPE_CHECKING:
+    from PIL.Image import Image as ImageObject
+    from transformers import PreTrainedTokenizer, ProcessorMixin
+    from transformers.image_processing_utils import BaseImageProcessor
+
+
+def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "torch.Tensor":
+    r"""
+    Processes visual inputs. (currently only supports a single image)
+
+    Returns:
+        pixel_values: tensor with shape (B, C, H, W)
+    """
+    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+    image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
+    return image_processor([image], return_tensors="pt")["pixel_values"]
+
+
+def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[List[int]]:
+    r"""
+    Gets paligemma token type ids for computing loss.
+
+    Returns:
+        token_type_ids: shape (1, seq_len)
+    """
+    image_seq_length = getattr(processor, "image_seq_length")
+    return [[0] * image_seq_length + [1] * (input_len - image_seq_length)]
+
+
+def get_qwen2vl_image_inputs(
+    images: Sequence["ImageObject"], processor: "ProcessorMixin"
+) -> Dict[str, "torch.Tensor"]:
+    r"""
+    Processes qwen2-vl visual inputs. Supports multiple images.
+
+    Returns:
+        pixel_values: tensor with shape (num_patches, patch_dim)
+        image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, height, width
+
+    It holds that num_patches == torch.prod(image_grid_thw)
+    """
+    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+    if len(images) != 0:
+        image_inputs = image_processor(images=images, return_tensors="pt")
+    else:
+        image = Image.new("RGB", (56, 56), (255, 255, 255))
+        image_inputs = image_processor(images=[image], return_tensors="pt")
+        image_inputs["image_grid_thw"][0][0] = 0  # fake image
+
+    return {"pixel_values": image_inputs["pixel_values"], "image_grid_thw": image_inputs["image_grid_thw"]}
+
+
+class BasePlugin:
+    def __init__(self, image_token: str) -> None:
+        self.image_token = image_token
+
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageObject"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        return messages
+
+    def process_token_ids(
+        self,
+        input_ids: List[int],
+        labels: Optional[List[int]],
+        tokenizer: "PreTrainedTokenizer",
+        processor: Optional["ProcessorMixin"],
+    ) -> Tuple[List[int], Optional[List[int]]]:
+        return input_ids, labels
+
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageObject"],
+        feature_seqlens: Dict[str, int],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Any]:
+        return {}
+
+    def process_model_inputs(
+        self,
+        model_inputs: Dict[str, List[Any]],
+        images: Sequence["ImageObject"],
+        feature_seqlens: Dict[str, int],
+        processor: Optional["ProcessorMixin"],
+    ) -> None:
+        return
+
+
+class LlavaPlugin(BasePlugin):
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageObject"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        image_count = 0
+        new_messages = []
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                image_count += 1
+                if image_count > 1:
+                    raise ValueError("Llava model only accepts one image per sample.")
+
+                content = content.replace(IMAGE_PLACEHOLDER, self.image_token, 1)
+
+            new_messages.append({"role": message["role"], "content": content})
+
+        return new_messages
+
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageObject"],
+        feature_seqlens: Dict[str, int],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Any]:
+        return {"pixel_values": get_pixel_values(images, processor)}
+
+    def process_model_inputs(
+        self,
+        model_inputs: Dict[str, List[Any]],
+        images: Sequence["ImageObject"],
+        feature_seqlens: Dict[str, int],
+        processor: Optional["ProcessorMixin"],
+    ) -> None:
+        mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
+        model_inputs["pixel_values"].append(mm_inputs["pixel_values"][0])
+
+
+class PaliGemmaPlugin(BasePlugin):
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageObject"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        image_count = 0
+        new_messages = []
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                image_count += 1
+                if image_count > 1:
+                    raise ValueError("PaliGemma model only accepts one image per sample.")
+
+                content = content.replace(IMAGE_PLACEHOLDER, "", 1)
+
+            new_messages.append({"role": message["role"], "content": content})
+
+        return new_messages
+
+    def process_token_ids(
+        self,
+        input_ids: List[int],
+        labels: Optional[List[int]],
+        tokenizer: "PreTrainedTokenizer",
+        processor: Optional["ProcessorMixin"],
+    ) -> Tuple[List[int], Optional[List[int]]]:
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        image_seq_length: int = getattr(image_processor, "image_seq_length")
+        image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+        input_ids = [image_token_id] * image_seq_length + input_ids
+        if labels is not None:
+            labels = [IGNORE_INDEX] * image_seq_length + labels
+
+        return input_ids, labels
+
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageObject"],
+        feature_seqlens: Dict[str, int],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Any]:
+        mm_inputs = {"pixel_values": get_pixel_values(images, processor)}
+        for feature_name, feature_length in feature_seqlens.items():
+            mm_inputs[feature_name] = get_paligemma_token_type_ids(feature_length, processor)
+
+        return mm_inputs
+
+    def process_model_inputs(
+        self,
+        model_inputs: Dict[str, List[Any]],
+        images: Sequence["ImageObject"],
+        feature_seqlens: Dict[str, int],
+        processor: Optional["ProcessorMixin"],
+    ) -> None:
+        mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
+        model_inputs["pixel_values"].append(mm_inputs["pixel_values"][0])
+        for feature_name in feature_seqlens.keys():
+            model_inputs[feature_name].append(mm_inputs[feature_name][0])
+
+
+class Qwen2vlPlugin(BasePlugin):
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageObject"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        merge_length: int = getattr(image_processor, "merge_size") ** 2
+        if len(images) > 0:
+            image_grid_thw = get_qwen2vl_image_inputs(images, processor)["image_grid_thw"]
+
+        index = 0
+        new_messages = []
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                content = content.replace(
+                    IMAGE_PLACEHOLDER,
+                    "<|vision_start|>{}<|vision_end|>".format(
+                        self.image_token * (image_grid_thw[index].prod() // merge_length)
+                    ),
+                    1,
+                )
+                index += 1
+
+            new_messages.append({"role": message["role"], "content": content})
+
+        return new_messages
+
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageObject"],
+        feature_seqlens: Dict[str, int],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Any]:
+        return get_qwen2vl_image_inputs(images, processor)
+
+    def process_model_inputs(
+        self,
+        model_inputs: Dict[str, List[Any]],
+        images: Sequence["ImageObject"],
+        feature_seqlens: Dict[str, int],
+        processor: Optional["ProcessorMixin"],
+    ) -> None:
+        mm_inputs = self.get_mm_inputs(images, feature_seqlens, processor)
+        model_inputs["pixel_values"].append(mm_inputs["pixel_values"])
+        model_inputs["image_grid_thw"].append(mm_inputs["image_grid_thw"])
+
+
+PLUGINS = {
+    "llava": LlavaPlugin,
+    "paligemma": PaliGemmaPlugin,
+    "qwen2_vl": Qwen2vlPlugin,
+}
+
+
+def get_mm_plugin(name: str, image_token: str) -> "BasePlugin":
+    if name not in PLUGINS:
+        raise ValueError("{} not found.".format(name))
+
+    return PLUGINS[name](image_token)
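To see how the pieces of the new plugin API fit together, here is a hedged walk-through. The identifiers come from `mm_plugin.py` above; the processor checkpoint, the demo image, and the assumption that `IMAGE_PLACEHOLDER` is the literal `<image>` string are mine, not the diff's:

```python
from PIL import Image
from transformers import AutoProcessor

from llamafactory.data.mm_plugin import get_mm_plugin  # assumes the module above is importable

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")

image = Image.new("RGB", (112, 112), (255, 255, 255))  # stand-in for a real sample
messages = [{"role": "user", "content": "<image>Describe the picture."}]  # "<image>" = assumed IMAGE_PLACEHOLDER

# 1. expand the placeholder into <|vision_start|>...<|vision_end|> with one <|image_pad|>
#    per merged patch: prod(image_grid_thw) // merge_size**2 copies of the token
messages = plugin.process_messages(messages, [image], processor)

# 2. collect the tensors the model consumes alongside input_ids
mm_inputs = plugin.get_mm_inputs([image], feature_seqlens={}, processor=processor)
print(mm_inputs["pixel_values"].shape, mm_inputs["image_grid_thw"])
```

For PaliGemma, `process_token_ids` would additionally prepend `image_seq_length` copies of the image token to `input_ids` (and `IGNORE_INDEX` to `labels`), which is exactly the special-case code the processors below delete.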
@@ -12,11 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
+from .processor_utils import infer_seqlen
 
 
 if TYPE_CHECKING:
@@ -40,9 +41,6 @@ def _encode_feedback_example(
     processor: Optional["ProcessorMixin"],
     cutoff_len: int,
 ) -> Tuple[List[int], List[int], List[int], List[int], bool]:
-    if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
-        prompt[0]["content"] = template.image_token + prompt[0]["content"]
-
     if response[0]["content"]:  # desired example
         kto_tag = True
         messages = prompt + [response[0]]
@@ -62,10 +60,8 @@ def _encode_feedback_example(
         response_ids += [tokenizer.eos_token_id]
         kl_response_ids += [tokenizer.eos_token_id]
 
-    if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
-        image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
-        prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
-        kl_prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + kl_prompt_ids
+    prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
+    kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, tokenizer, processor)
 
     source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
     prompt_ids = prompt_ids[:source_len]
@@ -91,28 +87,15 @@ def preprocess_feedback_dataset(
 ) -> Dict[str, List[List[int]]]:
     # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
     kl_response = examples["response"][::-1]
-    model_inputs = {
-        "input_ids": [],
-        "attention_mask": [],
-        "labels": [],
-        "kl_input_ids": [],
-        "kl_attention_mask": [],
-        "kl_labels": [],
-        "kto_tags": [],
-    }
-    if processor is not None:
-        model_inputs["pixel_values"] = []
-        if hasattr(processor, "image_seq_length"):  # paligemma models
-            model_inputs["token_type_ids"] = []
-            model_inputs["kl_token_type_ids"] = []
-
+    model_inputs = defaultdict(list)
     for i in range(len(examples["prompt"])):
        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
             logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
             continue
 
+        prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
         input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
-            prompt=examples["prompt"][i],
+            prompt=prompt,
             response=examples["response"][i],
             kl_response=kl_response[i],
             system=examples["system"][i],
@@ -129,11 +112,15 @@ def preprocess_feedback_dataset(
         model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
         model_inputs["kl_labels"].append(kl_labels)
         model_inputs["kto_tags"].append(kto_tag)
-        if processor is not None:
-            model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
-            if hasattr(processor, "image_seq_length"):  # paligemma models
-                model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
-                model_inputs["kl_token_type_ids"].append(get_paligemma_token_type_ids(len(kl_input_ids), processor))
+        template.mm_plugin.process_model_inputs(
+            model_inputs=model_inputs,
+            images=examples["images"][i],
+            feature_seqlens={
+                "token_type_ids": len(input_ids),
+                "kl_token_type_ids": len(kl_input_ids),
+            },
+            processor=processor,
+        )
 
     desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
     undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
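The recurring swap to `model_inputs = defaultdict(list)` in these processors is what lets `process_model_inputs` append plugin-specific keys (`pixel_values`, `image_grid_thw`, `token_type_ids`, ...) without each processor pre-registering every possible column. A minimal illustration with toy values, not repository code:

```python
from collections import defaultdict

model_inputs = defaultdict(list)
model_inputs["input_ids"].append([101, 2023, 102])  # text features: always appended
model_inputs["pixel_values"].append([[0.5, 0.5]])   # appended only when a vision plugin runs
print(sorted(model_inputs.keys()))  # ['input_ids', 'pixel_values'] -- only keys actually used
```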
@@ -12,11 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
+from .processor_utils import infer_seqlen
 
 
 if TYPE_CHECKING:
@@ -39,9 +40,6 @@ def _encode_pairwise_example(
     processor: Optional["ProcessorMixin"],
     cutoff_len: int,
 ) -> Tuple[List[int], List[int], List[int], List[int]]:
-    if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
-        prompt[0]["content"] = template.image_token + prompt[0]["content"]
-
     chosen_messages = prompt + [response[0]]
     rejected_messages = prompt + [response[1]]
     prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
@@ -51,10 +49,7 @@ def _encode_pairwise_example(
         chosen_ids += [tokenizer.eos_token_id]
         rejected_ids += [tokenizer.eos_token_id]
 
-    if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
-        image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
-        prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
-
+    prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
     # consider the response is more important
     source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
     prompt_ids = prompt_ids[:source_len]
@@ -77,27 +72,15 @@ def preprocess_pairwise_dataset(
     data_args: "DataArguments",
 ) -> Dict[str, List[List[int]]]:
     # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
-    model_inputs = {
-        "chosen_input_ids": [],
-        "chosen_attention_mask": [],
-        "chosen_labels": [],
-        "rejected_input_ids": [],
-        "rejected_attention_mask": [],
-        "rejected_labels": [],
-    }
-    if processor is not None:
-        model_inputs["pixel_values"] = []
-        if hasattr(processor, "image_seq_length"):  # paligemma models
-            model_inputs["chosen_token_type_ids"] = []
-            model_inputs["rejected_token_type_ids"] = []
-
+    model_inputs = defaultdict(list)
     for i in range(len(examples["prompt"])):
         if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
             logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
             continue
 
+        prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
         chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
-            prompt=examples["prompt"][i],
+            prompt=prompt,
             response=examples["response"][i],
             system=examples["system"][i],
             tools=examples["tools"][i],
@@ -112,15 +95,15 @@ def preprocess_pairwise_dataset(
         model_inputs["rejected_input_ids"].append(rejected_input_ids)
         model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
         model_inputs["rejected_labels"].append(rejected_labels)
-        if processor is not None:
-            model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
-            if hasattr(processor, "image_seq_length"):  # paligemma models
-                model_inputs["chosen_token_type_ids"].append(
-                    get_paligemma_token_type_ids(len(chosen_input_ids), processor)
-                )
-                model_inputs["rejected_token_type_ids"].append(
-                    get_paligemma_token_type_ids(len(rejected_input_ids), processor)
-                )
+        template.mm_plugin.process_model_inputs(
+            model_inputs=model_inputs,
+            images=examples["images"][i],
+            feature_seqlens={
+                "chosen_token_type_ids": len(chosen_input_ids),
+                "rejected_token_type_ids": len(rejected_input_ids),
+            },
+            processor=processor,
+        )
 
     return model_inputs
@@ -13,20 +13,7 @@
 # limitations under the License.
 
 import bisect
-from typing import TYPE_CHECKING, List, Sequence, Tuple
-
-from ...extras.packages import is_pillow_available
-
-
-if is_pillow_available():
-    from PIL import Image
-
-
-if TYPE_CHECKING:
-    from numpy.typing import NDArray
-    from PIL.Image import Image as ImageObject
-    from transformers import ProcessorMixin
-    from transformers.image_processing_utils import BaseImageProcessor
+from typing import List, Sequence, Tuple
 
 
 def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
@@ -61,37 +48,6 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
     return knapsacks
 
 
-def get_pixel_values(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
-    r"""
-    Processes visual inputs. (currently only supports a single image)
-    """
-    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-    image = images[0] if len(images) != 0 else Image.new("RGB", (100, 100), (255, 255, 255))
-    return image_processor(image, return_tensors="pt")["pixel_values"][0]  # shape (C, H, W)
-
-
-def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") -> List[int]:
-    r"""
-    Gets paligemma token type ids for computing loss.
-    """
-    image_seq_length = getattr(processor, "image_seq_length")
-    return [0] * image_seq_length + [1] * (input_len - image_seq_length)
-
-
-def get_qwen2vl_image_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
-    r"""
-    Processes visual inputs. support multi images
-    """
-    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-    if len(images) != 0:
-        image_inputs = image_processor(images=images, return_tensors="pt")
-    else:
-        image = Image.new("RGB", (56, 56), (255, 255, 255))
-        image_inputs = image_processor(images=[image], return_tensors="pt")
-        image_inputs["image_grid_thw"][0][0] = 0
-    return {"pixel_values": image_inputs["pixel_values"], "image_grid_thw": image_inputs["image_grid_thw"]}
-
-
 def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
     r"""
     Computes the real sequence length after truncation by the cutoff_len.
@ -17,17 +17,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
 | 
			
		||||
 | 
			
		||||
from ...extras.constants import IGNORE_INDEX
 | 
			
		||||
from ...extras.logging import get_logger
 | 
			
		||||
from .processor_utils import (
 | 
			
		||||
    get_paligemma_token_type_ids,
 | 
			
		||||
    get_pixel_values,
 | 
			
		||||
    get_qwen2vl_image_inputs,
 | 
			
		||||
    greedy_knapsack,
 | 
			
		||||
    infer_seqlen,
 | 
			
		||||
)
 | 
			
		||||
from .processor_utils import greedy_knapsack, infer_seqlen
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from PIL.Image import Image as ImageObject
 | 
			
		||||
    from transformers import PreTrainedTokenizer, ProcessorMixin
 | 
			
		||||
 | 
			
		||||
    from ...hparams import DataArguments
 | 
			
		||||
@@ -43,41 +36,15 @@ def _encode_supervised_example(
     system: Optional[str],
     tools: Optional[str],
     template: "Template",
-    images: Sequence["ImageObject"],
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
     cutoff_len: int,
     train_on_prompt: bool,
     mask_history: bool,
 ) -> Tuple[List[int], List[int]]:
-    if processor is not None and "image_grid_thw" in processor.model_input_names:  # qwen2_vl models
-        image_processor = getattr(processor, "image_processor")
-        merge_length = image_processor.merge_size**2
-        if len(images) > 0:
-            image_grid_thw = get_qwen2vl_image_inputs(images, processor)["image_grid_thw"]
-        index = 0
-        for message in prompt:
-            content = message["content"]
-            while "<|image_pad|>" in content:
-                content = content.replace(
-                    "<|image_pad|>",
-                    template.vision_start_token
-                    + "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length)
-                    + template.vision_end_token,
-                    1,
-                )
-                index += 1
-            message["content"] = content.replace("<|placeholder|>", "<|image_pad|>")
-    elif processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
-        prompt[0]["content"] = template.image_token + prompt[0]["content"]
-
     messages = prompt + response
     input_ids, labels = [], []
-
-    if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
-        image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
-        input_ids += [image_token_id] * getattr(processor, "image_seq_length")
-        labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
+    input_ids, labels = template.mm_plugin.process_token_ids(input_ids, labels, tokenizer, processor)

     encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
     total_length = 1 if template.efficient_eos else 0
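The removed qwen2_vl branch above (now owned by the template's mm_plugin) expands each `<|image_pad|>` placeholder into one token per merged vision patch. The arithmetic in isolation:

```python
import torch

# one image tokenized into a (t=1, h=28, w=28) patch grid, with a 2x2 merger
image_grid_thw = torch.tensor([[1, 28, 28]])
merge_length = 2 ** 2  # image_processor.merge_size ** 2

num_image_tokens = image_grid_thw[0].prod() // merge_length
print(num_image_tokens)  # tensor(196): 196 <|image_pad|> tokens for this image
```
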
@@ -125,28 +92,21 @@ def preprocess_supervised_dataset(
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
     data_args: "DataArguments",
-) -> Dict[str, List[List[int]]]:
+) -> Dict[str, List[Any]]:
     # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
     # for multiturn examples, we only mask the prompt part in each prompt-response pair.
-    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
-    if processor is not None:
-        model_inputs["pixel_values"] = []
-        if hasattr(processor, "image_seq_length"):  # paligemma models
-            model_inputs["token_type_ids"] = []
-        if "image_grid_thw" in processor.model_input_names:  # qwen2_vl models
-            model_inputs["image_grid_thw"] = []
-
+    model_inputs = defaultdict(list)
     for i in range(len(examples["prompt"])):
         if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
             logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
             continue

+        prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
         input_ids, labels = _encode_supervised_example(
-            prompt=examples["prompt"][i],
+            prompt=prompt,
             response=examples["response"][i],
             system=examples["system"][i],
             tools=examples["tools"][i],
-            images=examples["images"][i],
             template=template,
             tokenizer=tokenizer,
             processor=processor,
@@ -157,15 +117,12 @@ def preprocess_supervised_dataset(
         model_inputs["input_ids"].append(input_ids)
         model_inputs["attention_mask"].append([1] * len(input_ids))
         model_inputs["labels"].append(labels)
-        if processor is not None:
-            if "image_grid_thw" in processor.model_input_names:  # qwen2_vl models
-                image_inputs = get_qwen2vl_image_inputs(examples["images"][i], processor)
-                model_inputs["pixel_values"].append(image_inputs["pixel_values"])
-                model_inputs["image_grid_thw"].append(image_inputs["image_grid_thw"])
-            else:
-                model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
-                if hasattr(processor, "image_seq_length"):  # paligemma models
-                    model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
+        template.mm_plugin.process_model_inputs(
+            model_inputs=model_inputs,
+            images=examples["images"][i],
+            feature_seqlens={"token_type_ids": len(input_ids)},
+            processor=processor,
+        )

     return model_inputs

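A side effect of switching to `defaultdict(list)` is that model-specific feature keys no longer have to be pre-declared per processor type; the plugin can append whatever it needs. In miniature:

```python
from collections import defaultdict

model_inputs = defaultdict(list)
model_inputs["input_ids"].append([1, 2, 3])         # key created on first use
model_inputs["image_grid_thw"].append([1, 28, 28])  # no per-model init needed
print(dict(model_inputs))
```
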
@@ -175,7 +132,7 @@ def preprocess_packed_supervised_dataset(
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     data_args: "DataArguments",
-) -> Dict[str, List[List[int]]]:
+) -> Dict[str, List[Any]]:
     # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
     # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
     valid_num = 0

@@ -209,7 +166,7 @@ def preprocess_packed_supervised_dataset(
             batch_labels.append(labels)
             valid_num += 1

-    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
+    model_inputs = defaultdict(list)
     knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1)  # reserved for the padding token
     for knapsack in knapsacks:
         packed_input_ids, packed_attention_masks, packed_labels = [], [], []

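`greedy_knapsack`'s exact strategy isn't shown in this hunk; as an illustration of the idea, a first-fit-decreasing sketch (assumed behavior, not the repository's implementation):

```python
from typing import List

def greedy_knapsack_sketch(lengths: List[int], capacity: int) -> List[List[int]]:
    # sort longest-first, then drop each sequence into the first pack with room
    packs: List[List[int]] = []
    for length in sorted(lengths, reverse=True):
        for pack in packs:
            if sum(pack) + length <= capacity:
                pack.append(length)
                break
        else:
            packs.append([length])
    return packs

print(greedy_knapsack_sketch([512, 400, 300, 200, 100], 1024))
# [[512, 400, 100], [300, 200]] -- two packed sequences instead of five
```
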
@@ -12,11 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

 from ...extras.logging import get_logger
 from ..data_utils import Role
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
+from .processor_utils import infer_seqlen


 if TYPE_CHECKING:
@@ -39,9 +40,6 @@ def _encode_unsupervised_example(
     processor: Optional["ProcessorMixin"],
     cutoff_len: int,
 ) -> Tuple[List[int], List[int]]:
-    if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
-        prompt[0]["content"] = template.image_token + prompt[0]["content"]
-
     if len(response) == 1:
         messages = prompt + response
     else:

@@ -51,10 +49,7 @@ def _encode_unsupervised_example(
     if template.efficient_eos:
         labels += [tokenizer.eos_token_id]

-    if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
-        image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
-        input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
-
+    input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, tokenizer, processor)
     source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
     input_ids = input_ids[:source_len]
     labels = labels[:target_len]
@@ -69,19 +64,15 @@ def preprocess_unsupervised_dataset(
     data_args: "DataArguments",
 ) -> Dict[str, List[List[int]]]:
     # build inputs with format `<bos> X` and labels with format `Y <eos>`
-    model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
-    if processor is not None:
-        model_inputs["pixel_values"] = []
-        if hasattr(processor, "image_seq_length"):  # paligemma models
-            model_inputs["token_type_ids"] = []
-
+    model_inputs = defaultdict(list)
     for i in range(len(examples["prompt"])):
         if len(examples["prompt"][i]) % 2 != 1:
             logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
             continue

+        prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
         input_ids, labels = _encode_unsupervised_example(
-            prompt=examples["prompt"][i],
+            prompt=prompt,
             response=examples["response"][i],
             system=examples["system"][i],
             tools=examples["tools"][i],

@@ -93,10 +84,12 @@ def preprocess_unsupervised_dataset(
         model_inputs["input_ids"].append(input_ids)
         model_inputs["attention_mask"].append([1] * len(input_ids))
         model_inputs["labels"].append(labels)
-        if processor is not None:
-            model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
-            if hasattr(processor, "image_seq_length"):  # paligemma models
-                model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
+        template.mm_plugin.process_model_inputs(
+            model_inputs=model_inputs,
+            images=examples["images"][i],
+            feature_seqlens={"token_type_ids": len(input_ids)},
+            processor=processor,
+        )

     return model_inputs

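Taken together, the preprocessing paths now lean on three plugin hooks. A skeletal view of the interface they imply (a sketch inferred from the call sites above, not the repository's actual `BasePlugin`):

```python
from typing import Any, Dict, List, Optional, Sequence, Tuple

class MMPluginSketch:
    def process_messages(self, messages: List[Dict[str, str]], images: Sequence[Any], processor: Any) -> List[Dict[str, str]]:
        return messages  # e.g. expand image placeholders inside message contents

    def process_token_ids(self, input_ids: List[int], labels: Optional[List[int]], tokenizer: Any, processor: Any) -> Tuple[List[int], Optional[List[int]]]:
        return input_ids, labels  # e.g. prepend image tokens for paligemma-style models

    def process_model_inputs(self, model_inputs: Dict[str, Any], images: Sequence[Any], feature_seqlens: Dict[str, int], processor: Any) -> None:
        pass  # e.g. append pixel_values / image_grid_thw / token_type_ids
```
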
@@ -15,9 +15,11 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union

+from ..extras.constants import IMAGE_PLACEHOLDER
 from ..extras.logging import get_logger
 from .data_utils import Role
 from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
+from .mm_plugin import BasePlugin, get_mm_plugin


 if TYPE_CHECKING:
@@ -41,11 +43,9 @@ class Template:
     format_prefix: "Formatter"
     default_system: str
     stop_words: List[str]
-    image_token: str
-    vision_start_token: str
-    vision_end_token: str
     efficient_eos: bool
     replace_eos: bool
+    mm_plugin: "BasePlugin"

     def encode_oneturn(
         self,
@@ -207,11 +207,9 @@ def _register_template(
     format_prefix: Optional["Formatter"] = None,
     default_system: str = "",
     stop_words: Sequence[str] = [],
-    image_token: str = "<image>",
-    vision_start_token: str = "<|vision_start|>",
-    vision_end_token: str = "<|vision_end|>",
     efficient_eos: bool = False,
     replace_eos: bool = False,
+    mm_plugin: "BasePlugin" = BasePlugin(IMAGE_PLACEHOLDER),
 ) -> None:
     r"""
     Registers a chat template.

@@ -258,11 +256,9 @@ def _register_template(
         format_prefix=format_prefix or default_prefix_formatter,
         default_system=default_system,
         stop_words=stop_words,
-        image_token=image_token,
-        vision_start_token=vision_start_token,
-        vision_end_token=vision_end_token,
         efficient_eos=efficient_eos,
         replace_eos=replace_eos,
+        mm_plugin=mm_plugin,
     )

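With the vision tokens folded into plugins, registering a multimodal template reduces to passing `mm_plugin`. A hypothetical minimal registration in the new style (the name `my_vlm` is illustrative; all other parameters fall back to their defaults):

```python
_register_template(
    name="my_vlm",  # illustrative template name
    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
    mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
)
```
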
@@ -722,6 +718,17 @@ _register_template(
 )


+_register_template(
+    name="llava",
+    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
+    default_system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
+)
+
+
 _register_template(
     name="mistral",
     format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
@@ -766,6 +773,19 @@ _register_template(
 )


+_register_template(
+    name="paligemma",
+    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
+    format_observation=StringFormatter(
+        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
+    ),
+    format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    efficient_eos=True,
+    mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
+)
+
+
 _register_template(
     name="phi",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
@@ -790,17 +810,15 @@ _register_template(


 _register_template(
-    name="qwen2vl",
+    name="qwen2_vl",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
     format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     format_separator=EmptyFormatter(slots=["\n"]),
     default_system="You are a helpful assistant.",
-    image_token="<|image_pad|>",
-    vision_start_token="<|vision_start|>",
-    vision_end_token="<|vision_end|>",
     stop_words=["<|im_end|>"],
     replace_eos=True,
+    mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>"),
 )


@@ -915,6 +933,7 @@ _register_template(
     ),
     stop_words=["###"],
     efficient_eos=True,
+    mm_plugin=get_mm_plugin(name="llava", image_token="<image>"),
 )

@@ -47,6 +47,8 @@ FILEEXT2TYPE = {

 IGNORE_INDEX = -100

+IMAGE_PLACEHOLDER = "<image>"
+
 LAYERNORM_NAMES = {"norm", "ln"}

 LLAMABOARD_CONFIG = "llamaboard_config.yaml"

@@ -785,7 +787,7 @@ register_model_group(
             DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
         },
     },
-    template="vicuna",
+    template="llava",
     vision=True,
 )

@@ -930,27 +932,28 @@ register_model_group(

 register_model_group(
     models={
-        "PaliGemma-3B-pt-224": {
+        "PaliGemma-3B-pt-224-Chat": {
             DownloadSource.DEFAULT: "google/paligemma-3b-pt-224",
             DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-224",
         },
-        "PaliGemma-3B-pt-448": {
+        "PaliGemma-3B-pt-448-Chat": {
             DownloadSource.DEFAULT: "google/paligemma-3b-pt-448",
             DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-448",
         },
-        "PaliGemma-3B-pt-896": {
+        "PaliGemma-3B-pt-896-Chat": {
             DownloadSource.DEFAULT: "google/paligemma-3b-pt-896",
             DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-896",
         },
-        "PaliGemma-3B-mix-224": {
+        "PaliGemma-3B-mix-224-Chat": {
             DownloadSource.DEFAULT: "google/paligemma-3b-mix-224",
             DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-224",
         },
-        "PaliGemma-3B-mix-448": {
+        "PaliGemma-3B-mix-448-Chat": {
             DownloadSource.DEFAULT: "google/paligemma-3b-mix-448",
             DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-448",
         },
     },
     template="paligemma",
     vision=True,
 )

@@ -1329,6 +1332,34 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Qwen2VL-2B-Chat": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct",
+            DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct",
+        },
+        "Qwen2VL-7B-Chat": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct",
+            DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct",
+        },
+        "Qwen2VL-2B-int8-Chat": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
+        },
+        "Qwen2VL-2B-int4-Chat": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-AWQ",
+        },
+        "Qwen2VL-7B-int8-Chat": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
+        },
+        "Qwen2VL-7B-int4-Chat": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-AWQ",
+        },
+    },
+    template="qwen2_vl",
+    vision=True,
+)
+
+
 register_model_group(
     models={
         "SOLAR-10.7B": {
@@ -79,7 +79,7 @@ def check_dependencies() -> None:
     if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
         logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
     else:
-        require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
+        require_version("transformers>=4.41.2,<=4.45.0", "To fix: pip install transformers>=4.41.2,<=4.45.0")
         require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0")
         require_version("accelerate>=0.30.1,<=0.33.0", "To fix: pip install accelerate>=0.30.1,<=0.33.0")
         require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")

@@ -117,7 +117,7 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
     )
-    use_liger_kernel: bool = field(
+    enable_liger_kernel: bool = field(
         default=False,
         metadata={"help": "Whether or not to enable liger kernel for faster training."},
     )

@@ -116,7 +116,7 @@ def _check_extra_dependencies(
     if model_args.use_unsloth:
         require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")

-    if model_args.use_liger_kernel:
+    if model_args.enable_liger_kernel:
         require_version("liger-kernel", "To fix: pip install liger-kernel")

     if model_args.mixture_of_depths is not None:

@@ -27,7 +27,7 @@ logger = get_logger(__name__)


 def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
-    if not is_trainable or not model_args.use_liger_kernel:
+    if not is_trainable or not model_args.enable_liger_kernel:
         return

     if getattr(config, "model_type", None) == "gemma":
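The body below this check is truncated by the diff; presumably it dispatches to Liger's per-architecture patchers. A hedged sketch of that continuation, using the `config` from the function above (the `apply_liger_kernel_to_*` entry points come from the liger-kernel package and may differ across versions):

```python
# hypothetical continuation of configure_liger_kernel (sketch)
if getattr(config, "model_type", None) == "gemma":
    from liger_kernel.transformers import apply_liger_kernel_to_gemma

    apply_liger_kernel_to_gemma()  # swaps in fused RMSNorm / RoPE / cross-entropy kernels
```
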
@@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(


 def _apply_llama_patch() -> None:
-    require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
+    require_version("transformers>=4.41.2,<=4.44.3", "To fix: pip install transformers>=4.41.2,<=4.44.3")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward

@@ -36,11 +36,14 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
         forbidden_modules.add("output")
     elif model.config.model_type in ["llava", "paligemma"]:
         forbidden_modules.add("multi_modal_projector")
-    elif model.config.model_type in ["qwen2_vl"]:
+    elif model.config.model_type == "qwen2_vl":
         forbidden_modules.add("merger")

     if freeze_vision_tower:
-        forbidden_modules.add("vision_tower")
+        if model.config.model_type == "qwen2_vl":
+            forbidden_modules.add("visual")
+        else:
+            forbidden_modules.add("vision_tower")

     module_names = set()
     for name, module in model.named_modules():
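The scan that follows collects LoRA target names while honoring these exclusions; condensed, the pattern amounts to something like this (a sketch of the idea, not the verbatim continuation):

```python
import torch

def find_linear_names(model: torch.nn.Module, forbidden_modules: set) -> set:
    # keep the leaf names of Linear layers that do not live under any
    # forbidden module (projector, vision tower, qwen2_vl's "visual"/"merger")
    module_names = set()
    for name, module in model.named_modules():
        if any(fm in name for fm in forbidden_modules):
            continue
        if isinstance(module, torch.nn.Linear):
            module_names.add(name.split(".")[-1])  # e.g. "q_proj", "v_proj"
    return module_names
```
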
@@ -114,7 +114,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor


 def _patch_for_block_diag_attn(model_type: str) -> None:
-    require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
+    require_version("transformers>=4.41.2,<=4.44.3", "To fix: pip install transformers>=4.41.2,<=4.44.3")
     if is_transformers_version_greater_than_4_43():
         import transformers.modeling_flash_attention_utils

@@ -130,6 +130,9 @@ class CustomKTOTrainer(KTOTrainer):
         if "pixel_values" in batch:
             model_inputs["pixel_values"] = batch["pixel_values"]

+        if "image_grid_thw" in batch:
+            model_inputs["image_grid_thw"] = batch["image_grid_thw"]
+
         if "{}token_type_ids".format(prefix) in batch:
             model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]

@@ -17,9 +17,7 @@

 from typing import TYPE_CHECKING, List, Optional

-from transformers import DataCollatorWithPadding
-
-from ...data import get_dataset
+from ...data import CustomDataCollatorForSeq2Seq, get_dataset
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..callbacks import fix_valuehead_checkpoint

@@ -47,7 +45,7 @@ def run_ppo(
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)

     tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training
-    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
+    data_collator = CustomDataCollatorForSeq2Seq(tokenizer=tokenizer)

     # Create reference model and reward model
     ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)

@@ -115,7 +115,7 @@ class Runner:
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
-            use_liger_kernel=(get("top.booster") == "liger_kernel"),
+            enable_liger_kernel=(get("top.booster") == "liger_kernel"),
             visual_inputs=get("top.visual_inputs"),
             dataset_dir=get("train.dataset_dir"),
             dataset=",".join(get("train.dataset")),