[data] fix minicpmv plugin (#6801)

* fix template name

* tiny fix

* support minicpm-o-2.6

* support inference of minicpmv

* update readme

* support dpo of minicpmv

* update init audio

* update init audio

* [model] fix image processing in minicpmo

Former-commit-id: ab9bd068efee861452407cdda08ef014d5ce23d5
Zhangchi Feng authored on 2025-02-04 21:20:15 +08:00, committed by GitHub
parent 822d5d362c
commit 85f22d01bf
3 changed files with 7 additions and 6 deletions

@@ -543,7 +543,7 @@ class MiniCPMVPlugin(BasePlugin):
         self._validate_input(images, videos)
         image_bounds_list = []
         valid_image_nums_ls = []
-        for input_ids in batch_ids:
+        for i, input_ids in enumerate(batch_ids):
             input_ids_ = torch.tensor(input_ids)
             start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
                 input_ids_ == processor.tokenizer.slice_start_id
@@ -552,12 +552,11 @@ class MiniCPMVPlugin(BasePlugin):
             image_start_tokens = torch.where(start_cond)[0]
             image_start_tokens += 1
             image_end_tokens = torch.where(end_cond)[0]
-            valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
-            valid_image_nums_ls.append(valid_image_nums)
+            valid_image_nums_ls.append(imglens[i])
             image_bounds = torch.hstack(
                 [
-                    image_start_tokens[:valid_image_nums].unsqueeze(-1),
-                    image_end_tokens[:valid_image_nums].unsqueeze(-1),
+                    image_start_tokens.unsqueeze(-1),
+                    image_end_tokens.unsqueeze(-1),
                 ]
             )
             image_bounds_list.append(image_bounds)
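For context, a minimal self-contained sketch of what this bounds computation produces. The token ids and sequence below are made up, and IM_START_ID / IM_END_ID stand in for the processor.tokenizer attributes used above; the point of the fix is that every detected (start, end) pair is kept, while the number of valid images now comes from imglens rather than from counting special tokens (the two can disagree when a sequence is truncated mid-image).

import torch

# Hypothetical special-token ids standing in for
# processor.tokenizer.im_start_id / im_end_id.
IM_START_ID, IM_END_ID = 101, 102

# Toy sequence containing two image placeholder spans.
input_ids_ = torch.tensor([5, IM_START_ID, 7, 8, IM_END_ID, 9, IM_START_ID, 7, IM_END_ID, 3])

start_cond = input_ids_ == IM_START_ID
end_cond = input_ids_ == IM_END_ID

image_start_tokens = torch.where(start_cond)[0] + 1  # first index inside each image span
image_end_tokens = torch.where(end_cond)[0]          # index of each closing token

# One (start, end) row per image span.
image_bounds = torch.hstack([image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)])
print(image_bounds)  # tensor([[2, 4], [7, 8]])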

@@ -125,7 +125,7 @@ def patch_config(
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn
 
     if getattr(config, "model_type", None) == "minicpmo":
-        setattr(config, "init_audio", False)
+        setattr(config, "init_audio", True)
         setattr(config, "init_tts", False)
 
     if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):

@@ -152,6 +152,7 @@ def test_llama3_template(use_fast: bool):
     _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)
 
 
+@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
 @pytest.mark.parametrize(
     "use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken."))]
 )
@@ -166,6 +167,7 @@ def test_phi4_template(use_fast: bool):
     _check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)
 
 
+@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_qwen_template(use_fast: bool):
     prompt_str = (
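A minimal sketch of how the added marker behaves, assuming HF_TOKEN is read from the environment as in the test module (the test body below is hypothetical): when no Hugging Face token is available, pytest skips every parametrized case instead of failing on the gated download.

import os

import pytest

HF_TOKEN = os.getenv("HF_TOKEN")  # assumed to mirror the test module's setup


@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.parametrize("use_fast", [True, False])
def test_gated_template(use_fast: bool):
    # Hypothetical placeholder for _check_template(...); skipped
    # entirely when HF_TOKEN is unset.
    assert isinstance(use_fast, bool)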