From 2b21c749c1add9b5568032e726f7559ab7a3fe14 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Thu, 6 Mar 2025 15:18:32 +0800 Subject: [PATCH] [data] fix mm template (#7181) Former-commit-id: be66df1f0211cd2d90eac3ab407dced653c9e443 --- src/llamafactory/data/mm_plugin.py | 6 +++--- src/llamafactory/data/template.py | 10 +++++++++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index e074d021..0a63800a 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -1042,15 +1042,15 @@ class Qwen2VLPlugin(BasePlugin): image = super()._preprocess_image(image, **kwargs) if min(image.width, image.height) < 28: width, height = max(image.width, 28), max(image.height, 28) - image = image.resize((width, height), resample=Image.Resampling.NEAREST) + image = image.resize((width, height)) if image.width / image.height > 200: width, height = image.height * 180, image.height - image = image.resize((width, height), resample=Image.Resampling.NEAREST) + image = image.resize((width, height)) if image.height / image.width > 200: width, height = image.width, image.width * 180 - image = image.resize((width, height), resample=Image.Resampling.NEAREST) + image = image.resize((width, height)) return image diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 58393563..a789cf7a 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -1268,9 +1268,17 @@ register_template( ) -# copied from gemma template register_template( name="paligemma", + format_user=StringFormatter(slots=["{{content}}\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + mm_plugin=get_mm_plugin(name="paligemma", image_token=""), +) + + +# copied from gemma template +register_template( + name="paligemma_chat", format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), format_assistant=StringFormatter(slots=["{{content}}\n"]), format_observation=StringFormatter(