mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-24 06:42:52 +08:00
add pixtral template
Former-commit-id: 86f5a9be548ef02ce334bba35a529c70e8b3ad7f
This commit is contained in:
parent
944ae8780c
commit
c436d6ea0b
@ -323,6 +323,12 @@ class PaliGemmaPlugin(BasePlugin):
|
|||||||
mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
|
mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
|
||||||
return mm_inputs
|
return mm_inputs
|
||||||
|
|
||||||
|
class PixtralPlugin(BasePlugin):
|
||||||
|
#TODO preprocess according to Pixtral hf
|
||||||
|
from transformers import LlavaForConditionalGeneration
|
||||||
|
@override
|
||||||
|
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
|
||||||
|
pass
|
||||||
|
|
||||||
class Qwen2vlPlugin(BasePlugin):
|
class Qwen2vlPlugin(BasePlugin):
|
||||||
@override
|
@override
|
||||||
|
@ -821,6 +821,13 @@ _register_template(
|
|||||||
replace_eos=True,
|
replace_eos=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_register_template(
|
||||||
|
name="pixtral",
|
||||||
|
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
||||||
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
|
mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="qwen",
|
name="qwen",
|
||||||
|
@ -894,6 +894,16 @@ register_model_group(
|
|||||||
template="mistral",
|
template="mistral",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"Pixtral-12B-2409": {
|
||||||
|
DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
|
||||||
|
DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
template="mistral"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
|
@ -119,6 +119,44 @@ def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
|
|||||||
Loads model config.
|
Loads model config.
|
||||||
"""
|
"""
|
||||||
init_kwargs = _get_init_kwargs(model_args)
|
init_kwargs = _get_init_kwargs(model_args)
|
||||||
|
if "pixtral" in model_args.model_name_or_path:
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
class PixtralVisionConfig(PretrainedConfig):
|
||||||
|
model_type = "pixtral"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size=1024,
|
||||||
|
intermediate_size=4096,
|
||||||
|
num_hidden_layers=24,
|
||||||
|
num_attention_heads=16,
|
||||||
|
num_channels=3,
|
||||||
|
image_size=1024,
|
||||||
|
patch_size=16,
|
||||||
|
hidden_act="gelu",
|
||||||
|
attention_dropout=0.0,
|
||||||
|
rope_theta=10000.0,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.num_channels = num_channels
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.image_size = image_size
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.tie_word_embeddings = tie_word_embeddings
|
||||||
|
self.head_dim = hidden_size // num_attention_heads
|
||||||
|
|
||||||
|
return PixtralVisionConfig()
|
||||||
|
|
||||||
return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
|
return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user