From 8ffe7daa8d5b64657d3d191ad96bcdf29d42e1de Mon Sep 17 00:00:00 2001
From: Kingsley
Date: Mon, 9 Jun 2025 10:37:42 +0800
Subject: [PATCH] [model] support Mistral3.1 small 2503 (#8335)

---
 src/llamafactory/data/mm_plugin.py            |  5 +++--
 src/llamafactory/data/template.py             |  1 +
 src/llamafactory/extras/constants.py          | 16 ++++++++++++++++
 src/llamafactory/model/model_utils/unsloth.py |  7 +++++--
 src/llamafactory/model/model_utils/visual.py  |  5 +++++
 5 files changed, 30 insertions(+), 4 deletions(-)

diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 446fbe60..3ac9c307 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -1274,9 +1274,10 @@ class PixtralPlugin(BasePlugin):
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
                 if self.expand_mm_tokens:
+                    patch_size = processor.patch_size * getattr(processor, "spatial_merge_size", 1)
                     height, width = next(image_sizes)
-                    num_height_tokens = height // processor.patch_size
-                    num_width_tokens = width // processor.patch_size
+                    num_height_tokens = height // patch_size
+                    num_width_tokens = width // patch_size
                     replace_tokens = [[self.image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
                     replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten list
                     replace_tokens[-1] = image_end_token
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 634636f7..3f3a3164 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -1433,6 +1433,7 @@ register_template(
     format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
     format_tools=ToolFormatter(tool_format="mistral"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
 )
 
 
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index a9d5b7cd..99388461 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -1592,6 +1592,22 @@ register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "Mistral-Small-3.1-24B-Base-2503": {
+            DownloadSource.DEFAULT: "mistralai/Mistral-Small-3.1-24B-Base-2503",
+            DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-3.1-24B-Base-2503",
+        },
+        "Mistral-Small-3.1-24B-Instruct-2503": {
+            DownloadSource.DEFAULT: "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
+            DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
+        },
+    },
+    template="mistral_small",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "Mixtral-8x7B-v0.1": {
diff --git a/src/llamafactory/model/model_utils/unsloth.py b/src/llamafactory/model/model_utils/unsloth.py
index 352ef048..37791524 100644
--- a/src/llamafactory/model/model_utils/unsloth.py
+++ b/src/llamafactory/model/model_utils/unsloth.py
@@ -21,14 +21,17 @@ from ...extras.misc import get_current_device
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel
 
-    from ...hparams import ModelArguments, FinetuningArguments
+    from ...hparams import FinetuningArguments, ModelArguments
 
 
 logger = logging.get_logger(__name__)
 
 
 def _get_unsloth_kwargs(
-    config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
+    config: "PretrainedConfig",
+    model_name_or_path: str,
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
 ) -> dict[str, Any]:
     return {
         "model_name": model_name_or_path,
diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py
index 247de48b..9d4e535a 100644
--- a/src/llamafactory/model/model_utils/visual.py
+++ b/src/llamafactory/model/model_utils/visual.py
@@ -263,6 +263,11 @@ _register_composite_model(
 )
 
 
+_register_composite_model(
+    model_type="mistral3",
+)
+
+
 _register_composite_model(
     model_type="qwen2_audio",
     vision_model_keys=["audio_tower"],
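
Note on the PixtralPlugin change: Mistral Small 3.1 merges vision patches before they reach the language model, so the number of [IMG] placeholders per row and per column is the resized image size divided by patch_size * spatial_merge_size rather than by patch_size alone. The snippet below is a minimal standalone sketch of that arithmetic, not code taken from LLaMA-Factory: the ExampleProcessor class, the literal [IMG]/[IMG_BREAK]/[IMG_END] strings, and the 1540x1092 example size are illustrative assumptions, with patch_size=14 and spatial_merge_size=2 assumed from the released 2503 processor config. Older Pixtral processors expose no spatial_merge_size attribute, which is why the patch falls back to 1 via getattr.

# Minimal sketch of the placeholder expansion performed by the patched
# PixtralPlugin hunk above. ExampleProcessor and the token strings are
# illustrative assumptions, not objects imported from LLaMA-Factory.
from dataclasses import dataclass


@dataclass
class ExampleProcessor:
    patch_size: int = 14          # assumed Mistral Small 3.1 vision patch size
    spatial_merge_size: int = 2   # assumed 2x2 patch merging before the LM


def expand_image_tokens(height: int, width: int, processor: ExampleProcessor) -> list[str]:
    # Same arithmetic as the patch: one [IMG] per merged patch, [IMG_BREAK]
    # after every row, and [IMG_END] replacing the final break token.
    patch_size = processor.patch_size * getattr(processor, "spatial_merge_size", 1)
    num_height_tokens = height // patch_size
    num_width_tokens = width // patch_size
    rows = [["[IMG]"] * num_width_tokens + ["[IMG_BREAK]"] for _ in range(num_height_tokens)]
    tokens = [token for row in rows for token in row]  # flatten list
    tokens[-1] = "[IMG_END]"
    return tokens


if __name__ == "__main__":
    tokens = expand_image_tokens(1540, 1092, ExampleProcessor())
    print(len(tokens))  # 55 rows * (39 + 1) tokens = 2200 placeholders

With the merge factor applied, a 1540x1092 image expands to 55 rows of 39 [IMG] tokens plus one row terminator each, 2200 placeholders in total; dividing by patch_size alone would insert roughly four times as many placeholders and cause a length mismatch against the features the processor actually produces.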