use fp16 model, add logcallback

hiyouga
2023-05-28 21:30:28 +08:00
parent 769c6ab56b
commit 0c9fda01e3
7 changed files with 112 additions and 10 deletions
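
The commit title also mentions a log callback, which is not part of the file excerpt shown below. For context only, here is a minimal sketch of what a logging callback built on the transformers TrainerCallback API can look like; this is an illustrative guess, not the commit's actual LogCallback, and the file name trainer_log.jsonl is made up:

import json
from transformers import TrainerCallback

class LogCallback(TrainerCallback):
    # Illustrative sketch: append every logged metric dict to a JSONL file.
    def __init__(self, log_file="trainer_log.jsonl"):
        self.log_file = log_file

    def on_log(self, args, state, control, logs=None, **kwargs):
        # on_log is invoked by the Trainer whenever it emits metrics (loss, learning rate, epoch, ...).
        if logs is not None:
            record = dict(logs, step=state.global_step)
            with open(self.log_file, "a", encoding="utf-8") as f:
                f.write(json.dumps(record) + "\n")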


@@ -6,6 +6,7 @@ from typing import List, Literal, Optional, Tuple
 
 import transformers
 from transformers import (
+    LlamaConfig,
     LlamaForCausalLM,
     LlamaTokenizer,
     HfArgumentParser,
@@ -151,7 +152,7 @@ def load_pretrained(
         use_fast=model_args.use_fast_tokenizer,
         padding_side="left"
     )
-    tokenizer.pad_token_id = 0 # set as the <unk> token
+    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
 
     # Quantization configurations (using bitsandbytes library).
     config_kwargs = {}
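
The guard above keeps an already-defined padding token and only falls back to id 0 (the <unk> token in the LLaMA vocabulary) when none is set, so a tokenizer that ships with its own pad token is no longer overwritten. A standalone sketch of the same pattern, assuming a hypothetical local checkpoint path:

from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("path/to/llama", padding_side="left")
# Fall back to id 0 (<unk>) only if the tokenizer defines no pad token of its own.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = 0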
@@ -168,8 +169,15 @@ def load_pretrained(
         config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit
         logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))
 
+    config = LlamaConfig.from_pretrained(model_args.model_name_or_path)
+
     # Load and prepare pretrained models (without valuehead).
-    model = LlamaForCausalLM.from_pretrained(model_args.model_name_or_path, **config_kwargs)
+    model = LlamaForCausalLM.from_pretrained(
+        model_args.model_name_or_path,
+        config=config,
+        torch_dtype=torch.float16, # the llama weights are float16 type
+        **config_kwargs
+    )
     model = prepare_model_for_training(model) if is_trainable else model
     model = init_adapter(model, model_args, finetuning_args, is_trainable)
 
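
Passing torch_dtype=torch.float16 loads the LLaMA checkpoint in its native half precision instead of upcasting the weights to float32, roughly halving memory use during loading. A standalone sketch of the same call outside of load_pretrained, assuming a hypothetical local checkpoint path:

import torch
from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig.from_pretrained("path/to/llama")
model = LlamaForCausalLM.from_pretrained(
    "path/to/llama",
    config=config,
    torch_dtype=torch.float16,  # keep the float16 LLaMA weights instead of upcasting to float32
)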