mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 06:12:50 +08:00
[model] Support Kimi_VL thinking/instruct (#7719)
* add kimi_vl * patch config * check version * Update mm_plugin.py * Update mm_plugin.py --------- Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
parent
3a13d2cdb1
commit
df8752e8ee
@ -466,6 +466,41 @@ class Gemma3Plugin(BasePlugin):
|
||||
return mm_inputs
|
||||
|
||||
|
||||
@dataclass
|
||||
class KimiVLPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(self, messages, images, videos, audios, processor):
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
if self.expand_mm_tokens:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
|
||||
image_grid_hws = mm_inputs.get("image_grid_hws", [])
|
||||
num_image_tokens = 0
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
|
||||
merge_length = math.prod(image_processor.merge_kernel_size)
|
||||
messages = deepcopy(messages)
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
if num_image_tokens >= len(image_grid_hws):
|
||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
image_seqlen = image_grid_hws[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
||||
content = content.replace(
|
||||
IMAGE_PLACEHOLDER,
|
||||
f"<|media_start|>image<|media_content|>{self.image_token * image_seqlen}<|media_end|>",
|
||||
1,
|
||||
)
|
||||
num_image_tokens += 1
|
||||
|
||||
message["content"] = content
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
@dataclass
|
||||
class Llama4Plugin(BasePlugin):
|
||||
@override
|
||||
@ -1420,6 +1455,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
content = message["content"]
|
||||
# separate with audio-video
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
if num_image_tokens >= len(image_grid_thw):
|
||||
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
|
||||
content = content.replace(
|
||||
IMAGE_PLACEHOLDER,
|
||||
@ -1430,6 +1468,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
|
||||
if not use_audio_in_video:
|
||||
while AUDIO_PLACEHOLDER in content:
|
||||
if num_audio_tokens >= len(audio_lengths):
|
||||
raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
|
||||
|
||||
audio_token_replace_length = audio_lengths[num_audio_tokens]
|
||||
content = content.replace(
|
||||
AUDIO_PLACEHOLDER,
|
||||
@ -1440,6 +1481,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
|
||||
# TODO handle video_input and use_audio_in_video
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
if num_video_tokens >= len(video_grid_thw):
|
||||
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||
|
||||
video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
|
||||
content = content.replace(
|
||||
VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1
|
||||
@ -1448,6 +1492,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
|
||||
else: # if use the audio of video # deal video token and audio token togather
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
if num_video_tokens >= len(video_grid_thw):
|
||||
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||
|
||||
audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
|
||||
video_t_index = (
|
||||
torch.arange(video_grid_thw[num_video_tokens][0])
|
||||
@ -1471,10 +1518,11 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
|
||||
if video_chunk_index is not None:
|
||||
placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
|
||||
|
||||
if audio_chunk_index is not None:
|
||||
placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
|
||||
placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
|
||||
|
||||
placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
|
||||
content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
|
||||
content = content.replace(AUDIO_PLACEHOLDER, "", 1)
|
||||
num_audio_tokens += 1
|
||||
@ -1555,6 +1603,7 @@ class VideoLlavaPlugin(BasePlugin):
|
||||
PLUGINS = {
|
||||
"base": BasePlugin,
|
||||
"gemma3": Gemma3Plugin,
|
||||
"kimi_vl": KimiVLPlugin,
|
||||
"llama4": Llama4Plugin,
|
||||
"llava": LlavaPlugin,
|
||||
"llava_next": LlavaNextPlugin,
|
||||
|
@ -923,6 +923,20 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="kimi_vl",
|
||||
format_user=StringFormatter(
|
||||
slots=["<|im_user|>user<|im_middle|>{{content}}<|im_end|><|im_assistant|>assistant<|im_middle|>"]
|
||||
),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
|
||||
format_system=StringFormatter(slots=["<|im_system|>system<|im_middle|>{{content}}<|im_end|>"]),
|
||||
default_system="You are a helpful assistant",
|
||||
stop_words=["<|im_end|>"],
|
||||
thought_words=("◁think▷", "◁/think▷"),
|
||||
mm_plugin=get_mm_plugin("kimi_vl", image_token="<|media_pad|>"),
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="llama2",
|
||||
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
||||
|
@ -975,6 +975,22 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Kimi-VL-A3B-Instruct": {
|
||||
DownloadSource.DEFAULT: "moonshotai/Kimi-VL-A3B-Instruct",
|
||||
DownloadSource.MODELSCOPE: "moonshotai/Kimi-VL-A3B-Instruct",
|
||||
},
|
||||
"Kimi-VL-A3B-Thinking": {
|
||||
DownloadSource.DEFAULT: "moonshotai/Kimi-VL-A3B-Thinking",
|
||||
DownloadSource.MODELSCOPE: "moonshotai/Kimi-VL-A3B-Thinking",
|
||||
},
|
||||
},
|
||||
template="kimi_vl",
|
||||
multimodal=True,
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LingoWhale-8B": {
|
||||
|
@ -54,6 +54,12 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
|
||||
|
||||
_set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
|
||||
|
||||
if model_type in ["kimi_vl", "deepseek_v3"]:
|
||||
check_version("transformers>=4.51.1")
|
||||
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
|
||||
|
||||
_set_z3_leaf_modules(model, [DeepseekV3MoE])
|
||||
|
||||
if model_type == "mixtral":
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
|
@ -117,6 +117,10 @@ def patch_config(
|
||||
setattr(config, "init_audio", True)
|
||||
setattr(config, "init_tts", False)
|
||||
|
||||
# replace the top-k gating method
|
||||
if getattr(config, "model_type", None) == "kimi_vl" and is_trainable:
|
||||
setattr(config.text_config, "topk_method", "greedy")
|
||||
|
||||
if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
|
||||
raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user