mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-17 20:30:36 +08:00
[breaking] support transformers 4.48 (#6628)
This commit is contained in:
@@ -34,6 +34,7 @@ from transformers.utils import (
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from . import logging
|
||||
from .packages import is_transformers_version_greater_than
|
||||
|
||||
|
||||
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
|
||||
@@ -93,11 +94,13 @@ def check_dependencies() -> None:
|
||||
r"""
|
||||
Checks the version of the required packages.
|
||||
"""
|
||||
check_version("transformers>=4.41.2,<=4.46.1")
|
||||
check_version("datasets>=2.16.0,<=3.1.0")
|
||||
check_version("accelerate>=0.34.0,<=1.0.1")
|
||||
check_version("transformers>=4.41.2,<=4.48.1,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
|
||||
check_version("datasets>=2.16.0,<=3.2.0")
|
||||
check_version("accelerate>=0.34.0,<=1.2.1")
|
||||
check_version("peft>=0.11.1,<=0.12.0")
|
||||
check_version("trl>=0.8.6,<=0.9.6")
|
||||
if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
|
||||
logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")
|
||||
|
||||
|
||||
def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
|
||||
|
||||
@@ -87,11 +87,6 @@ def is_transformers_version_greater_than(content: str):
|
||||
return _get_package_version("transformers") >= version.parse(content)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def is_transformers_version_equal_to_4_46():
|
||||
return version.parse("4.46.0") <= _get_package_version("transformers") <= version.parse("4.46.1")
|
||||
|
||||
|
||||
def is_uvicorn_available():
|
||||
return _is_package_available("uvicorn")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user