diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index e074d021..0a63800a 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -1042,15 +1042,15 @@ class Qwen2VLPlugin(BasePlugin): image = super()._preprocess_image(image, **kwargs) if min(image.width, image.height) < 28: width, height = max(image.width, 28), max(image.height, 28) - image = image.resize((width, height), resample=Image.Resampling.NEAREST) + image = image.resize((width, height)) if image.width / image.height > 200: width, height = image.height * 180, image.height - image = image.resize((width, height), resample=Image.Resampling.NEAREST) + image = image.resize((width, height)) if image.height / image.width > 200: width, height = image.width, image.width * 180 - image = image.resize((width, height), resample=Image.Resampling.NEAREST) + image = image.resize((width, height)) return image diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 58393563..a789cf7a 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -1268,9 +1268,17 @@ register_template( ) -# copied from gemma template register_template( name="paligemma", + format_user=StringFormatter(slots=["{{content}}\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + mm_plugin=get_mm_plugin(name="paligemma", image_token=""), +) + + +# copied from gemma template +register_template( + name="paligemma_chat", format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), format_assistant=StringFormatter(slots=["{{content}}\n"]), format_observation=StringFormatter(