Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-29 18:20:35 +08:00)
support rank0 logger
Former-commit-id: 84528eabe560091bfd866b6a0ca864085af7529b
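This commit switches the LongLoRA patch over to the project's rank-aware logger. The module-level name logging now refers to ...extras.logging rather than transformers.utils.logging (the transformers logger stays reachable through the new import transformers), and configure_longlora logs through info_rank0/warning_rank0 so each message is emitted once on the main process instead of once per worker in distributed runs.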
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Optional, Tuple
 
 import torch
 import torch.nn as nn
+import transformers
 from transformers.models.llama.modeling_llama import (
     Cache,
     LlamaAttention,
@@ -30,11 +31,10 @@ from transformers.models.llama.modeling_llama import (
     apply_rotary_pos_emb,
     repeat_kv,
 )
-from transformers.utils import logging
 from transformers.utils.versions import require_version
 
+from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
-from ...extras.logging import get_logger
 from ...extras.packages import is_transformers_version_greater_than_4_43
 
 
@@ -44,7 +44,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments
 
 
-transformers_logger = logging.get_logger(__name__)
+transformers_logger = transformers.utils.logging.get_logger(__name__)
 
 
 # Modified from:
@@ -363,11 +363,11 @@ def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments",
     if not is_trainable or not model_args.shift_attn:
         return
 
-    logger = get_logger(__name__)
+    logger = logging.get_logger(__name__)
 
     if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
         setattr(config, "group_size_ratio", 0.25)
         _apply_llama_patch()
-        logger.info("Using shift short attention with group_size_ratio=1/4.")
+        logger.info_rank0("Using shift short attention with group_size_ratio=1/4.")
     else:
-        logger.warning("Current model does not support shift short attention.")
+        logger.warning_rank0("Current model does not support shift short attention.")
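The extras.logging module itself is outside this hunk, so only the call sites of the rank-zero helpers are visible here. Below is a minimal sketch of how such a logger can be built, assuming the launcher (e.g. torchrun) exposes the process rank via the RANK or LOCAL_RANK environment variables; _is_rank_zero, _RankZeroLogger, and this get_logger are illustrative stand-ins, not LLaMA-Factory's actual implementation.

import logging
import os


def _is_rank_zero() -> bool:
    # torchrun and similar launchers set RANK (and LOCAL_RANK) per process;
    # an unset variable means a single-process run, which we treat as rank 0.
    return int(os.getenv("RANK", os.getenv("LOCAL_RANK", "0"))) == 0


class _RankZeroLogger(logging.Logger):
    """Logger with *_rank0 variants that emit only on the main process."""

    def info_rank0(self, *args, **kwargs) -> None:
        if _is_rank_zero():
            self.info(*args, **kwargs)

    def warning_rank0(self, *args, **kwargs) -> None:
        if _is_rank_zero():
            self.warning(*args, **kwargs)


def get_logger(name: str) -> logging.Logger:
    # Register the subclass so logging.getLogger constructs new loggers
    # from it, then restore the default to avoid affecting other callers.
    logging.setLoggerClass(_RankZeroLogger)
    try:
        return logging.getLogger(name)
    finally:
        logging.setLoggerClass(logging.Logger)

With a helper like this in place, logger = logging.get_logger(__name__) followed by logger.info_rank0(...) behaves as in the diff above: every worker may construct the logger, but only rank 0 writes the message.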