Mirror of https://github.com/hiyouga/LLaMA-Factory.git

commit b76116bb6c (parent 35e44143fd)

add pixtral template

Former-commit-id: 7b3336dd97e06a11ec52433ef36980aefdbb45ba
@@ -24,6 +24,7 @@ if TYPE_CHECKING:
     from av.stream import Stream
     from transformers import PreTrainedTokenizer, ProcessorMixin
     from transformers.image_processing_utils import BaseImageProcessor
+    from transformers.processing_utils import _validate_images_text_input_order, ProcessingKwargs


 class EncodedImage(TypedDict):
     path: Optional[str]
@@ -324,11 +325,65 @@ class PaliGemmaPlugin(BasePlugin):
         return mm_inputs


 class PixtralPlugin(BasePlugin):
-    # TODO: preprocess according to the Pixtral implementation in Hugging Face
-    from transformers import LlavaForConditionalGeneration
     @override
-    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
-        pass
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        patch_size = processor.patch_size
+        image_token = processor.image_token
+        image_break_token = processor.image_break_token
+        image_end_token = processor.image_end_token
+
+        self._validate_input(images, videos)
+        num_image_tokens = 0
+        image_input_sizes = self._get_mm_inputs(images, videos, processor)["image_sizes"]
+        messages = deepcopy(messages)
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                # NOTE: only the first image's size is used for every
+                # placeholder; multi-image inputs are not handled yet.
+                image_size = image_input_sizes[0][0]
+                height, width = image_size
+                num_height_tokens = height // patch_size
+                num_width_tokens = width // patch_size
+                replace_tokens = [
+                    [image_token] * num_width_tokens + [image_break_token]
+                ] * num_height_tokens
+                # flatten the rows into a single token list
+                replace_tokens = [item for sublist in replace_tokens for item in sublist]
+                replace_tokens[-1] = image_end_token
+                replace_str = "".join(replace_tokens)
+                content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
+                num_image_tokens += 1
+
+            message["content"] = content
+
+        if len(images) != num_image_tokens:
+            raise ValueError("The number of images does not match the number of {} tokens.".format(IMAGE_PLACEHOLDER))
+
+        return messages
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        seqlens: Sequence[int],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(images, videos)
+        return self._get_mm_inputs(images, videos, processor)


 class Qwen2vlPlugin(BasePlugin):
     @override
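As a sanity check, below is a minimal, self-contained sketch of the expansion process_messages performs for one placeholder. The 512x512 size, the 16-pixel patch size, and the literal [IMG], [IMG_BREAK], and [IMG_END] strings stand in for the processor-provided attributes and are assumptions for illustration:

# Minimal sketch of the Pixtral-style image-token expansion above.
# Token strings and the patch size are illustrative assumptions; the real
# values come from the Hugging Face processor.
IMAGE_PLACEHOLDER = "<image>"

def expand_image_placeholder(content: str, height: int, width: int, patch_size: int = 16) -> str:
    num_height_tokens = height // patch_size
    num_width_tokens = width // patch_size
    # one row of the patch grid = width-many image tokens plus a break token
    rows = [["[IMG]"] * num_width_tokens + ["[IMG_BREAK]"]] * num_height_tokens
    tokens = [tok for row in rows for tok in row]
    tokens[-1] = "[IMG_END]"  # the final break is replaced by the end marker
    return content.replace(IMAGE_PLACEHOLDER, "".join(tokens), 1)

# A 512x512 image with 16x16 patches yields a 32x32 grid: 1024 [IMG] tokens
# plus 31 [IMG_BREAK] tokens and one trailing [IMG_END].
expanded = expand_image_placeholder("Describe this: <image>", 512, 512)
assert expanded.count("[IMG]") == 32 * 32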
@@ -428,6 +483,7 @@ PLUGINS = {
     "llava": LlavaPlugin,
     "paligemma": PaliGemmaPlugin,
     "qwen2_vl": Qwen2vlPlugin,
+    "pixtral": PixtralPlugin,
 }
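The registry entry makes the new plugin addressable by its template name. A hedged sketch of the kind of lookup a caller performs; get_mm_plugin and its signature are illustrative, not necessarily the codebase's actual factory:

# Illustrative dispatch over the PLUGINS registry; the real factory may take
# extra arguments (e.g., the image and video token strings).
def get_mm_plugin(name: str, **kwargs) -> "BasePlugin":
    if name not in PLUGINS:
        raise ValueError("Multimodal plugin `{}` not found.".format(name))
    return PLUGINS[name](**kwargs)

plugin = get_mm_plugin("pixtral")  # -> a PixtralPlugin instance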
@@ -119,43 +119,6 @@ def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
     Loads model config.
     """
     init_kwargs = _get_init_kwargs(model_args)
-    if "pixtral" in model_args.model_name_or_path:
-        from transformers import PretrainedConfig
-
-        class PixtralVisionConfig(PretrainedConfig):
-            model_type = "pixtral"
-
-            def __init__(
-                self,
-                hidden_size=1024,
-                intermediate_size=4096,
-                num_hidden_layers=24,
-                num_attention_heads=16,
-                num_channels=3,
-                image_size=1024,
-                patch_size=16,
-                hidden_act="gelu",
-                attention_dropout=0.0,
-                rope_theta=10000.0,
-                tie_word_embeddings=False,
-                **kwargs,
-            ):
-                super().__init__(**kwargs)
-
-                self.hidden_size = hidden_size
-                self.intermediate_size = intermediate_size
-                self.num_hidden_layers = num_hidden_layers
-                self.num_attention_heads = num_attention_heads
-                self.num_channels = num_channels
-                self.patch_size = patch_size
-                self.image_size = image_size
-                self.attention_dropout = attention_dropout
-                self.hidden_act = hidden_act
-                self.rope_theta = rope_theta
-                self.tie_word_embeddings = tie_word_embeddings
-                self.head_dim = hidden_size // num_attention_heads
-
-        return PixtralVisionConfig()
-
     return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
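The removed branch hand-rolled a PixtralVisionConfig because transformers could not yet resolve the "pixtral" model type; after this change load_config relies on AutoConfig alone. A minimal sketch of the intended behavior, assuming a transformers release that registers the type (the checkpoint id is illustrative):

from transformers import AutoConfig

# With "pixtral" registered in transformers, AutoConfig resolves the config
# directly and no hand-written fallback is needed. The checkpoint id is an
# assumption for illustration only.
config = AutoConfig.from_pretrained("mistral-community/pixtral-12b")
print(type(config).__name__)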