support rank0 logger

hiyouga
2024-11-02 18:31:04 +08:00
parent bd08b8c441
commit c38aa29336
42 changed files with 316 additions and 252 deletions


@@ -14,14 +14,14 @@
 from typing import TYPE_CHECKING, List
 
-from ...extras.logging import get_logger
+from ...extras import logging
 
 
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
 
 
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 
 
 def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
@@ -53,7 +53,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
         if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
             module_names.add(name.split(".")[-1])
 
-    logger.info("Found linear modules: {}".format(",".join(module_names)))
+    logger.info_rank0("Found linear modules: {}".format(",".join(module_names)))
     return list(module_names)
@@ -80,7 +80,7 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
     ):
         module_names.append(name)
 
-    logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
+    logger.info_rank0("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
     return module_names
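
The pattern throughout the diff is the same: plain logger.info(...) calls become logger.info_rank0(...), and modules obtain their logger via the project's extras.logging wrapper instead of importing get_logger directly. In distributed training this keeps each message to a single copy from the main process rather than one per worker. Below is a minimal sketch of how such a method could be grafted onto the standard library logger; the _info_rank0 name and the LOCAL_RANK environment check are illustrative assumptions, not the commit's actual implementation:

import logging
import os


def _info_rank0(self: logging.Logger, *args, **kwargs) -> None:
    # Emit the record only on the main process; all other ranks stay
    # silent, so an 8-GPU run prints each message once, not eight times.
    if int(os.getenv("LOCAL_RANK", "0")) == 0:
        self.info(*args, **kwargs)


# Attach the helper so every logging.Logger instance gains info_rank0().
logging.Logger.info_rank0 = _info_rank0


def get_logger(name: str) -> logging.Logger:
    return logging.getLogger(name)

With something like this in the wrapper module, logging.get_logger(__name__).info_rank0(...) behaves like info(...) on rank 0 and is a no-op on every other rank, which is what lets each call site in this diff switch over with a one-token change.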