Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)
[data] gemma3 plugin pan and scan (#7294)

* gemma3 pan and scan
* add test case
* fix test

parent 0be0d7796a
commit 93e6184cbe
@@ -290,7 +290,18 @@ class MMPluginMixin:
             if imglens is not None:
                 images = _make_batched_images(images, imglens)

-            mm_inputs.update(image_processor(images, return_tensors="pt"))
+            image_processor_kwargs = {}
+            if getattr(processor, "image_do_pan_and_scan", False):  # gemma3 image processor
+                image_processor_kwargs.update(
+                    {
+                        "do_pan_and_scan": True,
+                        "pan_and_scan_min_crop_size": 256,
+                        "pan_and_scan_max_num_crops": 4,
+                        "pan_and_scan_min_ratio_to_activate": 1.2,
+                    }
+                )
+
+            mm_inputs.update(image_processor(images, return_tensors="pt", **image_processor_kwargs))

         if len(videos) != 0:
             video_processor: BaseImageProcessor = getattr(
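For orientation, a minimal standalone sketch (not part of the commit) of the call this hunk enables. It assumes transformers >= 4.50 and access to the gated google/gemma-3-4b-it checkpoint; the image path is a placeholder.

# Sketch only: run one image through a gemma3 image processor with the same
# pan-and-scan kwargs as the hunk above; "example.jpg" is a placeholder path.
from PIL import Image
from transformers import AutoImageProcessor

image_processor = AutoImageProcessor.from_pretrained("google/gemma-3-4b-it")
images = [Image.open("example.jpg").convert("RGB")]

image_processor_kwargs = {
    "do_pan_and_scan": True,
    "pan_and_scan_min_crop_size": 256,
    "pan_and_scan_max_num_crops": 4,
    "pan_and_scan_min_ratio_to_activate": 1.2,
}
mm_inputs = image_processor(images, return_tensors="pt", **image_processor_kwargs)

# "num_crops" is what Gemma3Plugin uses below to add extra {{image}} slots.
print(mm_inputs["pixel_values"].shape, mm_inputs["num_crops"])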
@@ -401,10 +412,23 @@ class Gemma3Plugin(BasePlugin):
         boi_token: str = getattr(processor, "boi_token")
         full_image_sequence: str = getattr(processor, "full_image_sequence")
         image_str = full_image_sequence if self.expand_mm_tokens else boi_token

+        do_pan_and_scan: bool = getattr(processor, "image_do_pan_and_scan", False)
+        if do_pan_and_scan:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
+                if do_pan_and_scan:
+                    image_placeholder_str = (
+                        "Here is the original image {{image}} and here are some crops to help you see better "
+                        + " ".join(["{{image}}"] * mm_inputs["num_crops"][0][num_image_tokens])
+                    )
+                else:
+                    image_placeholder_str = "{{image}}"
+
+                content = content.replace(IMAGE_PLACEHOLDER, image_placeholder_str, 1)
+                num_image_tokens += 1

             message["content"] = content.replace("{{image}}", image_str)
@@ -1263,6 +1263,7 @@ register_template(
     format_user=StringFormatter(slots=["{{content}}\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
+    template_class=Llama2Template,
 )


@@ -1277,6 +1278,7 @@ register_template(
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<end_of_turn>"],
     mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
+    template_class=Llama2Template,
 )
@@ -218,6 +218,10 @@ class ProcessorArguments:
         default=32 * 32,
         metadata={"help": "The minimum number of pixels of image inputs."},
     )
+    image_do_pan_and_scan: bool = field(
+        default=False,
+        metadata={"help": "Use pan and scan to process image for gemma3."},
+    )
     video_max_pixels: int = field(
         default=256 * 256,
         metadata={"help": "The maximum number of pixels of video inputs."},
@@ -235,6 +239,13 @@ class ProcessorArguments:
         metadata={"help": "The maximum number of sampled frames for video inputs."},
     )

+    def __post_init__(self):
+        if self.image_max_pixels < self.image_min_pixels:
+            raise ValueError("`image_max_pixels` cannot be smaller than `image_min_pixels`.")
+
+        if self.video_max_pixels < self.video_min_pixels:
+            raise ValueError("`video_max_pixels` cannot be smaller than `video_min_pixels`.")
+

 @dataclass
 class ExportArguments:
@@ -342,6 +353,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz

     def __post_init__(self):
         BaseModelArguments.__post_init__(self)
+        ProcessorArguments.__post_init__(self)
         ExportArguments.__post_init__(self)
         VllmArguments.__post_init__(self)
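A hedged usage sketch, not from the diff: the new flag can be passed through the hparams entry point that the updated tests import. The template name here is a stand-in.

# Parse args the way LLaMA-Factory does; image_do_pan_and_scan lands on
# ModelArguments via the ProcessorArguments mixin shown above.
from llamafactory.hparams import get_infer_args

model_args, data_args, finetuning_args, generating_args = get_infer_args(
    {
        "model_name_or_path": "google/gemma-3-4b-it",
        "template": "gemma",  # stand-in template name
        "image_do_pan_and_scan": True,
    }
)
assert model_args.image_do_pan_and_scan  # later copied onto the processor

Note that the chained __post_init__ also runs the new pixel-bound checks, so for example image_max_pixels=16 with the default image_min_pixels of 32 * 32 now raises a ValueError at parse time.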
@@ -50,8 +50,8 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArgument
     if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
         tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)

-    if model_args.model_max_length is not None and tokenizer.model_max_length != model_args.model_max_length:
-        tokenizer.model_max_length = model_args.model_max_length
+    if model_args.model_max_length is not None and tokenizer.model_max_length < model_args.model_max_length:
+        tokenizer.model_max_length = model_args.model_max_length  # enlarge the tokenizer max length

     if model_args.new_special_tokens is not None:
         num_added_tokens = tokenizer.add_special_tokens(
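The guard change is behavioral, not cosmetic: the old != comparison could shrink tokenizer.model_max_length, while the new < comparison only ever enlarges it. A toy illustration:

# Toy illustration (not repo code) of the new guard keeping a larger
# tokenizer limit intact instead of overwriting it.
class FakeTokenizer:
    model_max_length = 8192

tokenizer, requested = FakeTokenizer(), 4096
if requested is not None and tokenizer.model_max_length < requested:  # new guard
    tokenizer.model_max_length = requested

assert tokenizer.model_max_length == 8192  # the old != guard would have set 4096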
@@ -72,6 +72,7 @@ def patch_processor(
     setattr(processor, "tokenizer", tokenizer)
     setattr(processor, "image_max_pixels", model_args.image_max_pixels)
     setattr(processor, "image_min_pixels", model_args.image_min_pixels)
+    setattr(processor, "image_do_pan_and_scan", model_args.image_do_pan_and_scan)
     setattr(processor, "video_max_pixels", model_args.video_max_pixels)
     setattr(processor, "video_min_pixels", model_args.video_min_pixels)
     setattr(processor, "video_fps", model_args.video_fps)
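A small sketch of the attribute hand-off, with names taken from the diffs above (loading the gated processor needs authentication): patch_processor stores the flag on the processor object, and MMPluginMixin later reads it back with getattr.

# Sketch only: the flag travels as a plain attribute on the processor.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")  # gated model
setattr(processor, "image_do_pan_and_scan", True)

# ...later, inside MMPluginMixin._get_mm_inputs:
if getattr(processor, "image_do_pan_and_scan", False):  # gemma3 image processor
    print("pan-and-scan kwargs will be passed to the image processor")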
@@ -20,6 +20,7 @@ import torch
 from PIL import Image

 from llamafactory.data.mm_plugin import get_mm_plugin
+from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.hparams import get_infer_args
 from llamafactory.model import load_tokenizer

@@ -135,6 +136,27 @@ def test_base_plugin():
     _check_plugin(**check_inputs)


+@pytest.mark.skipif(not HF_TOKEN or not is_transformers_version_greater_than("4.50.0"), reason="Gated model.")
+def test_gemma3_plugin():
+    image_seqlen = 256
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="google/gemma-3-4b-it")
+    gemma3_plugin = get_mm_plugin(name="gemma3", image_token="<image_soft_token>")
+    image_tokens_expanded = "<image_soft_token>" * image_seqlen
+    check_inputs = {"plugin": gemma3_plugin, **tokenizer_module}
+    check_inputs["expected_mm_messages"] = [
+        {
+            key: value.replace("<image>", f"\n\n<start_of_image>{image_tokens_expanded}<end_of_image>\n\n")
+            for key, value in message.items()
+        }
+        for message in MM_MESSAGES
+    ]
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
+    check_inputs["expected_mm_inputs"].pop("num_crops")
+    check_inputs["expected_mm_inputs"]["token_type_ids"] = [[0] * 1024]
+    check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[0] * 1024]}
+    _check_plugin(**check_inputs)
+
+
 def test_llava_plugin():
     image_seqlen = 576
     tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")