[data] fix minicpmv plugin (#6890)

* fix template name

* tiny fix

* support minicpm-o-2.6

* support inference of minicpmv

* update readme

* support dpo of minicpmv

* update init audio

* update init audio

* [model] fix image processing in minicpmo

* fix no mm inputs

Former-commit-id: 764627645abcd353f9130d5dd8c584810b0e0b1b
commit 5433b318bb (parent fe4f4e9758)
Author: Zhangchi Feng
Date: 2025-02-11 13:30:44 +08:00 (committed by GitHub)
5 changed files with 117 additions and 98 deletions

README.md

@@ -215,7 +215,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Supported Models
 
 | Model | Model size | Template |
-| ----------------------------------------------------------------- | -------------------------------- | ---------------- |
+| ----------------------------------------------------------------- | -------------------------------- | ------------------- |
 | [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
 | [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
 | [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
@@ -238,7 +238,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
 | [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
 | [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
-| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_v |
+| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
 | [Ministral/Mistral-Nemo](https://huggingface.co/mistralai) | 8B/12B | ministral |
 | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
 | [Mistral Small](https://huggingface.co/mistralai) | 24B | mistral_small |

README_zh.md

@@ -217,7 +217,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 ## 模型
 
 | 模型名 | 模型大小 | Template |
-| ----------------------------------------------------------------- | -------------------------------- | ---------------- |
+| ----------------------------------------------------------------- | -------------------------------- | ------------------- |
 | [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
 | [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
 | [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
@@ -240,7 +240,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
 | [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
 | [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
-| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_v |
+| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
 | [Ministral/Mistral-Nemo](https://huggingface.co/mistralai) | 8B/12B | ministral |
 | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
 | [Mistral Small](https://huggingface.co/mistralai) | 24B | mistral_small |
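
Both README tables now point MiniCPM-o-2.6 at the new minicpm_o template while MiniCPM-V-2.6 keeps minicpm_v. As a sketch, the mapping the tables document (the exact Hugging Face repo ids are an assumption, not spelled out in the diff):

    # Illustrative mapping only; repo ids assumed from the openbmb org linked above.
    MODEL_TEMPLATES = {
        "openbmb/MiniCPM-V-2_6": "minicpm_v",  # image/video inputs
        "openbmb/MiniCPM-o-2_6": "minicpm_o",  # image/video/audio inputs
    }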

src/llamafactory/data/collator.py

@@ -106,7 +106,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             batch_audlens.append(len(audios))
             batch_input_ids.append(feature["input_ids"])
 
-        fake_input_ids = None
+        fake_input_ids = []
         if (
             self.template.mm_plugin.image_token is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
         ):  # avoid process hanging in zero3/fsdp case
@@ -115,10 +115,11 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             fake_messages = self.template.mm_plugin.process_messages(
                 fake_messages, fake_images, [], [], self.processor
             )
-            fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
-            fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
-                fake_input_ids, None, fake_images, [], [], self.tokenizer, self.processor
+            _fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
+            _fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
+                _fake_input_ids, None, fake_images, [], [], self.tokenizer, self.processor
             )
+            fake_input_ids.extend(_fake_input_ids)
             batch_images = fake_images
             batch_imglens[0] = 1
 
@@ -130,14 +131,15 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             fake_messages = self.template.mm_plugin.process_messages(
                 fake_messages, [], [], fake_audios, self.processor
             )
-            fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
-            fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
-                fake_input_ids, None, [], [], fake_audios, self.tokenizer, self.processor
+            _fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
+            _fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
+                _fake_input_ids, None, [], [], fake_audios, self.tokenizer, self.processor
             )
+            fake_input_ids.extend(_fake_input_ids)
             batch_audios = fake_audios
             batch_audlens[0] = 1
 
-        if fake_input_ids is not None:
+        if len(fake_input_ids) != 0:
             if self.tokenizer.padding_side == "right":
                 features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
                 features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)

src/llamafactory/data/mm_plugin.py

@@ -645,6 +645,12 @@ class MiniCPMVPlugin(BasePlugin):
                 chunk_input=True,
                 sampling_rate=16000,
             )
+            audio_feature_lens = [
+                torch.tensor(audio_feature_len)
+                if not isinstance(audio_feature_len, torch.Tensor)
+                else audio_feature_len
+                for audio_feature_len in audio_feature_lens
+            ]
             mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
             if kwargs.get("ret_phs", False):
                 mm_inputs.update({"audio_phs": audio_phs})

src/llamafactory/data/template.py

@@ -982,6 +982,17 @@ _register_template(
     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
     stop_words=["<|im_end|>"],
+    mm_plugin=get_mm_plugin(name="minicpm_v", image_token="<image>", video_token="<video>"),
+)
+
+
+# copied from chatml template
+_register_template(
+    name="minicpm_o",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    stop_words=["<|im_end|>"],
     mm_plugin=get_mm_plugin(name="minicpm_v", image_token="<image>", video_token="<video>", audio_token="<audio>"),
 )
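
With this split, the audio-capable configuration is reachable under its own name while minicpm_v stays image/video only. A hedged lookup sketch, assuming the module-level TEMPLATES registry that _register_template populates in template.py:

    from llamafactory.data.template import TEMPLATES  # registry filled by _register_template

    minicpm_v = TEMPLATES["minicpm_v"]  # image/video only after this commit
    minicpm_o = TEMPLATES["minicpm_o"]  # same chatml-style format, plus audio
    print(minicpm_o.mm_plugin.audio_token)  # expected: "<audio>"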