support rank0 logger

This commit is contained in:
hiyouga
2024-11-02 18:31:04 +08:00
parent bd08b8c441
commit c38aa29336
42 changed files with 316 additions and 252 deletions

View File

@@ -32,7 +32,7 @@ from transformers.utils import (
)
from transformers.utils.versions import require_version
from .logging import get_logger
from . import logging
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
@@ -48,7 +48,7 @@ if TYPE_CHECKING:
from ..hparams import ModelArguments
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
class AverageMeter:
@@ -76,8 +76,8 @@ def check_dependencies() -> None:
r"""
Checks the version of the required packages.
"""
if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
else:
require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
require_version("datasets>=2.16.0,<=3.0.2", "To fix: pip install datasets>=2.16.0,<=3.0.2")