mirror of https://github.com/hiyouga/LLaMA-Factory.git
support rank0 logger
@@ -14,14 +14,14 @@
 from typing import TYPE_CHECKING, List
 
-from ...extras.logging import get_logger
+from ...extras import logging
 
 
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
 
 
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 
 
 def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
@@ -53,7 +53,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
         if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
             module_names.add(name.split(".")[-1])
 
-    logger.info("Found linear modules: {}".format(",".join(module_names)))
+    logger.info_rank0("Found linear modules: {}".format(",".join(module_names)))
     return list(module_names)
@@ -80,7 +80,7 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
         ):
             module_names.append(name)
 
-    logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
+    logger.info_rank0("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
     return module_names
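The diff replaces the old get_logger import with the extras logging module and switches logger.info(...) calls to logger.info_rank0(...), so that each message is emitted only once from the main process during distributed training rather than once per worker. Below is a minimal, hypothetical sketch of how such an info_rank0 method could be built; it is not LLaMA-Factory's actual ...extras.logging implementation, and reading the process rank from the RANK/LOCAL_RANK environment variables is an assumption.

# Hypothetical sketch of a rank-0-aware logger (not the real ...extras.logging module).
# Assumption: the global process rank is exposed via the RANK (or LOCAL_RANK) env var,
# as is typical for torchrun-launched distributed jobs.
import logging
import os


class _Rank0Logger(logging.Logger):
    def info_rank0(self, msg: str, *args, **kwargs) -> None:
        # Log only on the main process (rank 0) or in a non-distributed run.
        if int(os.getenv("RANK", os.getenv("LOCAL_RANK", "0"))) == 0:
            self.info(msg, *args, **kwargs)


def get_logger(name: str) -> logging.Logger:
    # Temporarily register the subclass so newly created loggers get the extra method.
    logging.setLoggerClass(_Rank0Logger)
    try:
        return logging.getLogger(name)
    finally:
        logging.setLoggerClass(logging.Logger)


logging.basicConfig(level=logging.INFO)
logger = get_logger(__name__)
logger.info_rank0("Found linear modules: q_proj,k_proj,v_proj")  # emitted once, not per worker

With a wrapper like this, the call sites in the diff stay one-liners: every worker process executes logger.info_rank0(...), but only rank 0 actually writes the record.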