[model] support Mistral3.1 small 2503 (#8335)

This commit is contained in:
Kingsley 2025-06-09 10:37:42 +08:00 committed by GitHub
parent 8fa55db1ec
commit 8ffe7daa8d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 30 additions and 4 deletions

View File

@ -1274,9 +1274,10 @@ class PixtralPlugin(BasePlugin):
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if self.expand_mm_tokens:
patch_size = processor.patch_size * getattr(processor, "spatial_merge_size", 1)
height, width = next(image_sizes)
num_height_tokens = height // processor.patch_size
num_width_tokens = width // processor.patch_size
num_height_tokens = height // patch_size
num_width_tokens = width // patch_size
replace_tokens = [[self.image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list
replace_tokens[-1] = image_end_token

View File

@ -1433,6 +1433,7 @@ register_template(
format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
format_tools=ToolFormatter(tool_format="mistral"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
)

View File

@ -1592,6 +1592,22 @@ register_model_group(
)
register_model_group(
models={
"Mistral-Small-24B-Base-2503": {
DownloadSource.DEFAULT: "mistralai/Mistral-Small-24B-Base-2503",
DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-24B-Base-2503",
},
"Mistral-Small-24B-Instruct-2503": {
DownloadSource.DEFAULT: "mistralai/Mistral-Small-24B-Instruct-2503",
DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-24B-Instruct-2503",
},
},
template="mistral_small",
multimodal=True,
)
register_model_group(
models={
"Mixtral-8x7B-v0.1": {

View File

@ -21,14 +21,17 @@ from ...extras.misc import get_current_device
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from ...hparams import ModelArguments, FinetuningArguments
from ...hparams import FinetuningArguments, ModelArguments
logger = logging.get_logger(__name__)
def _get_unsloth_kwargs(
config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
config: "PretrainedConfig",
model_name_or_path: str,
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
) -> dict[str, Any]:
return {
"model_name": model_name_or_path,

View File

@ -263,6 +263,11 @@ _register_composite_model(
)
_register_composite_model(
model_type="mistral3",
)
_register_composite_model(
model_type="qwen2_audio",
vision_model_keys=["audio_tower"],