add prompt template class

This commit is contained in:
hiyouga
2023-06-07 11:55:25 +08:00
parent 5d021d4ad5
commit 909af8f496
8 changed files with 67 additions and 40 deletions

View File

@@ -29,6 +29,8 @@ from peft import (
get_peft_model
)
from peft.utils import CONFIG_NAME
from trl import AutoModelForCausalLMWithValueHead
from .config import (
@@ -37,10 +39,7 @@ from .config import (
FinetuningArguments
)
from .template import (
prompt_template_alpaca,
prompt_template_ziya
)
from .template import Template
from .other import (
get_logger,
@@ -102,6 +101,9 @@ def _init_adapter(
logger.info("Fine-tuning method: LoRA")
lastest_checkpoint = None
assert os.path.exists(model_args.checkpoint_dir[0], CONFIG_NAME), \
"The given checkpoint is not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
if model_args.checkpoint_dir is not None:
if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
@@ -401,7 +403,7 @@ def preprocess_data(
column_names = list(dataset.column_names)
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
prompt_template = prompt_template_alpaca if data_args.prompt_template == "alpaca" else prompt_template_ziya
prompt_template = Template(data_args.prompt_template)
# support question with a single answer or multiple answers
def format_example(examples):
@@ -410,8 +412,7 @@ def preprocess_data(
query, answer = examples["prompt"][i], examples["response"][i]
if examples["query"][i]:
query += "\n" + examples["query"][i]
prompt = prompt_template(query, examples["history"][i])
prompt = prefix + prompt
prompt = prompt_template.get_prompt(query, examples["history"][i], prefix)
yield prompt, answer
def preprocess_pretrain_dataset(examples):