Support Inference of MiniCPM-V-2.6 and MiniCPM-o-2.6 (#6631)
* fix template name
* tiny fix
* support minicpm-o-2.6
* support inference of minicpmv

Former-commit-id: 7f3c64e853a7cdd49d02bf85e237611941ac7fa8
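For context, inference with these checkpoints runs through the models' own remote code rather than a stock transformers class. Below is a minimal standalone sketch for MiniCPM-V-2.6, assuming the Hub id openbmb/MiniCPM-V-2_6 and the chat() helper its model card documents; both are assumptions here, not part of this commit.

    import torch
    from PIL import Image
    from transformers import AutoModel, AutoTokenizer

    # Hub id is an assumption; adjust to the checkpoint you actually use.
    model_id = "openbmb/MiniCPM-V-2_6"

    # trust_remote_code is required: generation lives in the repo's own modeling code.
    model = AutoModel.from_pretrained(
        model_id, trust_remote_code=True, torch_dtype=torch.bfloat16
    ).eval().cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    image = Image.open("example.jpg").convert("RGB")
    msgs = [{"role": "user", "content": [image, "What is in the image?"]}]

    # chat() is provided by the remote code, not by transformers itself.
    print(model.chat(image=None, msgs=msgs, tokenizer=tokenizer))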
@@ -109,6 +109,10 @@ def patch_config(
     if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn
 
+    if getattr(config, "model_type", None) == "minicpmo":
+        setattr(config, "init_audio", False)
+        setattr(config, "init_tts", False)
+
     if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
         raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
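The new minicpmo branch above keeps MiniCPM-o-2.6 from initializing its audio encoder and TTS head, which text/vision inference does not need. A minimal sketch of the same pattern applied to a raw config, assuming the Hub id openbmb/MiniCPM-o-2_6 (an assumption; the diff does not name a checkpoint):

    from transformers import AutoConfig

    # Hub id is an assumption; any MiniCPM-o-2.6 checkpoint should expose these flags.
    config = AutoConfig.from_pretrained("openbmb/MiniCPM-o-2_6", trust_remote_code=True)

    if getattr(config, "model_type", None) == "minicpmo":
        # Skip the audio and TTS sub-modules when only text/vision inference is needed.
        setattr(config, "init_audio", False)
        setattr(config, "init_tts", False)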
@@ -145,7 +149,9 @@ def patch_model(
     ):
         gen_config.do_sample = True
 
-    if "GenerationMixin" not in str(model.generate.__func__):
+    if getattr(model.config, "model_type") not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
+        model.generate.__func__
+    ):
         model.generate = MethodType(PreTrainedModel.generate, model)
 
     if add_valuehead:
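The patch_model change widens the existing guard: models whose generate is not the stock GenerationMixin implementation normally get PreTrainedModel.generate rebound, but MiniCPM-V and MiniCPM-o define a custom generate in their remote code to route multimodal inputs, so rebinding would bypass that path. A commented sketch of the resulting logic, with the rationale inline (the helper name is hypothetical, and the None default in getattr is an addition for safety):

    from types import MethodType

    from transformers import PreTrainedModel

    def ensure_standard_generate(model: "PreTrainedModel") -> None:  # hypothetical helper name
        # MiniCPM-V / MiniCPM-o implement generate() in their remote code to feed
        # image (and audio) features into the language model; rebinding the stock
        # implementation would skip that path, so these model types are left as-is.
        if getattr(model.config, "model_type", None) not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
            model.generate.__func__
        ):
            model.generate = MethodType(PreTrainedModel.generate, model)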