Support Inference of MiniCPM-V-2.6 and MiniCPM-o-2.6 (#6631)

* fix template name

* tiny fix

* support minicpm-o-2.6

* support inference of minicpmv

Former-commit-id: 7f3c64e853a7cdd49d02bf85e237611941ac7fa8
Zhangchi Feng, 2025-01-14 17:34:58 +08:00 (committed by GitHub)
parent ff8ef6f52c
commit 068d44b509
4 changed files with 25 additions and 1 deletion


@@ -109,6 +109,10 @@ def patch_config(
     if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn
+    if getattr(config, "model_type", None) == "minicpmo":
+        setattr(config, "init_audio", False)
+        setattr(config, "init_tts", False)
     if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
         raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
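
The two new setattr calls keep MiniCPM-o-2.6 from building its audio encoder and TTS head when only text/vision inference is needed. A minimal standalone sketch of the same idea outside LLaMA-Factory, assuming the openbmb/MiniCPM-o-2_6 checkpoint and its remote-code config (the model id and the init_audio/init_tts flags come from the diff and the model card, and are assumptions here, not part of this commit):

from transformers import AutoConfig, AutoModel

model_id = "openbmb/MiniCPM-o-2_6"  # assumed checkpoint id

config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
if getattr(config, "model_type", None) == "minicpmo":
    # Skip initializing the audio encoder and TTS head; vision-language
    # inference does not need them, which saves load time and memory.
    setattr(config, "init_audio", False)
    setattr(config, "init_tts", False)

model = AutoModel.from_pretrained(model_id, config=config, trust_remote_code=True)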
@@ -145,7 +149,9 @@ def patch_model(
     ):
         gen_config.do_sample = True
-    if "GenerationMixin" not in str(model.generate.__func__):
+    if getattr(model.config, "model_type") not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
+        model.generate.__func__
+    ):
         model.generate = MethodType(PreTrainedModel.generate, model)
     if add_valuehead:
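
The widened guard in patch_model leaves the custom generate() of MiniCPM-V/MiniCPM-o in place, since those remote-code models need their own multimodal generation path, while other models whose generate is not the GenerationMixin one still get rebound to the stock Hugging Face implementation. A minimal sketch of that rebinding as a standalone helper (ensure_hf_generate is a hypothetical name; model is assumed to be an already loaded PreTrainedModel subclass):

from types import MethodType

from transformers import PreTrainedModel


def ensure_hf_generate(model: PreTrainedModel) -> None:
    # MiniCPM-V/-o ship their own multimodal generate() via trust_remote_code,
    # so it must be left untouched for inference to work.
    model_type = getattr(model.config, "model_type", None)
    overrides_generate = "GenerationMixin" not in str(model.generate.__func__)

    if model_type not in ["minicpmv", "minicpmo"] and overrides_generate:
        # Rebind the stock GenerationMixin.generate onto this instance.
        model.generate = MethodType(PreTrainedModel.generate, model)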