add ziya prompt template

This commit is contained in:
hiyouga
2023-06-03 19:05:51 +08:00
parent 771f454ff1
commit de09ee1315
6 changed files with 79 additions and 24 deletions

View File

@@ -264,6 +264,18 @@ def prepare_args(
return model_args, data_args, training_args, finetuning_args
def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments]:
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FinetuningArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
model_args, data_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, finetuning_args = parser.parse_args_into_dataclasses()
return model_args, data_args, finetuning_args
def prepare_data(
model_args: ModelArguments,
data_args: DataTrainingArguments
@@ -347,7 +359,8 @@ def preprocess_data(
column_names = list(dataset.column_names)
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
def format_example(examples): # support question with a single answer or multiple answers
# support question with a single answer or multiple answers
def format_example_alpaca(examples):
for i in range(len(examples["prompt"])):
if examples["prompt"][i] and examples["response"][i]:
query, answer = examples["prompt"][i], examples["response"][i]
@@ -357,12 +370,27 @@ def preprocess_data(
prompt += "Write a response that appropriately completes the request.\n"
prompt += "Instruction:\n" + prefix
if examples["history"][i]:
history = examples["history"][i]
for old_query, response in history:
for old_query, response in examples["history"][i]:
prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
prompt += "Human: {}\nAssistant: ".format(query)
yield prompt, answer
def format_example_ziya(examples):
for i in range(len(examples["prompt"])):
if examples["prompt"][i] and examples["response"][i]:
query, answer = examples["prompt"][i], examples["response"][i]
if examples["query"][i]:
query += "\n" + examples["query"][i]
prompt = ""
if examples["history"][i]:
for old_query, response in examples["history"][i]:
prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
prompt += "<human>: {}\n<bot>:".format(query)
prompt = prefix + prompt
yield prompt, answer
format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
def preprocess_pretrain_dataset(examples):
# build grouped texts with format `<s> X1 X2 X3 ...` (without </s>)
text_ids = tokenizer(examples["prompt"])["input_ids"]