From b1d31ff0f91c824565c55a19932a8780a574fc0b Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Tue, 18 Feb 2025 01:40:46 +0800
Subject: [PATCH] [data] add min resolution option (#6975)

Former-commit-id: 7faecc0301709326efa21e7a3fdb75fe0a9635c2
---
 .env.local                                 |  1 +
 examples/train_full/qwen2vl_full_sft.yaml  |  4 +--
 examples/train_lora/llava1_5_lora_sft.yaml |  2 --
 examples/train_lora/qwen2vl_lora_dpo.yaml  |  4 +--
 examples/train_lora/qwen2vl_lora_sft.yaml  |  4 +--
 src/llamafactory/api/chat.py               | 11 ++++--
 src/llamafactory/data/mm_plugin.py         | 39 ++++++++++++++++------
 src/llamafactory/hparams/model_args.py     | 12 +++++--
 src/llamafactory/model/patcher.py          |  6 ++--
 9 files changed, 59 insertions(+), 24 deletions(-)

diff --git a/.env.local b/.env.local
index aa43a4d7..38a5503a 100644
--- a/.env.local
+++ b/.env.local
@@ -4,6 +4,7 @@ API_HOST=
 API_PORT=
 API_KEY=
 API_MODEL_NAME=
+API_VERBOSE=
 FASTAPI_ROOT_PATH=
 MAX_CONCURRENT=
 # general
diff --git a/examples/train_full/qwen2vl_full_sft.yaml b/examples/train_full/qwen2vl_full_sft.yaml
index e1331e8e..561f2873 100644
--- a/examples/train_full/qwen2vl_full_sft.yaml
+++ b/examples/train_full/qwen2vl_full_sft.yaml
@@ -1,7 +1,7 @@
 ### model
 model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
-image_resolution: 262144
-video_resolution: 16384
+image_max_resolution: 262144
+video_max_resolution: 16384
 trust_remote_code: true
 
 ### method
diff --git a/examples/train_lora/llava1_5_lora_sft.yaml b/examples/train_lora/llava1_5_lora_sft.yaml
index ad8060d2..24d09d91 100644
--- a/examples/train_lora/llava1_5_lora_sft.yaml
+++ b/examples/train_lora/llava1_5_lora_sft.yaml
@@ -1,7 +1,5 @@
 ### model
 model_name_or_path: llava-hf/llava-1.5-7b-hf
-image_resolution: 262144
-video_resolution: 16384
 trust_remote_code: true
 
 ### method
diff --git a/examples/train_lora/qwen2vl_lora_dpo.yaml b/examples/train_lora/qwen2vl_lora_dpo.yaml
index 486c5bef..50de03ef 100644
--- a/examples/train_lora/qwen2vl_lora_dpo.yaml
+++ b/examples/train_lora/qwen2vl_lora_dpo.yaml
@@ -1,7 +1,7 @@
 ### model
 model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
-image_resolution: 262144
-video_resolution: 16384
+image_max_resolution: 262144
+video_max_resolution: 16384
 trust_remote_code: true
 
 ### method
diff --git a/examples/train_lora/qwen2vl_lora_sft.yaml b/examples/train_lora/qwen2vl_lora_sft.yaml
index b40ccf9b..78e3c99a 100644
--- a/examples/train_lora/qwen2vl_lora_sft.yaml
+++ b/examples/train_lora/qwen2vl_lora_sft.yaml
@@ -1,7 +1,7 @@
 ### model
 model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
-image_resolution: 262144
-video_resolution: 16384
+image_max_resolution: 262144
+video_max_resolution: 16384
 trust_remote_code: true
 
 ### method
diff --git a/src/llamafactory/api/chat.py b/src/llamafactory/api/chat.py
index 6959f4d8..a0edd8e0 100644
--- a/src/llamafactory/api/chat.py
+++ b/src/llamafactory/api/chat.py
@@ -22,6 +22,8 @@ from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
 
 from ..data import Role as DataRole
 from ..extras import logging
+from ..extras.constants import IMAGE_PLACEHOLDER
+from ..extras.misc import is_env_enabled
 from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
 from .common import dictify, jsonify
 from .protocol import (
@@ -70,7 +72,8 @@ ROLE_MAPPING = {
 def _process_request(
     request: "ChatCompletionRequest",
 ) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
-    logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
+    if is_env_enabled("API_VERBOSE", "1"):
+        logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
 
     if len(request.messages) == 0:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
@@ -99,10 +102,12 @@ def _process_request(
             content = json.dumps(tool_calls, ensure_ascii=False)
             input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
         elif isinstance(message.content, list):
+            text_content = ""
             for input_item in message.content:
                 if input_item.type == "text":
-                    input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
+                    text_content += input_item.text
                 else:
+                    text_content += IMAGE_PLACEHOLDER
                     image_url = input_item.image_url.url
                     if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):  # base64 image
                         image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
@@ -112,6 +117,8 @@ def _process_request(
                         image_stream = requests.get(image_url, stream=True).raw
 
                     images.append(Image.open(image_stream).convert("RGB"))
+
+            input_messages.append({"role": ROLE_MAPPING[message.role], "content": text_content})
         else:
             input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
 
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 781130c4..2807dcc7 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -104,12 +104,19 @@ class MMPluginMixin:
                 "This model does not support audio input. Please check whether the correct `template` is used."
             )
 
-    def _preprocess_image(self, image: "ImageObject", image_resolution: int, **kwargs) -> "ImageObject":
+    def _preprocess_image(
+        self, image: "ImageObject", image_max_resolution: int, image_min_resolution: int, **kwargs
+    ) -> "ImageObject":
         r"""
         Pre-processes a single image.
""" - if (image.width * image.height) > image_resolution: - resize_factor = math.sqrt(image_resolution / (image.width * image.height)) + if (image.width * image.height) > image_max_resolution: + resize_factor = math.sqrt(image_max_resolution / (image.width * image.height)) + width, height = int(image.width * resize_factor), int(image.height * resize_factor) + image = image.resize((width, height), resample=Image.Resampling.NEAREST) + + if (image.width * image.height) < image_min_resolution: + resize_factor = math.sqrt(image_min_resolution / (image.width * image.height)) width, height = int(image.width * resize_factor), int(image.height * resize_factor) image = image.resize((width, height), resample=Image.Resampling.NEAREST) @@ -217,14 +224,17 @@ class MMPluginMixin: if len(images) != 0: images = self._regularize_images( - images, image_resolution=getattr(processor, "image_resolution", 768 * 768) + images, + image_max_resolution=getattr(processor, "image_max_resolution", 768 * 768), + image_min_resolution=getattr(processor, "image_min_resolution", 32 * 32), ) mm_inputs.update(image_processor(images, return_tensors="pt")) if len(videos) != 0: videos = self._regularize_videos( videos, - image_resolution=getattr(processor, "video_resolution", 256 * 256), + image_max_resolution=getattr(processor, "video_max_resolution", 256 * 256), + image_min_resolution=getattr(processor, "video_min_resolution", 16 * 16), video_fps=getattr(processor, "video_fps", 2.0), video_maxlen=getattr(processor, "video_maxlen", 128), ) @@ -606,7 +616,8 @@ class MiniCPMVPlugin(BasePlugin): if len(images) != 0: images = self._regularize_images( images, - image_resolution=getattr(processor, "image_resolution", 768 * 768), + image_max_resolution=getattr(processor, "image_max_resolution", 768 * 768), + image_min_resolution=getattr(processor, "image_min_resolution", 32 * 32), ) if "valid_image_nums_ls" in kwargs: valid_image_nums_ls = kwargs["valid_image_nums_ls"] @@ -626,7 +637,8 @@ class MiniCPMVPlugin(BasePlugin): if len(videos) != 0: videos = self._regularize_videos( videos, - image_resolution=getattr(processor, "video_resolution", 256 * 256), + image_max_resolution=getattr(processor, "video_max_resolution", 256 * 256), + image_min_resolution=getattr(processor, "video_min_resolution", 16 * 16), video_fps=getattr(processor, "video_fps", 2.0), video_maxlen=getattr(processor, "video_maxlen", 128), ) @@ -774,7 +786,11 @@ class MllamaPlugin(BasePlugin): num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1). 
""" image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") - images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 768 * 768)) + images = self._regularize_images( + images, + image_max_resolution=getattr(processor, "image_max_resolution", 768 * 768), + image_min_resolution=getattr(processor, "image_min_resolution", 32 * 32), + ) batch_images = [] for image_length in imglens: batch_images.append(images[:image_length]) @@ -1065,14 +1081,17 @@ class Qwen2vlPlugin(BasePlugin): mm_inputs = {} if len(images) != 0: images = self._regularize_images( - images, image_resolution=getattr(processor, "image_resolution", 768 * 768) + images, + image_max_resolution=getattr(processor, "image_max_resolution", 768 * 768), + image_min_resolution=getattr(processor, "image_min_resolution", 32 * 32), ) mm_inputs.update(image_processor(images, return_tensors="pt")) if len(videos) != 0: videos, fps_per_video = self._regularize_videos( videos, - image_resolution=getattr(processor, "video_resolution", 256 * 256), + image_max_resolution=getattr(processor, "video_max_resolution", 256 * 256), + image_min_resolution=getattr(processor, "video_min_resolution", 16 * 16), video_fps=getattr(processor, "video_fps", 2.0), video_maxlen=getattr(processor, "video_maxlen", 128), ) diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index e4429167..f742a646 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -58,14 +58,22 @@ class ProcessorArguments: Arguments pertaining to the image processor. """ - image_resolution: int = field( + image_max_resolution: int = field( default=768 * 768, metadata={"help": "The maximum number of pixels of image inputs."}, ) - video_resolution: int = field( + image_min_resolution: int = field( + default=32 * 32, + metadata={"help": "The minimum number of pixels of image inputs."}, + ) + video_max_resolution: int = field( default=256 * 256, metadata={"help": "The maximum number of pixels of video inputs."}, ) + video_min_resolution: int = field( + default=16 * 16, + metadata={"help": "The minimum number of pixels of video inputs."}, + ) video_fps: float = field( default=2.0, metadata={"help": "The frames to sample per second for video inputs."}, diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py index 1a57b00e..11404e43 100644 --- a/src/llamafactory/model/patcher.py +++ b/src/llamafactory/model/patcher.py @@ -79,9 +79,11 @@ def patch_processor( setattr(processor, "tokenizer", tokenizer) if getattr(config, "vision_config", None) is not None: # visual models setattr(processor, "image_seqlen", get_image_seqlen(config)) - setattr(processor, "image_resolution", model_args.image_resolution) setattr(processor, "patch_size", get_patch_size(config, processor)) - setattr(processor, "video_resolution", model_args.video_resolution) + setattr(processor, "image_max_resolution", model_args.image_max_resolution) + setattr(processor, "image_min_resolution", model_args.image_min_resolution) + setattr(processor, "video_max_resolution", model_args.video_max_resolution) + setattr(processor, "video_min_resolution", model_args.video_min_resolution) setattr(processor, "video_fps", model_args.video_fps) setattr(processor, "video_maxlen", model_args.video_maxlen) setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config, processor))