mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	[model] support audio (#6701)
* support qwen2_audio * improve code * lint * fix * fix * fix --------- Co-authored-by: hiyouga <hiyouga@buaa.edu.cn> Former-commit-id: 5eacb5629e4d7733cd992a63747a1335f2c6a929
This commit is contained in:
		
							parent
							
								
									9feb78e7b4
								
							
						
					
					
						commit
						8f401e37f8
					
				@ -76,8 +76,9 @@ Choose your path:
 | 
			
		||||
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc.
 | 
			
		||||
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
 | 
			
		||||
- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
 | 
			
		||||
- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
 | 
			
		||||
- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
 | 
			
		||||
- **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
 | 
			
		||||
- **Wide tasks**: Multi-turn dialogue, tool using, image understanding, visual grounding, video recognition, audio understanding, etc.
 | 
			
		||||
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, SwanLab, etc.
 | 
			
		||||
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
 | 
			
		||||
 | 
			
		||||
@ -105,6 +106,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | 
			
		||||
 | 
			
		||||
## Changelog
 | 
			
		||||
 | 
			
		||||
[25/02/05] We supported fine-tuning the **[Qwen2-Audio](Qwen/Qwen2-Audio-7B-Instruct)** and **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** on audio understanding tasks.
 | 
			
		||||
 | 
			
		||||
[25/01/31] We supported fine-tuning the **[DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)** and **[Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)** model.
 | 
			
		||||
 | 
			
		||||
[25/01/15] We supported **[APOLLO](https://arxiv.org/abs/2412.05270)** optimizer. See [examples](examples/README.md) for usage.
 | 
			
		||||
@ -247,6 +250,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | 
			
		||||
| [Phi-4](https://huggingface.co/microsoft)                         | 14B                              | phi4             |
 | 
			
		||||
| [Pixtral](https://huggingface.co/mistralai)                       | 12B                              | pixtral          |
 | 
			
		||||
| [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen)   | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen             |
 | 
			
		||||
| [Qwen2-Audio](https://huggingface.co/Qwen)                        | 7B                               | qwen2_audio      |
 | 
			
		||||
| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen)            | 2B/3B/7B/72B                     | qwen2_vl         |
 | 
			
		||||
| [Skywork o1](https://huggingface.co/Skywork)                      | 8B                               | skywork_o1       |
 | 
			
		||||
| [StarCoder 2](https://huggingface.co/bigcode)                     | 3B/7B/15B                        | -                |
 | 
			
		||||
 | 
			
		||||
@ -78,8 +78,9 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 | 
			
		||||
- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Qwen2-VL、DeepSeek、Yi、Gemma、ChatGLM、Phi 等等。
 | 
			
		||||
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
 | 
			
		||||
- **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
 | 
			
		||||
- **先进算法**:[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[APOLLO](https://github.com/zhuhanqing/APOLLO)、[Adam-mini](https://github.com/zyushun/Adam-mini)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
 | 
			
		||||
- **先进算法**:[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[APOLLO](https://github.com/zhuhanqing/APOLLO)、[Adam-mini](https://github.com/zyushun/Adam-mini)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 PiSSA。
 | 
			
		||||
- **实用技巧**:[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)、[Unsloth](https://github.com/unslothai/unsloth)、[Liger Kernel](https://github.com/linkedin/Liger-Kernel)、RoPE scaling、NEFTune 和 rsLoRA。
 | 
			
		||||
- **广泛任务**:多轮对话、工具调用、图像理解、视觉定位、视频识别和语音理解等等。
 | 
			
		||||
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow、SwanLab 等等。
 | 
			
		||||
- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
 | 
			
		||||
 | 
			
		||||
@ -115,6 +116,8 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 | 
			
		||||
 | 
			
		||||
<details><summary>展开日志</summary>
 | 
			
		||||
 | 
			
		||||
[25/02/05] 我们支持了在语音理解任务上微调 **[Qwen2-Audio](Qwen/Qwen2-Audio-7B-Instruct)** 和 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 模型。
 | 
			
		||||
 | 
			
		||||
[25/01/14] 我们支持了 **[InternLM3](https://huggingface.co/collections/internlm/)** 模型的微调。感谢 [@hhaAndroid](https://github.com/hhaAndroid) 的 PR。
 | 
			
		||||
 | 
			
		||||
[25/01/10] 我们支持了 **[Phi-4](https://huggingface.co/microsoft/phi-4)** 模型的微调。
 | 
			
		||||
@ -249,6 +252,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 | 
			
		||||
| [Phi-4](https://huggingface.co/microsoft)                         | 14B                              | phi4             |
 | 
			
		||||
| [Pixtral](https://huggingface.co/mistralai)                       | 12B                              | pixtral          |
 | 
			
		||||
| [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen)   | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen             |
 | 
			
		||||
| [Qwen2-Audio](https://huggingface.co/Qwen)                        | 7B                               | qwen2_audio      |
 | 
			
		||||
| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen)            | 2B/3B/7B/72B                     | qwen2_vl         |
 | 
			
		||||
| [Skywork o1](https://huggingface.co/Skywork)                      | 8B                               | skywork_o1       |
 | 
			
		||||
| [StarCoder 2](https://huggingface.co/bigcode)                     | 3B/7B/15B                        | -                |
 | 
			
		||||
 | 
			
		||||
@ -24,6 +24,7 @@ Currently we support datasets in **alpaca** and **sharegpt** format.
 | 
			
		||||
    "tools": "the column name in the dataset containing the tool description. (default: None)",
 | 
			
		||||
    "images": "the column name in the dataset containing the image inputs. (default: None)",
 | 
			
		||||
    "videos": "the column name in the dataset containing the videos inputs. (default: None)",
 | 
			
		||||
    "audios": "the column name in the dataset containing the audios inputs. (default: None)",
 | 
			
		||||
    "chosen": "the column name in the dataset containing the chosen answers. (default: None)",
 | 
			
		||||
    "rejected": "the column name in the dataset containing the rejected answers. (default: None)",
 | 
			
		||||
    "kto_tag": "the column name in the dataset containing the kto tags. (default: None)"
 | 
			
		||||
@ -150,6 +151,10 @@ An additional column `images` is required. Please refer to the [sharegpt](#share
 | 
			
		||||
 | 
			
		||||
An additional column `videos` is required. Please refer to the [sharegpt](#sharegpt-format) format for details.
 | 
			
		||||
 | 
			
		||||
### Multimodal Audio Dataset
 | 
			
		||||
 | 
			
		||||
An additional column `audios` is required. Please refer to the [sharegpt](#sharegpt-format) format for details.
 | 
			
		||||
 | 
			
		||||
## Sharegpt Format
 | 
			
		||||
 | 
			
		||||
### Supervised Fine-Tuning Dataset
 | 
			
		||||
@ -296,7 +301,7 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
 | 
			
		||||
 | 
			
		||||
- [Example dataset](mllm_demo.json)
 | 
			
		||||
 | 
			
		||||
Multimodal image datasets require a `images` column containing the paths to the input images.
 | 
			
		||||
Multimodal image datasets require an `images` column containing the paths to the input images.
 | 
			
		||||
 | 
			
		||||
The number of images should be identical to the `<image>` tokens in the conversations.
 | 
			
		||||
 | 
			
		||||
@ -374,6 +379,47 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Multimodal Audio Dataset
 | 
			
		||||
 | 
			
		||||
- [Example dataset](mllm_audio_demo.json)
 | 
			
		||||
 | 
			
		||||
Multimodal audio datasets require an `audios` column containing the paths to the input audios.
 | 
			
		||||
 | 
			
		||||
The number of audios should be identical to the `<audio>` tokens in the conversations.
 | 
			
		||||
 | 
			
		||||
```json
 | 
			
		||||
[
 | 
			
		||||
  {
 | 
			
		||||
    "conversations": [
 | 
			
		||||
      {
 | 
			
		||||
        "from": "human",
 | 
			
		||||
        "value": "<audio>human instruction"
 | 
			
		||||
      },
 | 
			
		||||
      {
 | 
			
		||||
        "from": "gpt",
 | 
			
		||||
        "value": "model response"
 | 
			
		||||
      }
 | 
			
		||||
    ],
 | 
			
		||||
    "audios": [
 | 
			
		||||
      "audio path (required)"
 | 
			
		||||
    ]
 | 
			
		||||
  }
 | 
			
		||||
]
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
 | 
			
		||||
 | 
			
		||||
```json
 | 
			
		||||
"dataset_name": {
 | 
			
		||||
  "file_name": "data.json",
 | 
			
		||||
  "formatting": "sharegpt",
 | 
			
		||||
  "columns": {
 | 
			
		||||
    "messages": "conversations",
 | 
			
		||||
    "audios": "audios"
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### OpenAI Format
 | 
			
		||||
 | 
			
		||||
The openai format is simply a special case of the sharegpt format, where the first message may be a system prompt.
 | 
			
		||||
 | 
			
		||||
@ -24,6 +24,7 @@
 | 
			
		||||
    "tools": "数据集代表工具描述的表头名称(默认:None)",
 | 
			
		||||
    "images": "数据集代表图像输入的表头名称(默认:None)",
 | 
			
		||||
    "videos": "数据集代表视频输入的表头名称(默认:None)",
 | 
			
		||||
    "audios": "数据集代表音频输入的表头名称(默认:None)",
 | 
			
		||||
    "chosen": "数据集代表更优回答的表头名称(默认:None)",
 | 
			
		||||
    "rejected": "数据集代表更差回答的表头名称(默认:None)",
 | 
			
		||||
    "kto_tag": "数据集代表 KTO 标签的表头名称(默认:None)"
 | 
			
		||||
@ -150,6 +151,10 @@ KTO 数据集需要提供额外的 `kto_tag` 列。详情请参阅 [sharegpt](#s
 | 
			
		||||
 | 
			
		||||
多模态视频数据集需要提供额外的 `videos` 列。详情请参阅 [sharegpt](#sharegpt-格式)。
 | 
			
		||||
 | 
			
		||||
### 多模态音频数据集
 | 
			
		||||
 | 
			
		||||
多模态音频数据集需要提供额外的 `audios` 列。详情请参阅 [sharegpt](#sharegpt-格式)。
 | 
			
		||||
 | 
			
		||||
## Sharegpt 格式
 | 
			
		||||
 | 
			
		||||
### 指令监督微调数据集
 | 
			
		||||
@ -374,6 +379,48 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### 多模态音频数据集
 | 
			
		||||
 | 
			
		||||
- [样例数据集](mllm_audio_demo.json)
 | 
			
		||||
 | 
			
		||||
多模态音频数据集需要额外添加一个 `audios` 列,包含输入音频的路径。
 | 
			
		||||
 | 
			
		||||
注意音频的数量必须与文本中所有 `<audio>` 标记的数量严格一致。
 | 
			
		||||
 | 
			
		||||
```json
 | 
			
		||||
[
 | 
			
		||||
  {
 | 
			
		||||
    "conversations": [
 | 
			
		||||
      {
 | 
			
		||||
        "from": "human",
 | 
			
		||||
        "value": "<audio>人类指令"
 | 
			
		||||
      },
 | 
			
		||||
      {
 | 
			
		||||
        "from": "gpt",
 | 
			
		||||
        "value": "模型回答"
 | 
			
		||||
      }
 | 
			
		||||
    ],
 | 
			
		||||
    "audios": [
 | 
			
		||||
      "音频路径(必填)"
 | 
			
		||||
    ]
 | 
			
		||||
  }
 | 
			
		||||
]
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
对于上述格式的数据,`dataset_info.json` 中的*数据集描述*应为:
 | 
			
		||||
 | 
			
		||||
```json
 | 
			
		||||
"数据集名称": {
 | 
			
		||||
  "file_name": "data.json",
 | 
			
		||||
  "formatting": "sharegpt",
 | 
			
		||||
  "columns": {
 | 
			
		||||
    "messages": "conversations",
 | 
			
		||||
    "audios": "audios"
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
### OpenAI 格式
 | 
			
		||||
 | 
			
		||||
OpenAI 格式仅仅是 sharegpt 格式的一种特殊情况,其中第一条消息可能是系统提示词。
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										
											BIN
										
									
								
								data/mllm_demo_data/1.mp3
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								data/mllm_demo_data/1.mp3
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								data/mllm_demo_data/2.wav
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								data/mllm_demo_data/2.wav
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								data/mllm_demo_data/3.flac
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								data/mllm_demo_data/3.flac
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							@ -22,4 +22,5 @@ packaging
 | 
			
		||||
pyyaml
 | 
			
		||||
numpy<2.0.0
 | 
			
		||||
av
 | 
			
		||||
librosa
 | 
			
		||||
tyro<0.9.0
 | 
			
		||||
 | 
			
		||||
@ -49,6 +49,7 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
 | 
			
		||||
                    "labels": feature["chosen_input_ids"] if self.train_on_prompt else feature["chosen_labels"],
 | 
			
		||||
                    "images": feature["images"],
 | 
			
		||||
                    "videos": feature["videos"],
 | 
			
		||||
                    "audios": feature["audios"],
 | 
			
		||||
                }
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										1
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								setup.py
									
									
									
									
									
								
							@ -69,7 +69,6 @@ extra_require = {
 | 
			
		||||
        "msgpack",
 | 
			
		||||
        "referencing",
 | 
			
		||||
        "jsonschema_specifications",
 | 
			
		||||
        "librosa",
 | 
			
		||||
    ],
 | 
			
		||||
    "modelscope": ["modelscope"],
 | 
			
		||||
    "openmind": ["openmind"],
 | 
			
		||||
 | 
			
		||||
@ -22,7 +22,7 @@ if TYPE_CHECKING:
 | 
			
		||||
    from vllm import AsyncLLMEngine
 | 
			
		||||
 | 
			
		||||
    from ..data import Template
 | 
			
		||||
    from ..data.mm_plugin import ImageInput, VideoInput
 | 
			
		||||
    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
 | 
			
		||||
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -68,6 +68,7 @@ class BaseEngine(ABC):
 | 
			
		||||
        tools: Optional[str] = None,
 | 
			
		||||
        images: Optional[Sequence["ImageInput"]] = None,
 | 
			
		||||
        videos: Optional[Sequence["VideoInput"]] = None,
 | 
			
		||||
        audios: Optional[Sequence["AudioInput"]] = None,
 | 
			
		||||
        **input_kwargs,
 | 
			
		||||
    ) -> List["Response"]:
 | 
			
		||||
        r"""
 | 
			
		||||
@ -83,6 +84,7 @@ class BaseEngine(ABC):
 | 
			
		||||
        tools: Optional[str] = None,
 | 
			
		||||
        images: Optional[Sequence["ImageInput"]] = None,
 | 
			
		||||
        videos: Optional[Sequence["VideoInput"]] = None,
 | 
			
		||||
        audios: Optional[Sequence["AudioInput"]] = None,
 | 
			
		||||
        **input_kwargs,
 | 
			
		||||
    ) -> AsyncGenerator[str, None]:
 | 
			
		||||
        r"""
 | 
			
		||||
 | 
			
		||||
@ -27,7 +27,7 @@ from .vllm_engine import VllmEngine
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from ..data.mm_plugin import ImageInput, VideoInput
 | 
			
		||||
    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
 | 
			
		||||
    from .base_engine import BaseEngine, Response
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -66,13 +66,14 @@ class ChatModel:
 | 
			
		||||
        tools: Optional[str] = None,
 | 
			
		||||
        images: Optional[Sequence["ImageInput"]] = None,
 | 
			
		||||
        videos: Optional[Sequence["VideoInput"]] = None,
 | 
			
		||||
        audios: Optional[Sequence["AudioInput"]] = None,
 | 
			
		||||
        **input_kwargs,
 | 
			
		||||
    ) -> List["Response"]:
 | 
			
		||||
        r"""
 | 
			
		||||
        Gets a list of responses of the chat model.
 | 
			
		||||
        """
 | 
			
		||||
        task = asyncio.run_coroutine_threadsafe(
 | 
			
		||||
            self.achat(messages, system, tools, images, videos, **input_kwargs), self._loop
 | 
			
		||||
            self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop
 | 
			
		||||
        )
 | 
			
		||||
        return task.result()
 | 
			
		||||
 | 
			
		||||
@ -83,12 +84,13 @@ class ChatModel:
 | 
			
		||||
        tools: Optional[str] = None,
 | 
			
		||||
        images: Optional[Sequence["ImageInput"]] = None,
 | 
			
		||||
        videos: Optional[Sequence["VideoInput"]] = None,
 | 
			
		||||
        audios: Optional[Sequence["AudioInput"]] = None,
 | 
			
		||||
        **input_kwargs,
 | 
			
		||||
    ) -> List["Response"]:
 | 
			
		||||
        r"""
 | 
			
		||||
        Asynchronously gets a list of responses of the chat model.
 | 
			
		||||
        """
 | 
			
		||||
        return await self.engine.chat(messages, system, tools, images, videos, **input_kwargs)
 | 
			
		||||
        return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs)
 | 
			
		||||
 | 
			
		||||
    def stream_chat(
 | 
			
		||||
        self,
 | 
			
		||||
@ -97,12 +99,13 @@ class ChatModel:
 | 
			
		||||
        tools: Optional[str] = None,
 | 
			
		||||
        images: Optional[Sequence["ImageInput"]] = None,
 | 
			
		||||
        videos: Optional[Sequence["VideoInput"]] = None,
 | 
			
		||||
        audios: Optional[Sequence["AudioInput"]] = None,
 | 
			
		||||
        **input_kwargs,
 | 
			
		||||
    ) -> Generator[str, None, None]:
 | 
			
		||||
        r"""
 | 
			
		||||
        Gets the response token-by-token of the chat model.
 | 
			
		||||
        """
 | 
			
		||||
        generator = self.astream_chat(messages, system, tools, images, videos, **input_kwargs)
 | 
			
		||||
        generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs)
 | 
			
		||||
        while True:
 | 
			
		||||
            try:
 | 
			
		||||
                task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
 | 
			
		||||
@ -117,12 +120,15 @@ class ChatModel:
 | 
			
		||||
        tools: Optional[str] = None,
 | 
			
		||||
        images: Optional[Sequence["ImageInput"]] = None,
 | 
			
		||||
        videos: Optional[Sequence["VideoInput"]] = None,
 | 
			
		||||
        audios: Optional[Sequence["AudioInput"]] = None,
 | 
			
		||||
        **input_kwargs,
 | 
			
		||||
    ) -> AsyncGenerator[str, None]:
 | 
			
		||||
        r"""
 | 
			
		||||
        Asynchronously gets the response token-by-token of the chat model.
 | 
			
		||||
        """
 | 
			
		||||
        async for new_token in self.engine.stream_chat(messages, system, tools, images, videos, **input_kwargs):
 | 
			
		||||
        async for new_token in self.engine.stream_chat(
 | 
			
		||||
            messages, system, tools, images, videos, audios, **input_kwargs
 | 
			
		||||
        ):
 | 
			
		||||
            yield new_token
 | 
			
		||||
 | 
			
		||||
    def get_scores(
 | 
			
		||||
 | 
			
		||||
@ -24,7 +24,7 @@ from typing_extensions import override
 | 
			
		||||
 | 
			
		||||
from ..data import get_template_and_fix_tokenizer
 | 
			
		||||
from ..extras import logging
 | 
			
		||||
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 | 
			
		||||
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 | 
			
		||||
from ..extras.misc import get_logits_processor
 | 
			
		||||
from ..model import load_model, load_tokenizer
 | 
			
		||||
from .base_engine import BaseEngine, Response
 | 
			
		||||
@ -35,7 +35,7 @@ if TYPE_CHECKING:
 | 
			
		||||
    from trl import PreTrainedModelWrapper
 | 
			
		||||
 | 
			
		||||
    from ..data import Template
 | 
			
		||||
    from ..data.mm_plugin import ImageInput, VideoInput
 | 
			
		||||
    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
 | 
			
		||||
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -81,9 +81,10 @@ class HuggingfaceEngine(BaseEngine):
 | 
			
		||||
        tools: Optional[str] = None,
 | 
			
		||||
        images: Optional[Sequence["ImageInput"]] = None,
 | 
			
		||||
        videos: Optional[Sequence["VideoInput"]] = None,
 | 
			
		||||
        audios: Optional[Sequence["AudioInput"]] = None,
 | 
			
		||||
        input_kwargs: Optional[Dict[str, Any]] = {},
 | 
			
		||||
    ) -> Tuple[Dict[str, Any], int]:
 | 
			
		||||
        mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
 | 
			
		||||
        mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
 | 
			
		||||
        if images is not None:
 | 
			
		||||
            mm_input_dict.update({"images": images, "imglens": [len(images)]})
 | 
			
		||||
            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
 | 
			
		||||
@ -94,14 +95,25 @@ class HuggingfaceEngine(BaseEngine):
 | 
			
		||||
            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
 | 
			
		||||
                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
 | 
			
		||||
 | 
			
		||||
        if audios is not None:
 | 
			
		||||
            mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
 | 
			
		||||
            if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
 | 
			
		||||
                messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
 | 
			
		||||
 | 
			
		||||
        messages = template.mm_plugin.process_messages(
 | 
			
		||||
            messages, mm_input_dict["images"], mm_input_dict["videos"], processor
 | 
			
		||||
            messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], processor
 | 
			
		||||
        )
 | 
			
		||||
        paired_messages = messages + [{"role": "assistant", "content": ""}]
 | 
			
		||||
        system = system or generating_args["default_system"]
 | 
			
		||||
        prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
 | 
			
		||||
        prompt_ids, _ = template.mm_plugin.process_token_ids(
 | 
			
		||||
            prompt_ids, None, mm_input_dict["images"], mm_input_dict["videos"], tokenizer, processor
 | 
			
		||||
            prompt_ids,
 | 
			
		||||
            None,
 | 
			
		||||
            mm_input_dict["images"],
 | 
			
		||||
            mm_input_dict["videos"],
 | 
			
		||||
            mm_input_dict["audios"],
 | 
			
		||||
            tokenizer,
 | 
			
		||||
            processor,
 | 
			
		||||
        )
 | 
			
		||||
        prompt_length = len(prompt_ids)
 | 
			
		||||
        inputs = torch.tensor([prompt_ids], device=model.device)
 | 
			
		||||
@ -184,6 +196,9 @@ class HuggingfaceEngine(BaseEngine):
 | 
			
		||||
        if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
 | 
			
		||||
            gen_kwargs["input_ids"] = inputs
 | 
			
		||||
            gen_kwargs["tokenizer"] = tokenizer
 | 
			
		||||
            if "audio_feature_lens" in mm_inputs:
 | 
			
		||||
                gen_kwargs["audio_feature_lens"] = mm_inputs["audio_feature_lens"]
 | 
			
		||||
 | 
			
		||||
            gen_kwargs.pop("image_sizes", None)
 | 
			
		||||
 | 
			
		||||
        return gen_kwargs, prompt_length
 | 
			
		||||
@ -201,6 +216,7 @@ class HuggingfaceEngine(BaseEngine):
 | 
			
		||||
        tools: Optional[str] = None,
 | 
			
		||||
        images: Optional[Sequence["ImageInput"]] = None,
 | 
			
		||||
        videos: Optional[Sequence["VideoInput"]] = None,
 | 
			
		||||
        audios: Optional[Sequence["AudioInput"]] = None,
 | 
			
		||||
        input_kwargs: Optional[Dict[str, Any]] = {},
 | 
			
		||||
    ) -> List["Response"]:
 | 
			
		||||
        gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
 | 
			
		||||
@ -214,6 +230,7 @@ class HuggingfaceEngine(BaseEngine):
 | 
			
		||||
            tools,
 | 
			
		||||
            images,
 | 
			
		||||
            videos,
 | 
			
		||||
            audios,
 | 
			
		||||
            input_kwargs,
 | 
			
		||||
        )
 | 
			
		||||
        generate_output = model.generate(**gen_kwargs)
 | 
			
		||||
@ -252,6 +269,7 @@ class HuggingfaceEngine(BaseEngine):
 | 
			
		||||
        tools: Optional[str] = None,
 | 
			
		||||
        images: Optional[Sequence["ImageInput"]] = None,
 | 
			
		||||
        videos: Optional[Sequence["VideoInput"]] = None,
 | 
			
		||||
        audios: Optional[Sequence["AudioInput"]] = None,
 | 
			
		||||
        input_kwargs: Optional[Dict[str, Any]] = {},
 | 
			
		||||
    ) -> Callable[[], str]:
 | 
			
		||||
        gen_kwargs, _ = HuggingfaceEngine._process_args(
 | 
			
		||||
@ -265,6 +283,7 @@ class HuggingfaceEngine(BaseEngine):
 | 
			
		||||
            tools,
 | 
			
		||||
            images,
 | 
			
		||||
            videos,
 | 
			
		||||
            audios,
 | 
			
		||||
            input_kwargs,
 | 
			
		||||
        )
 | 
			
		||||
        streamer = TextIteratorStreamer(
 | 
			
		||||
@ -312,6 +331,7 @@ class HuggingfaceEngine(BaseEngine):
 | 
			
		||||
        tools: Optional[str] = None,
 | 
			
		||||
        images: Optional[Sequence["ImageInput"]] = None,
 | 
			
		||||
        videos: Optional[Sequence["VideoInput"]] = None,
 | 
			
		||||
        audios: Optional[Sequence["AudioInput"]] = None,
 | 
			
		||||
        **input_kwargs,
 | 
			
		||||
    ) -> List["Response"]:
 | 
			
		||||
        if not self.can_generate:
 | 
			
		||||
@ -329,6 +349,7 @@ class HuggingfaceEngine(BaseEngine):
 | 
			
		||||
            tools,
 | 
			
		||||
            images,
 | 
			
		||||
            videos,
 | 
			
		||||
            audios,
 | 
			
		||||
            input_kwargs,
 | 
			
		||||
        )
 | 
			
		||||
        async with self.semaphore:
 | 
			
		||||
@ -343,6 +364,7 @@ class HuggingfaceEngine(BaseEngine):
 | 
			
		||||
        tools: Optional[str] = None,
 | 
			
		||||
        images: Optional[Sequence["ImageInput"]] = None,
 | 
			
		||||
        videos: Optional[Sequence["VideoInput"]] = None,
 | 
			
		||||
        audios: Optional[Sequence["AudioInput"]] = None,
 | 
			
		||||
        **input_kwargs,
 | 
			
		||||
    ) -> AsyncGenerator[str, None]:
 | 
			
		||||
        if not self.can_generate:
 | 
			
		||||
@ -360,6 +382,7 @@ class HuggingfaceEngine(BaseEngine):
 | 
			
		||||
            tools,
 | 
			
		||||
            images,
 | 
			
		||||
            videos,
 | 
			
		||||
            audios,
 | 
			
		||||
            input_kwargs,
 | 
			
		||||
        )
 | 
			
		||||
        async with self.semaphore:
 | 
			
		||||
 | 
			
		||||
@ -19,7 +19,7 @@ from typing_extensions import override
 | 
			
		||||
 | 
			
		||||
from ..data import get_template_and_fix_tokenizer
 | 
			
		||||
from ..extras import logging
 | 
			
		||||
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 | 
			
		||||
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 | 
			
		||||
from ..extras.misc import get_device_count
 | 
			
		||||
from ..extras.packages import is_pillow_available, is_vllm_available
 | 
			
		||||
from ..model import load_config, load_tokenizer
 | 
			
		||||
@ -39,7 +39,7 @@ if is_vllm_available():
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from ..data.mm_plugin import ImageInput, VideoInput
 | 
			
		||||
    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
 | 
			
		||||
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -109,10 +109,11 @@ class VllmEngine(BaseEngine):
 | 
			
		||||
        tools: Optional[str] = None,
 | 
			
		||||
        images: Optional[Sequence["ImageInput"]] = None,
 | 
			
		||||
        videos: Optional[Sequence["VideoInput"]] = None,
 | 
			
		||||
        audios: Optional[Sequence["AudioInput"]] = None,
 | 
			
		||||
        **input_kwargs,
 | 
			
		||||
    ) -> AsyncIterator["RequestOutput"]:
 | 
			
		||||
        request_id = f"chatcmpl-{uuid.uuid4().hex}"
 | 
			
		||||
        mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
 | 
			
		||||
        mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
 | 
			
		||||
        if images is not None:
 | 
			
		||||
            mm_input_dict.update({"images": images, "imglens": [len(images)]})
 | 
			
		||||
            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
 | 
			
		||||
@ -123,8 +124,13 @@ class VllmEngine(BaseEngine):
 | 
			
		||||
            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
 | 
			
		||||
                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
 | 
			
		||||
 | 
			
		||||
        if audios is not None:
 | 
			
		||||
            mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
 | 
			
		||||
            if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
 | 
			
		||||
                messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
 | 
			
		||||
 | 
			
		||||
        messages = self.template.mm_plugin.process_messages(
 | 
			
		||||
            messages, mm_input_dict["images"], mm_input_dict["videos"], self.processor
 | 
			
		||||
            messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], self.processor
 | 
			
		||||
        )
 | 
			
		||||
        paired_messages = messages + [{"role": "assistant", "content": ""}]
 | 
			
		||||
        system = system or self.generating_args["default_system"]
 | 
			
		||||
@ -202,10 +208,11 @@ class VllmEngine(BaseEngine):
 | 
			
		||||
        tools: Optional[str] = None,
 | 
			
		||||
        images: Optional[Sequence["ImageInput"]] = None,
 | 
			
		||||
        videos: Optional[Sequence["VideoInput"]] = None,
 | 
			
		||||
        audios: Optional[Sequence["AudioInput"]] = None,
 | 
			
		||||
        **input_kwargs,
 | 
			
		||||
    ) -> List["Response"]:
 | 
			
		||||
        final_output = None
 | 
			
		||||
        generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
 | 
			
		||||
        generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
 | 
			
		||||
        async for request_output in generator:
 | 
			
		||||
            final_output = request_output
 | 
			
		||||
 | 
			
		||||
@ -230,10 +237,11 @@ class VllmEngine(BaseEngine):
 | 
			
		||||
        tools: Optional[str] = None,
 | 
			
		||||
        images: Optional[Sequence["ImageInput"]] = None,
 | 
			
		||||
        videos: Optional[Sequence["VideoInput"]] = None,
 | 
			
		||||
        audios: Optional[Sequence["AudioInput"]] = None,
 | 
			
		||||
        **input_kwargs,
 | 
			
		||||
    ) -> AsyncGenerator[str, None]:
 | 
			
		||||
        generated_text = ""
 | 
			
		||||
        generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
 | 
			
		||||
        generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
 | 
			
		||||
        async for result in generator:
 | 
			
		||||
            delta_text = result.outputs[0].text[len(generated_text) :]
 | 
			
		||||
            generated_text = result.outputs[0].text
 | 
			
		||||
 | 
			
		||||
@ -25,57 +25,33 @@ if TYPE_CHECKING:
 | 
			
		||||
    from transformers import Seq2SeqTrainingArguments
 | 
			
		||||
 | 
			
		||||
    from ..hparams import DataArguments
 | 
			
		||||
    from .mm_plugin import ImageInput, VideoInput
 | 
			
		||||
    from .parser import DatasetAttr
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _convert_images(
 | 
			
		||||
    images: Union["ImageInput", Sequence["ImageInput"]],
 | 
			
		||||
def _regularize_medias(
 | 
			
		||||
    inputs: Union[Any, Sequence[Any]],
 | 
			
		||||
    dataset_attr: "DatasetAttr",
 | 
			
		||||
    data_args: "DataArguments",
 | 
			
		||||
) -> Optional[List["ImageInput"]]:
 | 
			
		||||
) -> Optional[List[Any]]:
 | 
			
		||||
    r"""
 | 
			
		||||
    Optionally concatenates image path to dataset dir when loading from local disk.
 | 
			
		||||
    Optionally concatenates media path to media dir when loading from local disk.
 | 
			
		||||
    """
 | 
			
		||||
    if not isinstance(images, list):
 | 
			
		||||
        images = [images]
 | 
			
		||||
    elif len(images) == 0:
 | 
			
		||||
    if not isinstance(inputs, list):
 | 
			
		||||
        inputs = [inputs]
 | 
			
		||||
    elif len(inputs) == 0:
 | 
			
		||||
        return None
 | 
			
		||||
    else:
 | 
			
		||||
        images = images[:]
 | 
			
		||||
        inputs = inputs[:]
 | 
			
		||||
 | 
			
		||||
    if dataset_attr.load_from in ["script", "file"]:
 | 
			
		||||
        for i in range(len(images)):
 | 
			
		||||
            if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.image_dir, images[i])):
 | 
			
		||||
                images[i] = os.path.join(data_args.image_dir, images[i])
 | 
			
		||||
        for i in range(len(inputs)):
 | 
			
		||||
            if isinstance(inputs[i], str) and os.path.isfile(os.path.join(data_args.media_dir, inputs[i])):
 | 
			
		||||
                inputs[i] = os.path.join(data_args.media_dir, inputs[i])
 | 
			
		||||
 | 
			
		||||
    return images
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _convert_videos(
 | 
			
		||||
    videos: Union["VideoInput", Sequence["VideoInput"]],
 | 
			
		||||
    dataset_attr: "DatasetAttr",
 | 
			
		||||
    data_args: "DataArguments",
 | 
			
		||||
) -> Optional[List["VideoInput"]]:
 | 
			
		||||
    r"""
 | 
			
		||||
    Optionally concatenates video path to dataset dir when loading from local disk.
 | 
			
		||||
    """
 | 
			
		||||
    if not isinstance(videos, list):
 | 
			
		||||
        videos = [videos]
 | 
			
		||||
    elif len(videos) == 0:
 | 
			
		||||
        return None
 | 
			
		||||
    else:
 | 
			
		||||
        videos = videos[:]
 | 
			
		||||
 | 
			
		||||
    if dataset_attr.load_from in ["script", "file"]:
 | 
			
		||||
        for i in range(len(videos)):
 | 
			
		||||
            if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.image_dir, videos[i])):
 | 
			
		||||
                videos[i] = os.path.join(data_args.image_dir, videos[i])
 | 
			
		||||
 | 
			
		||||
    return videos
 | 
			
		||||
    return inputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_alpaca(
 | 
			
		||||
@ -121,15 +97,15 @@ def convert_alpaca(
 | 
			
		||||
    else:  # unsupervised
 | 
			
		||||
        response = []
 | 
			
		||||
 | 
			
		||||
    convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
 | 
			
		||||
    convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
 | 
			
		||||
    regularize_medias = partial(_regularize_medias, dataset_attr=dataset_attr, data_args=data_args)
 | 
			
		||||
    output = {
 | 
			
		||||
        "_prompt": prompt,
 | 
			
		||||
        "_response": response,
 | 
			
		||||
        "_system": example[dataset_attr.system] if dataset_attr.system else "",
 | 
			
		||||
        "_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
 | 
			
		||||
        "_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
 | 
			
		||||
        "_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
 | 
			
		||||
        "_images": regularize_medias(example[dataset_attr.images]) if dataset_attr.images else None,
 | 
			
		||||
        "_videos": regularize_medias(example[dataset_attr.videos]) if dataset_attr.videos else None,
 | 
			
		||||
        "_audios": regularize_medias(example[dataset_attr.audios]) if dataset_attr.audios else None,
 | 
			
		||||
    }
 | 
			
		||||
    return output
 | 
			
		||||
 | 
			
		||||
@ -214,15 +190,15 @@ def convert_sharegpt(
 | 
			
		||||
        logger.warning_rank0("Skipping this abnormal example.")
 | 
			
		||||
        prompt, response = [], []
 | 
			
		||||
 | 
			
		||||
    convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
 | 
			
		||||
    convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
 | 
			
		||||
    regularize_medias = partial(_regularize_medias, dataset_attr=dataset_attr, data_args=data_args)
 | 
			
		||||
    output = {
 | 
			
		||||
        "_prompt": prompt,
 | 
			
		||||
        "_response": response,
 | 
			
		||||
        "_system": system,
 | 
			
		||||
        "_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
 | 
			
		||||
        "_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
 | 
			
		||||
        "_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
 | 
			
		||||
        "_images": regularize_medias(example[dataset_attr.images]) if dataset_attr.images else None,
 | 
			
		||||
        "_videos": regularize_medias(example[dataset_attr.videos]) if dataset_attr.videos else None,
 | 
			
		||||
        "_audios": regularize_medias(example[dataset_attr.audios]) if dataset_attr.audios else None,
 | 
			
		||||
    }
 | 
			
		||||
    return output
 | 
			
		||||
 | 
			
		||||
@ -241,6 +217,7 @@ def align_dataset(
 | 
			
		||||
        _tools: "...",
 | 
			
		||||
        _images: [],
 | 
			
		||||
        _videos: [],
 | 
			
		||||
        _audios: [],
 | 
			
		||||
    """
 | 
			
		||||
    if dataset_attr.formatting == "alpaca":
 | 
			
		||||
        convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
 | 
			
		||||
 | 
			
		||||
@ -18,11 +18,12 @@
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from transformers import DataCollatorForSeq2Seq
 | 
			
		||||
 | 
			
		||||
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
 | 
			
		||||
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
 | 
			
		||||
from ..extras.packages import is_pillow_available
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -80,7 +81,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
 | 
			
		||||
    r"""
 | 
			
		||||
    Data collator that supports VLMs.
 | 
			
		||||
 | 
			
		||||
    Features should contain input_ids, attention_mask, labels, and optionally contain images and videos.
 | 
			
		||||
    Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    template: Optional["Template"] = None
 | 
			
		||||
@ -91,26 +92,54 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
 | 
			
		||||
            raise ValueError("Template is required for MultiModalDataCollator.")
 | 
			
		||||
 | 
			
		||||
    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
 | 
			
		||||
        batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
 | 
			
		||||
        batch_images, batch_videos, batch_audios = [], [], []
 | 
			
		||||
        batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
 | 
			
		||||
        for feature in features:
 | 
			
		||||
            images = feature.pop("images", None) or []
 | 
			
		||||
            videos = feature.pop("videos", None) or []
 | 
			
		||||
            audios = feature.pop("audios", None) or []
 | 
			
		||||
            batch_images.extend(images)
 | 
			
		||||
            batch_videos.extend(videos)
 | 
			
		||||
            batch_audios.extend(audios)
 | 
			
		||||
            batch_imglens.append(len(images))
 | 
			
		||||
            batch_vidlens.append(len(videos))
 | 
			
		||||
            batch_audlens.append(len(audios))
 | 
			
		||||
            batch_input_ids.append(feature["input_ids"])
 | 
			
		||||
 | 
			
		||||
        fake_input_ids = None
 | 
			
		||||
        if (
 | 
			
		||||
            self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
 | 
			
		||||
            self.template.mm_plugin.image_token is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
 | 
			
		||||
        ):  # avoid process hanging in zero3/fsdp case
 | 
			
		||||
            fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
 | 
			
		||||
            fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
 | 
			
		||||
            fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
 | 
			
		||||
            fake_messages = self.template.mm_plugin.process_messages(
 | 
			
		||||
                fake_messages, fake_images, [], [], self.processor
 | 
			
		||||
            )
 | 
			
		||||
            fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
 | 
			
		||||
            fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
 | 
			
		||||
                fake_input_ids, None, fake_images, [], self.tokenizer, self.processor
 | 
			
		||||
                fake_input_ids, None, fake_images, [], [], self.tokenizer, self.processor
 | 
			
		||||
            )
 | 
			
		||||
            batch_images = fake_images
 | 
			
		||||
            batch_imglens[0] = 1
 | 
			
		||||
            batch_input_ids[0] = features[0]["input_ids"]
 | 
			
		||||
 | 
			
		||||
        if (
 | 
			
		||||
            self.template.mm_plugin.audio_token is not None and sum(batch_audlens) == 0
 | 
			
		||||
        ):  # avoid process hanging in zero3/fsdp case
 | 
			
		||||
            fake_messages = [{"role": "user", "content": AUDIO_PLACEHOLDER}]
 | 
			
		||||
            fake_audios = [np.zeros(1600)]
 | 
			
		||||
            fake_messages = self.template.mm_plugin.process_messages(
 | 
			
		||||
                fake_messages, [], [], fake_audios, self.processor
 | 
			
		||||
            )
 | 
			
		||||
            fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
 | 
			
		||||
            fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
 | 
			
		||||
                fake_input_ids, None, [], [], fake_audios, self.tokenizer, self.processor
 | 
			
		||||
            )
 | 
			
		||||
            batch_audios = fake_audios
 | 
			
		||||
            batch_audlens[0] = 1
 | 
			
		||||
            batch_input_ids[0] = features[0]["input_ids"]
 | 
			
		||||
 | 
			
		||||
        if fake_input_ids is not None:
 | 
			
		||||
            if self.tokenizer.padding_side == "right":
 | 
			
		||||
                features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
 | 
			
		||||
                features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
 | 
			
		||||
@ -120,12 +149,15 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
 | 
			
		||||
                features[0]["attention_mask"] = [0] * len(fake_input_ids) + features[0]["attention_mask"]
 | 
			
		||||
                features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"]
 | 
			
		||||
 | 
			
		||||
            batch_images = fake_images
 | 
			
		||||
            batch_imglens[0] = 1
 | 
			
		||||
            batch_input_ids[0] = features[0]["input_ids"]
 | 
			
		||||
 | 
			
		||||
        mm_inputs = self.template.mm_plugin.get_mm_inputs(
 | 
			
		||||
            batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
 | 
			
		||||
            batch_images,
 | 
			
		||||
            batch_videos,
 | 
			
		||||
            batch_audios,
 | 
			
		||||
            batch_imglens,
 | 
			
		||||
            batch_vidlens,
 | 
			
		||||
            batch_audlens,
 | 
			
		||||
            batch_input_ids,
 | 
			
		||||
            self.processor,
 | 
			
		||||
        )
 | 
			
		||||
        if "token_type_ids" in mm_inputs:
 | 
			
		||||
            token_type_ids = mm_inputs.pop("token_type_ids")
 | 
			
		||||
@ -208,6 +240,7 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
 | 
			
		||||
                    "labels": feature[f"{key}_labels"],
 | 
			
		||||
                    "images": feature["images"],
 | 
			
		||||
                    "videos": feature["videos"],
 | 
			
		||||
                    "audios": feature["audios"],
 | 
			
		||||
                }
 | 
			
		||||
                concatenated_features.append(target_feature)
 | 
			
		||||
 | 
			
		||||
@ -231,6 +264,7 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
 | 
			
		||||
                "labels": feature["labels"],
 | 
			
		||||
                "images": feature["images"],
 | 
			
		||||
                "videos": feature["videos"],
 | 
			
		||||
                "audios": feature["audios"],
 | 
			
		||||
            }
 | 
			
		||||
            kl_feature = {
 | 
			
		||||
                "input_ids": feature["kl_input_ids"],
 | 
			
		||||
@ -238,6 +272,7 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
 | 
			
		||||
                "labels": feature["kl_labels"],
 | 
			
		||||
                "images": feature["images"],
 | 
			
		||||
                "videos": feature["videos"],
 | 
			
		||||
                "audios": feature["audios"],
 | 
			
		||||
            }
 | 
			
		||||
            target_features.append(target_feature)
 | 
			
		||||
            kl_features.append(kl_feature)
 | 
			
		||||
 | 
			
		||||
@ -9,8 +9,17 @@ import torch
 | 
			
		||||
from transformers.image_utils import get_image_size, to_numpy_array
 | 
			
		||||
from typing_extensions import override
 | 
			
		||||
 | 
			
		||||
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 | 
			
		||||
from ..extras.packages import is_pillow_available, is_pyav_available, is_transformers_version_greater_than
 | 
			
		||||
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 | 
			
		||||
from ..extras.packages import (
 | 
			
		||||
    is_librosa_available,
 | 
			
		||||
    is_pillow_available,
 | 
			
		||||
    is_pyav_available,
 | 
			
		||||
    is_transformers_version_greater_than,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if is_librosa_available():
 | 
			
		||||
    import librosa
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if is_pillow_available():
 | 
			
		||||
@ -31,7 +40,9 @@ if is_transformers_version_greater_than("4.45.0"):
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from av.stream import Stream
 | 
			
		||||
    from numpy.typing import NDArray
 | 
			
		||||
    from transformers import PreTrainedTokenizer, ProcessorMixin
 | 
			
		||||
    from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
 | 
			
		||||
    from transformers.image_processing_utils import BaseImageProcessor
 | 
			
		||||
 | 
			
		||||
    class EncodedImage(TypedDict):
 | 
			
		||||
@ -40,6 +51,7 @@ if TYPE_CHECKING:
 | 
			
		||||
 | 
			
		||||
    ImageInput = Union[str, bytes, EncodedImage, ImageObject]
 | 
			
		||||
    VideoInput = str
 | 
			
		||||
    AudioInput = Union[str, NDArray]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_paligemma_token_type_ids(
 | 
			
		||||
@ -60,15 +72,17 @@ def _get_paligemma_token_type_ids(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BasePlugin:
 | 
			
		||||
    def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
 | 
			
		||||
    def __init__(self, image_token: Optional[str], video_token: Optional[str], audio_token: Optional[str]) -> None:
 | 
			
		||||
        self.image_token = image_token
 | 
			
		||||
        self.video_token = video_token
 | 
			
		||||
        self.audio_token = audio_token
 | 
			
		||||
        self.expand_mm_tokens = True
 | 
			
		||||
 | 
			
		||||
    def _validate_input(
 | 
			
		||||
        self,
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        r"""
 | 
			
		||||
        Validates if this model accepts the input modalities.
 | 
			
		||||
@ -83,11 +97,16 @@ class BasePlugin:
 | 
			
		||||
                "This model does not support video input. Please check whether the correct `template` is used."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if len(audios) != 0 and self.audio_token is None:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "This model does not support audio input. Please check whether the correct `template` is used."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
 | 
			
		||||
        r"""
 | 
			
		||||
        Pre-processes a single image.
 | 
			
		||||
        """
 | 
			
		||||
        image_resolution: int = kwargs.get("image_resolution")
 | 
			
		||||
        image_resolution: int = kwargs["image_resolution"]
 | 
			
		||||
        if (image.width * image.height) > image_resolution:
 | 
			
		||||
            resize_factor = math.sqrt(image_resolution / (image.width * image.height))
 | 
			
		||||
            width, height = int(image.width * resize_factor), int(image.height * resize_factor)
 | 
			
		||||
@ -102,8 +121,8 @@ class BasePlugin:
 | 
			
		||||
        r"""
 | 
			
		||||
        Computes video sample frames according to fps.
 | 
			
		||||
        """
 | 
			
		||||
        video_fps: float = kwargs.get("video_fps")
 | 
			
		||||
        video_maxlen: int = kwargs.get("video_maxlen")
 | 
			
		||||
        video_fps: float = kwargs["video_fps"]
 | 
			
		||||
        video_maxlen: int = kwargs["video_maxlen"]
 | 
			
		||||
        total_frames = video_stream.frames
 | 
			
		||||
        sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
 | 
			
		||||
        sample_frames = min(total_frames, video_maxlen, sample_frames)
 | 
			
		||||
@ -126,7 +145,7 @@ class BasePlugin:
 | 
			
		||||
                    image = Image.open(image["path"])
 | 
			
		||||
 | 
			
		||||
            if not isinstance(image, ImageObject):
 | 
			
		||||
                raise ValueError(f"Expect input is a list of Images, but got {type(image)}.")
 | 
			
		||||
                raise ValueError(f"Expect input is a list of images, but got {type(image)}.")
 | 
			
		||||
 | 
			
		||||
            results.append(self._preprocess_image(image, **kwargs))
 | 
			
		||||
 | 
			
		||||
@ -154,10 +173,28 @@ class BasePlugin:
 | 
			
		||||
 | 
			
		||||
        return results
 | 
			
		||||
 | 
			
		||||
    def _regularize_audios(self, audios: Sequence["AudioInput"], **kwargs) -> List["NDArray"]:
 | 
			
		||||
        r"""
 | 
			
		||||
        Regularizes audios to avoid error. Including reading and resampling.
 | 
			
		||||
        """
 | 
			
		||||
        results = []
 | 
			
		||||
        sampling_rate = kwargs["sampling_rate"]
 | 
			
		||||
        for audio in audios:
 | 
			
		||||
            if isinstance(audio, str):
 | 
			
		||||
                audio = librosa.load(audio, sr=sampling_rate)[0]
 | 
			
		||||
 | 
			
		||||
            if not isinstance(audio, np.ndarray):
 | 
			
		||||
                raise ValueError(f"Expect input is a list of audios, but got {type(audio)}.")
 | 
			
		||||
 | 
			
		||||
            results.append(audio)
 | 
			
		||||
 | 
			
		||||
        return results
 | 
			
		||||
 | 
			
		||||
    def _get_mm_inputs(
 | 
			
		||||
        self,
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        processor: "ProcessorMixin",
 | 
			
		||||
    ) -> Dict[str, "torch.Tensor"]:
 | 
			
		||||
        r"""
 | 
			
		||||
@ -172,15 +209,17 @@ class BasePlugin:
 | 
			
		||||
 | 
			
		||||
        It holds num_patches == torch.prod(image_grid_thw)
 | 
			
		||||
        """
 | 
			
		||||
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
 | 
			
		||||
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
 | 
			
		||||
        video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
 | 
			
		||||
        input_dict = {"images": None}  # default key
 | 
			
		||||
        feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
 | 
			
		||||
        mm_inputs = {}
 | 
			
		||||
 | 
			
		||||
        if len(images) != 0:
 | 
			
		||||
            images = self._regularize_images(
 | 
			
		||||
                images,
 | 
			
		||||
                image_resolution=getattr(processor, "image_resolution", 768 * 768),
 | 
			
		||||
            )
 | 
			
		||||
            input_dict["images"] = images
 | 
			
		||||
            mm_inputs.update(image_processor(images, return_tensors="pt"))
 | 
			
		||||
 | 
			
		||||
        if len(videos) != 0:
 | 
			
		||||
            videos = self._regularize_videos(
 | 
			
		||||
@ -189,16 +228,23 @@ class BasePlugin:
 | 
			
		||||
                video_fps=getattr(processor, "video_fps", 2.0),
 | 
			
		||||
                video_maxlen=getattr(processor, "video_maxlen", 128),
 | 
			
		||||
            )
 | 
			
		||||
            input_dict["videos"] = videos
 | 
			
		||||
            mm_inputs.update(video_processor(videos, return_tensors="pt"))
 | 
			
		||||
 | 
			
		||||
        mm_inputs = {}
 | 
			
		||||
        if image_processor != video_processor:
 | 
			
		||||
            if input_dict.get("images") is not None:
 | 
			
		||||
                mm_inputs.update(image_processor(input_dict["images"], return_tensors="pt"))
 | 
			
		||||
            if input_dict.get("videos") is not None:
 | 
			
		||||
                mm_inputs.update(video_processor(input_dict["videos"], return_tensors="pt"))
 | 
			
		||||
        elif input_dict.get("images") is not None or input_dict.get("videos") is not None:  # same processor (qwen2-vl)
 | 
			
		||||
            mm_inputs.update(image_processor(**input_dict, return_tensors="pt"))
 | 
			
		||||
        if len(audios) != 0:
 | 
			
		||||
            audios = self._regularize_audios(
 | 
			
		||||
                audios,
 | 
			
		||||
                sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
 | 
			
		||||
            )
 | 
			
		||||
            mm_inputs.update(
 | 
			
		||||
                feature_extractor(
 | 
			
		||||
                    audios,
 | 
			
		||||
                    sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
 | 
			
		||||
                    return_attention_mask=True,
 | 
			
		||||
                    padding="max_length",
 | 
			
		||||
                    return_tensors="pt",
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
            mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask")  # prevent conflicts
 | 
			
		||||
 | 
			
		||||
        return mm_inputs
 | 
			
		||||
 | 
			
		||||
@ -207,12 +253,13 @@ class BasePlugin:
 | 
			
		||||
        messages: Sequence[Dict[str, str]],
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
    ) -> List[Dict[str, str]]:
 | 
			
		||||
        r"""
 | 
			
		||||
        Pre-processes input messages before tokenization for VLMs.
 | 
			
		||||
        """
 | 
			
		||||
        self._validate_input(images, videos)
 | 
			
		||||
        self._validate_input(images, videos, audios)
 | 
			
		||||
        return messages
 | 
			
		||||
 | 
			
		||||
    def process_token_ids(
 | 
			
		||||
@ -221,21 +268,24 @@ class BasePlugin:
 | 
			
		||||
        labels: Optional[List[int]],
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        tokenizer: "PreTrainedTokenizer",
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
    ) -> Tuple[List[int], Optional[List[int]]]:
 | 
			
		||||
        r"""
 | 
			
		||||
        Pre-processes token ids after tokenization for VLMs.
 | 
			
		||||
        """
 | 
			
		||||
        self._validate_input(images, videos)
 | 
			
		||||
        self._validate_input(images, videos, audios)
 | 
			
		||||
        return input_ids, labels
 | 
			
		||||
 | 
			
		||||
    def get_mm_inputs(
 | 
			
		||||
        self,
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        imglens: Sequence[int],
 | 
			
		||||
        vidlens: Sequence[int],
 | 
			
		||||
        audlens: Sequence[int],
 | 
			
		||||
        batch_ids: Sequence[List[int]],
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
 | 
			
		||||
@ -247,10 +297,11 @@ class BasePlugin:
 | 
			
		||||
            videos: a list of video inputs, shape (num_videos,)
 | 
			
		||||
            imglens: number of images in each sample, shape (batch_size,)
 | 
			
		||||
            vidlens: number of videos in each sample, shape (batch_size,)
 | 
			
		||||
            audlens: number of audios in each sample, shape (batch_size,)
 | 
			
		||||
            batch_ids: token ids of input samples, shape (batch_size, seq_len)
 | 
			
		||||
            processor: a processor for pre-processing images and videos
 | 
			
		||||
        """
 | 
			
		||||
        self._validate_input(images, videos)
 | 
			
		||||
        self._validate_input(images, videos, audios)
 | 
			
		||||
        return {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -261,9 +312,10 @@ class LlavaPlugin(BasePlugin):
 | 
			
		||||
        messages: Sequence[Dict[str, str]],
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
    ) -> List[Dict[str, str]]:
 | 
			
		||||
        self._validate_input(images, videos)
 | 
			
		||||
        self._validate_input(images, videos, audios)
 | 
			
		||||
        num_image_tokens = 0
 | 
			
		||||
        image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
 | 
			
		||||
        messages = deepcopy(messages)
 | 
			
		||||
@ -285,13 +337,15 @@ class LlavaPlugin(BasePlugin):
 | 
			
		||||
        self,
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        imglens: Sequence[int],
 | 
			
		||||
        vidlens: Sequence[int],
 | 
			
		||||
        audlens: Sequence[int],
 | 
			
		||||
        batch_ids: Sequence[List[int]],
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
 | 
			
		||||
        self._validate_input(images, videos)
 | 
			
		||||
        return self._get_mm_inputs(images, videos, processor)
 | 
			
		||||
        self._validate_input(images, videos, audios)
 | 
			
		||||
        return self._get_mm_inputs(images, videos, audios, processor)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LlavaNextPlugin(BasePlugin):
 | 
			
		||||
@ -301,12 +355,13 @@ class LlavaNextPlugin(BasePlugin):
 | 
			
		||||
        messages: Sequence[Dict[str, str]],
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
    ) -> List[Dict[str, str]]:
 | 
			
		||||
        self._validate_input(images, videos)
 | 
			
		||||
        self._validate_input(images, videos, audios)
 | 
			
		||||
        num_image_tokens = 0
 | 
			
		||||
        messages = deepcopy(messages)
 | 
			
		||||
        mm_inputs = self._get_mm_inputs(images, videos, processor)
 | 
			
		||||
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
 | 
			
		||||
        if "image_sizes" in mm_inputs:
 | 
			
		||||
            image_sizes = iter(mm_inputs["image_sizes"])
 | 
			
		||||
 | 
			
		||||
@ -339,13 +394,15 @@ class LlavaNextPlugin(BasePlugin):
 | 
			
		||||
        self,
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        imglens: Sequence[int],
 | 
			
		||||
        vidlens: Sequence[int],
 | 
			
		||||
        audlens: Sequence[int],
 | 
			
		||||
        batch_ids: Sequence[List[int]],
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
 | 
			
		||||
        self._validate_input(images, videos)
 | 
			
		||||
        return self._get_mm_inputs(images, videos, processor)
 | 
			
		||||
        self._validate_input(images, videos, audios)
 | 
			
		||||
        return self._get_mm_inputs(images, videos, audios, processor)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LlavaNextVideoPlugin(BasePlugin):
 | 
			
		||||
@ -355,12 +412,13 @@ class LlavaNextVideoPlugin(BasePlugin):
 | 
			
		||||
        messages: Sequence[Dict[str, str]],
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
    ) -> List[Dict[str, str]]:
 | 
			
		||||
        self._validate_input(images, videos)
 | 
			
		||||
        self._validate_input(images, videos, audios)
 | 
			
		||||
        num_image_tokens, num_video_tokens = 0, 0
 | 
			
		||||
        messages = deepcopy(messages)
 | 
			
		||||
        mm_inputs = self._get_mm_inputs(images, videos, processor)
 | 
			
		||||
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
 | 
			
		||||
        if "pixel_values" in mm_inputs:
 | 
			
		||||
            image_sizes = iter(mm_inputs["image_sizes"])
 | 
			
		||||
            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
 | 
			
		||||
@ -408,13 +466,15 @@ class LlavaNextVideoPlugin(BasePlugin):
 | 
			
		||||
        self,
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        imglens: Sequence[int],
 | 
			
		||||
        vidlens: Sequence[int],
 | 
			
		||||
        audlens: Sequence[int],
 | 
			
		||||
        batch_ids: Sequence[List[int]],
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
 | 
			
		||||
        self._validate_input(images, videos)
 | 
			
		||||
        return self._get_mm_inputs(images, videos, processor)
 | 
			
		||||
        self._validate_input(images, videos, audios)
 | 
			
		||||
        return self._get_mm_inputs(images, videos, audios, processor)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MiniCPMVPlugin(BasePlugin):
 | 
			
		||||
@ -424,26 +484,30 @@ class MiniCPMVPlugin(BasePlugin):
 | 
			
		||||
        messages: Sequence[Dict[str, str]],
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
    ) -> List[Dict[str, str]]:
 | 
			
		||||
        self._validate_input(images, videos)
 | 
			
		||||
        self._validate_input(images, videos, audios)
 | 
			
		||||
        num_image_tokens = 0
 | 
			
		||||
        num_video_tokens = 0
 | 
			
		||||
        num_audio_tokens = 0
 | 
			
		||||
        messages = deepcopy(messages)
 | 
			
		||||
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
 | 
			
		||||
        mm_inputs = {}
 | 
			
		||||
        audio_inputs = {}
 | 
			
		||||
        audio_parts = []
 | 
			
		||||
        if len(images) != 0 and len(videos) != 0:
 | 
			
		||||
            raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
 | 
			
		||||
 | 
			
		||||
        if len(videos) != 0:
 | 
			
		||||
            max_slice_nums = 2
 | 
			
		||||
            use_image_id = False
 | 
			
		||||
            mm_inputs = self._get_mm_inputs([], videos, processor)
 | 
			
		||||
            mm_inputs = self._get_mm_inputs([], videos, [], processor)
 | 
			
		||||
        else:
 | 
			
		||||
            max_slice_nums = image_processor.max_slice_nums
 | 
			
		||||
            use_image_id = image_processor.use_image_id
 | 
			
		||||
 | 
			
		||||
        for message in messages:
 | 
			
		||||
        for i, message in enumerate(messages):
 | 
			
		||||
            content = message["content"]
 | 
			
		||||
            while IMAGE_PLACEHOLDER in content:
 | 
			
		||||
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
 | 
			
		||||
@ -454,15 +518,25 @@ class MiniCPMVPlugin(BasePlugin):
 | 
			
		||||
                content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
 | 
			
		||||
                num_video_tokens += 1
 | 
			
		||||
 | 
			
		||||
            message["content"] = content.replace("{{image}}", "(<image>./</image>)")
 | 
			
		||||
            while AUDIO_PLACEHOLDER in content:
 | 
			
		||||
                audio_parts.append(i)
 | 
			
		||||
                content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
 | 
			
		||||
                num_audio_tokens += 1
 | 
			
		||||
 | 
			
		||||
            message["content"] = content.replace("{{image}}", "(<image>./</image>)").replace(
 | 
			
		||||
                "{{audio}}", "(<audio>./</audio>)"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if num_image_tokens > 0:
 | 
			
		||||
            mm_inputs = self._get_mm_inputs(images, [], processor)
 | 
			
		||||
            mm_inputs = self._get_mm_inputs(images, [], [], processor)
 | 
			
		||||
 | 
			
		||||
        if num_audio_tokens > 0:
 | 
			
		||||
            audio_parts_ls = [audio_parts]
 | 
			
		||||
            audio_inputs = self._get_mm_inputs([], [], audios, processor, audio_parts_ls=audio_parts_ls, ret_phs=True)
 | 
			
		||||
 | 
			
		||||
        if mm_inputs:
 | 
			
		||||
            pattern = "(<image>./</image>)"
 | 
			
		||||
            image_sizes = mm_inputs["image_sizes"]
 | 
			
		||||
 | 
			
		||||
            for index, message in enumerate(messages):
 | 
			
		||||
                text = message["content"]
 | 
			
		||||
                image_tags = re.findall(pattern, text)
 | 
			
		||||
@ -480,12 +554,29 @@ class MiniCPMVPlugin(BasePlugin):
 | 
			
		||||
                final_text += text_chunks[-1]
 | 
			
		||||
                messages[index]["content"] = final_text
 | 
			
		||||
 | 
			
		||||
        if audio_inputs:
 | 
			
		||||
            pattern = "(<audio>./</audio>)"
 | 
			
		||||
            for index, message in enumerate(messages):
 | 
			
		||||
                text = message["content"]
 | 
			
		||||
                audio_tags = re.findall(pattern, text)
 | 
			
		||||
                text_chunks = text.split(pattern)
 | 
			
		||||
                final_text = ""
 | 
			
		||||
                for i in range(len(audio_tags)):
 | 
			
		||||
                    audio_placeholder = audio_inputs["audio_phs"][0][i]
 | 
			
		||||
                    final_text = final_text + text_chunks[i] + audio_placeholder
 | 
			
		||||
 | 
			
		||||
                final_text += text_chunks[-1]
 | 
			
		||||
                messages[index]["content"] = final_text
 | 
			
		||||
 | 
			
		||||
        if len(images) != num_image_tokens:
 | 
			
		||||
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
 | 
			
		||||
 | 
			
		||||
        if len(videos) != num_video_tokens:
 | 
			
		||||
            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
 | 
			
		||||
 | 
			
		||||
        if len(audios) != num_audio_tokens:
 | 
			
		||||
            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
 | 
			
		||||
 | 
			
		||||
        return messages
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
@ -493,6 +584,7 @@ class MiniCPMVPlugin(BasePlugin):
 | 
			
		||||
        self,
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        processor: "ProcessorMixin",
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> Dict[str, "torch.Tensor"]:
 | 
			
		||||
@ -528,6 +620,30 @@ class MiniCPMVPlugin(BasePlugin):
 | 
			
		||||
            video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
 | 
			
		||||
            mm_inputs.update(video_inputs)
 | 
			
		||||
 | 
			
		||||
        if len(audios) != 0:
 | 
			
		||||
            audio_parts_ls = kwargs.get("audio_parts_ls", None)
 | 
			
		||||
            new_audios = []
 | 
			
		||||
            for audio in audios:
 | 
			
		||||
                if not isinstance(audio, np.ndarray):
 | 
			
		||||
                    audio = librosa.load(audio, sr=processor.feature_extractor.sampling_rate)[0]
 | 
			
		||||
                new_audios.append(audio)
 | 
			
		||||
 | 
			
		||||
            audios_ls = []
 | 
			
		||||
            idx = 0
 | 
			
		||||
            for audio_parts in audio_parts_ls:
 | 
			
		||||
                audios_ls.append(new_audios[idx : idx + len(audio_parts)])
 | 
			
		||||
                idx += len(audio_parts)
 | 
			
		||||
 | 
			
		||||
            audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
 | 
			
		||||
                audios_ls,
 | 
			
		||||
                audio_parts_ls,
 | 
			
		||||
                chunk_input=True,
 | 
			
		||||
                sampling_rate=16000,
 | 
			
		||||
            )
 | 
			
		||||
            mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
 | 
			
		||||
            if kwargs.get("ret_phs", False):
 | 
			
		||||
                mm_inputs.update({"audio_phs": audio_phs})
 | 
			
		||||
 | 
			
		||||
        return mm_inputs
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
@ -535,12 +651,16 @@ class MiniCPMVPlugin(BasePlugin):
 | 
			
		||||
        self,
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        imglens: Sequence[int],
 | 
			
		||||
        vidlens: Sequence[int],
 | 
			
		||||
        audlens: Sequence[int],
 | 
			
		||||
        batch_ids: Sequence[List[int]],
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
 | 
			
		||||
        self._validate_input(images, videos)
 | 
			
		||||
        self._validate_input(images, videos, audios)
 | 
			
		||||
 | 
			
		||||
        # image bound
 | 
			
		||||
        image_bounds_list = []
 | 
			
		||||
        valid_image_nums_ls = []
 | 
			
		||||
        for i, input_ids in enumerate(batch_ids):
 | 
			
		||||
@ -561,8 +681,38 @@ class MiniCPMVPlugin(BasePlugin):
 | 
			
		||||
            )
 | 
			
		||||
            image_bounds_list.append(image_bounds)
 | 
			
		||||
 | 
			
		||||
        mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
 | 
			
		||||
        mm_inputs = self._get_mm_inputs(images, videos, [], processor, valid_image_nums_ls=valid_image_nums_ls)
 | 
			
		||||
        if "tgt_sizes" not in mm_inputs:
 | 
			
		||||
            dummy_data = [torch.empty(0) for _ in range(len(batch_ids))]
 | 
			
		||||
            mm_inputs.update({"tgt_sizes": dummy_data, "pixel_values": dummy_data, "image_sizes": dummy_data})
 | 
			
		||||
 | 
			
		||||
        mm_inputs.update({"image_bound": image_bounds_list})
 | 
			
		||||
 | 
			
		||||
        if len(audios) > 0:
 | 
			
		||||
            # audio bound
 | 
			
		||||
            audio_bounds_ls = []
 | 
			
		||||
            spk_bounds_ls = []
 | 
			
		||||
            audio_parts_ls = []
 | 
			
		||||
 | 
			
		||||
            for input_ids, audiolen in zip(batch_ids, audlens):
 | 
			
		||||
                input_ids_ = torch.tensor(input_ids)
 | 
			
		||||
                audio_start_idx = torch.where(input_ids_ == processor.tokenizer.audio_start_id)[0]
 | 
			
		||||
                audio_end_idx = torch.where(input_ids_ == processor.tokenizer.audio_end_id)[0]
 | 
			
		||||
                assert len(audio_start_idx) == len(audio_end_idx)
 | 
			
		||||
                audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])
 | 
			
		||||
                audio_bounds_ls.append(audio_bounds)
 | 
			
		||||
                audio_parts_ls.append(list(range(audiolen)))
 | 
			
		||||
 | 
			
		||||
                spk_start_idx = torch.where(input_ids_ == processor.tokenizer.spk_start_id)[0]
 | 
			
		||||
                spk_end_idx = torch.where(input_ids_ == processor.tokenizer.spk_end_id)[0]
 | 
			
		||||
                assert len(spk_start_idx) == len(spk_end_idx)
 | 
			
		||||
                spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
 | 
			
		||||
                spk_bounds_ls.append(spk_bounds)
 | 
			
		||||
 | 
			
		||||
            audio_inputs = self._get_mm_inputs([], [], audios, processor, audio_parts_ls=audio_parts_ls)
 | 
			
		||||
            mm_inputs.update(audio_inputs)
 | 
			
		||||
            mm_inputs.update({"audio_bounds": audio_bounds_ls, "spk_bounds": spk_bounds_ls})
 | 
			
		||||
 | 
			
		||||
        return mm_inputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -573,9 +723,10 @@ class MllamaPlugin(BasePlugin):
 | 
			
		||||
        messages: Sequence[Dict[str, str]],
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
    ) -> List[Dict[str, str]]:
 | 
			
		||||
        self._validate_input(images, videos)
 | 
			
		||||
        self._validate_input(images, videos, audios)
 | 
			
		||||
        num_image_tokens = 0
 | 
			
		||||
        messages = deepcopy(messages)
 | 
			
		||||
        for message in messages:
 | 
			
		||||
@ -593,6 +744,7 @@ class MllamaPlugin(BasePlugin):
 | 
			
		||||
        self,
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        processor: "ProcessorMixin",
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> Dict[str, "torch.Tensor"]:
 | 
			
		||||
@ -617,17 +769,20 @@ class MllamaPlugin(BasePlugin):
 | 
			
		||||
 | 
			
		||||
        return image_processor(batch_images, return_tensors="pt")
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def get_mm_inputs(
 | 
			
		||||
        self,
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        imglens: Sequence[int],
 | 
			
		||||
        vidlens: Sequence[int],
 | 
			
		||||
        audlens: Sequence[int],
 | 
			
		||||
        batch_ids: Sequence[List[int]],
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
 | 
			
		||||
        self._validate_input(images, videos)
 | 
			
		||||
        mm_inputs = self._get_mm_inputs(images, videos, processor, imglens=imglens)
 | 
			
		||||
        self._validate_input(images, videos, audios)
 | 
			
		||||
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens=imglens)
 | 
			
		||||
        num_tiles = mm_inputs.pop("num_tiles")
 | 
			
		||||
        image_token_id = getattr(processor, "image_token_id")
 | 
			
		||||
        max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
 | 
			
		||||
@ -652,9 +807,10 @@ class PaliGemmaPlugin(BasePlugin):
 | 
			
		||||
        messages: Sequence[Dict[str, str]],
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
    ) -> List[Dict[str, str]]:
 | 
			
		||||
        self._validate_input(images, videos)
 | 
			
		||||
        self._validate_input(images, videos, audios)
 | 
			
		||||
        num_image_tokens = 0
 | 
			
		||||
        messages = deepcopy(messages)
 | 
			
		||||
        for message in messages:
 | 
			
		||||
@ -677,10 +833,11 @@ class PaliGemmaPlugin(BasePlugin):
 | 
			
		||||
        labels: Optional[List[int]],
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        tokenizer: "PreTrainedTokenizer",
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
    ) -> Tuple[List[int], Optional[List[int]]]:
 | 
			
		||||
        self._validate_input(images, videos)
 | 
			
		||||
        self._validate_input(images, videos, audios)
 | 
			
		||||
        num_images = len(images)
 | 
			
		||||
        image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0  # skip mm token
 | 
			
		||||
        image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
 | 
			
		||||
@ -695,14 +852,16 @@ class PaliGemmaPlugin(BasePlugin):
 | 
			
		||||
        self,
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        imglens: Sequence[int],
 | 
			
		||||
        vidlens: Sequence[int],
 | 
			
		||||
        audlens: Sequence[int],
 | 
			
		||||
        batch_ids: Sequence[List[int]],
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
 | 
			
		||||
        self._validate_input(images, videos)
 | 
			
		||||
        self._validate_input(images, videos, audios)
 | 
			
		||||
        seqlens = [len(input_ids) for input_ids in batch_ids]
 | 
			
		||||
        mm_inputs = self._get_mm_inputs(images, videos, processor)
 | 
			
		||||
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
 | 
			
		||||
        mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
 | 
			
		||||
        return mm_inputs
 | 
			
		||||
 | 
			
		||||
@ -714,9 +873,10 @@ class PixtralPlugin(BasePlugin):
 | 
			
		||||
        messages: Sequence[Dict[str, str]],
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
    ) -> List[Dict[str, str]]:
 | 
			
		||||
        self._validate_input(images, videos)
 | 
			
		||||
        self._validate_input(images, videos, audios)
 | 
			
		||||
        patch_size = getattr(processor, "patch_size")
 | 
			
		||||
        image_token = getattr(processor, "image_token")
 | 
			
		||||
        image_break_token = getattr(processor, "image_break_token")
 | 
			
		||||
@ -724,7 +884,7 @@ class PixtralPlugin(BasePlugin):
 | 
			
		||||
 | 
			
		||||
        num_image_tokens = 0
 | 
			
		||||
        messages = deepcopy(messages)
 | 
			
		||||
        mm_inputs = self._get_mm_inputs(images, videos, processor)
 | 
			
		||||
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
 | 
			
		||||
        image_input_sizes = mm_inputs.get("image_sizes", None)
 | 
			
		||||
        for message in messages:
 | 
			
		||||
            content = message["content"]
 | 
			
		||||
@ -759,13 +919,15 @@ class PixtralPlugin(BasePlugin):
 | 
			
		||||
        self,
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        imglens: Sequence[int],
 | 
			
		||||
        vidlens: Sequence[int],
 | 
			
		||||
        audlens: Sequence[int],
 | 
			
		||||
        batch_ids: Sequence[List[int]],
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
 | 
			
		||||
        self._validate_input(images, videos)
 | 
			
		||||
        mm_inputs = self._get_mm_inputs(images, videos, processor)
 | 
			
		||||
        self._validate_input(images, videos, audios)
 | 
			
		||||
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
 | 
			
		||||
        if mm_inputs.get("pixel_values"):
 | 
			
		||||
            mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0]
 | 
			
		||||
 | 
			
		||||
@ -773,6 +935,58 @@ class PixtralPlugin(BasePlugin):
 | 
			
		||||
        return mm_inputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Qwen2AudioPlugin(BasePlugin):
 | 
			
		||||
    @override
 | 
			
		||||
    def process_messages(
 | 
			
		||||
        self,
 | 
			
		||||
        messages: Sequence[Dict[str, str]],
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
    ) -> List[Dict[str, str]]:
 | 
			
		||||
        self._validate_input(images, videos, audios)
 | 
			
		||||
        bos_token: str = getattr(processor, "audio_bos_token")
 | 
			
		||||
        eos_token: str = getattr(processor, "audio_eos_token")
 | 
			
		||||
        mm_inputs = self._get_mm_inputs([], [], audios, processor)
 | 
			
		||||
        if "feature_attention_mask" in mm_inputs:
 | 
			
		||||
            audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist()
 | 
			
		||||
 | 
			
		||||
        num_audio_tokens = 0
 | 
			
		||||
        for message in messages:
 | 
			
		||||
            content = message["content"]
 | 
			
		||||
            while AUDIO_PLACEHOLDER in content:
 | 
			
		||||
                audio_length = audio_lengths.pop(0)
 | 
			
		||||
                input_length = (audio_length - 1) // 2 + 1
 | 
			
		||||
                audio_seqlen = (input_length - 2) // 2 + 1 if self.expand_mm_tokens else 1
 | 
			
		||||
                content = content.replace(
 | 
			
		||||
                    AUDIO_PLACEHOLDER, f"{bos_token}{self.audio_token * audio_seqlen}{eos_token}", 1
 | 
			
		||||
                )
 | 
			
		||||
                num_audio_tokens += 1
 | 
			
		||||
 | 
			
		||||
            message["content"] = content
 | 
			
		||||
 | 
			
		||||
        if len(audios) != num_audio_tokens:
 | 
			
		||||
            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
 | 
			
		||||
 | 
			
		||||
        return messages
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def get_mm_inputs(
 | 
			
		||||
        self,
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        imglens: Sequence[int],
 | 
			
		||||
        vidlens: Sequence[int],
 | 
			
		||||
        audlens: Sequence[int],
 | 
			
		||||
        batch_ids: Sequence[List[int]],
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
 | 
			
		||||
        self._validate_input(images, videos, audios)
 | 
			
		||||
        return self._get_mm_inputs(images, videos, audios, processor)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Qwen2vlPlugin(BasePlugin):
 | 
			
		||||
    @override
 | 
			
		||||
    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
 | 
			
		||||
@ -820,12 +1034,13 @@ class Qwen2vlPlugin(BasePlugin):
 | 
			
		||||
        messages: Sequence[Dict[str, str]],
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
    ) -> List[Dict[str, str]]:
 | 
			
		||||
        self._validate_input(images, videos)
 | 
			
		||||
        self._validate_input(images, videos, audios)
 | 
			
		||||
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
 | 
			
		||||
        merge_length: int = getattr(image_processor, "merge_size") ** 2
 | 
			
		||||
        mm_inputs = self._get_mm_inputs(images, videos, processor)
 | 
			
		||||
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
 | 
			
		||||
        image_grid_thw = mm_inputs.get("image_grid_thw", [])
 | 
			
		||||
        video_grid_thw = mm_inputs.get("video_grid_thw", [])
 | 
			
		||||
 | 
			
		||||
@ -868,13 +1083,15 @@ class Qwen2vlPlugin(BasePlugin):
 | 
			
		||||
        self,
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        imglens: Sequence[int],
 | 
			
		||||
        vidlens: Sequence[int],
 | 
			
		||||
        audlens: Sequence[int],
 | 
			
		||||
        batch_ids: Sequence[List[int]],
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
 | 
			
		||||
        self._validate_input(images, videos)
 | 
			
		||||
        mm_inputs = self._get_mm_inputs(images, videos, processor)
 | 
			
		||||
        self._validate_input(images, videos, audios)
 | 
			
		||||
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
 | 
			
		||||
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
 | 
			
		||||
        if "second_per_grid_ts" in getattr(image_processor, "model_input_names", []) and "video_grid_thw" in mm_inputs:
 | 
			
		||||
            video_fps = getattr(processor, "video_fps", 2.0)
 | 
			
		||||
@ -892,12 +1109,13 @@ class VideoLlavaPlugin(BasePlugin):
 | 
			
		||||
        messages: Sequence[Dict[str, str]],
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
    ) -> List[Dict[str, str]]:
 | 
			
		||||
        self._validate_input(images, videos)
 | 
			
		||||
        self._validate_input(images, videos, audios)
 | 
			
		||||
        num_image_tokens, num_video_tokens = 0, 0
 | 
			
		||||
        messages = deepcopy(messages)
 | 
			
		||||
        mm_inputs = self._get_mm_inputs(images, videos, processor)
 | 
			
		||||
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
 | 
			
		||||
        num_frames = 0
 | 
			
		||||
        has_images = "pixel_values_images" in mm_inputs
 | 
			
		||||
        has_videos = "pixel_values_videos" in mm_inputs
 | 
			
		||||
@ -945,13 +1163,15 @@ class VideoLlavaPlugin(BasePlugin):
 | 
			
		||||
        self,
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        imglens: Sequence[int],
 | 
			
		||||
        vidlens: Sequence[int],
 | 
			
		||||
        audlens: Sequence[int],
 | 
			
		||||
        batch_ids: Sequence[List[int]],
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
 | 
			
		||||
        self._validate_input(images, videos)
 | 
			
		||||
        return self._get_mm_inputs(images, videos, processor)
 | 
			
		||||
        self._validate_input(images, videos, audios)
 | 
			
		||||
        return self._get_mm_inputs(images, videos, audios, processor)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
PLUGINS = {
 | 
			
		||||
@ -963,6 +1183,7 @@ PLUGINS = {
 | 
			
		||||
    "mllama": MllamaPlugin,
 | 
			
		||||
    "paligemma": PaliGemmaPlugin,
 | 
			
		||||
    "pixtral": PixtralPlugin,
 | 
			
		||||
    "qwen2_audio": Qwen2AudioPlugin,
 | 
			
		||||
    "qwen2_vl": Qwen2vlPlugin,
 | 
			
		||||
    "video_llava": VideoLlavaPlugin,
 | 
			
		||||
}
 | 
			
		||||
@ -972,9 +1193,10 @@ def get_mm_plugin(
 | 
			
		||||
    name: str,
 | 
			
		||||
    image_token: Optional[str] = None,
 | 
			
		||||
    video_token: Optional[str] = None,
 | 
			
		||||
    audio_token: Optional[str] = None,
 | 
			
		||||
) -> "BasePlugin":
 | 
			
		||||
    plugin_class = PLUGINS.get(name, None)
 | 
			
		||||
    if plugin_class is None:
 | 
			
		||||
        raise ValueError(f"Multimodal plugin `{name}` not found.")
 | 
			
		||||
 | 
			
		||||
    return plugin_class(image_token, video_token)
 | 
			
		||||
    return plugin_class(image_token, video_token, audio_token)
 | 
			
		||||
 | 
			
		||||
@ -44,6 +44,7 @@ class DatasetAttr:
 | 
			
		||||
    tools: Optional[str] = None
 | 
			
		||||
    images: Optional[str] = None
 | 
			
		||||
    videos: Optional[str] = None
 | 
			
		||||
    audios: Optional[str] = None
 | 
			
		||||
    # rlhf columns
 | 
			
		||||
    chosen: Optional[str] = None
 | 
			
		||||
    rejected: Optional[str] = None
 | 
			
		||||
@ -135,7 +136,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
 | 
			
		||||
        dataset_attr.set_attr("num_samples", dataset_info[name])
 | 
			
		||||
 | 
			
		||||
        if "columns" in dataset_info[name]:
 | 
			
		||||
            column_names = ["system", "tools", "images", "videos", "chosen", "rejected", "kto_tag"]
 | 
			
		||||
            column_names = ["system", "tools", "images", "videos", "audios", "chosen", "rejected", "kto_tag"]
 | 
			
		||||
            if dataset_attr.formatting == "alpaca":
 | 
			
		||||
                column_names.extend(["prompt", "query", "response", "history"])
 | 
			
		||||
            else:
 | 
			
		||||
 | 
			
		||||
@ -24,7 +24,7 @@ if TYPE_CHECKING:
 | 
			
		||||
    from transformers import PreTrainedTokenizer, ProcessorMixin
 | 
			
		||||
 | 
			
		||||
    from ...hparams import DataArguments
 | 
			
		||||
    from ..mm_plugin import ImageInput, VideoInput
 | 
			
		||||
    from ..mm_plugin import AudioInput, ImageInput, VideoInput
 | 
			
		||||
    from ..template import Template
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -39,6 +39,7 @@ def _encode_feedback_example(
 | 
			
		||||
    tools: Optional[str],
 | 
			
		||||
    images: Sequence["ImageInput"],
 | 
			
		||||
    videos: Sequence["VideoInput"],
 | 
			
		||||
    audios: Sequence["AudioInput"],
 | 
			
		||||
    template: "Template",
 | 
			
		||||
    tokenizer: "PreTrainedTokenizer",
 | 
			
		||||
    processor: Optional["ProcessorMixin"],
 | 
			
		||||
@ -56,8 +57,8 @@ def _encode_feedback_example(
 | 
			
		||||
    else:
 | 
			
		||||
        kl_messages = prompt + [kl_response[1]]
 | 
			
		||||
 | 
			
		||||
    messages = template.mm_plugin.process_messages(messages, images, videos, processor)
 | 
			
		||||
    kl_messages = template.mm_plugin.process_messages(kl_messages, images, videos, processor)
 | 
			
		||||
    messages = template.mm_plugin.process_messages(messages, images, videos, audios, processor)
 | 
			
		||||
    kl_messages = template.mm_plugin.process_messages(kl_messages, images, videos, audios, processor)
 | 
			
		||||
    prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
 | 
			
		||||
    kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
 | 
			
		||||
 | 
			
		||||
@ -65,8 +66,12 @@ def _encode_feedback_example(
 | 
			
		||||
        response_ids += [tokenizer.eos_token_id]
 | 
			
		||||
        kl_response_ids += [tokenizer.eos_token_id]
 | 
			
		||||
 | 
			
		||||
    prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
 | 
			
		||||
    kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, images, videos, tokenizer, processor)
 | 
			
		||||
    prompt_ids, _ = template.mm_plugin.process_token_ids(
 | 
			
		||||
        prompt_ids, None, images, videos, audios, tokenizer, processor
 | 
			
		||||
    )
 | 
			
		||||
    kl_prompt_ids, _ = template.mm_plugin.process_token_ids(
 | 
			
		||||
        kl_prompt_ids, None, images, videos, audios, tokenizer, processor
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
 | 
			
		||||
    prompt_ids = prompt_ids[:source_len]
 | 
			
		||||
@ -107,6 +112,7 @@ def preprocess_feedback_dataset(
 | 
			
		||||
            tools=examples["_tools"][i],
 | 
			
		||||
            images=examples["_images"][i] or [],
 | 
			
		||||
            videos=examples["_videos"][i] or [],
 | 
			
		||||
            audios=examples["_audios"][i] or [],
 | 
			
		||||
            template=template,
 | 
			
		||||
            tokenizer=tokenizer,
 | 
			
		||||
            processor=processor,
 | 
			
		||||
@ -121,6 +127,7 @@ def preprocess_feedback_dataset(
 | 
			
		||||
        model_inputs["kto_tags"].append(kto_tag)
 | 
			
		||||
        model_inputs["images"].append(examples["_images"][i])
 | 
			
		||||
        model_inputs["videos"].append(examples["_videos"][i])
 | 
			
		||||
        model_inputs["audios"].append(examples["_audios"][i])
 | 
			
		||||
 | 
			
		||||
    desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
 | 
			
		||||
    undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
 | 
			
		||||
 | 
			
		||||
@ -24,7 +24,7 @@ if TYPE_CHECKING:
 | 
			
		||||
    from transformers import PreTrainedTokenizer, ProcessorMixin
 | 
			
		||||
 | 
			
		||||
    from ...hparams import DataArguments
 | 
			
		||||
    from ..mm_plugin import ImageInput, VideoInput
 | 
			
		||||
    from ..mm_plugin import AudioInput, ImageInput, VideoInput
 | 
			
		||||
    from ..template import Template
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -38,13 +38,14 @@ def _encode_pairwise_example(
 | 
			
		||||
    tools: Optional[str],
 | 
			
		||||
    images: Sequence["ImageInput"],
 | 
			
		||||
    videos: Sequence["VideoInput"],
 | 
			
		||||
    audios: Sequence["AudioInput"],
 | 
			
		||||
    template: "Template",
 | 
			
		||||
    tokenizer: "PreTrainedTokenizer",
 | 
			
		||||
    processor: Optional["ProcessorMixin"],
 | 
			
		||||
    cutoff_len: int,
 | 
			
		||||
) -> Tuple[List[int], List[int], List[int], List[int]]:
 | 
			
		||||
    chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, videos, processor)
 | 
			
		||||
    rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, videos, processor)
 | 
			
		||||
    chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, videos, audios, processor)
 | 
			
		||||
    rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, videos, audios, processor)
 | 
			
		||||
    prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
 | 
			
		||||
    _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
 | 
			
		||||
 | 
			
		||||
@ -52,7 +53,9 @@ def _encode_pairwise_example(
 | 
			
		||||
        chosen_ids += [tokenizer.eos_token_id]
 | 
			
		||||
        rejected_ids += [tokenizer.eos_token_id]
 | 
			
		||||
 | 
			
		||||
    prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
 | 
			
		||||
    prompt_ids, _ = template.mm_plugin.process_token_ids(
 | 
			
		||||
        prompt_ids, None, images, videos, audios, tokenizer, processor
 | 
			
		||||
    )
 | 
			
		||||
    # consider the response is more important
 | 
			
		||||
    source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
 | 
			
		||||
    prompt_ids = prompt_ids[:source_len]
 | 
			
		||||
@ -89,6 +92,7 @@ def preprocess_pairwise_dataset(
 | 
			
		||||
            tools=examples["_tools"][i],
 | 
			
		||||
            images=examples["_images"][i] or [],
 | 
			
		||||
            videos=examples["_videos"][i] or [],
 | 
			
		||||
            audios=examples["_audios"][i] or [],
 | 
			
		||||
            template=template,
 | 
			
		||||
            tokenizer=tokenizer,
 | 
			
		||||
            processor=processor,
 | 
			
		||||
@ -102,6 +106,7 @@ def preprocess_pairwise_dataset(
 | 
			
		||||
        model_inputs["rejected_labels"].append(rejected_labels)
 | 
			
		||||
        model_inputs["images"].append(examples["_images"][i])
 | 
			
		||||
        model_inputs["videos"].append(examples["_videos"][i])
 | 
			
		||||
        model_inputs["audios"].append(examples["_audios"][i])
 | 
			
		||||
 | 
			
		||||
    return model_inputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -24,7 +24,7 @@ if TYPE_CHECKING:
 | 
			
		||||
    from transformers import PreTrainedTokenizer, ProcessorMixin
 | 
			
		||||
 | 
			
		||||
    from ...hparams import DataArguments
 | 
			
		||||
    from ..mm_plugin import ImageInput, VideoInput
 | 
			
		||||
    from ..mm_plugin import AudioInput, ImageInput, VideoInput
 | 
			
		||||
    from ..template import Template
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -38,6 +38,7 @@ def _encode_supervised_example(
 | 
			
		||||
    tools: Optional[str],
 | 
			
		||||
    images: Sequence["ImageInput"],
 | 
			
		||||
    videos: Sequence["VideoInput"],
 | 
			
		||||
    audios: Sequence["AudioInput"],
 | 
			
		||||
    template: "Template",
 | 
			
		||||
    tokenizer: "PreTrainedTokenizer",
 | 
			
		||||
    processor: Optional["ProcessorMixin"],
 | 
			
		||||
@ -45,8 +46,8 @@ def _encode_supervised_example(
 | 
			
		||||
    train_on_prompt: bool,
 | 
			
		||||
    mask_history: bool,
 | 
			
		||||
) -> Tuple[List[int], List[int]]:
 | 
			
		||||
    messages = template.mm_plugin.process_messages(prompt + response, images, videos, processor)
 | 
			
		||||
    input_ids, labels = template.mm_plugin.process_token_ids([], [], images, videos, tokenizer, processor)
 | 
			
		||||
    messages = template.mm_plugin.process_messages(prompt + response, images, videos, audios, processor)
 | 
			
		||||
    input_ids, labels = template.mm_plugin.process_token_ids([], [], images, videos, audios, tokenizer, processor)
 | 
			
		||||
    encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
 | 
			
		||||
    total_length = len(input_ids) + (1 if template.efficient_eos else 0)
 | 
			
		||||
    if mask_history:
 | 
			
		||||
@ -111,6 +112,7 @@ def preprocess_supervised_dataset(
 | 
			
		||||
            tools=examples["_tools"][i],
 | 
			
		||||
            images=examples["_images"][i] or [],
 | 
			
		||||
            videos=examples["_videos"][i] or [],
 | 
			
		||||
            audios=examples["_audios"][i] or [],
 | 
			
		||||
            template=template,
 | 
			
		||||
            tokenizer=tokenizer,
 | 
			
		||||
            processor=processor,
 | 
			
		||||
@ -123,6 +125,7 @@ def preprocess_supervised_dataset(
 | 
			
		||||
        model_inputs["labels"].append(labels)
 | 
			
		||||
        model_inputs["images"].append(examples["_images"][i])
 | 
			
		||||
        model_inputs["videos"].append(examples["_videos"][i])
 | 
			
		||||
        model_inputs["audios"].append(examples["_audios"][i])
 | 
			
		||||
 | 
			
		||||
    return model_inputs
 | 
			
		||||
 | 
			
		||||
@ -138,7 +141,7 @@ def preprocess_packed_supervised_dataset(
 | 
			
		||||
    # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
 | 
			
		||||
    # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
 | 
			
		||||
    valid_num = 0
 | 
			
		||||
    batch_input_ids, batch_labels, batch_images, batch_videos = [], [], [], []
 | 
			
		||||
    batch_input_ids, batch_labels, batch_images, batch_videos, batch_audios = [], [], [], [], []
 | 
			
		||||
    lengths = []
 | 
			
		||||
    length2indexes = defaultdict(list)
 | 
			
		||||
    for i in range(len(examples["_prompt"])):
 | 
			
		||||
@ -155,6 +158,7 @@ def preprocess_packed_supervised_dataset(
 | 
			
		||||
            tools=examples["_tools"][i],
 | 
			
		||||
            images=examples["_images"][i] or [],
 | 
			
		||||
            videos=examples["_videos"][i] or [],
 | 
			
		||||
            audios=examples["_audios"][i] or [],
 | 
			
		||||
            template=template,
 | 
			
		||||
            tokenizer=tokenizer,
 | 
			
		||||
            processor=processor,
 | 
			
		||||
@ -172,19 +176,21 @@ def preprocess_packed_supervised_dataset(
 | 
			
		||||
            batch_labels.append(labels)
 | 
			
		||||
            batch_images.append(examples["_images"][i] or [])
 | 
			
		||||
            batch_videos.append(examples["_videos"][i] or [])
 | 
			
		||||
            batch_audios.append(examples["_audios"][i] or [])
 | 
			
		||||
            valid_num += 1
 | 
			
		||||
 | 
			
		||||
    model_inputs = defaultdict(list)
 | 
			
		||||
    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1)  # reserved for the padding token
 | 
			
		||||
    for knapsack in knapsacks:
 | 
			
		||||
        packed_input_ids, packed_attention_masks, packed_labels = [], [], []
 | 
			
		||||
        packed_images, packed_videos = [], []
 | 
			
		||||
        packed_images, packed_videos, packed_audios = [], [], []
 | 
			
		||||
        for i, length in enumerate(knapsack):
 | 
			
		||||
            index = length2indexes[length].pop()
 | 
			
		||||
            packed_input_ids += batch_input_ids[index]
 | 
			
		||||
            packed_labels += batch_labels[index]
 | 
			
		||||
            packed_images += batch_images[index]
 | 
			
		||||
            packed_videos += batch_videos[index]
 | 
			
		||||
            packed_audios += batch_audios[index]
 | 
			
		||||
            if data_args.neat_packing:
 | 
			
		||||
                packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
 | 
			
		||||
            else:
 | 
			
		||||
@ -207,6 +213,7 @@ def preprocess_packed_supervised_dataset(
 | 
			
		||||
        model_inputs["labels"].append(packed_labels)
 | 
			
		||||
        model_inputs["images"].append(packed_images or None)
 | 
			
		||||
        model_inputs["videos"].append(packed_videos or None)
 | 
			
		||||
        model_inputs["audios"].append(packed_audios or None)
 | 
			
		||||
 | 
			
		||||
    return model_inputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -24,7 +24,7 @@ if TYPE_CHECKING:
 | 
			
		||||
    from transformers import PreTrainedTokenizer, ProcessorMixin
 | 
			
		||||
 | 
			
		||||
    from ...hparams import DataArguments
 | 
			
		||||
    from ..mm_plugin import ImageInput, VideoInput
 | 
			
		||||
    from ..mm_plugin import AudioInput, ImageInput, VideoInput
 | 
			
		||||
    from ..template import Template
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -38,6 +38,7 @@ def _encode_unsupervised_example(
 | 
			
		||||
    tools: Optional[str],
 | 
			
		||||
    images: Sequence["ImageInput"],
 | 
			
		||||
    videos: Sequence["VideoInput"],
 | 
			
		||||
    audios: Sequence["AudioInput"],
 | 
			
		||||
    template: "Template",
 | 
			
		||||
    tokenizer: "PreTrainedTokenizer",
 | 
			
		||||
    processor: Optional["ProcessorMixin"],
 | 
			
		||||
@ -48,12 +49,12 @@ def _encode_unsupervised_example(
 | 
			
		||||
    else:
 | 
			
		||||
        messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
 | 
			
		||||
 | 
			
		||||
    messages = template.mm_plugin.process_messages(messages, images, videos, processor)
 | 
			
		||||
    messages = template.mm_plugin.process_messages(messages, images, videos, audios, processor)
 | 
			
		||||
    input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
 | 
			
		||||
    if template.efficient_eos:
 | 
			
		||||
        labels += [tokenizer.eos_token_id]
 | 
			
		||||
 | 
			
		||||
    input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, videos, tokenizer, processor)
 | 
			
		||||
    input_ids, _ = template.mm_plugin.process_token_ids(input_ids, None, images, videos, audios, tokenizer, processor)
 | 
			
		||||
    source_len, target_len = infer_seqlen(len(input_ids), len(labels), cutoff_len)
 | 
			
		||||
    input_ids = input_ids[:source_len]
 | 
			
		||||
    labels = labels[:target_len]
 | 
			
		||||
@ -83,6 +84,7 @@ def preprocess_unsupervised_dataset(
 | 
			
		||||
            tools=examples["_tools"][i],
 | 
			
		||||
            images=examples["_images"][i] or [],
 | 
			
		||||
            videos=examples["_videos"][i] or [],
 | 
			
		||||
            audios=examples["_audios"][i] or [],
 | 
			
		||||
            template=template,
 | 
			
		||||
            tokenizer=tokenizer,
 | 
			
		||||
            processor=processor,
 | 
			
		||||
@ -93,6 +95,7 @@ def preprocess_unsupervised_dataset(
 | 
			
		||||
        model_inputs["labels"].append(labels)
 | 
			
		||||
        model_inputs["images"].append(examples["_images"][i])
 | 
			
		||||
        model_inputs["videos"].append(examples["_videos"][i])
 | 
			
		||||
        model_inputs["audios"].append(examples["_audios"][i])
 | 
			
		||||
 | 
			
		||||
    return model_inputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -890,7 +890,7 @@ _register_template(
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# copied from chatml template
 | 
			
		||||
# copied from qwen template
 | 
			
		||||
_register_template(
 | 
			
		||||
    name="llava_next_qwen",
 | 
			
		||||
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
 | 
			
		||||
@ -979,7 +979,7 @@ _register_template(
 | 
			
		||||
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
 | 
			
		||||
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
 | 
			
		||||
    stop_words=["<|im_end|>"],
 | 
			
		||||
    mm_plugin=get_mm_plugin(name="minicpm_v", image_token="<image>", video_token="<video>"),
 | 
			
		||||
    mm_plugin=get_mm_plugin(name="minicpm_v", image_token="<image>", video_token="<video>", audio_token="<audio>"),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1144,6 +1144,18 @@ _register_template(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# copied from chatml template
 | 
			
		||||
_register_template(
 | 
			
		||||
    name="qwen2_audio",
 | 
			
		||||
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
 | 
			
		||||
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
 | 
			
		||||
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
 | 
			
		||||
    default_system="You are a helpful assistant.",
 | 
			
		||||
    stop_words=["<|im_end|>"],
 | 
			
		||||
    mm_plugin=get_mm_plugin(name="qwen2_audio", audio_token="<|AUDIO|>"),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# copied from qwen template
 | 
			
		||||
_register_template(
 | 
			
		||||
    name="qwen2_vl",
 | 
			
		||||
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
 | 
			
		||||
 | 
			
		||||
@ -22,6 +22,8 @@ from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
 | 
			
		||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
AUDIO_PLACEHOLDER = os.environ.get("AUDIO_PLACEHOLDER", "<audio>")
 | 
			
		||||
 | 
			
		||||
CHECKPOINT_NAMES = {
 | 
			
		||||
    SAFE_ADAPTER_WEIGHTS_NAME,
 | 
			
		||||
    ADAPTER_WEIGHTS_NAME,
 | 
			
		||||
@ -58,6 +60,8 @@ METHODS = ["full", "freeze", "lora"]
 | 
			
		||||
 | 
			
		||||
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
 | 
			
		||||
 | 
			
		||||
MULTIMODAL_SUPPORTED_MODELS = set()
 | 
			
		||||
 | 
			
		||||
PEFT_METHODS = {"lora"}
 | 
			
		||||
 | 
			
		||||
RUNNING_LOG = "running_log.txt"
 | 
			
		||||
@ -89,8 +93,6 @@ V_HEAD_WEIGHTS_NAME = "value_head.bin"
 | 
			
		||||
 | 
			
		||||
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
 | 
			
		||||
 | 
			
		||||
VISION_MODELS = set()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DownloadSource(str, Enum):
 | 
			
		||||
    DEFAULT = "hf"
 | 
			
		||||
@ -101,14 +103,16 @@ class DownloadSource(str, Enum):
 | 
			
		||||
def register_model_group(
 | 
			
		||||
    models: Dict[str, Dict[DownloadSource, str]],
 | 
			
		||||
    template: Optional[str] = None,
 | 
			
		||||
    vision: bool = False,
 | 
			
		||||
    multimodal: bool = False,
 | 
			
		||||
) -> None:
 | 
			
		||||
    for name, path in models.items():
 | 
			
		||||
        SUPPORTED_MODELS[name] = path
 | 
			
		||||
        if template is not None and (any(suffix in name for suffix in ("-Chat", "-Distill", "-Instruct")) or vision):
 | 
			
		||||
        if template is not None and (
 | 
			
		||||
            any(suffix in name for suffix in ("-Chat", "-Distill", "-Instruct")) or multimodal
 | 
			
		||||
        ):
 | 
			
		||||
            DEFAULT_TEMPLATE[name] = template
 | 
			
		||||
        if vision:
 | 
			
		||||
            VISION_MODELS.add(name)
 | 
			
		||||
        if multimodal:
 | 
			
		||||
            MULTIMODAL_SUPPORTED_MODELS.add(name)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_model_group(
 | 
			
		||||
@ -1030,7 +1034,7 @@ register_model_group(
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="mllama",
 | 
			
		||||
    vision=True,
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1046,7 +1050,7 @@ register_model_group(
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="llava",
 | 
			
		||||
    vision=True,
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1062,7 +1066,7 @@ register_model_group(
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="llava_next",
 | 
			
		||||
    vision=True,
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1074,7 +1078,7 @@ register_model_group(
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="llava_next_mistral",
 | 
			
		||||
    vision=True,
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1086,7 +1090,7 @@ register_model_group(
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="llava_next_llama3",
 | 
			
		||||
    vision=True,
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1098,7 +1102,7 @@ register_model_group(
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="llava_next_yi",
 | 
			
		||||
    vision=True,
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1114,7 +1118,7 @@ register_model_group(
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="llava_next_qwen",
 | 
			
		||||
    vision=True,
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1130,7 +1134,7 @@ register_model_group(
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="llava_next_video",
 | 
			
		||||
    vision=True,
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1142,7 +1146,7 @@ register_model_group(
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="llava_next_video_mistral",
 | 
			
		||||
    vision=True,
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1157,7 +1161,7 @@ register_model_group(
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="llava_next_video_yi",
 | 
			
		||||
    vision=True,
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1207,7 +1211,7 @@ register_model_group(
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="minicpm_v",
 | 
			
		||||
    vision=True,
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1219,7 +1223,7 @@ register_model_group(
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="minicpm_v",
 | 
			
		||||
    vision=True,
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1424,7 +1428,7 @@ register_model_group(
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="paligemma",
 | 
			
		||||
    vision=True,
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1468,7 +1472,7 @@ register_model_group(
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="paligemma",
 | 
			
		||||
    vision=True,
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1551,7 +1555,7 @@ register_model_group(
 | 
			
		||||
        }
 | 
			
		||||
    },
 | 
			
		||||
    template="pixtral",
 | 
			
		||||
    vision=True,
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -2134,6 +2138,22 @@ register_model_group(
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_model_group(
 | 
			
		||||
    models={
 | 
			
		||||
        "Qwen2-Audio-7B": {
 | 
			
		||||
            DownloadSource.DEFAULT: "Qwen/Qwen2-Audio-7B",
 | 
			
		||||
            DownloadSource.MODELSCOPE: "Qwen/Qwen2-Audio-7B",
 | 
			
		||||
        },
 | 
			
		||||
        "Qwen2-Audio-7B-Instruct": {
 | 
			
		||||
            DownloadSource.DEFAULT: "Qwen/Qwen2-Audio-7B-Instruct",
 | 
			
		||||
            DownloadSource.MODELSCOPE: "Qwen/Qwen2-Audio-7B-Instruct",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="qwen2_audio",
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_model_group(
 | 
			
		||||
    models={
 | 
			
		||||
        "Qwen2-VL-2B-Instruct": {
 | 
			
		||||
@ -2204,7 +2224,7 @@ register_model_group(
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="qwen2_vl",
 | 
			
		||||
    vision=True,
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -2329,7 +2349,7 @@ register_model_group(
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="video_llava",
 | 
			
		||||
    vision=True,
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -2556,7 +2576,7 @@ register_model_group(
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="yi_vl",
 | 
			
		||||
    vision=True,
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -42,6 +42,10 @@ def is_pyav_available():
 | 
			
		||||
    return _is_package_available("av")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def is_librosa_available():
 | 
			
		||||
    return _is_package_available("librosa")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def is_fastapi_available():
 | 
			
		||||
    return _is_package_available("fastapi")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -41,9 +41,9 @@ class DataArguments:
 | 
			
		||||
        default="data",
 | 
			
		||||
        metadata={"help": "Path to the folder containing the datasets."},
 | 
			
		||||
    )
 | 
			
		||||
    image_dir: Optional[str] = field(
 | 
			
		||||
    media_dir: Optional[str] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "Path to the folder containing the images or videos. Defaults to `dataset_dir`."},
 | 
			
		||||
        metadata={"help": "Path to the folder containing the images, videos or audios. Defaults to `dataset_dir`."},
 | 
			
		||||
    )
 | 
			
		||||
    cutoff_len: int = field(
 | 
			
		||||
        default=2048,
 | 
			
		||||
@ -133,8 +133,8 @@ class DataArguments:
 | 
			
		||||
        self.dataset = split_arg(self.dataset)
 | 
			
		||||
        self.eval_dataset = split_arg(self.eval_dataset)
 | 
			
		||||
 | 
			
		||||
        if self.image_dir is None:
 | 
			
		||||
            self.image_dir = self.dataset_dir
 | 
			
		||||
        if self.media_dir is None:
 | 
			
		||||
            self.media_dir = self.dataset_dir
 | 
			
		||||
 | 
			
		||||
        if self.dataset is None and self.val_size > 1e-6:
 | 
			
		||||
            raise ValueError("Cannot specify `val_size` if `dataset` is None.")
 | 
			
		||||
 | 
			
		||||
@ -16,7 +16,14 @@ import os
 | 
			
		||||
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
 | 
			
		||||
from transformers import (
 | 
			
		||||
    AutoConfig,
 | 
			
		||||
    AutoModelForCausalLM,
 | 
			
		||||
    AutoModelForSeq2SeqLM,
 | 
			
		||||
    AutoModelForVision2Seq,
 | 
			
		||||
    AutoProcessor,
 | 
			
		||||
    AutoTokenizer,
 | 
			
		||||
)
 | 
			
		||||
from trl import AutoModelForCausalLMWithValueHead
 | 
			
		||||
 | 
			
		||||
from ..extras import logging
 | 
			
		||||
@ -142,6 +149,8 @@ def load_model(
 | 
			
		||||
        else:
 | 
			
		||||
            if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # assume built-in models
 | 
			
		||||
                load_class = AutoModelForVision2Seq
 | 
			
		||||
            elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():
 | 
			
		||||
                load_class = AutoModelForSeq2SeqLM
 | 
			
		||||
            else:
 | 
			
		||||
                load_class = AutoModelForCausalLM
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -280,6 +280,12 @@ _register_composite_model(
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_register_composite_model(
 | 
			
		||||
    model_type="qwen2_audio",
 | 
			
		||||
    vision_model_keys=["audio_tower"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_register_composite_model(
 | 
			
		||||
    model_type="qwen2_vl",
 | 
			
		||||
    projector_key="visual.merger",
 | 
			
		||||
 | 
			
		||||
@ -78,13 +78,14 @@ def patch_processor(
 | 
			
		||||
    model_args: "ModelArguments",
 | 
			
		||||
) -> None:
 | 
			
		||||
    setattr(processor, "tokenizer", tokenizer)
 | 
			
		||||
    setattr(processor, "image_seqlen", get_image_seqlen(config))
 | 
			
		||||
    setattr(processor, "image_resolution", model_args.image_resolution)
 | 
			
		||||
    setattr(processor, "patch_size", get_patch_size(config, processor))
 | 
			
		||||
    setattr(processor, "video_resolution", model_args.video_resolution)
 | 
			
		||||
    setattr(processor, "video_fps", model_args.video_fps)
 | 
			
		||||
    setattr(processor, "video_maxlen", model_args.video_maxlen)
 | 
			
		||||
    setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config, processor))
 | 
			
		||||
    if getattr(config, "vision_config", None) is not None:  # visual models
 | 
			
		||||
        setattr(processor, "image_seqlen", get_image_seqlen(config))
 | 
			
		||||
        setattr(processor, "image_resolution", model_args.image_resolution)
 | 
			
		||||
        setattr(processor, "patch_size", get_patch_size(config, processor))
 | 
			
		||||
        setattr(processor, "video_resolution", model_args.video_resolution)
 | 
			
		||||
        setattr(processor, "video_fps", model_args.video_fps)
 | 
			
		||||
        setattr(processor, "video_maxlen", model_args.video_maxlen)
 | 
			
		||||
        setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config, processor))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def patch_config(
 | 
			
		||||
 | 
			
		||||
@ -172,6 +172,7 @@ class WebChatModel(ChatModel):
 | 
			
		||||
        tools: str,
 | 
			
		||||
        image: Optional[Any],
 | 
			
		||||
        video: Optional[Any],
 | 
			
		||||
        audio: Optional[Any],
 | 
			
		||||
        max_new_tokens: int,
 | 
			
		||||
        top_p: float,
 | 
			
		||||
        temperature: float,
 | 
			
		||||
@ -190,6 +191,7 @@ class WebChatModel(ChatModel):
 | 
			
		||||
            tools,
 | 
			
		||||
            images=[image] if image else None,
 | 
			
		||||
            videos=[video] if video else None,
 | 
			
		||||
            audios=[audio] if audio else None,
 | 
			
		||||
            max_new_tokens=max_new_tokens,
 | 
			
		||||
            top_p=top_p,
 | 
			
		||||
            temperature=temperature,
 | 
			
		||||
 | 
			
		||||
@ -26,9 +26,9 @@ from ..extras import logging
 | 
			
		||||
from ..extras.constants import (
 | 
			
		||||
    DATA_CONFIG,
 | 
			
		||||
    DEFAULT_TEMPLATE,
 | 
			
		||||
    MULTIMODAL_SUPPORTED_MODELS,
 | 
			
		||||
    SUPPORTED_MODELS,
 | 
			
		||||
    TRAINING_ARGS,
 | 
			
		||||
    VISION_MODELS,
 | 
			
		||||
    DownloadSource,
 | 
			
		||||
)
 | 
			
		||||
from ..extras.misc import use_modelscope, use_openmind
 | 
			
		||||
@ -136,13 +136,6 @@ def get_template(model_name: str) -> str:
 | 
			
		||||
    return DEFAULT_TEMPLATE.get(model_name, "default")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_visual(model_name: str) -> bool:
 | 
			
		||||
    r"""
 | 
			
		||||
    Judges if the model is a vision language model.
 | 
			
		||||
    """
 | 
			
		||||
    return model_name in VISION_MODELS
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_time() -> str:
 | 
			
		||||
    r"""
 | 
			
		||||
    Gets current date and time.
 | 
			
		||||
@ -150,6 +143,13 @@ def get_time() -> str:
 | 
			
		||||
    return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def is_multimodal(model_name: str) -> bool:
 | 
			
		||||
    r"""
 | 
			
		||||
    Judges if the model is a vision language model.
 | 
			
		||||
    """
 | 
			
		||||
    return model_name in MULTIMODAL_SUPPORTED_MODELS
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
 | 
			
		||||
    r"""
 | 
			
		||||
    Loads dataset_info.json.
 | 
			
		||||
 | 
			
		||||
@ -64,10 +64,13 @@ def create_chat_box(
 | 
			
		||||
 | 
			
		||||
                    with gr.Column() as mm_box:
 | 
			
		||||
                        with gr.Tab("Image"):
 | 
			
		||||
                            image = gr.Image(sources=["upload"], type="pil")
 | 
			
		||||
                            image = gr.Image(type="pil")
 | 
			
		||||
 | 
			
		||||
                        with gr.Tab("Video"):
 | 
			
		||||
                            video = gr.Video(sources=["upload"])
 | 
			
		||||
                            video = gr.Video()
 | 
			
		||||
 | 
			
		||||
                        with gr.Tab("Audio"):
 | 
			
		||||
                            audio = gr.Audio(type="filepath")
 | 
			
		||||
 | 
			
		||||
                query = gr.Textbox(show_label=False, lines=8)
 | 
			
		||||
                submit_btn = gr.Button(variant="primary")
 | 
			
		||||
@ -86,7 +89,7 @@ def create_chat_box(
 | 
			
		||||
        [chatbot, messages, query],
 | 
			
		||||
    ).then(
 | 
			
		||||
        engine.chatter.stream,
 | 
			
		||||
        [chatbot, messages, lang, system, tools, image, video, max_new_tokens, top_p, temperature],
 | 
			
		||||
        [chatbot, messages, lang, system, tools, image, video, audio, max_new_tokens, top_p, temperature],
 | 
			
		||||
        [chatbot, messages],
 | 
			
		||||
    )
 | 
			
		||||
    clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])
 | 
			
		||||
@ -102,6 +105,7 @@ def create_chat_box(
 | 
			
		||||
            mm_box=mm_box,
 | 
			
		||||
            image=image,
 | 
			
		||||
            video=video,
 | 
			
		||||
            audio=audio,
 | 
			
		||||
            query=query,
 | 
			
		||||
            submit_btn=submit_btn,
 | 
			
		||||
            max_new_tokens=max_new_tokens,
 | 
			
		||||
 | 
			
		||||
@ -15,7 +15,7 @@
 | 
			
		||||
from typing import TYPE_CHECKING, Dict
 | 
			
		||||
 | 
			
		||||
from ...extras.packages import is_gradio_available
 | 
			
		||||
from ..common import get_visual
 | 
			
		||||
from ..common import is_multimodal
 | 
			
		||||
from .chatbot import create_chat_box
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -66,7 +66,7 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
 | 
			
		||||
    ).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]])
 | 
			
		||||
 | 
			
		||||
    engine.manager.get_elem_by_id("top.model_name").change(
 | 
			
		||||
        lambda model_name: gr.Column(visible=get_visual(model_name)),
 | 
			
		||||
        lambda model_name: gr.Column(visible=is_multimodal(model_name)),
 | 
			
		||||
        [engine.manager.get_elem_by_id("top.model_name")],
 | 
			
		||||
        [chat_elems["mm_box"]],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@ -52,12 +52,16 @@ NO_IMAGES = []
 | 
			
		||||
 | 
			
		||||
NO_VIDEOS = []
 | 
			
		||||
 | 
			
		||||
NO_AUDIOS = []
 | 
			
		||||
 | 
			
		||||
IMGLENS = [1]
 | 
			
		||||
 | 
			
		||||
NO_IMGLENS = [0]
 | 
			
		||||
 | 
			
		||||
NO_VIDLENS = [0]
 | 
			
		||||
 | 
			
		||||
NO_AUDLENS = [0]
 | 
			
		||||
 | 
			
		||||
INPUT_IDS = [0, 1, 2, 3, 4]
 | 
			
		||||
 | 
			
		||||
LABELS = [0, 1, 2, 3, 4]
 | 
			
		||||
@ -99,23 +103,25 @@ def _check_plugin(
 | 
			
		||||
    expected_no_mm_inputs: Dict[str, Any] = {},
 | 
			
		||||
) -> None:
 | 
			
		||||
    # test mm_messages
 | 
			
		||||
    assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, processor) == expected_mm_messages
 | 
			
		||||
    assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, tokenizer, processor) == (
 | 
			
		||||
    assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
 | 
			
		||||
    assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
 | 
			
		||||
        expected_input_ids,
 | 
			
		||||
        expected_labels,
 | 
			
		||||
    )
 | 
			
		||||
    _is_close(
 | 
			
		||||
        plugin.get_mm_inputs(IMAGES, NO_VIDEOS, IMGLENS, NO_VIDLENS, BATCH_IDS, processor),
 | 
			
		||||
        plugin.get_mm_inputs(IMAGES, NO_VIDEOS, NO_AUDIOS, IMGLENS, NO_VIDLENS, NO_AUDLENS, BATCH_IDS, processor),
 | 
			
		||||
        expected_mm_inputs,
 | 
			
		||||
    )
 | 
			
		||||
    # test text_messages
 | 
			
		||||
    assert plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, NO_VIDEOS, processor) == TEXT_MESSAGES
 | 
			
		||||
    assert plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, NO_VIDEOS, tokenizer, processor) == (
 | 
			
		||||
    assert plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == TEXT_MESSAGES
 | 
			
		||||
    assert plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
 | 
			
		||||
        INPUT_IDS,
 | 
			
		||||
        LABELS,
 | 
			
		||||
    )
 | 
			
		||||
    _is_close(
 | 
			
		||||
        plugin.get_mm_inputs(NO_IMAGES, NO_VIDEOS, NO_IMGLENS, NO_VIDLENS, BATCH_IDS, processor),
 | 
			
		||||
        plugin.get_mm_inputs(
 | 
			
		||||
            NO_IMAGES, NO_VIDEOS, NO_AUDIOS, NO_IMGLENS, NO_VIDLENS, NO_AUDLENS, BATCH_IDS, processor
 | 
			
		||||
        ),
 | 
			
		||||
        expected_no_mm_inputs,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -167,7 +167,7 @@ def test_phi4_template(use_fast: bool):
 | 
			
		||||
    _check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
 | 
			
		||||
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")  # TODO: why it is gated?
 | 
			
		||||
@pytest.mark.parametrize("use_fast", [True, False])
 | 
			
		||||
def test_qwen_template(use_fast: bool):
 | 
			
		||||
    prompt_str = (
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user