From 9c69cc1e16a502c479c37b5e9523f463d851dae3 Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Fri, 26 Apr 2024 03:22:40 +0800
Subject: [PATCH] Update loader.py

Former-commit-id: 7d812ed8419817a8724f8584996607e8685b5ea9
---
 src/llmtuner/model/loader.py | 54 ++++++++++++++----------------------
 1 file changed, 21 insertions(+), 33 deletions(-)

diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 5b5c0a4d..0ff7a350 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -1,12 +1,6 @@
-from typing import TYPE_CHECKING, Any, Dict, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
 
-from transformers import (
-    AutoConfig,
-    AutoModelForCausalLM,
-    AutoModelForVision2Seq,
-    AutoProcessor,
-    AutoTokenizer,
-)
+from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
 from trl import AutoModelForCausalLMWithValueHead
 
 from ..extras.logging import get_logger
@@ -19,13 +13,19 @@ from .utils.unsloth import load_unsloth_pretrained_model
 
 
 if TYPE_CHECKING:
-    from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
+    from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
 
     from ..hparams import FinetuningArguments, ModelArguments
 
+
 logger = get_logger(__name__)
 
 
+class TokenizerModule(TypedDict):
+    tokenizer: "PreTrainedTokenizer"
+    processor: Optional["ProcessorMixin"]
+
+
 def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
     r"""
     Gets arguments to load config/tokenizer/model.
@@ -41,7 +41,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
     }
 
 
-def load_tokenizer(model_args: "ModelArguments") -> Dict[str, Union["PreTrainedTokenizer", "AutoProcessor"]]:
+def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
     r"""
     Loads pretrained tokenizer.
 
@@ -75,25 +75,14 @@ def load_tokenizer(model_args: "ModelArguments") -> Dict[str, Union["PreTrainedT
             logger.warning("New tokens have been added, changed `resize_vocab` to True.")
 
     patch_tokenizer(tokenizer)
-    tokenizer_modules = {"tokenizer": tokenizer, "processor": None}
-    if model_args.use_mllm:
-        try:
-            processor = AutoProcessor.from_pretrained(
-                model_args.model_name_or_path,
-                use_fast=model_args.use_fast_tokenizer,
-                split_special_tokens=model_args.split_special_tokens,
-                padding_side="right",
-                **init_kwargs,
-            )
-        except Exception:  # try the fast one
-            processor = AutoProcessor.from_pretrained(
-                model_args.model_name_or_path,
-                use_fast=True,
-                padding_side="right",
-                **init_kwargs,
-            )
-        tokenizer_modules["processor"] = processor
-    return tokenizer_modules
+
+    if model_args.visual_inputs:
+        processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
+        setattr(processor, "tokenizer", tokenizer)
+    else:
+        processor = None
+
+    return {"tokenizer": tokenizer, "processor": processor}
 
 
 def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
@@ -132,11 +121,10 @@ def load_model(
         if model_args.mixture_of_depths == "load":
            model = load_mod_pretrained_model(**init_kwargs)
+        elif model_args.visual_inputs:
+            model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
         else:
-            if model_args.use_mllm:
-                model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
-            else:
-                model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
+            model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
 
         if model_args.mixture_of_depths == "convert":
             model = convert_pretrained_model_to_mod(model, config, model_args)
 
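
A minimal usage sketch (not part of the patch) of how a caller might consume the
TokenizerModule returned by the updated load_tokenizer. The llmtuner import
paths, the ModelArguments dataclass with a boolean visual_inputs field, and the
checkpoint name are assumptions taken from or illustrating the code above.

from llmtuner.hparams import ModelArguments       # assumed export path
from llmtuner.model.loader import load_tokenizer

# visual_inputs mirrors the flag checked in the patched loader (assumed to be a
# plain boolean field on ModelArguments).
model_args = ModelArguments(
    model_name_or_path="llava-hf/llava-1.5-7b-hf",  # example multimodal checkpoint
    visual_inputs=True,
)

tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]    # always a PreTrainedTokenizer
processor = tokenizer_module["processor"]    # None unless visual_inputs is True

if processor is not None:
    # the patch attaches the tokenizer to the processor via setattr
    assert processor.tokenizer is tokenizer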