support rank0 logger

commit e83cb17f97 (parent 4b2c47fcae)
Author: hiyouga
Date:   2024-11-02 18:31:04 +08:00
Former-commit-id: c38aa29336

42 changed files with 316 additions and 252 deletions
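The hunks below only show call sites migrating from logger.info(...) / logger.warning(...) to the new rank-zero variants; the logger itself lives in the `extras.logging` module, which is not part of this excerpt. As a rough idea of what such a rank-zero logger can look like, here is a minimal sketch (an assumption, not the actual implementation: the _is_rank0 helper, the env-var check, and the handler setup are all illustrative), gating each message on the distributed rank so that only the main process prints:

import logging
import os
import sys


def _is_rank0() -> bool:
    # torchrun / accelerate set LOCAL_RANK (or RANK); single-process runs count as rank 0
    return int(os.getenv("LOCAL_RANK", os.getenv("RANK", "0"))) == 0


def get_logger(name: str) -> logging.Logger:
    logger = logging.getLogger(name)
    if not logger.handlers:  # avoid duplicate handlers on repeated calls
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(logging.Formatter("[%(levelname)s|%(name)s] %(message)s"))
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)

    seen = set()  # deduplication state for warning_once

    def info_rank0(msg, *args, **kwargs):
        if _is_rank0():
            logger.info(msg, *args, **kwargs)

    def warning_rank0(msg, *args, **kwargs):
        if _is_rank0():
            logger.warning(msg, *args, **kwargs)

    def warning_once(msg, *args, **kwargs):
        if _is_rank0() and msg not in seen:
            seen.add(msg)
            logger.warning(msg, *args, **kwargs)

    # attach the rank-aware helpers to the standard Logger instance
    logger.info_rank0 = info_rank0
    logger.warning_rank0 = warning_rank0
    logger.warning_once = warning_once
    return logger

Call sites then switch the import to `from ..extras import logging`, create the module logger with `logging.get_logger(__name__)`, and route messages through `logger.info_rank0(...)` / `logger.warning_rank0(...)`, exactly as in the diffs below.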

View File

@@ -20,7 +20,7 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
-from ..extras.logging import get_logger
+from ..extras import logging
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
@@ -33,7 +33,7 @@ if TYPE_CHECKING:
from ..hparams import FinetuningArguments, ModelArguments
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
def _setup_full_tuning(
@@ -45,7 +45,7 @@ def _setup_full_tuning(
if not is_trainable:
return
logger.info("Fine-tuning method: Full")
logger.info_rank0("Fine-tuning method: Full")
forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
for name, param in model.named_parameters():
if not any(forbidden_module in name for forbidden_module in forbidden_modules):
@@ -64,7 +64,7 @@ def _setup_freeze_tuning(
if not is_trainable:
return
logger.info("Fine-tuning method: Freeze")
logger.info_rank0("Fine-tuning method: Freeze")
if hasattr(model.config, "text_config"): # composite models
config = getattr(model.config, "text_config")
else:
@@ -133,7 +133,7 @@ def _setup_freeze_tuning(
else:
param.requires_grad_(False)
logger.info("Set trainable layers: {}".format(",".join(trainable_layers)))
logger.info_rank0("Set trainable layers: {}".format(",".join(trainable_layers)))
def _setup_lora_tuning(
@@ -145,7 +145,7 @@ def _setup_lora_tuning(
cast_trainable_params_to_fp32: bool,
) -> "PeftModel":
if is_trainable:
logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
adapter_to_resume = None
@@ -182,7 +182,7 @@ def _setup_lora_tuning(
model = model.merge_and_unload()
if len(adapter_to_merge) > 0:
logger.info(f"Merged {len(adapter_to_merge)} adapter(s).")
logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")
if adapter_to_resume is not None: # resume lora training
if model_args.use_unsloth:
@@ -190,7 +190,7 @@ def _setup_lora_tuning(
else:
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
if is_trainable and adapter_to_resume is None: # create new lora weights while training
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
@@ -219,7 +219,7 @@ def _setup_lora_tuning(
module_names.add(name.split(".")[-1])
finetuning_args.additional_target = module_names
logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
logger.warning_rank0("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
peft_kwargs = {
"r": finetuning_args.lora_rank,
@@ -236,10 +236,10 @@ def _setup_lora_tuning(
else:
if finetuning_args.pissa_init:
if finetuning_args.pissa_iter == -1:
logger.info("Using PiSSA initialization.")
logger.info_rank0("Using PiSSA initialization.")
peft_kwargs["init_lora_weights"] = "pissa"
else:
logger.info(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
logger.info_rank0(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"
lora_config = LoraConfig(
@@ -284,11 +284,11 @@ def init_adapter(
if not is_trainable:
pass
elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.")
logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.")
else:
logger.info("Upcasting trainable params to float32.")
logger.info_rank0("Upcasting trainable params to float32.")
cast_trainable_params_to_fp32 = True
if finetuning_args.finetuning_type == "full":

View File

@@ -18,7 +18,7 @@ import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
-from ..extras.logging import get_logger
+from ..extras import logging
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
from .adapter import init_adapter
from .model_utils.liger_kernel import apply_liger_kernel
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
from ..hparams import FinetuningArguments, ModelArguments
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
class TokenizerModule(TypedDict):
@@ -90,10 +90,10 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
dict(additional_special_tokens=model_args.new_special_tokens),
replace_additional_special_tokens=False,
)
logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
if num_added_tokens > 0 and not model_args.resize_vocab:
model_args.resize_vocab = True
logger.warning("New tokens have been added, changed `resize_vocab` to True.")
logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")
patch_tokenizer(tokenizer)
try:
@@ -180,7 +180,7 @@ def load_model(
vhead_params = load_valuehead_params(vhead_path, model_args)
if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False)
logger.info(f"Loaded valuehead from checkpoint: {vhead_path}")
logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
if not is_trainable:
model.requires_grad_(False)
@@ -200,7 +200,7 @@ def load_model(
else:
param_stats = f"all params: {all_param:,}"
-logger.info(param_stats)
+logger.info_rank0(param_stats)
if model_args.print_param_status:
for name, param in model.named_parameters():

View File

@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
from transformers.utils.versions import require_version
-from ...extras.logging import get_logger
+from ...extras import logging
if TYPE_CHECKING:
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
def configure_attn_implementation(
@@ -38,13 +38,15 @@ def configure_attn_implementation(
require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
if model_args.flash_attn != "fa2":
logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
model_args.flash_attn = "fa2"
else:
logger.warning("FlashAttention-2 is not installed, use eager attention.")
logger.warning_rank0("FlashAttention-2 is not installed, use eager attention.")
model_args.flash_attn = "disabled"
elif model_args.flash_attn == "sdpa":
logger.warning("Gemma-2 should use soft-capping attention, while the SDPA attention does not support it.")
logger.warning_rank0(
"Gemma-2 should use soft-capping attention, while the SDPA attention does not support it."
)
if model_args.flash_attn == "auto":
return
@@ -54,13 +56,13 @@ def configure_attn_implementation(
elif model_args.flash_attn == "sdpa":
if not is_torch_sdpa_available():
logger.warning("torch>=2.1.1 is required for SDPA attention.")
logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
return
requested_attn_implementation = "sdpa"
elif model_args.flash_attn == "fa2":
if not is_flash_attn_2_available():
logger.warning("FlashAttention-2 is not installed.")
logger.warning_rank0("FlashAttention-2 is not installed.")
return
requested_attn_implementation = "flash_attention_2"
@@ -80,8 +82,8 @@ def print_attn_implementation(config: "PretrainedConfig") -> None:
attn_implementation = getattr(config, "_attn_implementation", None)
if attn_implementation == "flash_attention_2":
logger.info("Using FlashAttention-2 for faster training and inference.")
logger.info_rank0("Using FlashAttention-2 for faster training and inference.")
elif attn_implementation == "sdpa":
logger.info("Using torch SDPA for faster training and inference.")
logger.info_rank0("Using torch SDPA for faster training and inference.")
else:
logger.info("Using vanilla attention implementation.")
logger.info_rank0("Using vanilla attention implementation.")

View File

@@ -25,8 +25,8 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
import torch
+from ...extras import logging
from ...extras.constants import LAYERNORM_NAMES
-from ...extras.logging import get_logger
if TYPE_CHECKING:
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
def get_unsloth_gradient_checkpointing_func() -> Callable:
@@ -122,7 +122,7 @@ def _gradient_checkpointing_enable(
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
self.apply(partial(self._set_gradient_checkpointing, value=True))
self.enable_input_require_grads()
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
logger.warning_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
else: # have already enabled input require gradients
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
@@ -141,14 +141,14 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
(3) add the upcasting of the lm_head in fp32
"""
if model_args.upcast_layernorm:
logger.info("Upcasting layernorm weights in float32.")
logger.info_rank0("Upcasting layernorm weights in float32.")
for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
param.data = param.data.to(torch.float32)
if not model_args.disable_gradient_checkpointing:
if not getattr(model, "supports_gradient_checkpointing", False):
logger.warning("Current model does not support gradient checkpointing.")
logger.warning_rank0("Current model does not support gradient checkpointing.")
else:
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339
@@ -158,10 +158,10 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")
logger.info_rank0("Gradient checkpointing enabled.")
if model_args.upcast_lmhead_output:
output_layer = model.get_output_embeddings()
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
logger.info("Upcasting lm_head outputs in float32.")
logger.info_rank0("Upcasting lm_head outputs in float32.")
output_layer.register_forward_hook(_fp32_forward_post_hook)

View File

@@ -19,14 +19,14 @@ from typing import TYPE_CHECKING
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
-from ...extras.logging import get_logger
+from ...extras import logging
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
@@ -69,4 +69,4 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
logger.info(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")
logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")

View File

@@ -15,7 +15,7 @@
import inspect
from typing import TYPE_CHECKING
-from ...extras.logging import get_logger
+from ...extras import logging
if TYPE_CHECKING:
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
def apply_liger_kernel(
@@ -54,14 +54,14 @@ def apply_liger_kernel(
elif model_type == "qwen2_vl":
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
else:
logger.warning("Current model does not support liger kernel.")
logger.warning_rank0("Current model does not support liger kernel.")
return
if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters:
logger.info("Current training stage does not support chunked cross entropy.")
logger.info_rank0("Current training stage does not support chunked cross entropy.")
kwargs = {"fused_linear_cross_entropy": False}
else:
kwargs = {}
apply_liger_kernel(**kwargs)
logger.info("Liger kernel has been applied to the model.")
logger.info_rank0("Liger kernel has been applied to the model.")

View File

@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Optional, Tuple
import torch
import torch.nn as nn
+import transformers
from transformers.models.llama.modeling_llama import (
Cache,
LlamaAttention,
@@ -30,11 +31,10 @@ from transformers.models.llama.modeling_llama import (
apply_rotary_pos_emb,
repeat_kv,
)
-from transformers.utils import logging
from transformers.utils.versions import require_version
+from ...extras import logging
from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
-from ...extras.logging import get_logger
from ...extras.packages import is_transformers_version_greater_than_4_43
@@ -44,7 +44,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments
-transformers_logger = logging.get_logger(__name__)
+transformers_logger = transformers.utils.logging.get_logger(__name__)
# Modified from:
@@ -363,11 +363,11 @@ def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments",
if not is_trainable or not model_args.shift_attn:
return
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
setattr(config, "group_size_ratio", 0.25)
_apply_llama_patch()
logger.info("Using shift short attention with group_size_ratio=1/4.")
logger.info_rank0("Using shift short attention with group_size_ratio=1/4.")
else:
logger.warning("Current model does not support shift short attention.")
logger.warning_rank0("Current model does not support shift short attention.")

View File

@@ -14,14 +14,14 @@
from typing import TYPE_CHECKING, List
-from ...extras.logging import get_logger
+from ...extras import logging
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
@@ -53,7 +53,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
module_names.add(name.split(".")[-1])
logger.info("Found linear modules: {}".format(",".join(module_names)))
logger.info_rank0("Found linear modules: {}".format(",".join(module_names)))
return list(module_names)
@@ -80,7 +80,7 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
):
module_names.append(name)
logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
logger.info_rank0("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
return module_names

View File

@@ -43,8 +43,8 @@ import torch
import torch.nn.functional as F
from transformers.utils.versions import require_version
+from ...extras import logging
from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
-from ...extras.logging import get_logger
from ...extras.packages import is_transformers_version_greater_than_4_43
@@ -54,7 +54,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
@@ -152,6 +152,6 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments",
model_type = getattr(config, "model_type", None)
if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
_patch_for_block_diag_attn(model_type)
logger.info("Using block diagonal attention for sequence packing without cross-attention.")
logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
else:
raise ValueError("Current model does not support block diagonal attention.")

View File

@@ -28,8 +28,8 @@ from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
from transformers.utils.versions import require_version
+from ...extras import logging
from ...extras.constants import FILEEXT2TYPE
-from ...extras.logging import get_logger
from ...extras.misc import get_current_device
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
@unique
@@ -109,7 +109,7 @@ def configure_quantization(
"""
if getattr(config, "quantization_config", None): # ptq
if model_args.quantization_bit is not None:
logger.warning("`quantization_bit` will not affect on the PTQ-quantized models.")
logger.warning_rank0("`quantization_bit` will not affect on the PTQ-quantized models.")
if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
@@ -130,7 +130,7 @@ def configure_quantization(
quantization_config["bits"] = 2
quant_bits = quantization_config.get("bits", "?")
logger.info(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")
logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")
elif model_args.export_quantization_bit is not None: # auto-gptq
if model_args.export_quantization_bit not in [8, 4, 3, 2]:
@@ -149,7 +149,7 @@ def configure_quantization(
)
init_kwargs["device_map"] = "auto"
init_kwargs["max_memory"] = get_max_memory()
logger.info(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
elif model_args.quantization_bit is not None: # on-the-fly
if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
@@ -179,7 +179,7 @@ def configure_quantization(
else:
init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
logger.info(f"Quantizing model to {model_args.quantization_bit} bit with bitsandbytes.")
logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with bitsandbytes.")
elif model_args.quantization_method == QuantizationMethod.HQQ.value:
if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
@@ -191,7 +191,7 @@ def configure_quantization(
init_kwargs["quantization_config"] = HqqConfig(
nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
) # use ATEN kernel (axis=0) for performance
logger.info(f"Quantizing model to {model_args.quantization_bit} bit with HQQ.")
logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with HQQ.")
elif model_args.quantization_method == QuantizationMethod.EETQ.value:
if model_args.quantization_bit != 8:
raise ValueError("EETQ only accepts 8-bit quantization.")
@@ -201,4 +201,4 @@ def configure_quantization(
require_version("eetq", "To fix: pip install eetq")
init_kwargs["quantization_config"] = EetqConfig()
logger.info(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.")
logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.")

View File

@@ -19,7 +19,7 @@
import math
from typing import TYPE_CHECKING
-from ...extras.logging import get_logger
+from ...extras import logging
if TYPE_CHECKING:
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
@@ -36,26 +36,28 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
return
if not hasattr(config, "rope_scaling"):
logger.warning("Current model does not support RoPE scaling.")
logger.warning_rank0("Current model does not support RoPE scaling.")
return
if model_args.model_max_length is not None:
if is_trainable and model_args.rope_scaling == "dynamic":
-logger.warning(
+logger.warning_rank0(
"Dynamic NTK scaling may not work well with fine-tuning. "
"See: https://github.com/huggingface/transformers/pull/24653"
)
current_max_length = getattr(config, "max_position_embeddings", None)
if current_max_length and model_args.model_max_length > current_max_length:
logger.info(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
setattr(config, "max_position_embeddings", model_args.model_max_length)
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
else:
logger.warning("Input length is smaller than max length. Consider increase input length.")
logger.warning_rank0("Input length is smaller than max length. Consider increase input length.")
scaling_factor = 1.0
else:
scaling_factor = 2.0
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
logger.info(f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {scaling_factor}")
logger.info_rank0(
f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {scaling_factor}"
)

View File

@@ -14,7 +14,7 @@
from typing import TYPE_CHECKING, Any, Dict, Optional
-from ...extras.logging import get_logger
+from ...extras import logging
from ...extras.misc import get_current_device
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
def _get_unsloth_kwargs(
@@ -56,7 +56,7 @@ def load_unsloth_pretrained_model(
try:
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
except NotImplementedError:
logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
logger.warning_rank0("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
model = None
model_args.use_unsloth = False

View File

@@ -17,8 +17,8 @@ from typing import TYPE_CHECKING, Dict
import torch
from transformers.utils import cached_file
+from ...extras import logging
from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ...extras.logging import get_logger
if TYPE_CHECKING:
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
@@ -54,8 +54,8 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
except Exception as err:
err_text = str(err)
logger.info(f"Provided path ({path_or_repo_id}) does not contain value head weights: {err_text}.")
logger.info("Ignore the above message if you are not resuming the training of a value head model.")
logger.info_rank0(f"Provided path ({path_or_repo_id}) does not contain value head weights: {err_text}.")
logger.info_rank0("Ignore the above message if you are not resuming the training of a value head model.")
return None

View File

@@ -18,11 +18,11 @@
from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union
import torch
import transformers
import transformers.models
from transformers.activations import ACT2FN
-from transformers.utils import logging
-from ...extras.logging import get_logger
+from ...extras import logging
if TYPE_CHECKING:
@@ -31,8 +31,8 @@ if TYPE_CHECKING:
from ...hparams import FinetuningArguments, ModelArguments
-logger = get_logger(__name__)
-transformers_logger = logging.get_logger(__name__)
+logger = logging.get_logger(__name__)
+transformers_logger = transformers.utils.logging.get_logger(__name__)
class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
@@ -99,7 +99,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
else:
return
logger.info(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
logger.info_rank0(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
@@ -119,7 +119,7 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
if getattr(config, "is_yi_vl_derived_model", None):
logger.info("Detected Yi-VL model, applying projector patch.")
logger.info_rank0("Detected Yi-VL model, applying projector patch.")
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL

View File

@@ -22,7 +22,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled
-from ..extras.logging import get_logger
+from ..extras import logging
from ..extras.misc import infer_optim_dtype
from .model_utils.attention import configure_attn_implementation, print_attn_implementation
from .model_utils.checkpointing import prepare_model_for_training
@@ -49,7 +49,7 @@ if TYPE_CHECKING:
from ..hparams import ModelArguments
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
@@ -100,7 +100,7 @@ def patch_config(
if model_args.use_cache and not is_trainable:
setattr(config, "use_cache", True)
logger.info("Using KV cache for faster generation.")
logger.info_rank0("Using KV cache for faster generation.")
if getattr(config, "model_type", None) == "qwen":
setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
@@ -165,7 +165,7 @@ def patch_model(
try:
model.add_model_tags(["llama-factory"])
except Exception:
logger.warning("Cannot properly tag the model.")
logger.warning_rank0("Cannot properly tag the model.")
def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
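
Taken together, the per-module pattern after this commit is: import the project logging wrapper, create the logger once at module level, and route user-facing messages through the rank-zero helpers; modules that also patch transformers internals keep a separate transformers_logger obtained via transformers.utils.logging.get_logger, as the hunks above show. A minimal usage sketch mirroring those hunks (the message strings are illustrative):

from ..extras import logging

logger = logging.get_logger(__name__)
logger.info_rank0("printed once, on the main process only")
logger.warning_once("printed at most once per process run")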