From fd79cf8551418954f199ec266f4385f8b1e5f894 Mon Sep 17 00:00:00 2001
From: Kingsley
Date: Sat, 28 Sep 2024 22:50:53 +0800
Subject: [PATCH] tiny fix

Former-commit-id: 3d3cc6705d4575f7f20bf4da2b7dab60b337006b
---
 README.md                                    |  1 +
 README_zh.md                                 |  1 +
 src/llamafactory/data/mm_plugin.py           | 21 ++++++---------------
 src/llamafactory/extras/constants.py         |  4 ++--
 src/llamafactory/model/model_utils/visual.py | 13 +++++++++++--
 5 files changed, 21 insertions(+), 19 deletions(-)

diff --git a/README.md b/README.md
index 92bbcc88..cf37565b 100644
--- a/README.md
+++ b/README.md
@@ -183,6 +183,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
 | [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
 | [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
+| [Pixtral](https://huggingface.co/mistralai/Pixtral-12B-2409) | 12B | pixtral |
 | [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
 | [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
 | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
diff --git a/README_zh.md b/README_zh.md
index 0b02f35f..4b3b53de 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -184,6 +184,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 | [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
 | [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
 | [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
+| [Pixtral](https://huggingface.co/mistralai/Pixtral-12B-2409) | 12B | pixtral |
 | [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
 | [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
 | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 6716527c..2b85c2c5 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -325,14 +325,6 @@ class PaliGemmaPlugin(BasePlugin):
         return mm_inputs
 
 class PixtralPlugin(BasePlugin):
-    # @override
-    # def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
-    #     image = super()._preprocess_image(image, **kwargs)
-    #     UP_SIZE = (512,512)
-    #     image = image.resize(UP_SIZE, resample=Image.NEAREST)
-
-    #     return image
-
     @override
     def process_messages(
         self,
@@ -396,16 +388,15 @@ class PixtralPlugin(BasePlugin):
         seqlens: Sequence[int],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-
         self._validate_input(images, videos)
         mm_inputs = self._get_mm_inputs(images, videos, processor)
-        if mm_inputs.get('image_sizes'):
-            del mm_inputs['image_sizes']
-        # TODO fix this type error
-        # if isinstance(mm_inputs.get("pixel_values"), list): #List[List[torch.tensor]] -> [B C W H]
-        # recommend for batch==1 for one gpu or it will rise the error of BatchEncoding. 
+        if mm_inputs.get("image_sizes"):
+            mm_inputs.pop("image_sizes")
+        if isinstance(mm_inputs.get("pixel_values"), list):  # List[List[torch.Tensor]] -> one image per sample
+            if len(mm_inputs["pixel_values"][0]) >= 2:
+                raise ValueError("Pixtral currently supports only batch size 1 per GPU, since a `List[torch.Tensor]` cannot be packed into a `BatchEncoding`.")
+            mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0][0].unsqueeze(0)  # (C, H, W) -> (1, C, H, W)
 
-        # mm_inputs["pixel_values"] = mm_inputs.get("pixel_values")
         return mm_inputs
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index e3f6a99d..3de1c7a2 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -1060,8 +1060,8 @@ register_model_group(
 register_model_group(
     models={
         "Pixtral-12B-2409": {
-            DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
-            DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
+            DownloadSource.DEFAULT: "mistralai/Pixtral-12B-2409",
+            DownloadSource.MODELSCOPE: "LLM-Research/Pixtral-12B-2409",
         }
     },
     template="mistral",
diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py
index 23f880a6..107590bd 100644
--- a/src/llamafactory/model/model_utils/visual.py
+++ b/src/llamafactory/model/model_utils/visual.py
@@ -96,6 +96,9 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
         mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
     elif model_type == "qwen2_vl":
         mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
+    # TODO: verify the adapter attribute name on the released checkpoint
+    elif model_type == "pixtral":
+        mm_projector: "torch.nn.Module" = getattr(model, "vision_language_adapter")
     else:
         return
 
@@ -122,9 +125,11 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
     """
     model_type = getattr(config, "model_type", None)
     forbidden_modules = set()
-    if model_type in ["llava", "paligemma"]:
+    if model_type in ["llava", "paligemma", "pixtral"]:
         if finetuning_args.freeze_vision_tower:
             forbidden_modules.add("vision_tower")
+            # TODO: verify this module name for Pixtral
+            forbidden_modules.add("vision_encoder")
 
         if finetuning_args.train_mm_proj_only:
             forbidden_modules.add("language_model")
@@ -150,7 +155,7 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
         image_seqlen += 1
     elif model_type == "paligemma":
         image_seqlen = config.vision_config.num_image_tokens
-    elif model_type == "qwen2_vl":  # variable length
+    elif model_type in ["qwen2_vl", "pixtral"]:  # variable length
         image_seqlen = -1
 
     return image_seqlen
@@ -168,10 +173,14 @@ def patch_target_modules(
             return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
         elif model_type == "qwen2_vl":
             return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
+        elif model_type == "pixtral":
+            return "^(?!.*vision_encoder).*(?:{}).*".format("|".join(target_modules))
         else:
             return target_modules
     else:
         if model_type == "qwen2_vl":
             return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules))
+        elif model_type == "pixtral":
+            return "^(?!.*patch_conv).*(?:{}).*".format("|".join(target_modules))
         else:
             return target_modules
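
Note on the `get_mm_inputs` hunk above: as of this patch, the Pixtral processor returns `pixel_values` as a nested `List[List[torch.Tensor]]` (outer list over batch samples, inner list over the images in each sample), because Pixtral accepts variable-sized images that cannot be stacked into one tensor. The minimal sketch below mirrors the unwrapping the plugin performs; the tensor shape and the 512x512 size are made up for illustration and are not taken from the patch.

    import torch

    # Processor-style output: one batch sample holding one image of shape (C, H, W).
    pixel_values = [[torch.randn(3, 512, 512)]]

    # The plugin's guard: a second image per sample is rejected, since a list of
    # variable-sized tensors cannot be packed into a single BatchEncoding tensor.
    if len(pixel_values[0]) >= 2:
        raise ValueError("Pixtral plugin currently supports one image per sample (batch size 1 per GPU).")

    # Unwrap the nesting and restore a batch axis: (C, H, W) -> (1, C, H, W).
    pixel_values = pixel_values[0][0].unsqueeze(0)
    print(pixel_values.shape)  # torch.Size([1, 3, 512, 512])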