Update evaluate.py

Former-commit-id: 40201845b707c8b888b17744622ad78b4fb08a09
This commit is contained in:
hiyouga 2023-06-26 23:41:33 +08:00
parent fe7ca5cb63
commit a7e53dcfef

View File

@ -1,6 +1,7 @@
# coding=utf-8 # coding=utf-8
# Evaluates fine-tuned models automatically. # Evaluates fine-tuned models automatically.
# Usage: python evaluate.py --evalset ceval/ceval-exam:law --split dev --api_base http://localhost:8000/v1 --task_type choice # Usage: python evaluate.py --evalset ceval/ceval-exam:law --split dev --api_base http://localhost:8000/v1 --task_type choice
# dataset format: question (string), A (string), B (string), C (string), D (string), answer Literal["A", "B", "C", "D"]
import os import os
@ -12,13 +13,6 @@ from typing import Literal, Optional
from datasets import load_dataset from datasets import load_dataset
EXT2TYPE = {
"csv": "csv",
"json": "json",
"jsonl": "json"
}
def format_example_choice(examples): def format_example_choice(examples):
model_inputs = {"query": [], "label": []} model_inputs = {"query": [], "label": []}
task_template = "请从ABCD四个选项中选出正确的选项仅输出选项序号。\n{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\n答案:" task_template = "请从ABCD四个选项中选出正确的选项仅输出选项序号。\n{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\n答案:"
@ -53,9 +47,28 @@ def format_example_cloze(examples):
return model_inputs return model_inputs
def format_example_openqa(examples):
model_inputs = {"query": [], "label": []}
task_template = "回答以下问题:{question}\n答案:"
for i in range(len(examples["id"])):
query = task_template.format(question=examples["question"][i])
label = examples[examples["answer"][i]][i]
model_inputs["query"].append(query)
model_inputs["label"].append(label)
return model_inputs
TASK_DICT = { TASK_DICT = {
"choice": format_example_choice, "choice": format_example_choice,
"cloze": format_example_cloze "cloze": format_example_cloze,
"openqa": format_example_openqa
}
EXT2TYPE = {
"csv": "csv",
"json": "json",
"jsonl": "json"
} }
@ -63,7 +76,7 @@ def evaluate(
evalset: str, evalset: str,
api_base: str, api_base: str,
split: Optional[str] = "val", split: Optional[str] = "val",
task_type: Optional[Literal["choice", "cloze"]] = "choice", task_type: Optional[Literal["choice", "cloze", "openqa"]] = "choice",
n_samples: Optional[int] = 20 n_samples: Optional[int] = 20
): ):
@ -72,12 +85,11 @@ def evaluate(
if os.path.isfile(evalset): if os.path.isfile(evalset):
dataset = load_dataset(EXT2TYPE[evalset.split(".")[-1]], data_files=evalset)["train"] dataset = load_dataset(EXT2TYPE[evalset.split(".")[-1]], data_files=evalset)["train"]
elif ":" in evalset:
evalset, subset = evalset.split(":")
dataset = load_dataset(evalset, subset, split=split)
else: else:
if ":" in evalset: dataset = load_dataset(evalset, split=split)
evalset, subset = evalset.split(":")
dataset = load_dataset(evalset, subset, split=split)
else:
dataset = load_dataset(evalset, split=split)
n_samples = min(len(dataset), n_samples) n_samples = min(len(dataset), n_samples)
@ -87,12 +99,12 @@ def evaluate(
n_correct = 0 n_correct = 0
predictions = [] predictions = []
for example in tqdm(dataset): for example in tqdm(dataset):
query = example["query"] query, label = example["query"], example["label"]
label = example["label"]
predict = openai.ChatCompletion.create( predict = openai.ChatCompletion.create(
model="main", model="default",
messages=[{"role": "user", "content": query}], messages=[{"role": "user", "content": query}],
temperature=0.01, temperature=0.01,
top_p=0.01,
max_new_tokens=20 max_new_tokens=20
).choices[0].message.content ).choices[0].message.content
@ -100,6 +112,8 @@ def evaluate(
n_correct += 1 n_correct += 1
if task_type == "cloze" and label in [predict[:len(label)], predict[-len(label):]]: if task_type == "cloze" and label in [predict[:len(label)], predict[-len(label):]]:
n_correct += 1 n_correct += 1
if task_type == "openqa" and label in predict:
n_correct += 1
predictions.append({ predictions.append({
"query": query, "query": query,