From 3720618c63e82961fa15be4e3364fcd0897167ec Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Mon, 30 Sep 2024 17:07:43 +0800
Subject: [PATCH] add patch processor func

Former-commit-id: 45841bb646afa9d0bc2ea4b6b7b107daa67d90f0
---
 src/llamafactory/model/loader.py  | 15 +++------------
 src/llamafactory/model/patcher.py | 26 ++++++++++++++++++++++++--
 2 files changed, 27 insertions(+), 14 deletions(-)

diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index c90913ae..9e47fb72 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -25,8 +25,7 @@ from .model_utils.misc import register_autoclass
 from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
 from .model_utils.unsloth import load_unsloth_pretrained_model
 from .model_utils.valuehead import load_valuehead_params
-from .model_utils.visual import get_image_seqlen, get_patch_size, get_vision_feature_select_strategy
-from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
+from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
 
 
 if TYPE_CHECKING:
@@ -61,7 +60,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
 
 def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
     r"""
-    Loads pretrained tokenizer.
+    Loads pretrained tokenizer and optionally loads processor.
 
     Note: including inplace operation of model_args.
     """
@@ -94,17 +93,9 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         logger.warning("New tokens have been added, changed `resize_vocab` to True.")
 
     patch_tokenizer(tokenizer)
-
     try:
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
-        setattr(processor, "tokenizer", tokenizer)
-        setattr(processor, "image_seqlen", get_image_seqlen(config))
-        setattr(processor, "image_resolution", model_args.image_resolution)
-        setattr(processor, "patch_size", get_patch_size(config))
-        setattr(processor, "video_resolution", model_args.video_resolution)
-        setattr(processor, "video_fps", model_args.video_fps)
-        setattr(processor, "video_maxlen", model_args.video_maxlen)
-        setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config))
+        patch_processor(processor, config, tokenizer, model_args)
     except Exception:
         processor = None
 
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 3de82703..e4bb7ac1 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -34,11 +34,17 @@ from .model_utils.packing import configure_packing
 from .model_utils.quantization import configure_quantization
 from .model_utils.rope import configure_rope
 from .model_utils.valuehead import prepare_valuehead_model
-from .model_utils.visual import autocast_projector_dtype, configure_visual_model
+from .model_utils.visual import (
+    autocast_projector_dtype,
+    configure_visual_model,
+    get_image_seqlen,
+    get_patch_size,
+    get_vision_feature_select_strategy,
+)
 
 
 if TYPE_CHECKING:
-    from transformers import PretrainedConfig, PreTrainedTokenizer
+    from transformers import PretrainedConfig, PreTrainedTokenizer, ProcessorMixin
     from trl import AutoModelForCausalLMWithValueHead
 
     from ..hparams import ModelArguments
@@ -52,6 +58,22 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
         tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
 
 
+def patch_processor(
+    processor: "ProcessorMixin",
+    config: "PretrainedConfig",
+    tokenizer: "PreTrainedTokenizer",
+    model_args: "ModelArguments",
+) -> None:
+    setattr(processor, "tokenizer", tokenizer)
+    setattr(processor, "image_seqlen", get_image_seqlen(config))
+    setattr(processor, "image_resolution", model_args.image_resolution)
+    setattr(processor, "patch_size", get_patch_size(config))
+    setattr(processor, "video_resolution", model_args.video_resolution)
+    setattr(processor, "video_fps", model_args.video_fps)
+    setattr(processor, "video_maxlen", model_args.video_maxlen)
+    setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config))
+
+
 def patch_config(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",