From df8752e8ee12415bf72823c8fe544ffe74479c7f Mon Sep 17 00:00:00 2001
From: Kingsley <82590017+Kuangdd01@users.noreply.github.com>
Date: Tue, 15 Apr 2025 00:21:58 +0800
Subject: [PATCH] [model] Support Kimi_VL thinking/instruct (#7719)

* add kimi_vl

* patch config

* check version

* Update mm_plugin.py

* Update mm_plugin.py

---------

Co-authored-by: hoshi-hiyouga
---
 src/llamafactory/data/mm_plugin.py        | 51 ++++++++++++++++++++++-
 src/llamafactory/data/template.py         | 14 +++++++
 src/llamafactory/extras/constants.py      | 16 +++++++
 src/llamafactory/model/model_utils/moe.py |  6 +++
 src/llamafactory/model/patcher.py         |  4 ++
 5 files changed, 90 insertions(+), 1 deletion(-)

diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index e56f5dd2..e07b764d 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -466,6 +466,41 @@ class Gemma3Plugin(BasePlugin):
         return mm_inputs
 
 
+@dataclass
+class KimiVLPlugin(BasePlugin):
+    @override
+    def process_messages(self, messages, images, videos, audios, processor):
+        self._validate_input(processor, images, videos, audios)
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+
+        image_grid_hws = mm_inputs.get("image_grid_hws", [])
+        num_image_tokens = 0
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+        merge_length = math.prod(image_processor.merge_kernel_size)
+        messages = deepcopy(messages)
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                if num_image_tokens >= len(image_grid_hws):
+                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+
+                image_seqlen = image_grid_hws[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
+                content = content.replace(
+                    IMAGE_PLACEHOLDER,
+                    f"<|media_start|>image<|media_content|>{self.image_token * image_seqlen}<|media_end|>",
+                    1,
+                )
+                num_image_tokens += 1
+
+            message["content"] = content
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        return messages
+
+
 @dataclass
 class Llama4Plugin(BasePlugin):
     @override
@@ -1420,6 +1455,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
             content = message["content"]
             # separate with audio-video
             while IMAGE_PLACEHOLDER in content:
+                if num_image_tokens >= len(image_grid_thw):
+                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+
                 image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
                 content = content.replace(
                     IMAGE_PLACEHOLDER,
@@ -1430,6 +1468,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
 
             if not use_audio_in_video:
                 while AUDIO_PLACEHOLDER in content:
+                    if num_audio_tokens >= len(audio_lengths):
+                        raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
+
                     audio_token_replace_length = audio_lengths[num_audio_tokens]
                     content = content.replace(
                         AUDIO_PLACEHOLDER,
@@ -1440,6 +1481,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
 
                 # TODO handle video_input and use_audio_in_video
                 while VIDEO_PLACEHOLDER in content:
+                    if num_video_tokens >= len(video_grid_thw):
+                        raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
+
                     video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
                     content = content.replace(
                         VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1
@@ -1448,6 +1492,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
             else:  # if use the audio of video
                 # deal video token and audio token togather
                 while VIDEO_PLACEHOLDER in content:
+                    if num_video_tokens >= len(video_grid_thw):
+                        raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
+
                     audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
                     video_t_index = (
                         torch.arange(video_grid_thw[num_video_tokens][0])
@@ -1471,10 +1518,11 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                         audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
                         if video_chunk_index is not None:
                             placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
+
                         if audio_chunk_index is not None:
                             placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
-                        placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
+                    placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
                     content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
                     content = content.replace(AUDIO_PLACEHOLDER, "", 1)
                     num_audio_tokens += 1
 
@@ -1555,6 +1603,7 @@ class VideoLlavaPlugin(BasePlugin):
 PLUGINS = {
     "base": BasePlugin,
     "gemma3": Gemma3Plugin,
+    "kimi_vl": KimiVLPlugin,
     "llama4": Llama4Plugin,
     "llava": LlavaPlugin,
     "llava_next": LlavaNextPlugin,
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 6a84cbf9..0a334787 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -923,6 +923,20 @@ register_template(
 )
 
 
+register_template(
+    name="kimi_vl",
+    format_user=StringFormatter(
+        slots=["<|im_user|>user<|im_middle|>{{content}}<|im_end|><|im_assistant|>assistant<|im_middle|>"]
+    ),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
+    format_system=StringFormatter(slots=["<|im_system|>system<|im_middle|>{{content}}<|im_end|>"]),
+    default_system="You are a helpful assistant",
+    stop_words=["<|im_end|>"],
+    thought_words=("◁think▷", "◁/think▷"),
+    mm_plugin=get_mm_plugin("kimi_vl", image_token="<|media_pad|>"),
+)
+
+
 register_template(
     name="llama2",
     format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index e10b6b7f..ffa5cefa 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -975,6 +975,22 @@ register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "Kimi-VL-A3B-Instruct": {
+            DownloadSource.DEFAULT: "moonshotai/Kimi-VL-A3B-Instruct",
+            DownloadSource.MODELSCOPE: "moonshotai/Kimi-VL-A3B-Instruct",
+        },
+        "Kimi-VL-A3B-Thinking": {
+            DownloadSource.DEFAULT: "moonshotai/Kimi-VL-A3B-Thinking",
+            DownloadSource.MODELSCOPE: "moonshotai/Kimi-VL-A3B-Thinking",
+        },
+    },
+    template="kimi_vl",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "LingoWhale-8B": {
diff --git a/src/llamafactory/model/model_utils/moe.py b/src/llamafactory/model/model_utils/moe.py
index bc4f2906..b3fca4f7 100644
--- a/src/llamafactory/model/model_utils/moe.py
+++ b/src/llamafactory/model/model_utils/moe.py
@@ -54,6 +54,12 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
 
         _set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
 
+    if model_type in ["kimi_vl", "deepseek_v3"]:
+        check_version("transformers>=4.51.1")
+        from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
+
+        _set_z3_leaf_modules(model, [DeepseekV3MoE])
+
     if model_type == "mixtral":
         from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 1706d177..996fe7ef 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -117,6 +117,10 @@ def patch_config(
         setattr(config, "init_audio", True)
         setattr(config, "init_tts", False)
 
+    # replace the top-k gating method
+    if getattr(config, "model_type", None) == "kimi_vl" and is_trainable:
+        setattr(config.text_config, "topk_method", "greedy")
+
     if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
         raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
 
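
Illustration of the token expansion performed by KimiVLPlugin (a standalone sketch, not part of the diff): each image placeholder is replaced by a <|media_start|>image<|media_content|>...<|media_end|> span whose number of <|media_pad|> tokens equals the image's patch-grid size divided by the merge-kernel area. The grid shape and merge kernel below are hypothetical stand-ins for what the Kimi-VL processor actually returns ("image_grid_hws" and image_processor.merge_kernel_size).

import math

# Hypothetical inputs: in the plugin these come from the processor output
# ("image_grid_hws") and from image_processor.merge_kernel_size.
IMAGE_PLACEHOLDER = "<image>"
image_token = "<|media_pad|>"
image_grid_hw = (28, 28)      # (height, width) patch grid for one image
merge_kernel_size = (2, 2)    # patches merged into a single media token

merge_length = math.prod(merge_kernel_size)               # 4
image_seqlen = math.prod(image_grid_hw) // merge_length   # 784 // 4 = 196

content = f"Describe this picture: {IMAGE_PLACEHOLDER}"
content = content.replace(
    IMAGE_PLACEHOLDER,
    f"<|media_start|>image<|media_content|>{image_token * image_seqlen}<|media_end|>",
    1,
)
print(f"inserted {image_seqlen} <|media_pad|> tokens")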