support rank0 logger

Former-commit-id: 84528eabe560091bfd866b6a0ca864085af7529b
This commit is contained in:
hiyouga
2024-11-02 18:31:04 +08:00
parent ceb701c2d4
commit 093eda2ad6
42 changed files with 316 additions and 252 deletions

View File

@@ -15,7 +15,7 @@
import inspect
from typing import TYPE_CHECKING
from ...extras.logging import get_logger
from ...extras import logging
if TYPE_CHECKING:
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def apply_liger_kernel(
@@ -54,14 +54,14 @@ def apply_liger_kernel(
elif model_type == "qwen2_vl":
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
else:
logger.warning("Current model does not support liger kernel.")
logger.warning_rank0("Current model does not support liger kernel.")
return
if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters:
logger.info("Current training stage does not support chunked cross entropy.")
logger.info_rank0("Current training stage does not support chunked cross entropy.")
kwargs = {"fused_linear_cross_entropy": False}
else:
kwargs = {}
apply_liger_kernel(**kwargs)
logger.info("Liger kernel has been applied to the model.")
logger.info_rank0("Liger kernel has been applied to the model.")