diff --git a/README.md b/README.md index fdc931b7..639b2041 100644 --- a/README.md +++ b/README.md @@ -190,6 +190,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ | [PaliGemma](https://huggingface.co/google) | 3B | paligemma | | [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - | | [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi | +| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral | | [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen | | [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl | | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - | @@ -717,7 +718,7 @@ If you have a project that should be incorporated, please contact via email or c This repository is licensed under the [Apache-2.0 License](LICENSE). -Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) +Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) ## Citation diff --git a/README_zh.md b/README_zh.md index c36cabf1..40e057f5 100644 --- a/README_zh.md +++ b/README_zh.md @@ -191,6 +191,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272 | [PaliGemma](https://huggingface.co/google) | 3B | paligemma | | [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - | | [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi | +| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral | | [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen | | [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl | | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - | @@ -717,7 +718,7 @@ run_name: test_run # 可选 本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。 -使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) +使用模型权重时,请遵循对应的模型协议:[Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan) ## 引用 diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index 2b1d9fe5..909f8161 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -166,7 +166,11 @@ class HuggingfaceEngine(BaseEngine): mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor) for key, value in mm_inputs.items(): - value = value if isinstance(value, torch.Tensor) else torch.tensor(value) + if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value): # for pixtral inputs + value = torch.stack(value) # assume they have same sizes + elif not isinstance(value, torch.Tensor): + value = torch.tensor(value) + gen_kwargs[key] = value.to(model.device) return gen_kwargs, prompt_length diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index bc8ab347..8fa6f0dd 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -99,6 +99,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): features: Dict[str, "torch.Tensor"] = super().__call__(features) features.update(mm_inputs) + if isinstance(features.get("pixel_values"), list): # for pixtral inputs + features = features.data # use default_collate() instead of BatchEncoding.to() + return features diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index a2a2b8cc..52c65cb7 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -448,6 +448,70 @@ class PaliGemmaPlugin(BasePlugin): return mm_inputs +class PixtralPlugin(BasePlugin): + @override + def process_messages( + self, + messages: Sequence[Dict[str, str]], + images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], + processor: Optional["ProcessorMixin"], + ) -> List[Dict[str, str]]: + self._validate_input(images, videos) + patch_size = getattr(processor, "patch_size") + image_token = getattr(processor, "image_token") + image_break_token = getattr(processor, "image_break_token") + image_end_token = getattr(processor, "image_end_token") + + num_image_tokens = 0 + messages = deepcopy(messages) + mm_inputs = self._get_mm_inputs(images, videos, processor) + image_input_sizes = mm_inputs.get("image_sizes", None) + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + if image_input_sizes is None: + raise ValueError( + "The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER) + ) + + image_size = image_input_sizes[0][num_image_tokens] + height, width = image_size + num_height_tokens = height // patch_size + num_width_tokens = width // patch_size + replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens + replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list + replace_tokens[-1] = image_end_token + replace_str = "".join(replace_tokens) + content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1) + num_image_tokens += 1 + + message["content"] = content + + if len(images) != num_image_tokens: + raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) + + return messages + + @override + def get_mm_inputs( + self, + images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], + imglens: Sequence[int], + vidlens: Sequence[int], + seqlens: Sequence[int], + processor: Optional["ProcessorMixin"], + ) -> Dict[str, Union[List[int], "torch.Tensor"]]: + self._validate_input(images, videos) + mm_inputs = self._get_mm_inputs(images, videos, processor) + if mm_inputs.get("pixel_values"): + mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0] + + mm_inputs.pop("image_sizes", None) + return mm_inputs + + class Qwen2vlPlugin(BasePlugin): @override def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject": @@ -610,6 +674,7 @@ PLUGINS = { "llava_next": LlavaNextPlugin, "llava_next_video": LlavaNextVideoPlugin, "paligemma": PaliGemmaPlugin, + "pixtral": PixtralPlugin, "qwen2_vl": Qwen2vlPlugin, "video_llava": VideoLlavaPlugin, } diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 96738b33..d0da3b30 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -935,6 +935,14 @@ _register_template( ) +_register_template( + name="pixtral", + format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"), +) + + _register_template( name="qwen", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index 7c58e2c5..4e81ef45 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -1178,6 +1178,18 @@ register_model_group( ) +register_model_group( + models={ + "Pixtral-12B-Chat": { + DownloadSource.DEFAULT: "mistral-community/pixtral-12b", + DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b", + } + }, + template="pixtral", + vision=True, +) + + register_model_group( models={ "Qwen-1.8B": { diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py index ac95eead..bcd21841 100644 --- a/src/llamafactory/model/model_utils/visual.py +++ b/src/llamafactory/model/model_utils/visual.py @@ -92,7 +92,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen if getattr(model, "quantization_method", None): model_type = getattr(model.config, "model_type", None) - if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]: + if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]: mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector") elif model_type == "qwen2_vl": mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger") @@ -113,6 +113,7 @@ def configure_visual_model(config: "PretrainedConfig") -> None: "llava_next", "llava_next_video", "paligemma", + "pixtral", "video_llava", ]: # required for ds zero3 and valuehead models setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None)) @@ -128,7 +129,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni """ model_type = getattr(config, "model_type", None) forbidden_modules = set() - if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]: + if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]: if finetuning_args.freeze_vision_tower: forbidden_modules.add("vision_tower") @@ -186,7 +187,7 @@ def patch_target_modules( """ model_type = getattr(config, "model_type", None) if finetuning_args.freeze_vision_tower: - if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]: + if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]: return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules)) elif model_type == "qwen2_vl": return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules)) @@ -195,5 +196,7 @@ def patch_target_modules( else: if model_type == "qwen2_vl": return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules)) + elif model_type == "pixtral": + return "^(?!.*patch_conv).*(?:{}).*".format("|".join(target_modules)) else: return target_modules diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py index 75541000..66e9b57c 100644 --- a/tests/data/test_mm_plugin.py +++ b/tests/data/test_mm_plugin.py @@ -74,6 +74,10 @@ def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None: for key in batch_a.keys(): if isinstance(batch_a[key], torch.Tensor): assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5) + elif isinstance(batch_a[key], list) and all(isinstance(item, torch.Tensor) for item in batch_a[key]): + assert len(batch_a[key]) == len(batch_b[key]) + for tensor_a, tensor_b in zip(batch_a[key], batch_b[key]): + assert torch.allclose(tensor_a, tensor_b, rtol=1e-4, atol=1e-5) else: assert batch_a[key] == batch_b[key] @@ -179,6 +183,28 @@ def test_paligemma_plugin(): _check_plugin(**check_inputs) +def test_pixtral_plugin(): + tokenizer, processor = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b") + pixtral_plugin = get_mm_plugin(name="pixtral", image_token="[IMG]") + image_slice_height, image_slice_width = 2, 2 + check_inputs = {"plugin": pixtral_plugin, "tokenizer": tokenizer, "processor": processor} + check_inputs["expected_mm_messages"] = [ + { + key: value.replace( + "", + ("{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height).rsplit("[IMG_BREAK]", 1)[0] + + "[IMG_END]", + ) + for key, value in message.items() + } + for message in MM_MESSAGES + ] + check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor) + check_inputs["expected_mm_inputs"].pop("image_sizes") + check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0] + _check_plugin(**check_inputs) + + def test_qwen2_vl_plugin(): tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct") qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")