Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-12-16 11:50:35 +08:00
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
 from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
 from transformers.utils.versions import require_version

-from ...extras.logging import get_logger
+from ...extras import logging


 if TYPE_CHECKING:
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
 from ...hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def configure_attn_implementation(
@@ -38,13 +38,15 @@ def configure_attn_implementation(
 require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
 require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
 if model_args.flash_attn != "fa2":
-logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
+logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
 model_args.flash_attn = "fa2"
 else:
-logger.warning("FlashAttention-2 is not installed, use eager attention.")
+logger.warning_rank0("FlashAttention-2 is not installed, use eager attention.")
 model_args.flash_attn = "disabled"
 elif model_args.flash_attn == "sdpa":
-logger.warning("Gemma-2 should use soft-capping attention, while the SDPA attention does not support it.")
+logger.warning_rank0(
+"Gemma-2 should use soft-capping attention, while the SDPA attention does not support it."
+)

 if model_args.flash_attn == "auto":
 return
@@ -54,13 +56,13 @@ def configure_attn_implementation(

 elif model_args.flash_attn == "sdpa":
 if not is_torch_sdpa_available():
-logger.warning("torch>=2.1.1 is required for SDPA attention.")
+logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
 return

 requested_attn_implementation = "sdpa"
 elif model_args.flash_attn == "fa2":
 if not is_flash_attn_2_available():
-logger.warning("FlashAttention-2 is not installed.")
+logger.warning_rank0("FlashAttention-2 is not installed.")
 return

 requested_attn_implementation = "flash_attention_2"
@@ -80,8 +82,8 @@ def print_attn_implementation(config: "PretrainedConfig") -> None:
 attn_implementation = getattr(config, "_attn_implementation", None)

 if attn_implementation == "flash_attention_2":
-logger.info("Using FlashAttention-2 for faster training and inference.")
+logger.info_rank0("Using FlashAttention-2 for faster training and inference.")
 elif attn_implementation == "sdpa":
-logger.info("Using torch SDPA for faster training and inference.")
+logger.info_rank0("Using torch SDPA for faster training and inference.")
 else:
-logger.info("Using vanilla attention implementation.")
+logger.info_rank0("Using vanilla attention implementation.")
@@ -25,8 +25,8 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.constants import LAYERNORM_NAMES
|
||||
from ...extras.logging import get_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_unsloth_gradient_checkpointing_func() -> Callable:
|
||||
@@ -122,7 +122,7 @@ def _gradient_checkpointing_enable(
|
||||
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
|
||||
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
||||
self.enable_input_require_grads()
|
||||
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
|
||||
logger.warning_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
|
||||
else: # have already enabled input require gradients
|
||||
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
|
||||
|
||||
@@ -141,14 +141,14 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
|
||||
(3) add the upcasting of the lm_head in fp32
|
||||
"""
|
||||
if model_args.upcast_layernorm:
|
||||
logger.info("Upcasting layernorm weights in float32.")
|
||||
logger.info_rank0("Upcasting layernorm weights in float32.")
|
||||
for name, param in model.named_parameters():
|
||||
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
if not model_args.disable_gradient_checkpointing:
|
||||
if not getattr(model, "supports_gradient_checkpointing", False):
|
||||
logger.warning("Current model does not support gradient checkpointing.")
|
||||
logger.warning_rank0("Current model does not support gradient checkpointing.")
|
||||
else:
|
||||
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
|
||||
# According to: https://github.com/huggingface/transformers/issues/28339
|
||||
@@ -158,10 +158,10 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
|
||||
model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
|
||||
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
|
||||
logger.info("Gradient checkpointing enabled.")
|
||||
logger.info_rank0("Gradient checkpointing enabled.")
|
||||
|
||||
if model_args.upcast_lmhead_output:
|
||||
output_layer = model.get_output_embeddings()
|
||||
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
|
||||
logger.info("Upcasting lm_head outputs in float32.")
|
||||
logger.info_rank0("Upcasting lm_head outputs in float32.")
|
||||
output_layer.register_forward_hook(_fp32_forward_post_hook)
|
||||
|
||||
@@ -19,14 +19,14 @@ from typing import TYPE_CHECKING
 import torch
 from transformers.integrations import is_deepspeed_zero3_enabled

-from ...extras.logging import get_logger
+from ...extras import logging


 if TYPE_CHECKING:
 from transformers import PreTrainedModel, PreTrainedTokenizer


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
@@ -69,4 +69,4 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken
 _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
 _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)

-logger.info(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")
+logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")
@@ -15,7 +15,7 @@
 import inspect
 from typing import TYPE_CHECKING

-from ...extras.logging import get_logger
+from ...extras import logging


 if TYPE_CHECKING:
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
 from ...hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def apply_liger_kernel(
@@ -54,14 +54,14 @@ def apply_liger_kernel(
 elif model_type == "qwen2_vl":
 from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
 else:
-logger.warning("Current model does not support liger kernel.")
+logger.warning_rank0("Current model does not support liger kernel.")
 return

 if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters:
-logger.info("Current training stage does not support chunked cross entropy.")
+logger.info_rank0("Current training stage does not support chunked cross entropy.")
 kwargs = {"fused_linear_cross_entropy": False}
 else:
 kwargs = {}

 apply_liger_kernel(**kwargs)
-logger.info("Liger kernel has been applied to the model.")
+logger.info_rank0("Liger kernel has been applied to the model.")
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Optional, Tuple

 import torch
 import torch.nn as nn
+import transformers
 from transformers.models.llama.modeling_llama import (
 Cache,
 LlamaAttention,
@@ -30,11 +31,10 @@ from transformers.models.llama.modeling_llama import (
 apply_rotary_pos_emb,
 repeat_kv,
 )
-from transformers.utils import logging
 from transformers.utils.versions import require_version

+from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
-from ...extras.logging import get_logger
 from ...extras.packages import is_transformers_version_greater_than_4_43


@@ -44,7 +44,7 @@ if TYPE_CHECKING:
 from ...hparams import ModelArguments


-transformers_logger = logging.get_logger(__name__)
+transformers_logger = transformers.utils.logging.get_logger(__name__)


 # Modified from:
@@ -363,11 +363,11 @@ def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments",
 if not is_trainable or not model_args.shift_attn:
 return

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)

 if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
 setattr(config, "group_size_ratio", 0.25)
 _apply_llama_patch()
-logger.info("Using shift short attention with group_size_ratio=1/4.")
+logger.info_rank0("Using shift short attention with group_size_ratio=1/4.")
 else:
-logger.warning("Current model does not support shift short attention.")
+logger.warning_rank0("Current model does not support shift short attention.")
@@ -14,14 +14,14 @@

 from typing import TYPE_CHECKING, List

-from ...extras.logging import get_logger
+from ...extras import logging


 if TYPE_CHECKING:
 from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
@@ -53,7 +53,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
 if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
 module_names.add(name.split(".")[-1])

-logger.info("Found linear modules: {}".format(",".join(module_names)))
+logger.info_rank0("Found linear modules: {}".format(",".join(module_names)))
 return list(module_names)


@@ -80,7 +80,7 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
 ):
 module_names.append(name)

-logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
+logger.info_rank0("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
 return module_names

@@ -43,8 +43,8 @@ import torch
 import torch.nn.functional as F
 from transformers.utils.versions import require_version

+from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
-from ...extras.logging import get_logger
 from ...extras.packages import is_transformers_version_greater_than_4_43


@@ -54,7 +54,7 @@ if TYPE_CHECKING:
 from ...hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
@@ -152,6 +152,6 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments",
 model_type = getattr(config, "model_type", None)
 if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
 _patch_for_block_diag_attn(model_type)
-logger.info("Using block diagonal attention for sequence packing without cross-attention.")
+logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
 else:
 raise ValueError("Current model does not support block diagonal attention.")
@@ -28,8 +28,8 @@ from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
 from transformers.utils.versions import require_version

+from ...extras import logging
 from ...extras.constants import FILEEXT2TYPE
-from ...extras.logging import get_logger
 from ...extras.misc import get_current_device


@@ -39,7 +39,7 @@ if TYPE_CHECKING:
 from ...hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 @unique
@@ -109,7 +109,7 @@ def configure_quantization(
 """
 if getattr(config, "quantization_config", None): # ptq
 if model_args.quantization_bit is not None:
-logger.warning("`quantization_bit` will not affect on the PTQ-quantized models.")
+logger.warning_rank0("`quantization_bit` will not affect on the PTQ-quantized models.")

 if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
 raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
@@ -130,7 +130,7 @@ def configure_quantization(
 quantization_config["bits"] = 2

 quant_bits = quantization_config.get("bits", "?")
-logger.info(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")
+logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")

 elif model_args.export_quantization_bit is not None: # auto-gptq
 if model_args.export_quantization_bit not in [8, 4, 3, 2]:
@@ -149,7 +149,7 @@ def configure_quantization(
 )
 init_kwargs["device_map"] = "auto"
 init_kwargs["max_memory"] = get_max_memory()
-logger.info(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
+logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")

 elif model_args.quantization_bit is not None: # on-the-fly
 if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
@@ -179,7 +179,7 @@ def configure_quantization(
 else:
 init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference

-logger.info(f"Quantizing model to {model_args.quantization_bit} bit with bitsandbytes.")
+logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with bitsandbytes.")
 elif model_args.quantization_method == QuantizationMethod.HQQ.value:
 if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
 raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
@@ -191,7 +191,7 @@ def configure_quantization(
 init_kwargs["quantization_config"] = HqqConfig(
 nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
 ) # use ATEN kernel (axis=0) for performance
-logger.info(f"Quantizing model to {model_args.quantization_bit} bit with HQQ.")
+logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with HQQ.")
 elif model_args.quantization_method == QuantizationMethod.EETQ.value:
 if model_args.quantization_bit != 8:
 raise ValueError("EETQ only accepts 8-bit quantization.")
@@ -201,4 +201,4 @@ def configure_quantization(

 require_version("eetq", "To fix: pip install eetq")
 init_kwargs["quantization_config"] = EetqConfig()
-logger.info(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.")
+logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.")
@@ -19,7 +19,7 @@
 import math
 from typing import TYPE_CHECKING

-from ...extras.logging import get_logger
+from ...extras import logging


 if TYPE_CHECKING:
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
 from ...hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
@@ -36,26 +36,28 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
 return

 if not hasattr(config, "rope_scaling"):
-logger.warning("Current model does not support RoPE scaling.")
+logger.warning_rank0("Current model does not support RoPE scaling.")
 return

 if model_args.model_max_length is not None:
 if is_trainable and model_args.rope_scaling == "dynamic":
-logger.warning(
+logger.warning_rank0(
 "Dynamic NTK scaling may not work well with fine-tuning. "
 "See: https://github.com/huggingface/transformers/pull/24653"
 )

 current_max_length = getattr(config, "max_position_embeddings", None)
 if current_max_length and model_args.model_max_length > current_max_length:
-logger.info(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
+logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
 setattr(config, "max_position_embeddings", model_args.model_max_length)
 scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
 else:
-logger.warning("Input length is smaller than max length. Consider increase input length.")
+logger.warning_rank0("Input length is smaller than max length. Consider increase input length.")
 scaling_factor = 1.0
 else:
 scaling_factor = 2.0

 setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
-logger.info(f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {scaling_factor}")
+logger.info_rank0(
+f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {scaling_factor}"
+)
@@ -14,7 +14,7 @@

 from typing import TYPE_CHECKING, Any, Dict, Optional

-from ...extras.logging import get_logger
+from ...extras import logging
 from ...extras.misc import get_current_device


@@ -24,7 +24,7 @@ if TYPE_CHECKING:
 from ...hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _get_unsloth_kwargs(
@@ -56,7 +56,7 @@ def load_unsloth_pretrained_model(
 try:
 model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
 except NotImplementedError:
-logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
+logger.warning_rank0("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
 model = None
 model_args.use_unsloth = False

@@ -17,8 +17,8 @@ from typing import TYPE_CHECKING, Dict
 import torch
 from transformers.utils import cached_file

+from ...extras import logging
 from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ...extras.logging import get_logger


 if TYPE_CHECKING:
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
 from ...hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
@@ -54,8 +54,8 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
 except Exception as err:
 err_text = str(err)

-logger.info(f"Provided path ({path_or_repo_id}) does not contain value head weights: {err_text}.")
-logger.info("Ignore the above message if you are not resuming the training of a value head model.")
+logger.info_rank0(f"Provided path ({path_or_repo_id}) does not contain value head weights: {err_text}.")
+logger.info_rank0("Ignore the above message if you are not resuming the training of a value head model.")
 return None

@@ -18,11 +18,11 @@
 from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union

 import torch
+import transformers
 import transformers.models
 from transformers.activations import ACT2FN
-from transformers.utils import logging

-from ...extras.logging import get_logger
+from ...extras import logging


 if TYPE_CHECKING:
@@ -31,8 +31,8 @@ if TYPE_CHECKING:
 from ...hparams import FinetuningArguments, ModelArguments


-logger = get_logger(__name__)
-transformers_logger = logging.get_logger(__name__)
+logger = logging.get_logger(__name__)
+transformers_logger = transformers.utils.logging.get_logger(__name__)


 class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
@@ -99,7 +99,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
 else:
 return

-logger.info(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
+logger.info_rank0(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
 mm_projector.register_forward_hook(_mm_projector_forward_post_hook)


@@ -119,7 +119,7 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
 setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))

 if getattr(config, "is_yi_vl_derived_model", None):
-logger.info("Detected Yi-VL model, applying projector patch.")
+logger.info_rank0("Detected Yi-VL model, applying projector patch.")
 transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
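Note on the pattern: every hunk above replaces `from ...extras.logging import get_logger` with the package-level `from ...extras import logging` and switches log calls to rank-aware helpers (`info_rank0`, `warning_rank0`, `warning_once`), so that in distributed runs only the main process emits each message and repeated warnings are deduplicated. The commit does not include `extras/logging.py` itself, so the snippet below is only a minimal sketch of the interface these call sites assume; gating on the `LOCAL_RANK` environment variable and deduplicating by message text are assumptions, not the project's actual implementation.

# Minimal sketch (not the project's real extras/logging.py): a logger wrapper
# exposing the helpers used by the call sites above. Assumptions: the rank is
# read from the LOCAL_RANK environment variable and warning_once() deduplicates
# by message text.
import logging
import os


class _RankAwareLogger(logging.LoggerAdapter):
    """Adds rank-0-only and log-once helpers on top of a standard logger."""

    _seen_messages = set()  # shared across instances to deduplicate warnings

    @staticmethod
    def _is_main_process() -> bool:
        # Assumption: non-main workers have LOCAL_RANK set to a non-zero value.
        return int(os.getenv("LOCAL_RANK", "0")) == 0

    def info_rank0(self, msg, *args, **kwargs) -> None:
        if self._is_main_process():
            self.info(msg, *args, **kwargs)

    def warning_rank0(self, msg, *args, **kwargs) -> None:
        if self._is_main_process():
            self.warning(msg, *args, **kwargs)

    def warning_once(self, msg, *args, **kwargs) -> None:
        if msg not in self._seen_messages:
            self._seen_messages.add(msg)
            self.warning(msg, *args, **kwargs)


def get_logger(name: str) -> _RankAwareLogger:
    return _RankAwareLogger(logging.getLogger(name), {})


# Usage mirroring the new call pattern introduced by this commit:
logger = get_logger(__name__)
logger.info_rank0("Gradient checkpointing enabled.")
logger.warning_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")

With a helper of this shape, `logging.get_logger(__name__)` at module scope gives each file a logger whose rank-0 variants can be called unconditionally, which is how the updated call sites above use it.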