fix version checking

This commit is contained in:
hiyouga
2024-03-06 14:51:51 +08:00
parent d1587c80de
commit 3016e65657
18 changed files with 49 additions and 33 deletions

View File

@@ -14,6 +14,7 @@ from transformers.utils import (
is_torch_npu_available,
is_torch_xpu_available,
)
from transformers.utils.versions import require_version
from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from .logging import get_logger
@@ -56,6 +57,17 @@ class AverageMeter:
self.avg = self.sum / self.count
def check_dependencies() -> None:
if int(os.environ.get("DISABLE_VERSION_CHECK", "0")):
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
else:
require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
require_version("peft>=0.9.0", "To fix: pip install peft>=0.9.0")
require_version("trl>=0.7.11", "To fix: pip install trl>=0.7.11")
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
r"""
Returns the number of trainable parameters and number of all parameters in the model.