From c436d6ea0b6cd51d1b6a81894c8f11d4ea05b366 Mon Sep 17 00:00:00 2001
From: Kingsley
Date: Thu, 26 Sep 2024 12:11:58 +0800
Subject: [PATCH] add pixtral template

Former-commit-id: 86f5a9be548ef02ce334bba35a529c70e8b3ad7f
---
 src/llamafactory/data/mm_plugin.py   |  6 +++++
 src/llamafactory/data/template.py    |  7 +++++
 src/llamafactory/extras/constants.py | 10 ++++++++
 src/llamafactory/model/loader.py     | 38 ++++++++++++++++++++++++++++
 4 files changed, 61 insertions(+)

diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index e22e2760..ea0f2185 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -323,6 +323,12 @@ class PaliGemmaPlugin(BasePlugin):
         mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
         return mm_inputs
 
+class PixtralPlugin(BasePlugin):
+    #TODO preprocess according to Pixtral hf
+    from transformers import LlavaForConditionalGeneration
+    @override
+    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
+        pass
 
 class Qwen2vlPlugin(BasePlugin):
     @override
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 54da4757..9b844d88 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -821,6 +821,13 @@ _register_template(
     replace_eos=True,
 )
 
+_register_template(
+    name="pixtral",
+    format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]")
+)
+
 
 _register_template(
     name="qwen",
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index 8d8d4424..e88f0da7 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -894,6 +894,16 @@ register_model_group(
     template="mistral",
 )
 
+register_model_group(
+    models={
+        "Pixtral-12B-2409": {
+            DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
+        }
+    },
+    template="mistral"
+)
+
 
 register_model_group(
     models={
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index 030ce90f..bc4e101c 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -119,6 +119,44 @@ def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
     r"""
     Loads model config.
     """
     init_kwargs = _get_init_kwargs(model_args)
+    if "pixtral" in model_args.model_name_or_path:
+        from transformers import PretrainedConfig
+
+        class PixtralVisionConfig(PretrainedConfig):
+            model_type = "pixtral"
+
+            def __init__(
+                self,
+                hidden_size=1024,
+                intermediate_size=4096,
+                num_hidden_layers=24,
+                num_attention_heads=16,
+                num_channels=3,
+                image_size=1024,
+                patch_size=16,
+                hidden_act="gelu",
+                attention_dropout=0.0,
+                rope_theta=10000.0,
+                tie_word_embeddings=False,
+                **kwargs,
+            ):
+                super().__init__(**kwargs)
+
+                self.hidden_size = hidden_size
+                self.intermediate_size = intermediate_size
+                self.num_hidden_layers = num_hidden_layers
+                self.num_attention_heads = num_attention_heads
+                self.num_channels = num_channels
+                self.patch_size = patch_size
+                self.image_size = image_size
+                self.attention_dropout = attention_dropout
+                self.hidden_act = hidden_act
+                self.rope_theta = rope_theta
+                self.tie_word_embeddings = tie_word_embeddings
+                self.head_dim = hidden_size // num_attention_heads
+
+        return PixtralVisionConfig()
+
     return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
 