mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-01 11:12:50 +08:00
[data] gemma3 plugin pan and scan (#7294)
* gemma3 pan and scan * add test case * fix test
This commit is contained in:
parent
3dff4ecca8
commit
ef5f1c1def
@ -290,7 +290,18 @@ class MMPluginMixin:
|
||||
if imglens is not None:
|
||||
images = _make_batched_images(images, imglens)
|
||||
|
||||
mm_inputs.update(image_processor(images, return_tensors="pt"))
|
||||
image_processor_kwargs = {}
|
||||
if getattr(processor, "image_do_pan_and_scan", False): # gemma3 image processor
|
||||
image_processor_kwargs.update(
|
||||
{
|
||||
"do_pan_and_scan": True,
|
||||
"pan_and_scan_min_crop_size": 256,
|
||||
"pan_and_scan_max_num_crops": 4,
|
||||
"pan_and_scan_min_ratio_to_activate": 1.2,
|
||||
}
|
||||
)
|
||||
|
||||
mm_inputs.update(image_processor(images, return_tensors="pt", **image_processor_kwargs))
|
||||
|
||||
if len(videos) != 0:
|
||||
video_processor: BaseImageProcessor = getattr(
|
||||
@ -401,10 +412,23 @@ class Gemma3Plugin(BasePlugin):
|
||||
boi_token: str = getattr(processor, "boi_token")
|
||||
full_image_sequence: str = getattr(processor, "full_image_sequence")
|
||||
image_str = full_image_sequence if self.expand_mm_tokens else boi_token
|
||||
|
||||
do_pan_and_scan: bool = getattr(processor, "image_do_pan_and_scan", False)
|
||||
if do_pan_and_scan:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
|
||||
if do_pan_and_scan:
|
||||
image_placeholder_str = (
|
||||
"Here is the original image {{image}} and here are some crops to help you see better "
|
||||
+ " ".join(["{{image}}"] * mm_inputs["num_crops"][0][num_image_tokens])
|
||||
)
|
||||
else:
|
||||
image_placeholder_str = "{{image}}"
|
||||
|
||||
content = content.replace(IMAGE_PLACEHOLDER, image_placeholder_str, 1)
|
||||
num_image_tokens += 1
|
||||
|
||||
message["content"] = content.replace("{{image}}", image_str)
|
||||
|
@ -1263,6 +1263,7 @@ register_template(
|
||||
format_user=StringFormatter(slots=["{{content}}\n"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
|
||||
template_class=Llama2Template,
|
||||
)
|
||||
|
||||
|
||||
@ -1277,6 +1278,7 @@ register_template(
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
stop_words=["<end_of_turn>"],
|
||||
mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
|
||||
template_class=Llama2Template,
|
||||
)
|
||||
|
||||
|
||||
|
@ -218,6 +218,10 @@ class ProcessorArguments:
|
||||
default=32 * 32,
|
||||
metadata={"help": "The minimum number of pixels of image inputs."},
|
||||
)
|
||||
image_do_pan_and_scan: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use pan and scan to process image for gemma3."},
|
||||
)
|
||||
video_max_pixels: int = field(
|
||||
default=256 * 256,
|
||||
metadata={"help": "The maximum number of pixels of video inputs."},
|
||||
@ -235,6 +239,13 @@ class ProcessorArguments:
|
||||
metadata={"help": "The maximum number of sampled frames for video inputs."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.image_max_pixels < self.image_min_pixels:
|
||||
raise ValueError("`image_max_pixels` cannot be smaller than `image_min_pixels`.")
|
||||
|
||||
if self.video_max_pixels < self.video_min_pixels:
|
||||
raise ValueError("`video_max_pixels` cannot be smaller than `video_min_pixels`.")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExportArguments:
|
||||
@ -342,6 +353,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
|
||||
|
||||
def __post_init__(self):
|
||||
BaseModelArguments.__post_init__(self)
|
||||
ProcessorArguments.__post_init__(self)
|
||||
ExportArguments.__post_init__(self)
|
||||
VllmArguments.__post_init__(self)
|
||||
|
||||
|
@ -50,8 +50,8 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArgument
|
||||
if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
|
||||
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
|
||||
|
||||
if model_args.model_max_length is not None and tokenizer.model_max_length != model_args.model_max_length:
|
||||
tokenizer.model_max_length = model_args.model_max_length
|
||||
if model_args.model_max_length is not None and tokenizer.model_max_length < model_args.model_max_length:
|
||||
tokenizer.model_max_length = model_args.model_max_length # enlarge the tokenizer max length
|
||||
|
||||
if model_args.new_special_tokens is not None:
|
||||
num_added_tokens = tokenizer.add_special_tokens(
|
||||
@ -72,6 +72,7 @@ def patch_processor(
|
||||
setattr(processor, "tokenizer", tokenizer)
|
||||
setattr(processor, "image_max_pixels", model_args.image_max_pixels)
|
||||
setattr(processor, "image_min_pixels", model_args.image_min_pixels)
|
||||
setattr(processor, "image_do_pan_and_scan", model_args.image_do_pan_and_scan)
|
||||
setattr(processor, "video_max_pixels", model_args.video_max_pixels)
|
||||
setattr(processor, "video_min_pixels", model_args.video_min_pixels)
|
||||
setattr(processor, "video_fps", model_args.video_fps)
|
||||
|
@ -20,6 +20,7 @@ import torch
|
||||
from PIL import Image
|
||||
|
||||
from llamafactory.data.mm_plugin import get_mm_plugin
|
||||
from llamafactory.extras.packages import is_transformers_version_greater_than
|
||||
from llamafactory.hparams import get_infer_args
|
||||
from llamafactory.model import load_tokenizer
|
||||
|
||||
@ -135,6 +136,27 @@ def test_base_plugin():
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HF_TOKEN or not is_transformers_version_greater_than("4.50.0"), reason="Gated model.")
|
||||
def test_gemma3_plugin():
|
||||
image_seqlen = 256
|
||||
tokenizer_module = _load_tokenizer_module(model_name_or_path="google/gemma-3-4b-it")
|
||||
gemma3_plugin = get_mm_plugin(name="gemma3", image_token="<image_soft_token>")
|
||||
image_tokens_expanded = "<image_soft_token>" * image_seqlen
|
||||
check_inputs = {"plugin": gemma3_plugin, **tokenizer_module}
|
||||
check_inputs["expected_mm_messages"] = [
|
||||
{
|
||||
key: value.replace("<image>", f"\n\n<start_of_image>{image_tokens_expanded}<end_of_image>\n\n")
|
||||
for key, value in message.items()
|
||||
}
|
||||
for message in MM_MESSAGES
|
||||
]
|
||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
|
||||
check_inputs["expected_mm_inputs"].pop("num_crops")
|
||||
check_inputs["expected_mm_inputs"]["token_type_ids"] = [[0] * 1024]
|
||||
check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[0] * 1024]}
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
def test_llava_plugin():
|
||||
image_seqlen = 576
|
||||
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
|
||||
|
Loading…
x
Reference in New Issue
Block a user