mirror of https://github.com/hiyouga/LLaMA-Factory.git
support LongLoRA
@@ -13,17 +13,19 @@ from transformers import (
     PreTrainedModel,
     PreTrainedTokenizerBase
 )
+from transformers.models.llama import modeling_llama as LlamaModule
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 from trl import AutoModelForCausalLMWithValueHead

 try:
     from transformers.integrations import is_deepspeed_zero3_enabled
-except ImportError:
+except ImportError: # https://github.com/huggingface/transformers/releases/tag/v4.33.1
     from transformers.deepspeed import is_deepspeed_zero3_enabled

 from llmtuner.extras.logging import reset_logging, get_logger
 from llmtuner.extras.misc import count_parameters
+from llmtuner.extras.patches import llama_patch as LlamaPatches
 from llmtuner.extras.save_and_load import load_valuehead_params
 from llmtuner.hparams import FinetuningArguments
 from llmtuner.tuner.core.adapter import init_adapter
@@ -73,10 +75,6 @@ def load_model_and_tokenizer(
         **config_kwargs
     )

-    # Fix tokenizer (for ChatGLM2)
-    if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
-        tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
-
     if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
         model_to_load = model_args.checkpoint_dir[0]
     else:
@@ -84,10 +82,15 @@ def load_model_and_tokenizer(

     config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)

+    # Fix tokenizer (for ChatGLM2)
+    if getattr(config, "model_type", None) == "chatglm":
+        tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
+
     # Fix config (for Qwen)
-    if hasattr(config, "fp16") and hasattr(config, "bf16"):
+    if getattr(config, "model_type", None) == "qwen":
         setattr(config, "fp16", model_args.compute_dtype == torch.float16)
         setattr(config, "bf16", model_args.compute_dtype == torch.bfloat16)
         setattr(config, "fp32", model_args.compute_dtype == torch.float32)

     # Set RoPE scaling
     if model_args.rope_scaling is not None:
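Note on the ChatGLM2 fix added above: tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer) rebinds the generic _pad implementation onto this tokenizer instance, presumably because ChatGLM2's custom tokenizer ships its own _pad that does not behave like the base one. A minimal, self-contained sketch of the same trick with a toy class instead of a real tokenizer (the toy names are illustrative, not part of the diff):

from types import MethodType

class Base:
    def _pad(self, x):
        return x + ["<pad>"]

class ToyTokenizer(Base):
    def _pad(self, x):  # stands in for ChatGLM2's custom _pad override
        raise NotImplementedError("custom _pad that the loader wants to bypass")

tok = ToyTokenizer()
# Rebind the base-class implementation to this single instance; the class
# itself (and any other instances) keep their own _pad.
tok._pad = MethodType(Base._pad, tok)
print(tok._pad(["hello"]))  # ['hello', '<pad>']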
@@ -103,7 +106,6 @@ def load_model_and_tokenizer(
         require_version("transformers>=4.31.0", "RoPE scaling requires transformers>=4.31.0")
         if is_trainable:
             if model_args.rope_scaling == "dynamic":
-                assert not model_args.flash_attn, "Flash attention does not support dynamic rope scaling."
                 logger.warning(
                     "Dynamic NTK may not work well with fine-tuning. "
                     "See: https://github.com/huggingface/transformers/pull/24653"
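Note on the RoPE scaling block this hunk touches: since transformers 4.31 the LLaMA config accepts a rope_scaling dict with a "type" ("linear" or "dynamic") and a float "factor". A small sketch of that config field; the 2.0 factor is an illustrative value, not the one the loader computes:

from transformers import LlamaConfig

# rope_scaling stretches the rotary position embeddings so a model pretrained on
# a short context can attend over a longer one; a factor of 2.0 roughly doubles
# the usable window ("linear" rescales positions, "dynamic" adjusts NTK-style at runtime).
config = LlamaConfig(rope_scaling={"type": "linear", "factor": 2.0})
print(config.rope_scaling)  # {'type': 'linear', 'factor': 2.0}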
@@ -126,17 +128,23 @@ def load_model_and_tokenizer(
         else:
             logger.warning("Current model does not support RoPE scaling.")

-    # Set flash attention
-    if model_args.flash_attn and getattr(config, "model_type", None) == "llama":
-        import transformers.models.llama.modeling_llama as LlamaModule
-        import llmtuner.extras.patches.flash_llama as FlashLlama
-        LlamaModule.LlamaRMSNorm = FlashLlama.LlamaRMSNorm
-        LlamaModule.LlamaAttention = FlashLlama.LlamaAttention
-        LlamaModule.LlamaModel._prepare_decoder_attention_mask = FlashLlama._prepare_decoder_attention_mask
-        if not hasattr(config, "num_key_value_heads"): # for LLaMA-1 models
-            setattr(config, "num_key_value_heads", getattr(config, "num_attention_heads"))
-        if getattr(config, "pretraining_tp", 1) != 1:
-            setattr(config, "pretraining_tp", 1)
+    # Fix RMSNorm in fp32 weight (https://github.com/huggingface/transformers/pull/23535)
+    if getattr(config, "model_type", None) == "llama":
+        LlamaModule.LlamaRMSNorm = LlamaPatches.LlamaRMSNorm
+
+    # Set FlashAttention-2
+    if model_args.flash_attn:
+        if getattr(config, "model_type", None) == "llama":
+            LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
+            LlamaModule.LlamaModel._prepare_decoder_attention_mask = (
+                LlamaPatches._prepare_decoder_attention_mask
+            )
+            logger.info("Using FlashAttention-2 for faster training and inference.")
+        else:
+            logger.warning("Current model does not support FlashAttention-2.")
+    elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
+        LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
+        logger.warning("Using `--flash_attn` for faster training in large context length.")

     # Quantization configurations (using bitsandbytes library).
     is_mergeable = True
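Note on the patching style above: assigning LlamaPatches classes onto LlamaModule works because the LLaMA decoder layers look up names such as LlamaAttention on the modeling_llama module when the model is constructed, so swapping the module attribute before from_pretrained changes which class every layer instantiates. A toy sketch of that mechanism, with made-up module and class names (nothing here is transformers API):

import types

# A stand-in for transformers.models.llama.modeling_llama.
fake_llama = types.ModuleType("fake_modeling_llama")

class EagerAttention:
    pass

class FlashAttention2:
    pass

fake_llama.LlamaAttention = EagerAttention

def build_decoder_layer():
    # Mirrors how a decoder layer would do `LlamaAttention(config)`:
    # the name is resolved on the module at construction time.
    return fake_llama.LlamaAttention()

fake_llama.LlamaAttention = FlashAttention2      # the monkey patch
print(type(build_decoder_layer()).__name__)      # FlashAttention2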
@@ -172,12 +180,20 @@ def load_model_and_tokenizer(
         **config_kwargs
     )

-    # Disable custom generate method (for Qwen)
+    # Set shift short attention (S^2-Attn)
+    if is_trainable and model_args.shift_attn:
+        if getattr(config, "model_type", None) == "llama":
+            setattr(model, "shift_ratio", 0.25)
+            logger.info("Using shift short attention proposed by LongLoRA.")
+        else:
+            logger.warning("Current model does not support shift short attention.")
+
+    # Disable custom generate method (for Qwen and Baichuan2)
     if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
         model.generate = MethodType(PreTrainedModel.generate, model)

     # Fix LM head (for ChatGLM2)
-    if not hasattr(model, "lm_head") and hasattr(model, "transformer"):
+    if getattr(config, "model_type", None) == "chatglm":
         setattr(model, "lm_head", model.transformer.output_layer)

     # Register auto class to save the custom code files.
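Note on shift_ratio and LlamaShiftShortAttention: the patched attention class itself is not part of this diff, but LongLoRA's shifted short attention (S^2-Attn) splits the sequence into groups, attends only within each group, and rolls half of the heads by half a group so information still crosses group boundaries. A rough, self-contained sketch under those assumptions; it omits the causal mask and padding handling, and reading shift_ratio=0.25 as "group size is a quarter of the sequence length" is an assumption, not something stated in the diff:

import torch

def s2_attn_sketch(q, k, v, group_ratio=0.25):
    """Toy shifted short attention (no causal mask, no padding handling)."""
    bsz, n_heads, seq_len, head_dim = q.shape
    group = int(seq_len * group_ratio)        # tokens per attention group
    assert seq_len % group == 0

    def shift(x, back=False):
        # Roll the second half of the heads by half a group so information can
        # cross group boundaries; back=True undoes the shift on the outputs.
        step = group // 2 if back else -(group // 2)
        lo, hi = x[:, : n_heads // 2], x[:, n_heads // 2 :]
        return torch.cat([lo, hi.roll(step, dims=2)], dim=1)

    q, k, v = shift(q), shift(k), shift(v)

    def fold(x):
        # Fold groups into an extra dimension so attention stays local per group.
        return x.reshape(bsz, n_heads, seq_len // group, group, head_dim)

    qg, kg, vg = fold(q), fold(k), fold(v)
    scores = qg @ kg.transpose(-2, -1) / head_dim ** 0.5
    out = torch.softmax(scores, dim=-1) @ vg
    out = out.reshape(bsz, n_heads, seq_len, head_dim)
    return shift(out, back=True)

# Example: per head, the O(L^2) cost of full attention drops to O(L * group).
q = k = v = torch.randn(1, 8, 64, 16)
print(s2_attn_sketch(q, k, v).shape)  # torch.Size([1, 8, 64, 16])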