Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 22:32:54 +08:00)

Merge pull request #5581 from Kuangdd01/pixtral-patch
[WIP] Support Pixtral-12B

Former-commit-id: 9009a467e621a17ad9fa25bb30fb9ac9ee15df97
Commit: 5142faca8f

@@ -190,6 +190,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
 | [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
 | [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
+| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
 | [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
 | [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
 | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |

@@ -717,7 +718,7 @@ If you have a project that should be incorporated, please contact via email or c

 This repository is licensed under the [Apache-2.0 License](LICENSE).

-Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

 ## Citation

@@ -191,6 +191,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 | [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
 | [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
 | [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
+| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
 | [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
 | [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
 | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |

@@ -717,7 +718,7 @@ run_name: test_run # optional

 The code in this repository is open-sourced under the [Apache-2.0](LICENSE) license.

-When using the model weights, please follow the corresponding model licenses: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+When using the model weights, please follow the corresponding model licenses: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

 ## Citation

@@ -166,7 +166,11 @@ class HuggingfaceEngine(BaseEngine):

         mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
         for key, value in mm_inputs.items():
-            value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
+            if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):  # for pixtral inputs
+                value = torch.stack(value)  # assume they have same sizes
+            elif not isinstance(value, torch.Tensor):
+                value = torch.tensor(value)
+
             gen_kwargs[key] = value.to(model.device)

         return gen_kwargs, prompt_length

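Note on the engine change above: Pixtral's processor hands back `pixel_values` as a Python list of per-image tensors rather than one batched tensor, so the engine now stacks such lists before moving them to the model device. A minimal standalone sketch of the same branching (the tensor shapes below are illustrative assumptions, not values produced by the processor):

```python
import torch

def to_generation_tensor(value):
    # A list of equally sized tensors (e.g. Pixtral pixel values) is stacked into one batch tensor.
    if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):
        return torch.stack(value)
    # Plain Python data (lists of ids, etc.) is wrapped into a tensor.
    if not isinstance(value, torch.Tensor):
        return torch.tensor(value)
    return value

# Illustrative only: two images already preprocessed to the same size.
pixel_values = [torch.randn(3, 512, 512), torch.randn(3, 512, 512)]
print(to_generation_tensor(pixel_values).shape)  # torch.Size([2, 3, 512, 512])
```
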
@@ -99,6 +99,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):

         features: Dict[str, "torch.Tensor"] = super().__call__(features)
         features.update(mm_inputs)
+        if isinstance(features.get("pixel_values"), list):  # for pixtral inputs
+            features = features.data  # use default_collate() instead of BatchEncoding.to()
+
         return features

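Note on the collator change above: when `pixel_values` stays a list (Pixtral images may differ in size, so they cannot be stacked into a single tensor), the `BatchEncoding` is unwrapped into its underlying dict so that, per the inline comment, the default collation path handles it instead of `BatchEncoding.to()`. A small sketch of the underlying constraint, with made-up shapes:

```python
import torch

# Per-image tensors of different spatial sizes cannot be batched with torch.stack().
pixel_values = [torch.randn(3, 512, 768), torch.randn(3, 256, 256)]
try:
    torch.stack(pixel_values)
except RuntimeError as err:
    print("cannot stack directly:", err)

# Kept as a plain list, each tensor can still be moved to the target device individually.
pixel_values = [v.to("cpu") for v in pixel_values]
```
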
@@ -448,6 +448,70 @@ class PaliGemmaPlugin(BasePlugin):
         return mm_inputs


+class PixtralPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        patch_size = getattr(processor, "patch_size")
+        image_token = getattr(processor, "image_token")
+        image_break_token = getattr(processor, "image_break_token")
+        image_end_token = getattr(processor, "image_end_token")
+
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        image_input_sizes = mm_inputs.get("image_sizes", None)
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                if image_input_sizes is None:
+                    raise ValueError(
+                        "The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)
+                    )
+
+                image_size = image_input_sizes[0][num_image_tokens]
+                height, width = image_size
+                num_height_tokens = height // patch_size
+                num_width_tokens = width // patch_size
+                replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
+                replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten list
+                replace_tokens[-1] = image_end_token
+                replace_str = "".join(replace_tokens)
+                content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
+                num_image_tokens += 1
+
+            message["content"] = content
+
+        if len(images) != num_image_tokens:
+            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+
+        return messages
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        seqlens: Sequence[int],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(images, videos)
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        if mm_inputs.get("pixel_values"):
+            mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0]
+
+        mm_inputs.pop("image_sizes", None)
+        return mm_inputs
+
+
 class Qwen2vlPlugin(BasePlugin):
     @override
     def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":

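How the placeholder expansion in `process_messages` works: for every image, the plugin emits `width // patch_size` copies of the image token per row of patches, ends each row with the break token, and swaps the very last token for the end token. A standalone sketch of that arithmetic (the token strings match the ones exercised by the tests below; the image and patch sizes are arbitrary example values):

```python
IMAGE_PLACEHOLDER = "<image>"

def expand_image_placeholder(content: str, height: int, width: int, patch_size: int) -> str:
    # One [IMG] token per patch, rows separated by [IMG_BREAK], image closed by [IMG_END].
    num_height_tokens = height // patch_size
    num_width_tokens = width // patch_size
    rows = [["[IMG]"] * num_width_tokens + ["[IMG_BREAK]"] for _ in range(num_height_tokens)]
    tokens = [token for row in rows for token in row]  # flatten the rows
    tokens[-1] = "[IMG_END]"
    return content.replace(IMAGE_PLACEHOLDER, "".join(tokens), 1)

# Arbitrary example: a 32x64 image with a 16-pixel patch -> 2 rows of 4 tokens each.
print(expand_image_placeholder("Describe this <image>", height=32, width=64, patch_size=16))
# Describe this [IMG][IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG][IMG_END]
```
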
@@ -610,6 +674,7 @@ PLUGINS = {
     "llava_next": LlavaNextPlugin,
     "llava_next_video": LlavaNextVideoPlugin,
     "paligemma": PaliGemmaPlugin,
+    "pixtral": PixtralPlugin,
     "qwen2_vl": Qwen2vlPlugin,
     "video_llava": VideoLlavaPlugin,
 }

@@ -935,6 +935,14 @@ _register_template(
 )


+_register_template(
+    name="pixtral",
+    format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
+)
+
+
 _register_template(
     name="qwen",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),

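With the `pixtral` template, each user turn is wrapped in Mistral-style `[INST] ... [/INST]` markers, the sequence is prefixed with the tokenizer's BOS token, and the mm-plugin replaces the image placeholder with the `[IMG]` grid described above. A rough sketch of a rendered single-turn prompt (the literal BOS string `<s>` is an assumption about the tokenizer, not taken from this diff):

```python
# Hypothetical rendering of one user turn under the "pixtral" template.
bos_token = "<s>"  # assumption: the tokenizer's BOS token renders as "<s>"
user_message = "Describe this [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]"  # placeholder already expanded
prompt = bos_token + "[INST] {} [/INST]".format(user_message)
print(prompt)
# <s>[INST] Describe this [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END] [/INST]
```
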
@@ -1178,6 +1178,18 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Pixtral-12B-Chat": {
+            DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
+        }
+    },
+    template="pixtral",
+    vision=True,
+)
+
+
 register_model_group(
     models={
         "Qwen-1.8B": {

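Registering the model group exposes `Pixtral-12B-Chat` in the launcher and binds it to the `pixtral` template, with weights pulled from the Hugging Face Hub by default or from ModelScope. A hedged usage sketch with the high-level chat wrapper (the argument and attribute names below are assumptions about the public API at the time and may differ between versions; image inputs would go through the engine's multimodal arguments, which are not shown):

```python
from llamafactory.chat import ChatModel  # assumption: high-level chat wrapper exported by the package

# HF-style arguments passed as a plain dict; treat this as a sketch, not a verified config.
chat_model = ChatModel({
    "model_name_or_path": "mistral-community/pixtral-12b",
    "template": "pixtral",
})
messages = [{"role": "user", "content": "<image> What is shown in the picture?"}]
responses = chat_model.chat(messages)
print(responses[0].response_text)  # assumption: responses carry a response_text field
```
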
@@ -92,7 +92,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen

     if getattr(model, "quantization_method", None):
         model_type = getattr(model.config, "model_type", None)
-        if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
+        if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
             mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
         elif model_type == "qwen2_vl":
             mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")

@@ -113,6 +113,7 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
         "llava_next",
         "llava_next_video",
         "paligemma",
+        "pixtral",
         "video_llava",
     ]:  # required for ds zero3 and valuehead models
         setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))

@@ -128,7 +129,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
     """
     model_type = getattr(config, "model_type", None)
     forbidden_modules = set()
-    if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
+    if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
         if finetuning_args.freeze_vision_tower:
             forbidden_modules.add("vision_tower")

@@ -186,7 +187,7 @@ patch_target_modules(
     """
     model_type = getattr(config, "model_type", None)
     if finetuning_args.freeze_vision_tower:
-        if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
+        if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
             return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
         elif model_type == "qwen2_vl":
             return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))

@@ -195,5 +196,7 @@ patch_target_modules(
     else:
         if model_type == "qwen2_vl":
             return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules))
+        elif model_type == "pixtral":
+            return "^(?!.*patch_conv).*(?:{}).*".format("|".join(target_modules))
         else:
             return target_modules

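The LoRA target-module patch uses a negative lookahead so the Pixtral vision patch embedding (`patch_conv`) is never picked up as a LoRA target, mirroring the existing `patch_embed` exclusion for Qwen2-VL. A quick illustration of how that regex behaves (the module paths are invented for the example):

```python
import re

target_modules = ["q_proj", "v_proj"]
pattern = "^(?!.*patch_conv).*(?:{}).*".format("|".join(target_modules))

modules = [
    "language_model.model.layers.0.self_attn.q_proj",      # matched: LoRA is applied here
    "vision_tower.transformer.layers.0.attention.v_proj",  # matched: vision tower is trainable in this branch
    "vision_tower.patch_conv",                              # rejected by the negative lookahead
]
for name in modules:
    print(name, "->", bool(re.match(pattern, name)))
```
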
@@ -74,6 +74,10 @@ def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
     for key in batch_a.keys():
         if isinstance(batch_a[key], torch.Tensor):
             assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5)
+        elif isinstance(batch_a[key], list) and all(isinstance(item, torch.Tensor) for item in batch_a[key]):
+            assert len(batch_a[key]) == len(batch_b[key])
+            for tensor_a, tensor_b in zip(batch_a[key], batch_b[key]):
+                assert torch.allclose(tensor_a, tensor_b, rtol=1e-4, atol=1e-5)
         else:
             assert batch_a[key] == batch_b[key]

@@ -179,6 +183,28 @@ def test_paligemma_plugin():
     _check_plugin(**check_inputs)


+def test_pixtral_plugin():
+    tokenizer, processor = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b")
+    pixtral_plugin = get_mm_plugin(name="pixtral", image_token="[IMG]")
+    image_slice_height, image_slice_width = 2, 2
+    check_inputs = {"plugin": pixtral_plugin, "tokenizer": tokenizer, "processor": processor}
+    check_inputs["expected_mm_messages"] = [
+        {
+            key: value.replace(
+                "<image>",
+                ("{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height).rsplit("[IMG_BREAK]", 1)[0]
+                + "[IMG_END]",
+            )
+            for key, value in message.items()
+        }
+        for message in MM_MESSAGES
+    ]
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
+    check_inputs["expected_mm_inputs"].pop("image_sizes")
+    check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0]
+    _check_plugin(**check_inputs)
+
+
 def test_qwen2_vl_plugin():
     tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
     qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")

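The test derives the expected replacement for a 2x2 patch grid from the same repeat-and-join expression used in the plugin: repeat `[IMG]` per row, join rows with `[IMG_BREAK]`, and swap the trailing break for `[IMG_END]`. Evaluating that expression on its own makes the expected string explicit:

```python
image_slice_height, image_slice_width = 2, 2
expected = (
    "{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height
).rsplit("[IMG_BREAK]", 1)[0] + "[IMG_END]"
print(expected)  # [IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]
```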