mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-05 18:32:14 +08:00
137 lines
5.7 KiB
Python
137 lines
5.7 KiB
Python
from typing import TYPE_CHECKING, Optional, Tuple
|
|
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
|
from transformers.integrations import is_deepspeed_zero3_enabled
|
|
from transformers.utils.versions import require_version
|
|
from trl import AutoModelForCausalLMWithValueHead
|
|
|
|
from llmtuner.extras.logging import get_logger
|
|
from llmtuner.extras.misc import count_parameters, get_current_device, try_download_model_from_ms
|
|
from llmtuner.extras.packages import is_flash_attn2_available
|
|
from llmtuner.hparams import FinetuningArguments
|
|
from llmtuner.model.adapter import init_adapter
|
|
from llmtuner.model.patches import patch_config, patch_model, patch_valuehead_model, patch_tokenizer, register_autoclass
|
|
from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training, resize_embedding_layer
|
|
|
|
if TYPE_CHECKING:
|
|
from transformers import PreTrainedModel, PreTrainedTokenizer
|
|
from llmtuner.hparams import ModelArguments
|
|
|
|
|
|
logger = get_logger(__name__)

# Minimum versions of the core dependencies this loader relies on. They are
# checked once, in this order, when the module is imported; require_version
# is called with a matching "To fix: pip install <spec>" hint for each one.
for _requirement in (
    "transformers>=4.36.1",
    "datasets>=2.14.3",
    "accelerate>=0.21.0",
    "peft>=0.7.0",
    "trl>=0.7.4",
):
    require_version(_requirement, "To fix: pip install {}".format(_requirement))
del _requirement  # keep the module namespace free of the loop variable
|
|
|
|
|
|
def load_model_and_tokenizer(
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
    is_trainable: Optional[bool] = False,
    add_valuehead: Optional[bool] = False
) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
    r"""
    Loads pretrained model and tokenizer.

    Support both training and inference.

    Args:
        model_args: model loading options (path, revision, hub token, tokenizer flags,
            quantization settings, flash-attn flag, compute dtype, ...).
            NOTE: ``model_args.quantization_bit`` is reset to ``None`` in place when the
            checkpoint is already quantized (its config carries a ``quantization_config``).
        finetuning_args: forwarded to ``prepare_model_for_training`` (only when
            ``is_trainable``) and to ``init_adapter``.
        is_trainable: when falsy, all parameters are frozen, the model is cast to
            ``model_args.compute_dtype`` (unless it reports a ``quantization_method``)
            and switched to eval mode; otherwise it is switched to train mode.
        add_valuehead: when truthy, wrap the model in ``AutoModelForCausalLMWithValueHead``
            (for RLHF) and load value-head weights if ``load_valuehead_params`` returns any.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: if bitsandbytes quantization is requested while DeepSpeed ZeRO-3 is
            enabled, or with a bit width other than 4 or 8.
    """
    try_download_model_from_ms(model_args)

    # Hub options shared by the tokenizer, config and model loaders.
    config_kwargs = {
        "trust_remote_code": True,
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "token": model_args.hf_hub_token
    }

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        use_fast=model_args.use_fast_tokenizer,
        split_special_tokens=model_args.split_special_tokens,
        padding_side="right", # training with left-padded tensors in fp16 precision may cause overflow
        **config_kwargs
    )
    patch_tokenizer(tokenizer)

    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    patch_config(config, model_args, is_trainable)

    # Keys added to config_kwargs from here on only reach the model loader below.

    # Set FlashAttention-2
    if model_args.flash_attn:
        if is_flash_attn2_available():
            config_kwargs["use_flash_attention_2"] = True
            logger.info("Using FlashAttention-2 for faster training and inference.")
        else:
            # Previously the request was dropped without any feedback.
            logger.warning("FlashAttention-2 is not available, ignoring `flash_attn`.")

    # Quantization configurations (using gptq or awq)
    quantization_config = getattr(config, "quantization_config", None)
    if quantization_config:
        # The checkpoint is already quantized, so bitsandbytes quantization is disabled.
        model_args.quantization_bit = None # remove bnb quantization
        config_kwargs["device_map"] = {"": get_current_device()}
        logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1)))

    # Quantization configurations (using bitsandbytes)
    if model_args.quantization_bit is not None:
        if is_deepspeed_zero3_enabled():
            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")

        if model_args.quantization_bit == 8:
            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
            config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
        elif model_args.quantization_bit == 4:
            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=model_args.compute_dtype,
                bnb_4bit_use_double_quant=model_args.double_quantization,
                bnb_4bit_quant_type=model_args.quantization_type
            )
        else:
            # Any other width used to load an unquantized model while still
            # logging "Quantizing model to N bit." below.
            raise ValueError(
                "Bitsandbytes only accepts 4-bit or 8-bit quantization, got {}.".format(model_args.quantization_bit)
            )

        config_kwargs["device_map"] = {"": get_current_device()}
        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

    # Load pre-trained models (without valuehead)
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        torch_dtype=model_args.compute_dtype,
        low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
        **config_kwargs
    )
    patch_model(model)
    register_autoclass(config, model, tokenizer)
    resize_embedding_layer(model, tokenizer)

    # Initialize adapters
    if is_trainable:
        model = prepare_model_for_training(model=model, finetuning_args=finetuning_args)
    model = init_adapter(model, model_args, finetuning_args, is_trainable)

    # Prepare model with valuehead for RLHF
    if add_valuehead:
        model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
        patch_valuehead_model(model)
        vhead_params = load_valuehead_params(model_args)
        if vhead_params is not None:
            # strict=False: vhead_params need not cover every key of the wrapped model.
            model.load_state_dict(vhead_params, strict=False)

    # Prepare model for inference
    if not is_trainable:
        model.requires_grad_(False) # fix all model params
        # Only models without a `quantization_method` are cast; presumably quantized
        # weights must keep their own storage dtype.
        if not getattr(model, "quantization_method", None):
            model = model.to(model_args.compute_dtype)
        model.eval()
    else:
        model.train()

    trainable_params, all_param = count_parameters(model)
    # Guard the percentage against a zero parameter count.
    trainable_pct = 100 * trainable_params / all_param if all_param > 0 else 0.0
    logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
        trainable_params, all_param, trainable_pct
    ))

    if not is_trainable:
        logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")

    return model, tokenizer
|