mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 14:22:51 +08:00
[model] Support Kimi_VL thinking/instruct (#7719)
* add kimi_vl * patch config * check version * Update mm_plugin.py * Update mm_plugin.py --------- Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
parent
3a13d2cdb1
commit
df8752e8ee
@ -466,6 +466,41 @@ class Gemma3Plugin(BasePlugin):
|
|||||||
return mm_inputs
|
return mm_inputs
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class KimiVLPlugin(BasePlugin):
|
||||||
|
@override
|
||||||
|
def process_messages(self, messages, images, videos, audios, processor):
|
||||||
|
self._validate_input(processor, images, videos, audios)
|
||||||
|
if self.expand_mm_tokens:
|
||||||
|
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||||
|
|
||||||
|
image_grid_hws = mm_inputs.get("image_grid_hws", [])
|
||||||
|
num_image_tokens = 0
|
||||||
|
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
|
||||||
|
merge_length = math.prod(image_processor.merge_kernel_size)
|
||||||
|
messages = deepcopy(messages)
|
||||||
|
for message in messages:
|
||||||
|
content = message["content"]
|
||||||
|
while IMAGE_PLACEHOLDER in content:
|
||||||
|
if num_image_tokens >= len(image_grid_hws):
|
||||||
|
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||||
|
|
||||||
|
image_seqlen = image_grid_hws[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
||||||
|
content = content.replace(
|
||||||
|
IMAGE_PLACEHOLDER,
|
||||||
|
f"<|media_start|>image<|media_content|>{self.image_token * image_seqlen}<|media_end|>",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
num_image_tokens += 1
|
||||||
|
|
||||||
|
message["content"] = content
|
||||||
|
|
||||||
|
if len(images) != num_image_tokens:
|
||||||
|
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Llama4Plugin(BasePlugin):
|
class Llama4Plugin(BasePlugin):
|
||||||
@override
|
@override
|
||||||
@ -1420,6 +1455,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
|||||||
content = message["content"]
|
content = message["content"]
|
||||||
# separate with audio-video
|
# separate with audio-video
|
||||||
while IMAGE_PLACEHOLDER in content:
|
while IMAGE_PLACEHOLDER in content:
|
||||||
|
if num_image_tokens >= len(image_grid_thw):
|
||||||
|
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||||
|
|
||||||
image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
|
image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
IMAGE_PLACEHOLDER,
|
IMAGE_PLACEHOLDER,
|
||||||
@ -1430,6 +1468,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
|||||||
|
|
||||||
if not use_audio_in_video:
|
if not use_audio_in_video:
|
||||||
while AUDIO_PLACEHOLDER in content:
|
while AUDIO_PLACEHOLDER in content:
|
||||||
|
if num_audio_tokens >= len(audio_lengths):
|
||||||
|
raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
|
||||||
|
|
||||||
audio_token_replace_length = audio_lengths[num_audio_tokens]
|
audio_token_replace_length = audio_lengths[num_audio_tokens]
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
AUDIO_PLACEHOLDER,
|
AUDIO_PLACEHOLDER,
|
||||||
@ -1440,6 +1481,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
|||||||
|
|
||||||
# TODO handle video_input and use_audio_in_video
|
# TODO handle video_input and use_audio_in_video
|
||||||
while VIDEO_PLACEHOLDER in content:
|
while VIDEO_PLACEHOLDER in content:
|
||||||
|
if num_video_tokens >= len(video_grid_thw):
|
||||||
|
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||||
|
|
||||||
video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
|
video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1
|
VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1
|
||||||
@ -1448,6 +1492,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
|||||||
|
|
||||||
else: # if use the audio of video # deal video token and audio token togather
|
else: # if use the audio of video # deal video token and audio token togather
|
||||||
while VIDEO_PLACEHOLDER in content:
|
while VIDEO_PLACEHOLDER in content:
|
||||||
|
if num_video_tokens >= len(video_grid_thw):
|
||||||
|
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||||
|
|
||||||
audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
|
audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
|
||||||
video_t_index = (
|
video_t_index = (
|
||||||
torch.arange(video_grid_thw[num_video_tokens][0])
|
torch.arange(video_grid_thw[num_video_tokens][0])
|
||||||
@ -1471,10 +1518,11 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
|||||||
audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
|
audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
|
||||||
if video_chunk_index is not None:
|
if video_chunk_index is not None:
|
||||||
placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
|
placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
|
||||||
|
|
||||||
if audio_chunk_index is not None:
|
if audio_chunk_index is not None:
|
||||||
placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
|
placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
|
||||||
placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
|
|
||||||
|
|
||||||
|
placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
|
||||||
content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
|
content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
|
||||||
content = content.replace(AUDIO_PLACEHOLDER, "", 1)
|
content = content.replace(AUDIO_PLACEHOLDER, "", 1)
|
||||||
num_audio_tokens += 1
|
num_audio_tokens += 1
|
||||||
@ -1555,6 +1603,7 @@ class VideoLlavaPlugin(BasePlugin):
|
|||||||
PLUGINS = {
|
PLUGINS = {
|
||||||
"base": BasePlugin,
|
"base": BasePlugin,
|
||||||
"gemma3": Gemma3Plugin,
|
"gemma3": Gemma3Plugin,
|
||||||
|
"kimi_vl": KimiVLPlugin,
|
||||||
"llama4": Llama4Plugin,
|
"llama4": Llama4Plugin,
|
||||||
"llava": LlavaPlugin,
|
"llava": LlavaPlugin,
|
||||||
"llava_next": LlavaNextPlugin,
|
"llava_next": LlavaNextPlugin,
|
||||||
|
@ -923,6 +923,20 @@ register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_template(
|
||||||
|
name="kimi_vl",
|
||||||
|
format_user=StringFormatter(
|
||||||
|
slots=["<|im_user|>user<|im_middle|>{{content}}<|im_end|><|im_assistant|>assistant<|im_middle|>"]
|
||||||
|
),
|
||||||
|
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
|
||||||
|
format_system=StringFormatter(slots=["<|im_system|>system<|im_middle|>{{content}}<|im_end|>"]),
|
||||||
|
default_system="You are a helpful assistant",
|
||||||
|
stop_words=["<|im_end|>"],
|
||||||
|
thought_words=("◁think▷", "◁/think▷"),
|
||||||
|
mm_plugin=get_mm_plugin("kimi_vl", image_token="<|media_pad|>"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="llama2",
|
name="llama2",
|
||||||
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
||||||
|
@ -975,6 +975,22 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"Kimi-VL-A3B-Instruct": {
|
||||||
|
DownloadSource.DEFAULT: "moonshotai/Kimi-VL-A3B-Instruct",
|
||||||
|
DownloadSource.MODELSCOPE: "moonshotai/Kimi-VL-A3B-Instruct",
|
||||||
|
},
|
||||||
|
"Kimi-VL-A3B-Thinking": {
|
||||||
|
DownloadSource.DEFAULT: "moonshotai/Kimi-VL-A3B-Thinking",
|
||||||
|
DownloadSource.MODELSCOPE: "moonshotai/Kimi-VL-A3B-Thinking",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
template="kimi_vl",
|
||||||
|
multimodal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"LingoWhale-8B": {
|
"LingoWhale-8B": {
|
||||||
|
@ -54,6 +54,12 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
|
|||||||
|
|
||||||
_set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
|
_set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
|
||||||
|
|
||||||
|
if model_type in ["kimi_vl", "deepseek_v3"]:
|
||||||
|
check_version("transformers>=4.51.1")
|
||||||
|
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
|
||||||
|
|
||||||
|
_set_z3_leaf_modules(model, [DeepseekV3MoE])
|
||||||
|
|
||||||
if model_type == "mixtral":
|
if model_type == "mixtral":
|
||||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||||
|
|
||||||
|
@ -117,6 +117,10 @@ def patch_config(
|
|||||||
setattr(config, "init_audio", True)
|
setattr(config, "init_audio", True)
|
||||||
setattr(config, "init_tts", False)
|
setattr(config, "init_tts", False)
|
||||||
|
|
||||||
|
# replace the top-k gating method
|
||||||
|
if getattr(config, "model_type", None) == "kimi_vl" and is_trainable:
|
||||||
|
setattr(config.text_config, "topk_method", "greedy")
|
||||||
|
|
||||||
if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
|
if "LlavaLlamaForCausalLM" in getattr(config, "architectures", []):
|
||||||
raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
|
raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user