From 1f4a0b11baed2e443ee9a1f548634e603ff62521 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Tue, 18 Feb 2025 02:12:51 +0800 Subject: [PATCH] [data] update vlm args (#6976) Former-commit-id: 3da2cc2710c9b13ab450815a92fff14b03251984 --- examples/train_full/qwen2vl_full_sft.yaml | 4 +-- examples/train_lora/qwen2vl_lora_dpo.yaml | 4 +-- examples/train_lora/qwen2vl_lora_sft.yaml | 4 +-- src/llamafactory/data/mm_plugin.py | 38 +++++++++++------------ src/llamafactory/hparams/model_args.py | 8 ++--- src/llamafactory/model/patcher.py | 8 ++--- 6 files changed, 33 insertions(+), 33 deletions(-) diff --git a/examples/train_full/qwen2vl_full_sft.yaml b/examples/train_full/qwen2vl_full_sft.yaml index 561f2873..bdf28fe9 100644 --- a/examples/train_full/qwen2vl_full_sft.yaml +++ b/examples/train_full/qwen2vl_full_sft.yaml @@ -1,7 +1,7 @@ ### model model_name_or_path: Qwen/Qwen2-VL-7B-Instruct -image_max_resolution: 262144 -video_max_resolution: 16384 +image_max_pixels: 262144 +video_max_pixels: 16384 trust_remote_code: true ### method diff --git a/examples/train_lora/qwen2vl_lora_dpo.yaml b/examples/train_lora/qwen2vl_lora_dpo.yaml index 50de03ef..6fed819e 100644 --- a/examples/train_lora/qwen2vl_lora_dpo.yaml +++ b/examples/train_lora/qwen2vl_lora_dpo.yaml @@ -1,7 +1,7 @@ ### model model_name_or_path: Qwen/Qwen2-VL-7B-Instruct -image_max_resolution: 262144 -video_max_resolution: 16384 +image_max_pixels: 262144 +video_max_pixels: 16384 trust_remote_code: true ### method diff --git a/examples/train_lora/qwen2vl_lora_sft.yaml b/examples/train_lora/qwen2vl_lora_sft.yaml index 78e3c99a..e2c11520 100644 --- a/examples/train_lora/qwen2vl_lora_sft.yaml +++ b/examples/train_lora/qwen2vl_lora_sft.yaml @@ -1,7 +1,7 @@ ### model model_name_or_path: Qwen/Qwen2-VL-7B-Instruct -image_max_resolution: 262144 -video_max_resolution: 16384 +image_max_pixels: 262144 +video_max_pixels: 16384 trust_remote_code: true ### method diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 2807dcc7..1ace77f8 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -105,18 +105,18 @@ class MMPluginMixin: ) def _preprocess_image( - self, image: "ImageObject", image_max_resolution: int, image_min_resolution: int, **kwargs + self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs ) -> "ImageObject": r""" Pre-processes a single image. """ - if (image.width * image.height) > image_max_resolution: - resize_factor = math.sqrt(image_max_resolution / (image.width * image.height)) + if (image.width * image.height) > image_max_pixels: + resize_factor = math.sqrt(image_max_pixels / (image.width * image.height)) width, height = int(image.width * resize_factor), int(image.height * resize_factor) image = image.resize((width, height), resample=Image.Resampling.NEAREST) - if (image.width * image.height) < image_min_resolution: - resize_factor = math.sqrt(image_min_resolution / (image.width * image.height)) + if (image.width * image.height) < image_min_pixels: + resize_factor = math.sqrt(image_min_pixels / (image.width * image.height)) width, height = int(image.width * resize_factor), int(image.height * resize_factor) image = image.resize((width, height), resample=Image.Resampling.NEAREST) @@ -225,16 +225,16 @@ class MMPluginMixin: if len(images) != 0: images = self._regularize_images( images, - image_max_resolution=getattr(processor, "image_max_resolution", 768 * 768), - image_min_resolution=getattr(processor, "image_min_resolution", 32 * 32), + image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), + image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), ) mm_inputs.update(image_processor(images, return_tensors="pt")) if len(videos) != 0: videos = self._regularize_videos( videos, - image_max_resolution=getattr(processor, "video_max_resolution", 256 * 256), - image_min_resolution=getattr(processor, "video_min_resolution", 16 * 16), + image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), + image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), video_fps=getattr(processor, "video_fps", 2.0), video_maxlen=getattr(processor, "video_maxlen", 128), ) @@ -616,8 +616,8 @@ class MiniCPMVPlugin(BasePlugin): if len(images) != 0: images = self._regularize_images( images, - image_max_resolution=getattr(processor, "image_max_resolution", 768 * 768), - image_min_resolution=getattr(processor, "image_min_resolution", 32 * 32), + image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), + image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), ) if "valid_image_nums_ls" in kwargs: valid_image_nums_ls = kwargs["valid_image_nums_ls"] @@ -637,8 +637,8 @@ class MiniCPMVPlugin(BasePlugin): if len(videos) != 0: videos = self._regularize_videos( videos, - image_max_resolution=getattr(processor, "video_max_resolution", 256 * 256), - image_min_resolution=getattr(processor, "video_min_resolution", 16 * 16), + image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), + image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), video_fps=getattr(processor, "video_fps", 2.0), video_maxlen=getattr(processor, "video_maxlen", 128), ) @@ -788,8 +788,8 @@ class MllamaPlugin(BasePlugin): image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") images = self._regularize_images( images, - image_max_resolution=getattr(processor, "image_max_resolution", 768 * 768), - image_min_resolution=getattr(processor, "image_min_resolution", 32 * 32), + image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), + image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), ) batch_images = [] for image_length in imglens: @@ -1082,16 +1082,16 @@ class Qwen2vlPlugin(BasePlugin): if len(images) != 0: images = self._regularize_images( images, - image_max_resolution=getattr(processor, "image_max_resolution", 768 * 768), - image_min_resolution=getattr(processor, "image_min_resolution", 32 * 32), + image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), + image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), ) mm_inputs.update(image_processor(images, return_tensors="pt")) if len(videos) != 0: videos, fps_per_video = self._regularize_videos( videos, - image_max_resolution=getattr(processor, "video_max_resolution", 256 * 256), - image_min_resolution=getattr(processor, "video_min_resolution", 16 * 16), + image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), + image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16), video_fps=getattr(processor, "video_fps", 2.0), video_maxlen=getattr(processor, "video_maxlen", 128), ) diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index f742a646..fba55832 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -58,19 +58,19 @@ class ProcessorArguments: Arguments pertaining to the image processor. """ - image_max_resolution: int = field( + image_max_pixels: int = field( default=768 * 768, metadata={"help": "The maximum number of pixels of image inputs."}, ) - image_min_resolution: int = field( + image_min_pixels: int = field( default=32 * 32, metadata={"help": "The minimum number of pixels of image inputs."}, ) - video_max_resolution: int = field( + video_max_pixels: int = field( default=256 * 256, metadata={"help": "The maximum number of pixels of video inputs."}, ) - video_min_resolution: int = field( + video_min_pixels: int = field( default=16 * 16, metadata={"help": "The minimum number of pixels of video inputs."}, ) diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py index 11404e43..126abe5d 100644 --- a/src/llamafactory/model/patcher.py +++ b/src/llamafactory/model/patcher.py @@ -80,10 +80,10 @@ def patch_processor( if getattr(config, "vision_config", None) is not None: # visual models setattr(processor, "image_seqlen", get_image_seqlen(config)) setattr(processor, "patch_size", get_patch_size(config, processor)) - setattr(processor, "image_max_resolution", model_args.image_max_resolution) - setattr(processor, "image_min_resolution", model_args.image_min_resolution) - setattr(processor, "video_max_resolution", model_args.video_max_resolution) - setattr(processor, "video_min_resolution", model_args.video_min_resolution) + setattr(processor, "image_max_pixels", model_args.image_max_pixels) + setattr(processor, "image_min_pixels", model_args.image_min_pixels) + setattr(processor, "video_max_pixels", model_args.video_max_pixels) + setattr(processor, "video_min_pixels", model_args.video_min_pixels) setattr(processor, "video_fps", model_args.video_fps) setattr(processor, "video_maxlen", model_args.video_maxlen) setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config, processor))