mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-01 11:12:50 +08:00
Support Inference of MiniCPM-V-2.6 and MiniCPM-o-2.6 (#6631)
* fix template name * tiny fix * support minicpm-o-2.6 * support inference of minicpmv Former-commit-id: 158a127d340d5e4ca23263ffad042f861fd77deb
This commit is contained in:
parent
8f73c75c16
commit
ad119afc58
1
setup.py
1
setup.py
@ -68,6 +68,7 @@ extra_require = {
|
||||
"msgpack",
|
||||
"referencing",
|
||||
"jsonschema_specifications",
|
||||
"librosa",
|
||||
],
|
||||
"modelscope": ["modelscope"],
|
||||
"openmind": ["openmind"],
|
||||
|
@ -168,6 +168,8 @@ class HuggingfaceEngine(BaseEngine):
|
||||
for key, value in mm_inputs.items():
|
||||
if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value): # for pixtral inputs
|
||||
value = torch.stack(value) # assume they have same sizes
|
||||
elif isinstance(value, list) and all(isinstance(v, list) for v in value): # for minicpmv inputs
|
||||
value = torch.stack([torch.stack(per_value) for per_value in value])
|
||||
elif not isinstance(value, torch.Tensor):
|
||||
value = torch.tensor(value)
|
||||
|
||||
@ -176,6 +178,11 @@ class HuggingfaceEngine(BaseEngine):
|
||||
|
||||
gen_kwargs[key] = value.to(model.device)
|
||||
|
||||
if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
|
||||
gen_kwargs["input_ids"] = inputs
|
||||
del gen_kwargs["image_sizes"]
|
||||
gen_kwargs["tokenizer"] = tokenizer
|
||||
|
||||
return gen_kwargs, prompt_length
|
||||
|
||||
@staticmethod
|
||||
@ -207,6 +214,9 @@ class HuggingfaceEngine(BaseEngine):
|
||||
input_kwargs,
|
||||
)
|
||||
generate_output = model.generate(**gen_kwargs)
|
||||
if isinstance(generate_output, tuple):
|
||||
generate_output = generate_output[1][0] # for minicpm_o
|
||||
|
||||
response_ids = generate_output[:, prompt_length:]
|
||||
response = tokenizer.batch_decode(
|
||||
response_ids, skip_special_tokens=generating_args["skip_special_tokens"], clean_up_tokenization_spaces=True
|
||||
|
@ -252,6 +252,13 @@ _register_composite_model(
|
||||
|
||||
_register_composite_model(
|
||||
model_type="minicpmv",
|
||||
vision_model_keys=["vpm"],
|
||||
language_model_keys=["llm"],
|
||||
)
|
||||
|
||||
|
||||
_register_composite_model(
|
||||
model_type="minicpmo",
|
||||
vision_model_keys=["vpm", "apm", "resampler", "tts"],
|
||||
language_model_keys=["llm"],
|
||||
)
|
||||
|
@ -109,6 +109,10 @@ def patch_config(
|
||||
|
||||
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
|
||||
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flash attn
|
||||
|
||||
if getattr(config, "model_type", None) == "minicpmo":
|
||||
setattr(config, "init_audio", False)
|
||||
setattr(config, "init_tts", False)
|
||||
|
||||
if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
|
||||
raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
|
||||
@ -145,7 +149,9 @@ def patch_model(
|
||||
):
|
||||
gen_config.do_sample = True
|
||||
|
||||
if "GenerationMixin" not in str(model.generate.__func__):
|
||||
if getattr(model.config, "model_type") not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
|
||||
model.generate.__func__
|
||||
):
|
||||
model.generate = MethodType(PreTrainedModel.generate, model)
|
||||
|
||||
if add_valuehead:
|
||||
|
Loading…
x
Reference in New Issue
Block a user