[model] Add Ministral3 (#9582)

Co-authored-by: kingsley <kingsleydodonow@gmail.com>
tangefly authored 2025-12-10 15:57:24 +08:00 · committed by GitHub
parent 22d6ac29d5
commit 4fd94141a4
8 changed files with 44 additions and 5 deletions

View File

@@ -1687,6 +1687,19 @@ register_template(
 )
+
+
+register_template(
+    name="ministral3",
+    format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
+    format_system=StringFormatter(slots=["{{content}}\n\n"]),
+    format_function=FunctionFormatter(slots=["[TOOL_CALLS]{{content}}", {"eos_token"}], tool_format="mistral"),
+    format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
+    format_tools=ToolFormatter(tool_format="mistral"),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    template_class=Llama2Template,
+    mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
+)
 
 
 register_template(
     name="olmo",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),

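The template above follows Mistral's v3 chat convention: Llama2Template folds the system prompt into the first user turn, and format_user wraps each user message in [INST]...[/INST]. A minimal sketch of the string this should render for a one-turn conversation (the messages are invented, and bos_token/eos_token are assumed to be <s> and </s>):

    # Hypothetical conversation, for illustration only.
    system = "You are a helpful assistant."
    user = "Hello!"
    assistant = "Hi there!"

    # Llama2Template is expected to merge the system text (format_system adds
    # the trailing "\n\n") into the first [INST] block:
    #   <s>[INST]You are a helpful assistant.\n\nHello![/INST]Hi there!</s>
    prompt = "<s>[INST]" + system + "\n\n" + user + "[/INST]" + assistant + "</s>"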
View File

@@ -141,6 +141,7 @@ class QuantizationMethod(str, Enum):
     EETQ = "eetq"
     HQQ = "hqq"
     MXFP4 = "mxfp4"
+    FP8 = "fp8"
 
 
 class RopeScaling(str, Enum):
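
Since QuantizationMethod subclasses str, the new FP8 member compares equal to the plain "fp8" string that quantized checkpoints record in their quantization config. A small illustrative check (the config dict here is made up):

    # Illustrative: a quantized checkpoint stores its method as a lowercase string.
    quantization_config = {"quant_method": "fp8"}
    quant_method = quantization_config.get("quant_method", "")
    assert quant_method == QuantizationMethod.FP8  # str-Enum equality with "fp8"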
@@ -1977,6 +1978,25 @@ register_model_group(
     template="mistral",
 )
+
+
+register_model_group(
+    models={
+        "Ministral-3-3B-Instruct-2512": {
+            DownloadSource.DEFAULT: "mistralai/Ministral-3-3B-Instruct-2512",
+            DownloadSource.MODELSCOPE: "mistralai/Ministral-3-3B-Instruct-2512",
+        },
+        "Ministral-3-8B-Instruct-2512": {
+            DownloadSource.DEFAULT: "mistralai/Ministral-3-8B-Instruct-2512",
+            DownloadSource.MODELSCOPE: "mistralai/Ministral-3-8B-Instruct-2512",
+        },
+        "Ministral-3-14B-Instruct-2512": {
+            DownloadSource.DEFAULT: "mistralai/Ministral-3-14B-Instruct-2512",
+            DownloadSource.MODELSCOPE: "mistralai/Ministral-3-14B-Instruct-2512",
+        },
+    },
+    template="ministral3",
+    multimodal=True,
+)
 
 
 register_model_group(
     models={

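Once registered, the checkpoints can be addressed by name and paired with the ministral3 template. A hedged usage sketch via LLaMA-Factory's high-level chat API (assuming the ChatModel entry point and that the Hub repository below resolves; substitute a local path otherwise):

    from llamafactory.chat import ChatModel

    # The args mirror the registration above; "template" selects "ministral3".
    chat_model = ChatModel(dict(
        model_name_or_path="mistralai/Ministral-3-8B-Instruct-2512",
        template="ministral3",
    ))
    messages = [{"role": "user", "content": "Hello!"}]
    print(chat_model.chat(messages)[0].response_text)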
View File

@@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, Any
 
 import torch
 from datasets import load_dataset
-from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
+from transformers import BitsAndBytesConfig, EetqConfig, FineGrainedFP8Config, GPTQConfig, HqqConfig
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
@@ -83,6 +83,7 @@ def configure_quantization(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
+    is_trainable: bool,
     init_kwargs: dict[str, Any],
 ) -> None:
     r"""Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)."""
@@ -109,6 +110,10 @@
             check_version("aqlm>=1.1.0", mandatory=True)
             quantization_config["bits"] = 2
 
+        if quant_method == QuantizationMethod.FP8 and is_trainable:
+            quant_config = FineGrainedFP8Config(dequantize=True)
+            init_kwargs["quantization_config"] = quant_config
+
         quant_bits = quantization_config.get("bits", "?")
         logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")

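The new branch makes FP8-quantized checkpoints usable for training by dequantizing them at load time instead of leaving the weights in FP8. A minimal sketch of the same mechanism used directly with transformers (the checkpoint path is a placeholder):

    from transformers import AutoModelForCausalLM, FineGrainedFP8Config

    # dequantize=True restores higher-precision weights from the FP8 checkpoint,
    # so they can receive gradient updates during fine-tuning.
    quant_config = FineGrainedFP8Config(dequantize=True)
    model = AutoModelForCausalLM.from_pretrained(
        "path/to/fp8-checkpoint",  # placeholder
        quantization_config=quant_config,
    )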
View File

@@ -301,6 +301,7 @@ _register_composite_model(
 _register_composite_model(
     model_type="mistral3",
+    projector_key="model.multi_modal_projector",
 )

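The projector_key points the composite-model utilities at Mistral 3's multimodal projector, which lives under model.multi_modal_projector rather than at the top level. A sketch of the kind of parameter selection such a key enables (assumed semantics, not code from this diff):

    from torch import nn

    def projector_param_names(model: nn.Module, key: str = "model.multi_modal_projector") -> list[str]:
        # Collect parameters under the projector prefix, e.g. to freeze or
        # train the projector independently of the language model.
        return [name for name, _ in model.named_parameters() if name.startswith(key)]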
View File

@@ -115,7 +115,7 @@ def patch_config(
     configure_attn_implementation(config, model_args)
     configure_rope(config, model_args)
     configure_longlora(config, model_args, is_trainable)
-    configure_quantization(config, tokenizer, model_args, init_kwargs)
+    configure_quantization(config, tokenizer, model_args, is_trainable, init_kwargs)
     configure_moe(config, model_args, is_trainable)
     configure_visual_model(config)
     configure_packing(model_args, is_trainable)

View File

@@ -78,7 +78,7 @@ def run_sft(
     gen_kwargs = generating_args.to_dict(obey_generation_config=True)
 
     # Compatible with Transformers v4 and Transformers v5
-    if is_transformers_version_greater_than("5.0.0RC0"):
+    if is_transformers_version_greater_than("4.58.0"):
         extra_ids = getattr(tokenizer, "additional_special_tokens_ids", None)
         if not isinstance(extra_ids, list):
             extra_special_tokens = getattr(tokenizer, "_extra_special_tokens", [])
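
The version gate drops from the 5.0.0 release candidate to 4.58.0, so the tokenizer fallback also runs on late 4.x releases. For reference, a sketch of what such a version helper typically looks like (the real implementation lives in LLaMA-Factory's extras and may differ):

    from functools import lru_cache

    import transformers
    from packaging import version

    @lru_cache
    def is_transformers_version_greater_than(content: str) -> bool:
        # Cached comparison of the installed transformers version to a threshold.
        return version.parse(transformers.__version__) >= version.parse(content)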