mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-03-04 18:56:01 +08:00
70 lines
2.2 KiB
Python
70 lines
2.2 KiB
Python
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
import importlib.metadata
|
|
import importlib.util
|
|
from functools import lru_cache
|
|
from typing import TYPE_CHECKING
|
|
|
|
from packaging import version
|
|
from transformers.utils.versions import require_version
|
|
|
|
from . import logging
|
|
from .env import is_env_enabled
|
|
|
|
|
|
# Module-level logger, created through the package-local `logging` helper imported above.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from packaging.version import Version
|
|
|
|
|
|
def _is_package_available(name: str) -> bool:
|
|
return importlib.util.find_spec(name) is not None
|
|
|
|
|
|
def _get_package_version(name: str) -> "Version":
    r"""Return the parsed installed version of the distribution ``name``.

    Any failure while looking up or parsing the version (for example the
    distribution is not installed) falls back to the placeholder ``0.0.0``.
    """
    try:
        installed = importlib.metadata.version(name)
        parsed = version.parse(installed)
    except Exception:
        # Best-effort lookup: an unknown version compares as the lowest one.
        parsed = version.parse("0.0.0")

    return parsed
|
|
|
|
|
|
@lru_cache
def is_transformers_version_greater_than(content: str):
    r"""Tell whether the installed ``transformers`` version is at least ``content``.

    Despite the name, the comparison is inclusive (``>=``). Results are
    memoized per ``content`` string via ``lru_cache``.
    """
    installed = _get_package_version("transformers")
    required = version.parse(content)
    return installed >= required
|
|
|
|
|
|
def check_version(requirement: str, mandatory: bool = False) -> None:
    r"""Optionally check the package version.

    Args:
        requirement: Requirement specifier forwarded to ``require_version``.
        mandatory: When true, the check runs even if
            ``is_env_enabled("DISABLE_VERSION_CHECK")`` reports the flag as set.

    When the check is skipped, a warning is logged and nothing else happens.
    Otherwise ``require_version`` is called with a hint describing the pip
    command that would satisfy the requirement.
    """
    if is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory:
        logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
        return

    # These packages get an extra `--no-build-isolation` flag in the suggested command.
    wants_no_build_isolation = any(pkg in requirement for pkg in ("gptqmodel", "autoawq"))
    pip_command = (
        f"pip install {requirement} --no-build-isolation"
        if wants_no_build_isolation
        else f"pip install {requirement}"
    )

    # Only optional checks advertise the environment-variable escape hatch.
    hint = (
        f"To fix: run `{pip_command}`."
        if mandatory
        else f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
    )

    require_version(requirement, hint)
|