From b76116bb6cbdb05f3203a3363d1e04037d701dc3 Mon Sep 17 00:00:00 2001
From: Kingsley
Date: Thu, 26 Sep 2024 17:14:51 +0800
Subject: [PATCH] add pixtral template

Former-commit-id: 7b3336dd97e06a11ec52433ef36980aefdbb45ba
---
 src/llamafactory/data/mm_plugin.py | 63 ++++++++++++++++++++++++++++--
 src/llamafactory/model/loader.py   | 37 -----------------
 2 files changed, 59 insertions(+), 41 deletions(-)

diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index ea0f2185..0e59ec0b 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -24,6 +24,7 @@ if TYPE_CHECKING:
     from av.stream import Stream
     from transformers import PreTrainedTokenizer, ProcessorMixin
     from transformers.image_processing_utils import BaseImageProcessor
+    from transformers.processing_utils import _validate_images_text_input_order, ProcessingKwargs
 
     class EncodedImage(TypedDict):
         path: Optional[str]
@@ -324,11 +325,64 @@ class PaliGemmaPlugin(BasePlugin):
         return mm_inputs
 
 class PixtralPlugin(BasePlugin):
-    #TODO preprocess according to Pixtral hf
-    from transformers import LlavaForConditionalGeneration
     @override
-    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
-        pass
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        patch_size = processor.patch_size
+        image_token = processor.image_token
+        image_break_token = processor.image_break_token
+        image_end_token = processor.image_end_token
+
+        num_image_tokens = 0
+        image_input_sizes = self._get_mm_inputs(images, videos, processor)["image_sizes"]
+        messages = deepcopy(messages)
+        for message in messages:
+            content = message["content"]
+            img_id = 0
+            while IMAGE_PLACEHOLDER in content:
+                # FIXME: only the first image size is used, so multi-image inputs are not handled yet
+                image_size = image_input_sizes[0][0]
+                height, width = image_size
+                num_height_tokens = height // patch_size
+                num_width_tokens = width // patch_size
+                replace_tokens = [
+                    [image_token] * num_width_tokens + [image_break_token]
+                ] * num_height_tokens
+                # flatten the per-row token lists into a single sequence
+                replace_tokens = [item for sublist in replace_tokens for item in sublist]
+                replace_tokens[-1] = image_end_token
+                replace_str = "".join(replace_tokens)
+                content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
+
+                img_id += 1
+                num_image_tokens += 1
+
+            message["content"] = content
+
+        if len(images) != num_image_tokens:
+            raise ValueError("The number of images does not match the number of {} tokens.".format(IMAGE_PLACEHOLDER))
+
+        return messages
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        seqlens: Sequence[int],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+
+        self._validate_input(images, videos)
+        return self._get_mm_inputs(images, videos, processor)
 
 class Qwen2vlPlugin(BasePlugin):
     @override
@@ -428,6 +483,7 @@ PLUGINS = {
     "llava": LlavaPlugin,
     "paligemma": PaliGemmaPlugin,
     "qwen2_vl": Qwen2vlPlugin,
+    "pixtral": PixtralPlugin,
 }
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index bc4e101c..96d61645 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -119,43 +119,6 @@ def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
     Loads model config.
     """
     init_kwargs = _get_init_kwargs(model_args)
-    if "pixtral" in model_args.model_name_or_path:
-        from transformers import PretrainedConfig
-
-        class PixtralVisionConfig(PretrainedConfig):
-            model_type = "pixtral"
-
-            def __init__(
-                self,
-                hidden_size=1024,
-                intermediate_size=4096,
-                num_hidden_layers=24,
-                num_attention_heads=16,
-                num_channels=3,
-                image_size=1024,
-                patch_size=16,
-                hidden_act="gelu",
-                attention_dropout=0.0,
-                rope_theta=10000.0,
-                tie_word_embeddings=False,
-                **kwargs,
-            ):
-                super().__init__(**kwargs)
-
-                self.hidden_size = hidden_size
-                self.intermediate_size = intermediate_size
-                self.num_hidden_layers = num_hidden_layers
-                self.num_attention_heads = num_attention_heads
-                self.num_channels = num_channels
-                self.patch_size = patch_size
-                self.image_size = image_size
-                self.attention_dropout = attention_dropout
-                self.hidden_act = hidden_act
-                self.rope_theta = rope_theta
-                self.tie_word_embeddings = tie_word_embeddings
-                self.head_dim = hidden_size // num_attention_heads
-
-        return PixtralVisionConfig()
     return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
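
Reviewer note: the placeholder expansion in PixtralPlugin.process_messages can be hard to follow inside the diff, so below is a minimal standalone sketch of the same row-by-row layout. The token strings, the example image size, and the expand_placeholder helper are illustrative assumptions for this note only; in the plugin these values come from the Pixtral processor.

# Illustrative sketch -- not part of the patch. Assumes example token strings and patch size;
# the plugin reads patch_size, image_token, image_break_token, and image_end_token from the processor.
IMAGE_PLACEHOLDER = "<image>"


def expand_placeholder(content: str, height: int, width: int, patch_size: int = 16) -> str:
    image_token, image_break_token, image_end_token = "[IMG]", "[IMG_BREAK]", "[IMG_END]"
    num_height_tokens = height // patch_size  # rows of patches
    num_width_tokens = width // patch_size  # patches per row
    # every row of image tokens ends with a break token; the final token closes the image
    rows = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
    tokens = [token for row in rows for token in row]
    tokens[-1] = image_end_token
    return content.replace(IMAGE_PLACEHOLDER, "".join(tokens), 1)


# a 64x48 image with 16x16 patches yields 4 rows of 3 image tokens:
# 12 x [IMG], 3 x [IMG_BREAK], and a single trailing [IMG_END]
print(expand_placeholder("Describe this image: <image>", height=64, width=48))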