[model] Support Kimi_VL thinking/instruct (#7719)

* add kimi_vl

* patch config

* check version

* Update mm_plugin.py

* Update mm_plugin.py

---------

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
Kingsley 2025-04-15 00:21:58 +08:00 committed by GitHub
parent 3a13d2cdb1
commit df8752e8ee
5 changed files with 90 additions and 1 deletion


@@ -466,6 +466,41 @@ class Gemma3Plugin(BasePlugin):
        return mm_inputs


@dataclass
class KimiVLPlugin(BasePlugin):
    @override
    def process_messages(self, messages, images, videos, audios, processor):
        self._validate_input(processor, images, videos, audios)
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            image_grid_hws = mm_inputs.get("image_grid_hws", [])

        num_image_tokens = 0
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
        merge_length = math.prod(image_processor.merge_kernel_size)
        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if num_image_tokens >= len(image_grid_hws):
                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")

                image_seqlen = image_grid_hws[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                content = content.replace(
                    IMAGE_PLACEHOLDER,
                    f"<|media_start|>image<|media_content|>{self.image_token * image_seqlen}<|media_end|>",
                    1,
                )
                num_image_tokens += 1

            message["content"] = content

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        return messages
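For intuition, here is a minimal standalone sketch of the per-image arithmetic KimiVLPlugin performs when expanding one IMAGE_PLACEHOLDER. The 28x42 grid and the 2x2 merge kernel are illustrative assumptions, not values taken from this diff:

import math

# Hypothetical example values, for illustration only.
image_grid_hw = (28, 42)       # (height, width) patch grid reported by the image processor
merge_kernel_size = (2, 2)     # assumed merge kernel; the plugin reads image_processor.merge_kernel_size
image_token = "<|media_pad|>"  # image_token registered for the kimi_vl template

merge_length = math.prod(merge_kernel_size)                            # 4
image_seqlen = (image_grid_hw[0] * image_grid_hw[1]) // merge_length   # 1176 // 4 = 294

# One placeholder becomes a media span holding image_seqlen pad tokens:
expanded = f"<|media_start|>image<|media_content|>{image_token * image_seqlen}<|media_end|>"

The final check in process_messages then requires that the number of consumed grids equals len(images).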


@dataclass
class Llama4Plugin(BasePlugin):
    @override
@@ -1420,6 +1455,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
            content = message["content"]
            # separate with audio-video
            while IMAGE_PLACEHOLDER in content:
                if num_image_tokens >= len(image_grid_thw):
                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")

                image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
                content = content.replace(
                    IMAGE_PLACEHOLDER,
@@ -1430,6 +1468,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
            if not use_audio_in_video:
                while AUDIO_PLACEHOLDER in content:
                    if num_audio_tokens >= len(audio_lengths):
                        raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")

                    audio_token_replace_length = audio_lengths[num_audio_tokens]
                    content = content.replace(
                        AUDIO_PLACEHOLDER,
@@ -1440,6 +1481,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                # TODO handle video_input and use_audio_in_video
                while VIDEO_PLACEHOLDER in content:
                    if num_video_tokens >= len(video_grid_thw):
                        raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")

                    video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
                    content = content.replace(
                        VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1
@@ -1448,6 +1492,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
            else:  # use the audio of the video: deal with video tokens and audio tokens together
                while VIDEO_PLACEHOLDER in content:
                    if num_video_tokens >= len(video_grid_thw):
                        raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")

                    audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
                    video_t_index = (
                        torch.arange(video_grid_thw[num_video_tokens][0])
@@ -1471,10 +1518,11 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
                        audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
                        if video_chunk_index is not None:
                            placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])

                        if audio_chunk_index is not None:
                            placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
                            placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"

                    placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
                    content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
                    content = content.replace(AUDIO_PLACEHOLDER, "", 1)
                    num_audio_tokens += 1
@@ -1555,6 +1603,7 @@ class VideoLlavaPlugin(BasePlugin):
PLUGINS = {
    "base": BasePlugin,
    "gemma3": Gemma3Plugin,
    "kimi_vl": KimiVLPlugin,
    "llama4": Llama4Plugin,
    "llava": LlavaPlugin,
    "llava_next": LlavaNextPlugin,


@@ -923,6 +923,20 @@ register_template(
)


register_template(
    name="kimi_vl",
    format_user=StringFormatter(
        slots=["<|im_user|>user<|im_middle|>{{content}}<|im_end|><|im_assistant|>assistant<|im_middle|>"]
    ),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
    format_system=StringFormatter(slots=["<|im_system|>system<|im_middle|>{{content}}<|im_end|>"]),
    default_system="You are a helpful assistant",
    stop_words=["<|im_end|>"],
    thought_words=("◁think▷", "◁/think▷"),
    mm_plugin=get_mm_plugin("kimi_vl", image_token="<|media_pad|>"),
)
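As a rough illustration, the slots above compose a single-turn prompt like the following. This is a hand-written sketch of the rendered string, not output from the actual template engine:

system = "You are a helpful assistant"
user = "What is in this image? <|media_start|>image<|media_content|>...<|media_end|>"

# format_system + format_user, ready for the assistant to continue generating:
prompt = (
    f"<|im_system|>system<|im_middle|>{system}<|im_end|>"
    f"<|im_user|>user<|im_middle|>{user}<|im_end|>"
    "<|im_assistant|>assistant<|im_middle|>"
)
# Generation stops at <|im_end|>; the Thinking variant may wrap its reasoning
# in ◁think▷ ... ◁/think▷ before the visible answer.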


register_template(
    name="llama2",
    format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),


@@ -975,6 +975,22 @@ register_model_group(
)


register_model_group(
    models={
        "Kimi-VL-A3B-Instruct": {
            DownloadSource.DEFAULT: "moonshotai/Kimi-VL-A3B-Instruct",
            DownloadSource.MODELSCOPE: "moonshotai/Kimi-VL-A3B-Instruct",
        },
        "Kimi-VL-A3B-Thinking": {
            DownloadSource.DEFAULT: "moonshotai/Kimi-VL-A3B-Thinking",
            DownloadSource.MODELSCOPE: "moonshotai/Kimi-VL-A3B-Thinking",
        },
    },
    template="kimi_vl",
    multimodal=True,
)
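Outside of LLaMA-Factory, the registered checkpoints can also be pulled directly from the Hub. A minimal loading sketch follows; the class choices and trust_remote_code usage are assumptions based on common usage, not part of this commit:

from transformers import AutoModelForCausalLM, AutoProcessor

model_id = "moonshotai/Kimi-VL-A3B-Instruct"  # or "moonshotai/Kimi-VL-A3B-Thinking"
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", trust_remote_code=True)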


register_model_group(
    models={
        "LingoWhale-8B": {


@@ -54,6 +54,12 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
        _set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])

    if model_type in ["kimi_vl", "deepseek_v3"]:
        check_version("transformers>=4.51.1")
        from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE

        _set_z3_leaf_modules(model, [DeepseekV3MoE])

    if model_type == "mixtral":
        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
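For context, marking the DeepSeek-V3 MoE block as a ZeRO-3 "leaf" tells DeepSpeed to gather that module's parameters as one unit instead of hooking each expert separately, which matters because only a subset of experts runs on any given step. A standalone sketch using DeepSpeed's public helper; it assumes deepspeed and transformers>=4.51.1 are installed and a kimi_vl/deepseek_v3 model is already loaded:

def mark_moe_leaf_modules(model) -> None:
    """Sketch: make ZeRO-3 treat each DeepseekV3MoE block as an indivisible unit."""
    from deepspeed.utils import set_z3_leaf_modules
    from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE

    set_z3_leaf_modules(model, [DeepseekV3MoE])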


@@ -117,6 +117,10 @@ def patch_config(
        setattr(config, "init_audio", True)
        setattr(config, "init_tts", False)

    # replace the top-k gating method
    if getattr(config, "model_type", None) == "kimi_vl" and is_trainable:
        setattr(config.text_config, "topk_method", "greedy")

    if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
        raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")