diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index eb8e7e5c..a0d03699 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -543,7 +543,7 @@ class MiniCPMVPlugin(BasePlugin):
         self._validate_input(images, videos)
         image_bounds_list = []
         valid_image_nums_ls = []
-        for input_ids in batch_ids:
+        for i, input_ids in enumerate(batch_ids):
             input_ids_ = torch.tensor(input_ids)
             start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
                 input_ids_ == processor.tokenizer.slice_start_id
@@ -552,12 +552,11 @@ class MiniCPMVPlugin(BasePlugin):
             image_start_tokens = torch.where(start_cond)[0]
             image_start_tokens += 1
             image_end_tokens = torch.where(end_cond)[0]
-            valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
-            valid_image_nums_ls.append(valid_image_nums)
+            valid_image_nums_ls.append(imglens[i])
             image_bounds = torch.hstack(
                 [
-                    image_start_tokens[:valid_image_nums].unsqueeze(-1),
-                    image_end_tokens[:valid_image_nums].unsqueeze(-1),
+                    image_start_tokens.unsqueeze(-1),
+                    image_end_tokens.unsqueeze(-1),
                 ]
             )
             image_bounds_list.append(image_bounds)
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 1f628c2f..cf110af9 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -125,7 +125,7 @@ def patch_config(
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn

     if getattr(config, "model_type", None) == "minicpmo":
-        setattr(config, "init_audio", False)
+        setattr(config, "init_audio", True)
         setattr(config, "init_tts", False)

     if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
diff --git a/tests/data/test_template.py b/tests/data/test_template.py
index dead0af0..e6b6ed2b 100644
--- a/tests/data/test_template.py
+++ b/tests/data/test_template.py
@@ -152,6 +152,7 @@ def test_llama3_template(use_fast: bool):
     _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)


+@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
 @pytest.mark.parametrize(
     "use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken."))]
 )
@@ -166,6 +167,7 @@ def test_phi4_template(use_fast: bool):
     _check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)


+@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_qwen_template(use_fast: bool):
     prompt_str = (
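
For reference, below is a minimal standalone sketch of what the patched MiniCPMVPlugin loop now computes: the per-sample valid image count is taken directly from imglens instead of being re-derived from the number of start/end marker tokens. The function name build_image_bounds and its flat argument list are illustrative only (not the repository's actual method signature), and end_cond is assumed to mirror start_cond with im_end_id/slice_end_id, since that line falls outside the hunk.

    import torch

    def build_image_bounds(batch_ids, imglens, tokenizer):
        # `tokenizer` is assumed to expose im_start_id / im_end_id /
        # slice_start_id / slice_end_id, as processor.tokenizer does in the diff.
        image_bounds_list, valid_image_nums_ls = [], []
        for i, input_ids in enumerate(batch_ids):
            input_ids_ = torch.tensor(input_ids)
            start_cond = (input_ids_ == tokenizer.im_start_id) | (input_ids_ == tokenizer.slice_start_id)
            end_cond = (input_ids_ == tokenizer.im_end_id) | (input_ids_ == tokenizer.slice_end_id)
            image_start_tokens = torch.where(start_cond)[0] + 1  # bounds begin after the start marker
            image_end_tokens = torch.where(end_cond)[0]
            valid_image_nums_ls.append(imglens[i])  # trust the per-sample image count from the dataset
            image_bounds = torch.hstack(
                [image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)]
            )
            image_bounds_list.append(image_bounds)
        return image_bounds_list, valid_image_nums_ls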