Support Inference of MiniCPM-V-2.6 and MiniCPM-o-2.6 (#6631)
* fix template name
* tiny fix
* support minicpm-o-2.6
* support inference of minicpmv

Former-commit-id: 7f3c64e853a7cdd49d02bf85e237611941ac7fa8
This commit is contained in:
parent d0da6f40b0
commit f7857c83e1

Changed files:
setup.py
@@ -68,6 +68,7 @@ extra_require = {
         "msgpack",
         "referencing",
         "jsonschema_specifications",
+        "librosa",
     ],
     "modelscope": ["modelscope"],
     "openmind": ["openmind"],
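librosa joins this dependency group because MiniCPM-o-2.6 accepts audio input, which has to be decoded and resampled before it reaches the processor. A minimal sketch of that kind of preprocessing, assuming a local sample.wav (hypothetical path) and the 16 kHz mono format speech models commonly expect:

import librosa

# Decode an audio clip and resample it to 16 kHz mono; the rate the MiniCPM-o
# processor actually expects is an assumption here.
audio, sampling_rate = librosa.load("sample.wav", sr=16_000, mono=True)
print(audio.shape, sampling_rate)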
@@ -168,6 +168,8 @@ class HuggingfaceEngine(BaseEngine):
             for key, value in mm_inputs.items():
                 if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):  # for pixtral inputs
                     value = torch.stack(value)  # assume they have same sizes
+                elif isinstance(value, list) and all(isinstance(v, list) for v in value):  # for minicpmv inputs
+                    value = torch.stack([torch.stack(per_value) for per_value in value])
                 elif not isinstance(value, torch.Tensor):
                     value = torch.tensor(value)
 
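For MiniCPM-V the processor returns multimodal values as nested lists (one inner list of tensors per image), so they are stacked twice to form a single batched tensor. A small self-contained sketch of that double stacking, with made-up patch shapes:

import torch

# One inner list of patch tensors per image; shapes are illustrative only.
per_image = [torch.randn(3, 14, 14) for _ in range(4)]
mm_value = [per_image, per_image]  # two images in the batch

# Stack each inner list first, then the outer list, as the patch above does.
batched = torch.stack([torch.stack(per_value) for per_value in mm_value])
print(batched.shape)  # torch.Size([2, 4, 3, 14, 14])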
@@ -176,6 +178,11 @@ class HuggingfaceEngine(BaseEngine):
 
                 gen_kwargs[key] = value.to(model.device)
 
+        if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
+            gen_kwargs["input_ids"] = inputs
+            del gen_kwargs["image_sizes"]
+            gen_kwargs["tokenizer"] = tokenizer
+
         return gen_kwargs, prompt_length
 
     @staticmethod
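For MiniCPM models the generation kwargs are reshaped: the prompt ids are passed explicitly, image_sizes is dropped, and the tokenizer object is handed to generate(). One plausible reading of the deletion is that MiniCPM's custom generate() defines its own signature and will not accept an image_sizes keyword; a toy illustration (the function name and signature below are made up for demonstration):

def minicpm_style_generate(input_ids=None, tokenizer=None, max_new_tokens=None):
    # stand-in for a custom generate() with a fixed keyword list
    return input_ids

gen_kwargs = {"input_ids": [[1, 2, 3]], "image_sizes": [[448, 448]], "tokenizer": "tok"}
del gen_kwargs["image_sizes"]  # without this line the call below raises TypeError
print(minicpm_style_generate(**gen_kwargs))  # [[1, 2, 3]]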
@@ -207,6 +214,9 @@ class HuggingfaceEngine(BaseEngine):
             input_kwargs,
         )
         generate_output = model.generate(**gen_kwargs)
+        if isinstance(generate_output, tuple):
+            generate_output = generate_output[1][0]  # for minicpm_o
+
         response_ids = generate_output[:, prompt_length:]
         response = tokenizer.batch_decode(
             response_ids, skip_special_tokens=generating_args["skip_special_tokens"], clean_up_tokenization_spaces=True
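MiniCPM-o's generate() can return a tuple instead of a plain tensor of token ids; the patch takes the first sequence from the tuple's second element. A sketch with dummy data (the exact tuple layout is assumed from the indexing above):

import torch

ids = torch.tensor([[1, 2, 3, 4]])
generate_output = ("auxiliary_output", [ids])  # stand-in for the tuple MiniCPM-o returns

if isinstance(generate_output, tuple):
    generate_output = generate_output[1][0]  # unwrap to the id tensor, as in the patch

print(generate_output.shape)  # torch.Size([1, 4])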
@@ -252,6 +252,13 @@ _register_composite_model(
 
 _register_composite_model(
     model_type="minicpmv",
+    vision_model_keys=["vpm"],
+    language_model_keys=["llm"],
+)
+
+
+_register_composite_model(
+    model_type="minicpmo",
     vision_model_keys=["vpm", "apm", "resampler", "tts"],
     language_model_keys=["llm"],
 )
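The registration maps each model type to the attribute names of its vision-side and language-side submodules (vpm, apm, resampler, tts versus llm), which is what lets the trainer decide which parameters to freeze or adapt. A simplified stand-in for that registry, assuming a plain dict instead of the real helper's full parameter list:

from typing import Dict, List

COMPOSITE_MODELS: Dict[str, Dict[str, List[str]]] = {}

def register_composite_model(model_type: str, vision_model_keys: List[str], language_model_keys: List[str]) -> None:
    # map a model type to the submodule names used for freezing / adapter placement
    COMPOSITE_MODELS[model_type] = {
        "vision_model_keys": vision_model_keys,
        "language_model_keys": language_model_keys,
    }

register_composite_model("minicpmv", vision_model_keys=["vpm"], language_model_keys=["llm"])
register_composite_model("minicpmo", vision_model_keys=["vpm", "apm", "resampler", "tts"], language_model_keys=["llm"])
print(COMPOSITE_MODELS["minicpmo"]["vision_model_keys"])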
@@ -110,6 +110,10 @@ def patch_config(
     if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn
 
+    if getattr(config, "model_type", None) == "minicpmo":
+        setattr(config, "init_audio", False)
+        setattr(config, "init_tts", False)
+
     if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
         raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
 
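Setting init_audio and init_tts to False on the MiniCPM-o config presumably keeps the audio and TTS towers from being instantiated when only text/vision inference is needed. A sketch with a stand-in config object (the real one is a PretrainedConfig loaded by AutoConfig; the effect of the flags is an assumption about MiniCPM-o's modeling code):

from types import SimpleNamespace

config = SimpleNamespace(model_type="minicpmo")  # stand-in for the HF config

if getattr(config, "model_type", None) == "minicpmo":
    setattr(config, "init_audio", False)  # assumed to skip building the audio encoder
    setattr(config, "init_tts", False)    # assumed to skip building the TTS head

print(config.init_audio, config.init_tts)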
@@ -145,7 +149,9 @@ def patch_model(
     ):
         gen_config.do_sample = True
 
-    if "GenerationMixin" not in str(model.generate.__func__):
+    if getattr(model.config, "model_type") not in ["minicpmv", "minicpmo"] and "GenerationMixin" not in str(
+        model.generate.__func__
+    ):
         model.generate = MethodType(PreTrainedModel.generate, model)
 
     if add_valuehead:
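The guard now rebinds generate() to the stock Transformers implementation only when a model overrides it and is not a MiniCPM variant, since MiniCPM's custom generate() is what the engine relies on. A minimal analogy of the detection-and-rebinding mechanism, using toy classes rather than real models:

from types import MethodType

class Base:
    def generate(self):
        return "base generate"

class CustomModel(Base):
    def generate(self):  # custom override, analogous to MiniCPM's generate()
        return "custom generate"

m = CustomModel()
print("Base" in str(m.generate.__func__))  # False: the override is bound
m.generate = MethodType(Base.generate, m)  # rebind to the base implementation
print(m.generate())                        # "base generate"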