From 4fd94141a42d99df399ab64607d6c7d6bfbfe7db Mon Sep 17 00:00:00 2001
From: tangefly <124695565+tangefly@users.noreply.github.com>
Date: Wed, 10 Dec 2025 15:57:24 +0800
Subject: [PATCH] [model] Add Ministral3 (#9582)

Co-authored-by: kingsley
---
 README.md                                     |  2 +-
 README_zh.md                                  |  2 +-
 src/llamafactory/data/template.py             | 13 ++++++++++++
 src/llamafactory/extras/constants.py          | 20 +++++++++++++++++++
 .../model/model_utils/quantization.py         |  7 ++++++-
 src/llamafactory/model/model_utils/visual.py  |  1 +
 src/llamafactory/model/patcher.py             |  2 +-
 src/llamafactory/train/sft/workflow.py        |  2 +-
 8 files changed, 44 insertions(+), 5 deletions(-)

diff --git a/README.md b/README.md
index c6b8b1d4..7e3c0c89 100644
--- a/README.md
+++ b/README.md
@@ -315,7 +315,7 @@ Read technical notes:
 | [MiMo](https://huggingface.co/XiaomiMiMo) | 7B | mimo |
 | [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
 | [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
-| [Ministral/Mistral-Nemo](https://huggingface.co/mistralai) | 8B/12B | ministral |
+| [Ministral(3)/Mistral-Nemo](https://huggingface.co/mistralai) | 3B/8B/12B/14B | ministral/ministral3 |
 | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
 | [Mistral Small](https://huggingface.co/mistralai) | 24B | mistral_small |
 | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
diff --git a/README_zh.md b/README_zh.md
index ff910371..ba7ed90b 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -317,7 +317,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | [MiMo](https://huggingface.co/XiaomiMiMo) | 7B | mimo |
 | [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 |
 | [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v |
-| [Ministral/Mistral-Nemo](https://huggingface.co/mistralai) | 8B/12B | ministral |
+| [Ministral(3)/Mistral-Nemo](https://huggingface.co/mistralai) | 3B/8B/12B/14B | ministral/ministral3 |
 | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
 | [Mistral Small](https://huggingface.co/mistralai) | 24B | mistral_small |
 | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 11b5d4f1..36a4d43b 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -1687,6 +1687,19 @@ register_template(
 )
 
 
+register_template(
+    name="ministral3",
+    format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
+    format_system=StringFormatter(slots=["{{content}}\n\n"]),
+    format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
+    format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
+    format_tools=ToolFormatter(tool_format="mistral"),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    template_class=Llama2Template,
+    mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
+)
+
+
 register_template(
     name="olmo",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index 64eadb4c..e0d4cbd9 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -141,6 +141,7 @@ class QuantizationMethod(str, Enum):
     EETQ = "eetq"
     HQQ = "hqq"
     MXFP4 = "mxfp4"
+    FP8 = "fp8"
 
 
 class RopeScaling(str, Enum):
@@ -1977,6 +1978,25 @@ register_model_group(
     template="mistral",
 )
 
+register_model_group(
+    models={
+        "Ministral-3-3B-Instruct-2512": {
+            DownloadSource.DEFAULT: "mistralai/Ministral-3-3B-Instruct-2512",
+            DownloadSource.MODELSCOPE: "mistralai/Ministral-3-3B-Instruct-2512",
+        },
+        "Ministral-3-8B-Instruct-2512": {
+            DownloadSource.DEFAULT: "mistralai/Ministral-3-8B-Instruct-2512",
+            DownloadSource.MODELSCOPE: "mistralai/Ministral-3-8B-Instruct-2512",
+        },
+        "Ministral-3-14B-Instruct-2512": {
+            DownloadSource.DEFAULT: "mistralai/Ministral-3-14B-Instruct-2512",
+            DownloadSource.MODELSCOPE: "mistralai/Ministral-3-14B-Instruct-2512",
+        },
+    },
+    template="ministral3",
+    multimodal=True,
+)
+
 
 register_model_group(
     models={
diff --git a/src/llamafactory/model/model_utils/quantization.py b/src/llamafactory/model/model_utils/quantization.py
index 8b227b7f..ebffbbc7 100644
--- a/src/llamafactory/model/model_utils/quantization.py
+++ b/src/llamafactory/model/model_utils/quantization.py
@@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, Any
 
 import torch
 from datasets import load_dataset
-from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
+from transformers import BitsAndBytesConfig, EetqConfig, FineGrainedFP8Config, GPTQConfig, HqqConfig
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
 
@@ -83,6 +83,7 @@ def configure_quantization(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
+    is_trainable: bool,
     init_kwargs: dict[str, Any],
 ) -> None:
     r"""Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)."""
@@ -109,6 +110,10 @@
             check_version("aqlm>=1.1.0", mandatory=True)
             quantization_config["bits"] = 2
 
+        if quant_method == QuantizationMethod.FP8 and is_trainable:
+            quant_config = FineGrainedFP8Config(dequantize=True)
+            init_kwargs["quantization_config"] = quant_config
+
         quant_bits = quantization_config.get("bits", "?")
         logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")
 
diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py
index 4ca64569..955c5229 100644
--- a/src/llamafactory/model/model_utils/visual.py
+++ b/src/llamafactory/model/model_utils/visual.py
@@ -301,6 +301,7 @@ _register_composite_model(
 
 
 _register_composite_model(
     model_type="mistral3",
+    projector_key="model.multi_modal_projector",
 )
 
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index fa2ac832..ca93fad5 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -115,7 +115,7 @@ def patch_config(
     configure_attn_implementation(config, model_args)
     configure_rope(config, model_args)
     configure_longlora(config, model_args, is_trainable)
-    configure_quantization(config, tokenizer, model_args, init_kwargs)
+    configure_quantization(config, tokenizer, model_args, is_trainable, init_kwargs)
     configure_moe(config, model_args, is_trainable)
     configure_visual_model(config)
     configure_packing(model_args, is_trainable)
diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py
index 52e0fd12..ebc1301c 100644
--- a/src/llamafactory/train/sft/workflow.py
+++ b/src/llamafactory/train/sft/workflow.py
@@ -78,7 +78,7 @@ def run_sft(
     gen_kwargs = generating_args.to_dict(obey_generation_config=True)
 
     # Compatible with Transformers v4 and Transformers v5
-    if is_transformers_version_greater_than("5.0.0RC0"):
+    if is_transformers_version_greater_than("4.58.0"):
         extra_ids = getattr(tokenizer, "additional_special_tokens_ids", None)
         if not isinstance(extra_ids, list):
             extra_special_tokens = getattr(tokenizer, "_extra_special_tokens", [])
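A quick way to exercise the new entries end-to-end is a LoRA SFT run against one of the registered checkpoints. The sketch below is illustrative only: the keys follow LLaMA-Factory's standard YAML config format, the model id is simply one of the three variants registered above, and the dataset and output path are placeholders, not part of this patch.

    ### ministral3_lora_sft.yaml -- run with: llamafactory-cli train ministral3_lora_sft.yaml
    model_name_or_path: mistralai/Ministral-3-8B-Instruct-2512
    template: ministral3
    stage: sft
    do_train: true
    finetuning_type: lora
    lora_target: all
    dataset: identity                          # placeholder dataset
    cutoff_len: 2048
    output_dir: saves/ministral3-8b/lora/sft   # placeholder output path

Because the model group is registered with multimodal=True and the template declares the pixtral mm_plugin with the [IMG] token, image inputs in the dataset should be routed through the same vision path as the existing Mistral 3 composite model.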