Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-10-15 08:08:09 +08:00)

commit 5f83860aa1 (parent 62b6a7971a)

    add option to disable version check

    Former-commit-id: fd769cb2de696aee3c5e882237e16eace6a9d675
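In short: the hard dependency pins move out of the model loader and into argument parsing, gated behind a new disable_version_checking flag on FinetuningArguments. A minimal usage sketch follows; the import path is an assumption inferred from the relative imports in the diff below, and the model path, output dir, and LoRA targets are placeholders:

    # Sketch: skipping dependency version checks via the new flag.
    from llmtuner.hparams.parser import get_train_args  # assumed import path

    args = {
        "model_name_or_path": "path/to/model",  # placeholder
        "output_dir": "path/to/output",         # placeholder
        "do_train": True,
        "finetuning_type": "lora",
        "lora_target": "q_proj,v_proj",         # placeholder target modules
        "disable_version_checking": True,       # new flag: skip _check_dependencies()
    }
    model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)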
@@ -2,11 +2,11 @@ import importlib.metadata
 import importlib.util


-def is_package_available(name: str) -> bool:
+def _is_package_available(name: str) -> bool:
     return importlib.util.find_spec(name) is not None


-def get_package_version(name: str) -> str:
+def _get_package_version(name: str) -> str:
     try:
         return importlib.metadata.version(name)
     except Exception:
@@ -14,36 +14,40 @@ def get_package_version(name: str) -> str:


 def is_fastapi_availble():
-    return is_package_available("fastapi")
+    return _is_package_available("fastapi")


 def is_flash_attn2_available():
-    return is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
+    return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2")


 def is_jieba_available():
-    return is_package_available("jieba")
+    return _is_package_available("jieba")


 def is_matplotlib_available():
-    return is_package_available("matplotlib")
+    return _is_package_available("matplotlib")


 def is_nltk_available():
-    return is_package_available("nltk")
+    return _is_package_available("nltk")


 def is_requests_available():
-    return is_package_available("requests")
+    return _is_package_available("requests")


 def is_rouge_available():
-    return is_package_available("rouge_chinese")
+    return _is_package_available("rouge_chinese")


 def is_starlette_available():
-    return is_package_available("sse_starlette")
+    return _is_package_available("sse_starlette")


+def is_unsloth_available():
+    return _is_package_available("unsloth")
+
+
 def is_uvicorn_available():
-    return is_package_available("uvicorn")
+    return _is_package_available("uvicorn")
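For reference, the private helpers above can be exercised standalone: find_spec probes for a module without importing it, and importlib.metadata.version raises when the distribution is absent, which is why _get_package_version wraps it in try/except. A minimal sketch, not part of the diff:

    import importlib.metadata
    import importlib.util

    def _is_package_available(name: str) -> bool:
        # find_spec returns None when no module spec exists; nothing is imported.
        return importlib.util.find_spec(name) is not None

    print(_is_package_available("json"))          # True: stdlib module is always present
    print(_is_package_available("no_such_pkg"))   # False
    try:
        importlib.metadata.version("no_such_pkg")
    except importlib.metadata.PackageNotFoundError:
        print("version lookup raises for a missing distribution")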
@@ -132,6 +132,9 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
     finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
         default="lora", metadata={"help": "Which fine-tuning method to use."}
     )
+    disable_version_checking: Optional[bool] = field(
+        default=False, metadata={"help": "Whether or not to disable version checking."}
+    )
     plot_loss: Optional[bool] = field(
         default=False, metadata={"help": "Whether or not to save the training loss curves."}
     )
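Because FinetuningArguments is parsed with HfArgumentParser, the new field is automatically exposed as a --disable_version_checking CLI flag. A self-contained sketch of that mechanism (the Demo dataclass is illustrative, not from the repo):

    from dataclasses import dataclass, field
    from typing import Optional

    from transformers import HfArgumentParser

    @dataclass
    class Demo:
        disable_version_checking: Optional[bool] = field(
            default=False, metadata={"help": "Whether or not to disable version checking."}
        )

    parser = HfArgumentParser(Demo)
    # HfArgumentParser maps bool fields to flags that accept "true"/"false".
    (demo,) = parser.parse_args_into_dataclasses(["--disable_version_checking", "true"])
    print(demo.disable_version_checking)  # True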
@@ -8,8 +8,10 @@ import torch
 import transformers
 from transformers import HfArgumentParser, Seq2SeqTrainingArguments
 from transformers.trainer_utils import get_last_checkpoint
+from transformers.utils.versions import require_version

 from ..extras.logging import get_logger
+from ..extras.packages import is_unsloth_available
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
@@ -28,6 +30,14 @@ _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArgu
 _EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]


+def _check_dependencies():
+    require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
+    require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
+    require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
+    require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
+    require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6")
+
+
 def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
     if args is not None:
         return parser.parse_dict(args)
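require_version (from transformers) raises at call time when a pin is not satisfied, so grouping the pins into _check_dependencies() is what makes them skippable. A quick sketch of the failure mode; the impossible pin is deliberate:

    from transformers.utils.versions import require_version

    try:
        require_version("transformers>=999.0", "deliberately unsatisfiable pin")
    except ImportError as exc:  # unsatisfied pins surface as ImportError
        print(exc)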
@@ -123,8 +133,14 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
         raise ValueError("Please specify `lora_target` in LoRA training.")

+    if training_args.do_train and model_args.use_unsloth and not is_unsloth_available():
+        raise ValueError("Install Unsloth: https://github.com/unslothai/unsloth")
+
     _verify_model_args(model_args, finetuning_args)

+    if not finetuning_args.disable_version_checking:
+        _check_dependencies()
+
     if (
         training_args.do_train
         and finetuning_args.finetuning_type == "lora"
@@ -145,7 +161,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
         logger.warning("Specify `ref_model` for computing rewards at evaluation.")

-    # postprocess training_args
+    # Post-process training arguments
     if (
         training_args.local_rank != -1
         and training_args.ddp_find_unused_parameters is None
@@ -158,7 +174,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:

     if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
         can_resume_from_checkpoint = False
-        training_args.resume_from_checkpoint = None
+        if training_args.resume_from_checkpoint is not None:
+            logger.warning("Cannot resume from checkpoint in current stage.")
+            training_args.resume_from_checkpoint = None
     else:
         can_resume_from_checkpoint = True

@@ -194,7 +212,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         )
     )

-    # postprocess model_args
+    # Post-process model arguments
     model_args.compute_dtype = (
         torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
     )
@@ -212,7 +230,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     )
     logger.info(f"Training/evaluation parameters {training_args}")

-    # Set seed before initializing model.
     transformers.set_seed(training_args.seed)

     return model_args, data_args, training_args, finetuning_args, generating_args
@@ -220,24 +237,30 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:


 def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)

     _set_transformers_logging()
+    _verify_model_args(model_args, finetuning_args)

     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")

-    _verify_model_args(model_args, finetuning_args)
+    if not finetuning_args.disable_version_checking:
+        _check_dependencies()
+
     return model_args, data_args, finetuning_args, generating_args


 def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
     model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)

     _set_transformers_logging()
+    _verify_model_args(model_args, finetuning_args)

     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")

-    _verify_model_args(model_args, finetuning_args)
+    if not finetuning_args.disable_version_checking:
+        _check_dependencies()
+
     transformers.set_seed(eval_args.seed)
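The same flag gates the checks for inference and evaluation, and _verify_model_args now runs before the template check in both entry points. A minimal inference-side sketch; the template name and model path are placeholders:

    # Sketch: the flag is honored by get_infer_args as well.
    args = {
        "model_name_or_path": "path/to/model",  # placeholder
        "template": "default",                  # placeholder template name
        "disable_version_checking": True,
    }
    model_args, data_args, finetuning_args, generating_args = get_infer_args(args)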
@@ -2,7 +2,6 @@ from typing import TYPE_CHECKING, Optional, Tuple

 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from transformers.integrations import is_deepspeed_zero3_enabled
-from transformers.utils.versions import require_version
 from trl import AutoModelForCausalLMWithValueHead

 from ..extras.logging import get_logger
@@ -21,13 +20,6 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


-require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
-require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
-require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
-require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
-require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6")
-
-
 def load_model_and_tokenizer(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
@@ -63,7 +55,6 @@ def load_model_and_tokenizer(

     model = None
     if is_trainable and model_args.use_unsloth:
-        require_version("unsloth", "Follow the instructions at: https://github.com/unslothai/unsloth")
         from unsloth import FastLlamaModel, FastMistralModel  # type: ignore

         unsloth_kwargs = {