diff --git a/src/llmtuner/extras/packages.py b/src/llmtuner/extras/packages.py
index 71893475..29323885 100644
--- a/src/llmtuner/extras/packages.py
+++ b/src/llmtuner/extras/packages.py
@@ -2,11 +2,11 @@ import importlib.metadata
 import importlib.util
 
 
-def is_package_available(name: str) -> bool:
+def _is_package_available(name: str) -> bool:
     return importlib.util.find_spec(name) is not None
 
 
-def get_package_version(name: str) -> str:
+def _get_package_version(name: str) -> str:
     try:
         return importlib.metadata.version(name)
     except Exception:
@@ -14,36 +14,40 @@ def get_package_version(name: str) -> str:
 
 
 def is_fastapi_availble():
-    return is_package_available("fastapi")
+    return _is_package_available("fastapi")
 
 
 def is_flash_attn2_available():
-    return is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
+    return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2")
 
 
 def is_jieba_available():
-    return is_package_available("jieba")
+    return _is_package_available("jieba")
 
 
 def is_matplotlib_available():
-    return is_package_available("matplotlib")
+    return _is_package_available("matplotlib")
 
 
 def is_nltk_available():
-    return is_package_available("nltk")
+    return _is_package_available("nltk")
 
 
 def is_requests_available():
-    return is_package_available("requests")
+    return _is_package_available("requests")
 
 
 def is_rouge_available():
-    return is_package_available("rouge_chinese")
+    return _is_package_available("rouge_chinese")
 
 
 def is_starlette_available():
-    return is_package_available("sse_starlette")
+    return _is_package_available("sse_starlette")
+
+
+def is_unsloth_available():
+    return _is_package_available("unsloth")
 
 
 def is_uvicorn_available():
-    return is_package_available("uvicorn")
+    return _is_package_available("uvicorn")
diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py
index 3c336574..7c143918 100644
--- a/src/llmtuner/hparams/finetuning_args.py
+++ b/src/llmtuner/hparams/finetuning_args.py
@@ -132,6 +132,9 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
     finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
         default="lora", metadata={"help": "Which fine-tuning method to use."}
     )
+    disable_version_checking: Optional[bool] = field(
+        default=False, metadata={"help": "Whether or not to disable version checking."}
+    )
     plot_loss: Optional[bool] = field(
         default=False, metadata={"help": "Whether or not to save the training loss curves."}
     )
diff --git a/src/llmtuner/hparams/parser.py b/src/llmtuner/hparams/parser.py
index b56e7a18..a09f84bc 100644
--- a/src/llmtuner/hparams/parser.py
+++ b/src/llmtuner/hparams/parser.py
@@ -8,8 +8,10 @@ import torch
 import transformers
 from transformers import HfArgumentParser, Seq2SeqTrainingArguments
 from transformers.trainer_utils import get_last_checkpoint
+from transformers.utils.versions import require_version
 
 from ..extras.logging import get_logger
+from ..extras.packages import is_unsloth_available
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
@@ -28,6 +30,14 @@ _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArgu
 _EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
 
 
+def _check_dependencies():
+    require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
+    require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
+    require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
+    require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
+    require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6")
+
+
 def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
     if args is not None:
         return parser.parse_dict(args)
@@ -123,8 +133,14 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
         raise ValueError("Please specify `lora_target` in LoRA training.")
 
+    if training_args.do_train and model_args.use_unsloth and not is_unsloth_available():
+        raise ValueError("Install Unsloth: https://github.com/unslothai/unsloth")
+
     _verify_model_args(model_args, finetuning_args)
 
+    if not finetuning_args.disable_version_checking:
+        _check_dependencies()
+
     if (
         training_args.do_train
         and finetuning_args.finetuning_type == "lora"
@@ -145,7 +161,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
         logger.warning("Specify `ref_model` for computing rewards at evaluation.")
 
-    # postprocess training_args
+    # Post-process training arguments
     if (
         training_args.local_rank != -1
         and training_args.ddp_find_unused_parameters is None
@@ -158,7 +174,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
 
     if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
         can_resume_from_checkpoint = False
-        training_args.resume_from_checkpoint = None
+        if training_args.resume_from_checkpoint is not None:
+            logger.warning("Cannot resume from checkpoint in current stage.")
+            training_args.resume_from_checkpoint = None
     else:
         can_resume_from_checkpoint = True
 
@@ -194,7 +212,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         )
     )
 
-    # postprocess model_args
+    # Post-process model arguments
     model_args.compute_dtype = (
         torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
     )
@@ -212,7 +230,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     )
     logger.info(f"Training/evaluation parameters {training_args}")
 
-    # Set seed before initializing model.
     transformers.set_seed(training_args.seed)
 
     return model_args, data_args, training_args, finetuning_args, generating_args
@@ -220,24 +237,30 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
 def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
 
+    _set_transformers_logging()
+    _verify_model_args(model_args, finetuning_args)
+
     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")
 
-    _verify_model_args(model_args, finetuning_args)
+    if not finetuning_args.disable_version_checking:
+        _check_dependencies()
 
     return model_args, data_args, finetuning_args, generating_args
 
 
 def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
     model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
 
+    _set_transformers_logging()
+    _verify_model_args(model_args, finetuning_args)
+
     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")
 
-    _verify_model_args(model_args, finetuning_args)
+    if not finetuning_args.disable_version_checking:
+        _check_dependencies()
 
     transformers.set_seed(eval_args.seed)
diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index ed86a1c0..deeaed40 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -2,7 +2,6 @@ from typing import TYPE_CHECKING, Optional, Tuple
 
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from transformers.integrations import is_deepspeed_zero3_enabled
-from transformers.utils.versions import require_version
 from trl import AutoModelForCausalLMWithValueHead
 
 from ..extras.logging import get_logger
@@ -21,13 +20,6 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
-require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
-require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
-require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
-require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6")
-
-
 def load_model_and_tokenizer(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
@@ -63,7 +55,6 @@ def load_model_and_tokenizer(
 
     model = None
     if is_trainable and model_args.use_unsloth:
-        require_version("unsloth", "Follow the instructions at: https://github.com/unslothai/unsloth")
        from unsloth import FastLlamaModel, FastMistralModel  # type: ignore
 
        unsloth_kwargs = {
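
Note: `_check_dependencies()` builds on `transformers.utils.versions.require_version`, which raises an `ImportError` carrying the hint string when the installed package is missing or does not satisfy the specifier. A minimal sketch of the behavior the new `disable_version_checking` flag opts out of (the try/except wrapper is illustrative and not part of this patch; the pinned version comes from the diff above):

```python
# Illustrative sketch only; mirrors one line of _check_dependencies() in parser.py.
from transformers.utils.versions import require_version

try:
    require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
except ImportError as err:
    # Raised when datasets is absent or older than 2.14.3; the message includes the hint.
    print(err)
```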
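With this change, callers that manage their own environments can skip the dependency check through the new argument. A hypothetical invocation of `get_train_args` (the model, dataset, and LoRA values below are placeholders for illustration; only `disable_version_checking` is the flag introduced in this diff):

```python
# Hypothetical usage sketch; argument values are made up and not part of the patch.
from llmtuner.hparams.parser import get_train_args

model_args, data_args, training_args, finetuning_args, generating_args = get_train_args({
    "stage": "sft",
    "do_train": True,
    "model_name_or_path": "meta-llama/Llama-2-7b-hf",
    "dataset": "alpaca_gpt4_en",
    "template": "default",
    "finetuning_type": "lora",
    "lora_target": "q_proj,v_proj",
    "output_dir": "saves/llama2-7b-lora",
    "disable_version_checking": True,  # skips _check_dependencies()
})
```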