mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-03 04:02:49 +08:00

[model] support Mistral3.1 small 2503 (#8335)

parent 8fa55db1ec
commit 8ffe7daa8d
@@ -1274,9 +1274,10 @@ class PixtralPlugin(BasePlugin):
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
                 if self.expand_mm_tokens:
+                    patch_size = processor.patch_size * getattr(processor, "spatial_merge_size", 1)
                     height, width = next(image_sizes)
-                    num_height_tokens = height // processor.patch_size
-                    num_width_tokens = width // processor.patch_size
+                    num_height_tokens = height // patch_size
+                    num_width_tokens = width // patch_size
                     replace_tokens = [[self.image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
                     replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten list
                     replace_tokens[-1] = image_end_token
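The spatial_merge_size factor is the substance of this hunk: the Mistral Small 3.1 processor merges neighboring patches, so each image expands to fewer placeholder tokens than under the plain Pixtral arithmetic. Below is a minimal, self-contained sketch of the expansion logic, assuming the standard Pixtral marker strings [IMG], [IMG_BREAK], and [IMG_END]; the actual plugin takes these tokens from its configuration rather than hard-coding them.

def expand_image_tokens(
    height: int, width: int, patch_size: int, spatial_merge_size: int = 1
) -> list[str]:
    """Build the flat token sequence that replaces one image placeholder."""
    # Merged patches cover a larger area, so fewer tokens per row/column.
    effective_patch = patch_size * spatial_merge_size
    num_height_tokens = height // effective_patch
    num_width_tokens = width // effective_patch
    # One row = width-many image tokens followed by a row-break token.
    rows = [["[IMG]"] * num_width_tokens + ["[IMG_BREAK]"] for _ in range(num_height_tokens)]
    tokens = [tok for row in rows for tok in row]  # flatten list of rows
    tokens[-1] = "[IMG_END]"  # the final row break becomes the end-of-image marker
    return tokens

# A 448x448 image with patch_size=14 and spatial_merge_size=2 yields a
# 16x16 grid (plus one separator per row) instead of 32x32 without the merge:
assert len(expand_image_tokens(448, 448, 14, 2)) == 16 * (16 + 1)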
@@ -1433,6 +1433,7 @@ register_template(
     format_observation=StringFormatter(slots=["""[TOOL_RESULTS]{"content": {{content}}}[/TOOL_RESULTS]"""]),
     format_tools=ToolFormatter(tool_format="mistral"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
 )
 
 
@@ -1592,6 +1592,22 @@ register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "Mistral-Small-24B-Base-2503": {
+            DownloadSource.DEFAULT: "mistralai/Mistral-Small-24B-Base-2503",
+            DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-24B-Base-2503",
+        },
+        "Mistral-Small-24B-Instruct-2503": {
+            DownloadSource.DEFAULT: "mistralai/Mistral-Small-24B-Instruct-2503",
+            DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-24B-Instruct-2503",
+        },
+    },
+    template="mistral_small",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "Mixtral-8x7B-v0.1": {
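For readers unfamiliar with this registry, each entry maps a display name to per-hub checkpoint ids, and the template/multimodal fields tie the group to the mistral_small chat template with the Pixtral plugin registered above. A toy sketch of how such an entry can be resolved follows; the resolve_model_path helper and the enum values shown are hypothetical illustrations, not LLaMA-Factory's actual loader.

from enum import Enum

class DownloadSource(str, Enum):
    DEFAULT = "hf"      # Hugging Face Hub
    MODELSCOPE = "ms"   # ModelScope mirror

# A single registry entry, shaped like the diff above.
MODEL_REGISTRY: dict[str, dict[DownloadSource, str]] = {
    "Mistral-Small-24B-Instruct-2503": {
        DownloadSource.DEFAULT: "mistralai/Mistral-Small-24B-Instruct-2503",
        DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-24B-Instruct-2503",
    },
}

def resolve_model_path(name: str, source: DownloadSource = DownloadSource.DEFAULT) -> str:
    """Pick the checkpoint id for the requested hub, falling back to DEFAULT."""
    entry = MODEL_REGISTRY[name]
    return entry.get(source, entry[DownloadSource.DEFAULT])

print(resolve_model_path("Mistral-Small-24B-Instruct-2503"))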
@@ -21,14 +21,17 @@ from ...extras.misc import get_current_device
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel
 
-    from ...hparams import ModelArguments, FinetuningArguments
+    from ...hparams import FinetuningArguments, ModelArguments
 
 
 logger = logging.get_logger(__name__)
 
 
 def _get_unsloth_kwargs(
-    config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments", finetuning_args: "FinetuningArguments"
+    config: "PretrainedConfig",
+    model_name_or_path: str,
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
 ) -> dict[str, Any]:
     return {
         "model_name": model_name_or_path,
@@ -263,6 +263,11 @@ _register_composite_model(
 )
 
 
+_register_composite_model(
+    model_type="mistral3",
+)
+
+
 _register_composite_model(
     model_type="qwen2_audio",
     vision_model_keys=["audio_tower"],
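Registering mistral3 as a composite model tells the framework that the checkpoint bundles a vision tower with a language model, which matters when only one of the two should be trained (the neighboring qwen2_audio entry shows how non-default submodule keys are declared; mistral3 relies on the registration defaults). A hedged sketch of the idea, assuming the common transformers submodule prefix vision_tower; the actual key names for mistral3 come from the registration defaults, not from this snippet.

import torch.nn as nn

def freeze_vision_tower(
    model: nn.Module, vision_prefixes: tuple[str, ...] = ("vision_tower",)
) -> None:
    """Disable gradients for every parameter under the vision submodule(s)."""
    for name, param in model.named_parameters():
        if name.startswith(vision_prefixes):  # str.startswith accepts a tuple of prefixes
            param.requires_grad = False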