mirror of https://github.com/hiyouga/LLaMA-Factory.git
use fp16 model, add logcallback
@@ -6,6 +6,7 @@ from typing import List, Literal, Optional, Tuple
 import transformers
 from transformers import (
     LlamaConfig,
     LlamaForCausalLM,
     LlamaTokenizer,
     HfArgumentParser,
@@ -151,7 +152,7 @@ def load_pretrained(
         use_fast=model_args.use_fast_tokenizer,
         padding_side="left"
     )
-    tokenizer.pad_token_id = 0 # set as the <unk> token
+    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
 
     # Quantization configurations (using bitsandbytes library).
     config_kwargs = {}
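Outside the diff, the new pad-token guard amounts to the following standalone snippet (a minimal sketch, not part of this commit; the checkpoint path "path/to/llama" is a placeholder):

import torch
from transformers import LlamaTokenizer

# Left padding for decoder-only generation; the path is a placeholder.
tokenizer = LlamaTokenizer.from_pretrained("path/to/llama", padding_side="left")

# Only fall back to 0 (the <unk> token id in the LLaMA vocab) when the
# tokenizer does not already define a pad token; otherwise keep its own id.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = 0

This avoids silently overwriting a pad token id that a checkpoint may already ship with, which the previous unconditional assignment did.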
@@ -168,8 +169,15 @@ def load_pretrained(
         config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit
         logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))
 
+    config = LlamaConfig.from_pretrained(model_args.model_name_or_path)
+
     # Load and prepare pretrained models (without valuehead).
-    model = LlamaForCausalLM.from_pretrained(model_args.model_name_or_path, **config_kwargs)
+    model = LlamaForCausalLM.from_pretrained(
+        model_args.model_name_or_path,
+        config=config,
+        torch_dtype=torch.float16, # the llama weights are float16 type
+        **config_kwargs
+    )
     model = prepare_model_for_training(model) if is_trainable else model
     model = init_adapter(model, model_args, finetuning_args, is_trainable)
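The fp16 loading path added above corresponds roughly to this standalone usage (a minimal sketch, not the project's exact code; the checkpoint path "path/to/llama" is a placeholder and the bitsandbytes quantization kwargs are omitted):

import torch
from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig.from_pretrained("path/to/llama")

# Load the checkpoint in half precision instead of upcasting to fp32,
# since the released LLaMA weights are stored as float16.
model = LlamaForCausalLM.from_pretrained(
    "path/to/llama",
    config=config,
    torch_dtype=torch.float16,
)

Passing torch_dtype=torch.float16 roughly halves the memory needed to hold the weights compared with the default fp32 load.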