mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-16 20:00:36 +08:00)

add prompt template class
@@ -29,6 +29,8 @@ from peft import (
     get_peft_model
 )
 
+from peft.utils import CONFIG_NAME
+
 from trl import AutoModelForCausalLMWithValueHead
 
 from .config import (
@@ -37,10 +39,7 @@ from .config import (
     FinetuningArguments
 )
 
-from .template import (
-    prompt_template_alpaca,
-    prompt_template_ziya
-)
+from .template import Template
 
 from .other import (
     get_logger,
@@ -102,6 +101,9 @@ def _init_adapter(
         logger.info("Fine-tuning method: LoRA")
         lastest_checkpoint = None
 
         if model_args.checkpoint_dir is not None:
+            assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
+                "The given checkpoint is not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
+
             if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
                 checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
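
Note: the assertion added above is just a file-existence probe on peft's adapter config. A minimal standalone sketch of the same check — the `is_lora_checkpoint` helper name is ours for illustration, not the repo's; in peft, CONFIG_NAME resolves to "adapter_config.json":

import os

from peft.utils import CONFIG_NAME  # "adapter_config.json"


def is_lora_checkpoint(checkpoint_dir: str) -> bool:
    # A directory counts as a LoRA checkpoint when peft's adapter
    # config file is present inside it.
    return os.path.exists(os.path.join(checkpoint_dir, CONFIG_NAME))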
@@ -401,7 +403,7 @@ def preprocess_data(
 
     column_names = list(dataset.column_names)
     prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
-    prompt_template = prompt_template_alpaca if data_args.prompt_template == "alpaca" else prompt_template_ziya
+    prompt_template = Template(data_args.prompt_template)
 
     # support question with a single answer or multiple answers
     def format_example(examples):
@@ -410,8 +412,7 @@ def preprocess_data(
             query, answer = examples["prompt"][i], examples["response"][i]
             if examples["query"][i]:
                 query += "\n" + examples["query"][i]
-            prompt = prompt_template(query, examples["history"][i])
-            prompt = prefix + prompt
+            prompt = prompt_template.get_prompt(query, examples["history"][i], prefix)
             yield prompt, answer
 
     def preprocess_pretrain_dataset(examples):
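
Taken together, the last two hunks replace per-format functions with a single class: dispatch on `data_args.prompt_template` moves out of `preprocess_data` and into the template itself, and the source `prefix` is folded into `get_prompt` instead of being concatenated at the call site. A minimal sketch of what such a `Template` class could look like, inferred only from its two call sites in this diff — the format strings below are assumptions modeled on the common Alpaca and Ziya prompt styles, not the repo's exact code:

class Template:

    def __init__(self, name: str):
        # Dispatch on the template name once, at construction time.
        if name == "alpaca":
            self.prompt = (
                "Below is an instruction that describes a task. "
                "Write a response that appropriately completes the request.\n\n"
                "### Instruction:\n{query}\n\n### Response:\n"
            )
        elif name == "ziya":
            self.prompt = "<human>:{query}\n<bot>:"
        else:
            raise ValueError("Unknown prompt template: {}".format(name))

    def get_prompt(self, query: str, history: list, prefix: str = "") -> str:
        # Prepend the optional source prefix, then replay earlier
        # (query, response) turns before formatting the current query.
        context = prefix
        for old_query, old_response in history or []:
            context += self.prompt.format(query=old_query) + old_response + "\n"
        return context + self.prompt.format(query=query)

With this shape, adding a new prompt format only touches the Template class rather than every call site in data preprocessing.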