diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 133984cd..bc6e1afe 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -125,12 +125,12 @@ class MMPluginMixin: if (image.width * image.height) > image_max_pixels: resize_factor = math.sqrt(image_max_pixels / (image.width * image.height)) width, height = int(image.width * resize_factor), int(image.height * resize_factor) - image = image.resize((width, height), resample=Image.Resampling.NEAREST) + image = image.resize((width, height)) if (image.width * image.height) < image_min_pixels: resize_factor = math.sqrt(image_min_pixels / (image.width * image.height)) width, height = int(image.width * resize_factor), int(image.height * resize_factor) - image = image.resize((width, height), resample=Image.Resampling.NEAREST) + image = image.resize((width, height)) if image.mode != "RGB": image = image.convert("RGB")