From 01915eaf4015490394d4360318eb07951adb02c4 Mon Sep 17 00:00:00 2001
From: Zhangchi Feng <64362896+BUAADreamer@users.noreply.github.com>
Date: Wed, 5 Feb 2025 04:59:09 +0800
Subject: [PATCH] [model] support audio (#6701)

* support qwen2_audio
* improve code
* lint
* fix
* fix
* fix

---------

Co-authored-by: hiyouga
Former-commit-id: 24c78429489809873a1269a735ea5421340b32a2
---
 README.md | 6 +-
 README_zh.md | 6 +-
 data/README.md | 48 ++-
 data/README_zh.md | 47 +++
 data/dataset_info.json | 14 +
 data/mllm_audio_demo.json | 47 +++
 data/mllm_demo_data/1.mp3 | Bin 0 -> 129024 bytes
 data/mllm_demo_data/2.wav | Bin 0 -> 92886 bytes
 data/mllm_demo_data/3.flac | Bin 0 -> 120041 bytes
 requirements.txt | 1 +
 scripts/stat_utils/cal_ppl.py | 1 +
 setup.py | 1 -
 src/llamafactory/chat/base_engine.py | 4 +-
 src/llamafactory/chat/chat_model.py | 16 +-
 src/llamafactory/chat/hf_engine.py | 33 +-
 src/llamafactory/chat/vllm_engine.py | 20 +-
 src/llamafactory/data/aligner.py | 65 ++--
 src/llamafactory/data/collator.py | 57 ++-
 src/llamafactory/data/mm_plugin.py | 344 ++++++++++++++----
 src/llamafactory/data/parser.py | 3 +-
 src/llamafactory/data/processors/feedback.py | 17 +-
 src/llamafactory/data/processors/pairwise.py | 13 +-
 .../data/processors/supervised.py | 17 +-
 .../data/processors/unsupervised.py | 9 +-
 src/llamafactory/data/template.py | 16 +-
 src/llamafactory/extras/constants.py | 68 ++--
 src/llamafactory/extras/packages.py | 4 +
 src/llamafactory/hparams/data_args.py | 8 +-
 src/llamafactory/model/loader.py | 11 +-
 src/llamafactory/model/model_utils/visual.py | 6 +
 src/llamafactory/model/patcher.py | 15 +-
 src/llamafactory/webui/chatter.py | 2 +
 src/llamafactory/webui/common.py | 16 +-
 src/llamafactory/webui/components/chatbot.py | 10 +-
 src/llamafactory/webui/components/infer.py | 4 +-
 tests/data/test_mm_plugin.py | 18 +-
 tests/data/test_template.py | 2 +-
 37 files changed, 736 insertions(+), 213 deletions(-)
 create mode 100644 data/mllm_audio_demo.json
 create mode 100644 data/mllm_demo_data/1.mp3
 create mode 100644 data/mllm_demo_data/2.wav
 create mode 100644 data/mllm_demo_data/3.flac

diff --git a/README.md b/README.md
index efb4a55b..b79c039c 100644
--- a/README.md
+++ b/README.md
@@ -76,8 +76,9 @@ Choose your path:
 - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc.
 - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
 - **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
-- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
+- **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
 - **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
+- **Wide tasks**: Multi-turn dialogue, tool use, image understanding, visual grounding, video recognition, audio understanding, etc.
 - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, SwanLab, etc.
 - **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
@@ -105,6 +106,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Changelog
 
+[25/02/05] We supported fine-tuning the **[Qwen2-Audio](https://huggingface.co/Qwen/Qwen2-Audio-7B-Instruct)** and **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** models on audio understanding tasks.
+
 [25/01/31] We supported fine-tuning the **[DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)** and **[Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)** model.
 
 [25/01/15] We supported **[APOLLO](https://arxiv.org/abs/2412.05270)** optimizer. See [examples](examples/README.md) for usage.
@@ -247,6 +250,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
 | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
 | [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
+| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
 | [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/72B | qwen2_vl |
 | [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
 | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
diff --git a/README_zh.md b/README_zh.md
index ac736767..ae6a68dc 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -78,8 +78,9 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 - **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Qwen2-VL、DeepSeek、Yi、Gemma、ChatGLM、Phi 等等。
 - **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
 - **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
-- **先进算法**:[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[APOLLO](https://github.com/zhuhanqing/APOLLO)、[Adam-mini](https://github.com/zyushun/Adam-mini)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
+- **先进算法**:[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[APOLLO](https://github.com/zhuhanqing/APOLLO)、[Adam-mini](https://github.com/zyushun/Adam-mini)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 PiSSA。
 - **实用技巧**:[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)、[Unsloth](https://github.com/unslothai/unsloth)、[Liger Kernel](https://github.com/linkedin/Liger-Kernel)、RoPE scaling、NEFTune 和 rsLoRA。
+- **广泛任务**:多轮对话、工具调用、图像理解、视觉定位、视频识别和语音理解等等。
 - **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow、SwanLab 等等。
 - **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
@@ -115,6 +116,8 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 展开日志
 
+[25/02/05] 我们支持了在语音理解任务上微调 **[Qwen2-Audio](https://huggingface.co/Qwen/Qwen2-Audio-7B-Instruct)** 和 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 模型。
+
 [25/01/14] 我们支持了 **[InternLM3](https://huggingface.co/collections/internlm/)** 模型的微调。感谢 [@hhaAndroid](https://github.com/hhaAndroid) 的 PR。
 
 [25/01/10] 我们支持了 **[Phi-4](https://huggingface.co/microsoft/phi-4)** 模型的微调。
@@ -249,6 +252,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 | [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 |
 | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
 | [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
+| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
 | [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/72B | qwen2_vl |
 | [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 |
 | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
diff --git a/data/README.md b/data/README.md
index 1786804f..913a15a5 100644
--- a/data/README.md
+++ b/data/README.md
@@ -24,6 +24,7 @@ Currently we support datasets in **alpaca** and **sharegpt** format.
   "tools": "the column name in the dataset containing the tool description. (default: None)",
   "images": "the column name in the dataset containing the image inputs. (default: None)",
   "videos": "the column name in the dataset containing the videos inputs. (default: None)",
+  "audios": "the column name in the dataset containing the audio inputs. (default: None)",
   "chosen": "the column name in the dataset containing the chosen answers. (default: None)",
   "rejected": "the column name in the dataset containing the rejected answers. (default: None)",
   "kto_tag": "the column name in the dataset containing the kto tags. (default: None)"
@@ -150,6 +151,10 @@ An additional column `images` is required. Please refer to the [sharegpt](#share
 
 An additional column `videos` is required. Please refer to the [sharegpt](#sharegpt-format) format for details.
 
+### Multimodal Audio Dataset
+
+An additional column `audios` is required. Please refer to the [sharegpt](#sharegpt-format) format for details.
+
 ## Sharegpt Format
 
 ### Supervised Fine-Tuning Dataset
@@ -296,7 +301,7 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
 
 - [Example dataset](mllm_demo.json)
 
-Multimodal image datasets require a `images` column containing the paths to the input images.
+Multimodal image datasets require an `images` column containing the paths to the input images.
 
 The number of images should be identical to the `<image>` tokens in the conversations.
 
@@ -374,6 +379,47 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
 }
 ```
 
+### Multimodal Audio Dataset
+
+- [Example dataset](mllm_audio_demo.json)
+
+Multimodal audio datasets require an `audios` column containing the paths to the input audios.
+
+The number of audios should be identical to the `
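
For reference, an audio sample in the sharegpt format pairs an `<audio>` placeholder in the message with one path in the `audios` column, mirroring the image and video examples above. The snippet below is a minimal, illustrative sketch; the prompt text, file path, and dataset name are assumptions rather than the exact contents of the shipped `mllm_audio_demo.json`.

```json
[
  {
    "messages": [
      {
        "content": "<audio>What sound can you hear in this clip?",
        "role": "user"
      },
      {
        "content": "It sounds like glass breaking.",
        "role": "assistant"
      }
    ],
    "audios": [
      "mllm_demo_data/1.mp3"
    ]
  }
]
```

A corresponding *dataset description* in `dataset_info.json` would then map the `messages` and `audios` columns, again as a sketch:

```json
"mllm_audio_demo": {
  "file_name": "mllm_audio_demo.json",
  "formatting": "sharegpt",
  "columns": {
    "messages": "messages",
    "audios": "audios"
  }
}
```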