Support Inference of MiniCPM-V-2.6 and MiniCPM-o-2.6 (#6631)

* fix template name

* tiny fix

* support minicpm-o-2.6

* support inference of minicpmv

Former-commit-id: 158a127d340d5e4ca23263ffad042f861fd77deb
Zhangchi Feng 2025-01-14 17:34:58 +08:00 committed by GitHub
parent 8f73c75c16
commit ad119afc58
4 changed files with 25 additions and 1 deletion


@@ -68,6 +68,7 @@ extra_require = {
         "msgpack",
         "referencing",
         "jsonschema_specifications",
+        "librosa",
     ],
     "modelscope": ["modelscope"],
     "openmind": ["openmind"],

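The new `librosa` entry presumably covers MiniCPM-o-2.6's audio path, which needs waveform decoding and resampling. A minimal sketch of the kind of preprocessing librosa is used for, assuming a local WAV file and a 16 kHz target rate (both illustrative):

```python
# Load an audio clip as a mono float32 waveform and resample it to 16 kHz.
# The file name and sample rate are assumptions for illustration only.
import librosa

waveform, sampling_rate = librosa.load("example.wav", sr=16000, mono=True)
print(waveform.shape, sampling_rate)  # 1-D numpy array, 16000
```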

@@ -168,6 +168,8 @@ class HuggingfaceEngine(BaseEngine):
         for key, value in mm_inputs.items():
             if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):  # for pixtral inputs
                 value = torch.stack(value)  # assume they have same sizes
+            elif isinstance(value, list) and all(isinstance(v, list) for v in value):  # for minicpmv inputs
+                value = torch.stack([torch.stack(per_value) for per_value in value])
             elif not isinstance(value, torch.Tensor):
                 value = torch.tensor(value)
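The MiniCPM-V processor returns its multimodal inputs as a list of lists of tensors (one inner list per sample), so the engine stacks twice to obtain a single batched tensor. A self-contained sketch of that double `torch.stack`, with illustrative shapes:

```python
# A batch of 4 samples, each with 2 equally sized image-slice tensors, becomes one
# [batch, slices, C, H, W] tensor via the nested torch.stack used above.
import torch

value = [[torch.randn(3, 14, 14) for _ in range(2)] for _ in range(4)]
stacked = torch.stack([torch.stack(per_value) for per_value in value])
print(stacked.shape)  # torch.Size([4, 2, 3, 14, 14])
```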
@@ -176,6 +178,11 @@ class HuggingfaceEngine(BaseEngine):
             gen_kwargs[key] = value.to(model.device)

+        if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
+            gen_kwargs["input_ids"] = inputs
+            del gen_kwargs["image_sizes"]
+            gen_kwargs["tokenizer"] = tokenizer
+
         return gen_kwargs, prompt_length

     @staticmethod
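MiniCPM-V/MiniCPM-o ship a custom `generate()` that expects the prompt ids as an explicit `input_ids` keyword plus the tokenizer, and handles image sizing internally, which is why `image_sizes` is dropped. A sketch of the resulting keyword handling with dummy values; the accepted keys are taken from the diff, not from a documented API:

```python
# Mirror the gen_kwargs adjustment above with dummy values. What the custom MiniCPM
# generate() actually accepts is defined by its remote modeling code; this only
# reproduces the keys touched in the diff.
import torch

inputs = torch.randint(0, 100, (1, 8))  # stand-in for the tokenized prompt
gen_kwargs = {"inputs": inputs, "image_sizes": [(448, 448)], "max_new_tokens": 64}

model_type = "minicpmv"  # would come from model.config.model_type
if model_type in ["minicpmv", "minicpmo"]:
    gen_kwargs["input_ids"] = inputs  # custom generate() reads ids from this keyword
    del gen_kwargs["image_sizes"]     # sizing is handled inside the model
    gen_kwargs["tokenizer"] = None    # a real PreTrainedTokenizer is passed here

print(sorted(gen_kwargs))  # ['input_ids', 'inputs', 'max_new_tokens', 'tokenizer']
```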
@@ -207,6 +214,9 @@ class HuggingfaceEngine(BaseEngine):
             input_kwargs,
         )
         generate_output = model.generate(**gen_kwargs)
+        if isinstance(generate_output, tuple):
+            generate_output = generate_output[1][0]  # for minicpm_o
+
         response_ids = generate_output[:, prompt_length:]
         response = tokenizer.batch_decode(
             response_ids, skip_special_tokens=generating_args["skip_special_tokens"], clean_up_tokenization_spaces=True
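MiniCPM-o's `generate()` can return a tuple rather than a plain tensor of token ids, so the engine unwraps it before stripping the prompt tokens. A small sketch of the unwrap-and-slice step with fake ids; the `[1][0]` layout mirrors the diff rather than a documented return contract:

```python
# Unwrap a tuple-shaped generate() output, then keep only the newly generated tokens.
import torch

prompt_length = 5
sequences = torch.arange(12).unsqueeze(0)      # fake [1, 12] token ids (prompt + response)
generate_output = ("aux output", [sequences])  # tuple-style output as in MiniCPM-o

if isinstance(generate_output, tuple):
    generate_output = generate_output[1][0]
response_ids = generate_output[:, prompt_length:]
print(response_ids.tolist())  # [[5, 6, 7, 8, 9, 10, 11]]
```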


@@ -252,6 +252,13 @@ _register_composite_model(
 _register_composite_model(
     model_type="minicpmv",
     vision_model_keys=["vpm"],
     language_model_keys=["llm"],
 )
+
+
+_register_composite_model(
+    model_type="minicpmo",
+    vision_model_keys=["vpm", "apm", "resampler", "tts"],
+    language_model_keys=["llm"],
+)
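`_register_composite_model` records which submodule keys belong to the vision (or audio/TTS) side and which to the language model, so other parts of the codebase can target them, for example when freezing the vision tower or selecting modules for fine-tuning. A minimal sketch of such a registry, assuming it only stores the key lists shown here; the real helper in LLaMA-Factory likely tracks additional fields:

```python
# Minimal sketch of a composite-model registry: it maps a model_type to the module
# key lists, which is the part the diff relies on. The actual helper may store more
# (projector keys, LoRA conflict keys, ...).
from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class CompositeModel:
    model_type: str
    vision_model_keys: List[str] = field(default_factory=list)
    language_model_keys: List[str] = field(default_factory=list)

COMPOSITE_MODELS: Dict[str, CompositeModel] = {}

def _register_composite_model(model_type: str, vision_model_keys=None, language_model_keys=None):
    COMPOSITE_MODELS[model_type] = CompositeModel(
        model_type=model_type,
        vision_model_keys=vision_model_keys or [],
        language_model_keys=language_model_keys or [],
    )

_register_composite_model(
    model_type="minicpmo",
    vision_model_keys=["vpm", "apm", "resampler", "tts"],
    language_model_keys=["llm"],
)
print(COMPOSITE_MODELS["minicpmo"].vision_model_keys)  # ['vpm', 'apm', 'resampler', 'tts']
```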


@@ -109,6 +109,10 @@ def patch_config(
     if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn

+    if getattr(config, "model_type", None) == "minicpmo":
+        setattr(config, "init_audio", False)
+        setattr(config, "init_tts", False)
+
     if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
         raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
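MiniCPM-o's configuration carries `init_audio` and `init_tts` flags that its remote modeling code reads when deciding whether to build the audio and TTS heads; forcing them off keeps text/vision inference from instantiating modules the engine does not use. A hedged sketch of flipping those flags before loading, assuming the `openbmb/MiniCPM-o-2_6` checkpoint and trusted remote code:

```python
# Disable MiniCPM-o's audio/TTS initialization via its config before loading the model.
# The checkpoint name is illustrative; the flag names themselves come from the diff above.
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("openbmb/MiniCPM-o-2_6", trust_remote_code=True)
if getattr(config, "model_type", None) == "minicpmo":
    setattr(config, "init_audio", False)  # skip the audio perception module
    setattr(config, "init_tts", False)    # skip the TTS head

model = AutoModel.from_pretrained("openbmb/MiniCPM-o-2_6", config=config, trust_remote_code=True)
```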
@@ -145,7 +149,9 @@ def patch_model(
     ):
         gen_config.do_sample = True

-    if "GenerationMixin" not in str(model.generate.__func__):
+    if getattr(model.config, "model_type") not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
+        model.generate.__func__
+    ):
         model.generate = MethodType(PreTrainedModel.generate, model)

     if add_valuehead:
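For most remote-code models LLaMA-Factory rebinds `generate` to `PreTrainedModel.generate`, but MiniCPM-V/MiniCPM-o depend on their own wrapper (it prepares image inputs and, for the "o" variant, returns the tuple handled earlier), so they are now excluded from the rebinding. A toy sketch of what the `MethodType` rebinding does on a single instance, using stand-in classes instead of transformers:

```python
# Replace one instance's generate with the base-class implementation, which is what
# patch_model does for most models and deliberately skips for minicpmv/minicpmo.
from types import MethodType

class PreTrainedToy:
    def generate(self):
        return "base generate"

class RemoteCodeToy(PreTrainedToy):
    def generate(self):  # custom wrapper, analogous to a remote-code generate()
        return "custom generate"

model = RemoteCodeToy()
print(model.generate())                                     # custom generate
model.generate = MethodType(PreTrainedToy.generate, model)  # rebind to the base method
print(model.generate())                                     # base generate
```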