diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index b7e6852d..5f0ceaae 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -290,7 +290,18 @@ class MMPluginMixin:
             if imglens is not None:
                 images = _make_batched_images(images, imglens)
 
-            mm_inputs.update(image_processor(images, return_tensors="pt"))
+            image_processor_kwargs = {}
+            if getattr(processor, "image_do_pan_and_scan", False):  # gemma3 image processor
+                image_processor_kwargs.update(
+                    {
+                        "do_pan_and_scan": True,
+                        "pan_and_scan_min_crop_size": 256,
+                        "pan_and_scan_max_num_crops": 4,
+                        "pan_and_scan_min_ratio_to_activate": 1.2,
+                    }
+                )
+
+            mm_inputs.update(image_processor(images, return_tensors="pt", **image_processor_kwargs))
 
         if len(videos) != 0:
             video_processor: BaseImageProcessor = getattr(
@@ -401,10 +412,23 @@ class Gemma3Plugin(BasePlugin):
         boi_token: str = getattr(processor, "boi_token")
         full_image_sequence: str = getattr(processor, "full_image_sequence")
         image_str = full_image_sequence if self.expand_mm_tokens else boi_token
+
+        do_pan_and_scan: bool = getattr(processor, "image_do_pan_and_scan", False)
+        if do_pan_and_scan:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
+                if do_pan_and_scan:
+                    image_placeholder_str = (
+                        "Here is the original image {{image}} and here are some crops to help you see better "
+                        + " ".join(["{{image}}"] * mm_inputs["num_crops"][0][num_image_tokens])
+                    )
+                else:
+                    image_placeholder_str = "{{image}}"
+
+                content = content.replace(IMAGE_PLACEHOLDER, image_placeholder_str, 1)
                 num_image_tokens += 1
 
             message["content"] = content.replace("{{image}}", image_str)
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index cc4a6dcb..1d8dbee7 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -1263,6 +1263,7 @@ register_template(
     format_user=StringFormatter(slots=["{{content}}\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
+    template_class=Llama2Template,
 )
 
 
@@ -1277,6 +1278,7 @@ register_template(
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<end_of_turn>"],
     mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
+    template_class=Llama2Template,
 )
diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 3b6a1eea..4d7693e8 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -218,6 +218,10 @@
         default=32 * 32,
         metadata={"help": "The minimum number of pixels of image inputs."},
     )
+    image_do_pan_and_scan: bool = field(
+        default=False,
+        metadata={"help": "Use pan and scan to process image for gemma3."},
+    )
     video_max_pixels: int = field(
         default=256 * 256,
         metadata={"help": "The maximum number of pixels of video inputs."},
@@ -235,6 +239,13 @@
         metadata={"help": "The maximum number of sampled frames for video inputs."},
     )
 
+    def __post_init__(self):
+        if self.image_max_pixels < self.image_min_pixels:
+            raise ValueError("`image_max_pixels` cannot be smaller than `image_min_pixels`.")
+
+        if self.video_max_pixels < self.video_min_pixels:
+            raise ValueError("`video_max_pixels` cannot be smaller than `video_min_pixels`.")
+
 
 @dataclass
 class ExportArguments:
@@ -342,6 +353,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
 
     def __post_init__(self):
         BaseModelArguments.__post_init__(self)
+        ProcessorArguments.__post_init__(self)
         ExportArguments.__post_init__(self)
         VllmArguments.__post_init__(self)
 
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 8997757d..863197e5 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -50,8 +50,8 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArgument
     if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
         tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
 
-    if model_args.model_max_length is not None and tokenizer.model_max_length != model_args.model_max_length:
-        tokenizer.model_max_length = model_args.model_max_length
+    if model_args.model_max_length is not None and tokenizer.model_max_length < model_args.model_max_length:
+        tokenizer.model_max_length = model_args.model_max_length  # enlarge the tokenizer max length
 
     if model_args.new_special_tokens is not None:
         num_added_tokens = tokenizer.add_special_tokens(
@@ -72,6 +72,7 @@ def patch_processor(
     setattr(processor, "tokenizer", tokenizer)
     setattr(processor, "image_max_pixels", model_args.image_max_pixels)
     setattr(processor, "image_min_pixels", model_args.image_min_pixels)
+    setattr(processor, "image_do_pan_and_scan", model_args.image_do_pan_and_scan)
     setattr(processor, "video_max_pixels", model_args.video_max_pixels)
     setattr(processor, "video_min_pixels", model_args.video_min_pixels)
     setattr(processor, "video_fps", model_args.video_fps)
diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py
index e064d195..e7c58b55 100644
--- a/tests/data/test_mm_plugin.py
+++ b/tests/data/test_mm_plugin.py
@@ -20,6 +20,7 @@ import torch
 from PIL import Image
 
 from llamafactory.data.mm_plugin import get_mm_plugin
+from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.hparams import get_infer_args
 from llamafactory.model import load_tokenizer
 
@@ -135,6 +136,27 @@ def test_base_plugin():
     _check_plugin(**check_inputs)
 
 
+@pytest.mark.skipif(not HF_TOKEN or not is_transformers_version_greater_than("4.50.0"), reason="Gated model.")
+def test_gemma3_plugin():
+    image_seqlen = 256
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="google/gemma-3-4b-it")
+    gemma3_plugin = get_mm_plugin(name="gemma3", image_token="<image_soft_token>")
+    image_tokens_expanded = "<image_soft_token>" * image_seqlen
+    check_inputs = {"plugin": gemma3_plugin, **tokenizer_module}
+    check_inputs["expected_mm_messages"] = [
+        {
+            key: value.replace("<image>", f"\n\n<start_of_image>{image_tokens_expanded}<end_of_image>\n\n")
+            for key, value in message.items()
+        }
+        for message in MM_MESSAGES
+    ]
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
+    check_inputs["expected_mm_inputs"].pop("num_crops")
+    check_inputs["expected_mm_inputs"]["token_type_ids"] = [[0] * 1024]
+    check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[0] * 1024]}
+    _check_plugin(**check_inputs)
+
+
 def test_llava_plugin():
     image_seqlen = 576
     tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
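
Note (illustration, not part of the patch): when `image_do_pan_and_scan` is enabled, the new branch in `Gemma3Plugin.process_messages` rewrites each image placeholder into "original image plus crops" before the usual `{{image}}` to `full_image_sequence` substitution. Below is a minimal standalone sketch of that rewriting step; the helper name `expand_placeholders` and the hard-coded `num_crops` list are illustrative assumptions, whereas in the plugin `num_crops` comes from the Gemma3 image processor outputs.

```python
# Standalone sketch of the placeholder rewriting performed when pan and scan is enabled.
# `expand_placeholders` and the literal num_crops values are illustrative only.
IMAGE_PLACEHOLDER = "<image>"


def expand_placeholders(content: str, num_crops: list) -> str:
    """Replace each <image> with the original image plus its pan-and-scan crops."""
    num_image_tokens = 0
    while IMAGE_PLACEHOLDER in content:
        image_placeholder_str = (
            "Here is the original image {{image}} and here are some crops to help you see better "
            + " ".join(["{{image}}"] * num_crops[num_image_tokens])
        )
        content = content.replace(IMAGE_PLACEHOLDER, image_placeholder_str, 1)
        num_image_tokens += 1

    return content


# One image split into two crops yields three {{image}} slots in total.
print(expand_placeholders("<image>Describe this picture.", num_crops=[2]))
```

Each remaining `{{image}}` is then replaced with the processor's `full_image_sequence` (the `<start_of_image>` token, 256 `<image_soft_token>` tokens, and `<end_of_image>`), so the original image and every crop each consume a full image token block.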