Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-01 11:12:50 +08:00)
[data] fix minicpmv plugin (#6801)
* fix template name
* tiny fix
* support minicpm-o-2.6
* support inference of minicpmv
* update readme
* support dpo of minicpmv
* update init audio
* update init audio
* [model] fix image process in minicpmo

Former-commit-id: ab9bd068efee861452407cdda08ef014d5ce23d5
commit 85f22d01bf
parent 822d5d362c
@@ -543,7 +543,7 @@ class MiniCPMVPlugin(BasePlugin):
         self._validate_input(images, videos)
         image_bounds_list = []
         valid_image_nums_ls = []
-        for input_ids in batch_ids:
+        for i, input_ids in enumerate(batch_ids):
             input_ids_ = torch.tensor(input_ids)
             start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
                 input_ids_ == processor.tokenizer.slice_start_id
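The switch to enumerate() is what makes the next hunk possible: the loop now carries the batch index i, so the per-sample image count can be read from imglens instead of being re-derived from the token stream. A tiny illustration with made-up values (batch_ids and imglens below are hypothetical, not taken from the repository):

# Hypothetical data: imglens[i] is the number of images attached to sample i.
batch_ids = [[1, 5, 9, 2], [1, 7, 2]]  # made-up token id sequences
imglens = [2, 1]                       # made-up per-sample image counts

for i, input_ids in enumerate(batch_ids):
    num_images = imglens[i]  # only reachable because the index is tracked
    print(f"sample {i}: {num_images} image(s), {len(input_ids)} tokens")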
@@ -552,12 +552,11 @@ class MiniCPMVPlugin(BasePlugin):
             image_start_tokens = torch.where(start_cond)[0]
             image_start_tokens += 1
             image_end_tokens = torch.where(end_cond)[0]
-            valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
-            valid_image_nums_ls.append(valid_image_nums)
+            valid_image_nums_ls.append(imglens[i])
             image_bounds = torch.hstack(
                 [
-                    image_start_tokens[:valid_image_nums].unsqueeze(-1),
-                    image_end_tokens[:valid_image_nums].unsqueeze(-1),
+                    image_start_tokens.unsqueeze(-1),
+                    image_end_tokens.unsqueeze(-1),
                 ]
             )
             image_bounds_list.append(image_bounds)
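With the expected count now taken from imglens, the remaining work is pairing the positions of the image start/end marker tokens into (start, end) rows. A standalone sketch of that pairing, using invented marker ids in place of processor.tokenizer.im_start_id and im_end_id:

import torch

# Invented marker ids standing in for im_start_id / im_end_id.
IM_START, IM_END = 101, 102
input_ids_ = torch.tensor([7, IM_START, 5, 5, IM_END, 9, IM_START, 5, IM_END])

start_cond = input_ids_ == IM_START
end_cond = input_ids_ == IM_END

image_start_tokens = torch.where(start_cond)[0] + 1  # shift to the first token inside the span
image_end_tokens = torch.where(end_cond)[0]
image_bounds = torch.hstack(
    [image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)]
)
print(image_bounds)  # tensor([[2, 4], [7, 8]]): one (start, end) row per image

Dropping the [:valid_image_nums] slice means every detected marker pair is kept; the expected image count is now supplied by imglens upstream.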
@@ -125,7 +125,7 @@ def patch_config(
         setattr(config, "use_cache", False)  # qwen2 does not support use_cache when using flash attn

     if getattr(config, "model_type", None) == "minicpmo":
-        setattr(config, "init_audio", False)
+        setattr(config, "init_audio", True)
         setattr(config, "init_tts", False)

     if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
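For the minicpmo branch this means the audio tower is now initialized while TTS stays disabled, which lines up with the "update init audio" entries in the commit message. A minimal sketch of applying the same flags to a loaded config; the model id is only an example:

from transformers import AutoConfig

# Example model id; any config with model_type == "minicpmo" is treated the same way.
config = AutoConfig.from_pretrained("openbmb/MiniCPM-o-2_6", trust_remote_code=True)

if getattr(config, "model_type", None) == "minicpmo":
    setattr(config, "init_audio", True)  # keep the audio encoder initialized
    setattr(config, "init_tts", False)   # leave the TTS head uninitialized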
@@ -152,6 +152,7 @@ def test_llama3_template(use_fast: bool):
     _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)


+@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
 @pytest.mark.parametrize(
     "use_fast", [True, pytest.param(False, marks=pytest.mark.xfail(reason="Phi-4 slow tokenizer is broken."))]
 )
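The added skipif mirrors the markers already used in this file: skip when no Hugging Face token is available, and mark the slow-tokenizer variant as an expected failure. A self-contained sketch of that decorator stack (HF_TOKEN and the test body below are placeholders, not the repository's code):

import os

import pytest

HF_TOKEN = os.getenv("HF_TOKEN")  # placeholder for the module-level constant


@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
@pytest.mark.parametrize(
    "use_fast",
    [True, pytest.param(False, marks=pytest.mark.xfail(reason="Slow tokenizer is broken."))],
)
def test_some_template(use_fast: bool):
    assert use_fast in (True, False)  # placeholder body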
@@ -166,6 +167,7 @@ def test_phi4_template(use_fast: bool):
     _check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)


+@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_qwen_template(use_fast: bool):
     prompt_str = (
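For context, a template test like these typically builds the expected prompt string and compares it against what the tokenizer's chat template renders. The helper below is only a rough approximation of that idea, assuming such a comparison; it is not the repository's _check_template:

from transformers import AutoTokenizer


def check_chat_template(model_id: str, messages: list, expected_prompt: str) -> None:
    # Approximation: render the conversation with the HF chat template and
    # compare it to the prompt string the local template is expected to produce.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    rendered = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    assert rendered == expected_prompt, f"mismatch:\n{rendered!r}\n!=\n{expected_prompt!r}"

The expected string has to be written per model, since each chat template differs (Qwen, for instance, may prepend a default system message).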