hiyouga 5ee1bdecdc add MMLU and C-Eval script
Former-commit-id: 465ee8119aa489a41bee0b01b3c105a2f3dd137f
2023-09-23 00:34:17 +08:00

174 lines
5.9 KiB
Python

# coding=utf-8
# Evaluates the performance of pre-trained models.
# Usage: python evaluate.py --model_name_or_path path_to_model --checkpoint_dir path_to_ckpt --template vanilla
# --task ceval --split validation --lang zh --n_shot 5 --batch_size 4
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
import os
import fire
import json
import torch
import numpy as np
from tqdm import tqdm
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
from datasets import load_dataset
from dataclasses import dataclass
from llmtuner import ChatModel
if TYPE_CHECKING:
from datasets import Dataset
# Option letters, in display order, that a benchmark example may carry as keys
# (each one present in an example is rendered as an answer option).
choices = list("ABCD")
@dataclass
class EvalTemplate:
    """Prompt layout for a multiple-choice benchmark in one language.

    All three fields are ``str.format`` templates / literal fragments that are
    concatenated to build prompts.
    """

    system: str  # subject header; formatted with ``subject=``
    choice: str  # one option line; formatted with ``choice=`` and ``content=``
    answer: str  # answer cue appended verbatim after the options

    def parse_example(
        self,
        example: Dict[str, str]
    ) -> Tuple[str, str]:
        """Render one example into ``(prompt, gold_answer)``.

        The prompt is the question text, followed by one ``choice`` line for
        every letter of the module-level ``choices`` that is a key of
        ``example``, followed by the ``answer`` cue. The gold answer is
        ``example["answer"]`` unchanged.
        """
        options = []
        for letter in choices:
            if letter in example:
                options.append(self.choice.format(choice=letter, content=example[letter]))
        prompt = "".join([example["question"], *options, self.answer])
        return prompt, example["answer"]

    def format_example(
        self,
        target_data: Dict[str, str],
        support_set: "Dataset",
        subject_name: str,
        use_history: bool
    ) -> Tuple[str, str, List[Tuple[str, str]]]:
        """Build ``(query, gold_answer, history)`` for one target example.

        Every row of ``support_set`` becomes a ``(prompt, answer)`` shot. The
        subject header (``system`` formatted with ``subject_name``) is prefixed
        to the earliest prompt: the first shot when there is one, otherwise the
        query itself.

        When ``use_history`` is true the shots are returned as the history.
        Otherwise each shot's prompt and answer are concatenated, those blocks
        and the query are joined with blank lines into a single query, and the
        returned history is empty.
        """
        query, resp = self.parse_example(target_data)
        shots = [self.parse_example(support_set[idx]) for idx in range(len(support_set))]
        if not shots:
            query = self.system.format(subject=subject_name) + query
        else:
            first_prompt, first_answer = shots[0]
            shots[0] = (self.system.format(subject=subject_name) + first_prompt, first_answer)

        if use_history:
            return query, resp, shots

        # Inline the demonstrations into a single prompt for templates without chat history.
        blocks = ["".join(shot) for shot in shots]
        blocks.append(query)
        return "\n\n".join(blocks), resp, []
# Prompt templates, selected by the ``lang`` argument of ``evaluate``.
eval_templates = dict(
    en=EvalTemplate(
        system="The following are multiple choice questions (with answers) about {subject}.\n\n",
        choice="\n{choice}. {content}",
        answer="\nAnswer: "
    ),
    # Chinese header reads roughly: "The following are single-choice questions from
    # Chinese exams on {subject}; please choose the correct answer." Cue: "Answer:".
    zh=EvalTemplate(
        system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
        choice="\n{choice}. {content}",
        answer="\n答案:"
    ),
)
@torch.inference_mode()
def batch_inference(
    chat_model: "ChatModel",
    batch_input: Dict[str, torch.Tensor]
) -> List[str]:
    """Predict one answer letter ("A"-"D") per row of a batch.

    For each row, the model logits at the last sequence position are compared
    across the token ids of the four option letters, and the highest-scoring
    letter is returned. Ties resolve to the earliest letter.

    Args:
        chat_model: Wrapper exposing ``model`` (called with ``**batch_input``;
            the result must have a ``.logits`` tensor indexed as
            ``[batch, position, vocab]``) and ``tokenizer``.
        batch_input: Keyword tensors forwarded to the model (e.g.
            ``input_ids`` / ``attention_mask``). Only position ``-1`` is read
            for every row, so the batch must be left-padded.

    Returns:
        A list with one predicted letter per batch row.
    """
    letters = ("A", "B", "C", "D")
    # Use the id of each letter as tokenized after a newline. Disabling special
    # tokens keeps tokenizers that append EOS from making ``[-1]`` the EOS id.
    letter_ids = [
        chat_model.tokenizer.encode("\n" + letter, add_special_tokens=False)[-1]
        for letter in letters
    ]
    logits = chat_model.model(**batch_input).logits
    # Softmax is monotonic, so the argmax over raw logits selects the same letter.
    choice_logits = logits[:, -1, letter_ids]
    return [letters[idx] for idx in torch.argmax(choice_logits, dim=-1).tolist()]
def _build_subject_inputs(
    chat_model: "ChatModel",
    eval_template: "EvalTemplate",
    dataset: Dict[str, "Dataset"],
    split: str,
    n_shot: int,
    subject_name: str
) -> Tuple[List[Dict[str, List[int]]], List[str]]:
    """Encode every example of ``dataset[split]`` into prompt ids and collect its gold answer."""
    # The few-shot demonstrations do not depend on the target example, so select them once.
    train_set = dataset["train"]
    support_set = train_set.select(range(min(n_shot, len(train_set))))
    use_history = chat_model.template.use_history

    inputs: List[Dict[str, List[int]]] = []
    labels: List[str] = []
    for example in dataset[split]:
        query, resp, history = eval_template.format_example(
            target_data=example,
            support_set=support_set,
            subject_name=subject_name,
            use_history=use_history
        )
        input_ids, _ = chat_model.template.encode_oneturn(
            tokenizer=chat_model.tokenizer,
            query=query,
            resp=resp,
            history=history
        )
        inputs.append({
            "input_ids": input_ids,
            "attention_mask": [1] * len(input_ids)
        })
        labels.append(resp)
    return inputs, labels


def _predict(
    chat_model: "ChatModel",
    inputs: List[Dict[str, List[int]]],
    batch_size: int
) -> List[str]:
    """Pad ``inputs`` into mini-batches on the model's device and collect predicted letters."""
    outputs: List[str] = []
    for start in range(0, len(inputs), batch_size):
        batch_input = chat_model.tokenizer.pad(
            inputs[start : start + batch_size],
            return_attention_mask=True,
            return_tensors="pt"
        ).to(chat_model.model.device)
        outputs.extend(batch_inference(chat_model, batch_input))
    return outputs


def _format_accuracy(corrects: np.ndarray) -> str:
    """Format the percentage of True entries, or "N/A" when there is nothing to average."""
    if corrects.size == 0:
        return "N/A"
    return "{:.2f}".format(100 * np.mean(corrects))


def evaluate(
    model_name_or_path: str,
    finetuning_type: Optional[str] = "lora",
    checkpoint_dir: Optional[str] = None,
    template: Optional[str] = "vanilla",
    task: Optional[str] = "ceval",
    dataset_dir: Optional[str] = "evaluation",
    split: Optional[Literal["validation", "test"]] = "validation",
    lang: Optional[Literal["zh", "en"]] = "zh",
    n_shot: Optional[int] = 5,
    batch_size: Optional[int] = 4
):
    """Score a model on a multiple-choice benchmark and print accuracies.

    Subjects are listed in ``<dataset_dir>/<task>/mapping.json``. Each entry maps
    a subject key to ``{"name": ..., "category": ...}``. Each subject is loaded
    with ``load_dataset(<dataset_dir>/<task>, subject)``. The first ``n_shot``
    rows of its "train" split serve as demonstrations, and every row of
    ``split`` is scored.

    Args:
        model_name_or_path: Model passed to ``ChatModel``.
        finetuning_type: Forwarded to ``ChatModel``.
        checkpoint_dir: Forwarded to ``ChatModel``.
        template: Chat template name forwarded to ``ChatModel``.
        task: Benchmark subdirectory of ``dataset_dir``.
        dataset_dir: Root directory holding the benchmark folders.
        split: Dataset split to score.
        lang: Selects the prompt template from ``eval_templates``.
        n_shot: Number of leading "train" rows used as few-shot examples.
        batch_size: Number of prompts per forward pass; must be >= 1.

    Raises:
        ValueError: If ``batch_size`` is not a positive integer.
    """
    if batch_size is None or batch_size < 1:
        raise ValueError("batch_size must be a positive integer, got {!r}".format(batch_size))

    with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f:
        categories = json.load(f)

    chat_model = ChatModel(dict(
        model_name_or_path=model_name_or_path,
        finetuning_type=finetuning_type,
        checkpoint_dir=checkpoint_dir,
        template=template
    ))
    # batch_inference reads the last position of every row, so pad on the left.
    chat_model.tokenizer.padding_side = "left"
    eval_template = eval_templates[lang]

    category_corrects: Dict[str, np.ndarray] = {
        name: np.array([], dtype="bool")
        for name in ("STEM", "Social Sciences", "Humanities", "Other")
    }
    overall_corrects = np.array([], dtype="bool")

    pbar = tqdm(categories.keys())
    for subject in pbar:
        subject_name = categories[subject]["name"]
        pbar.set_postfix_str(subject_name)
        dataset = load_dataset(os.path.join(dataset_dir, task), subject)
        inputs, labels = _build_subject_inputs(
            chat_model, eval_template, dataset, split, n_shot, subject_name
        )
        outputs = _predict(chat_model, inputs, batch_size)
        corrects = np.array(outputs) == np.array(labels)

        category_name = categories[subject]["category"]
        # A category outside the four defaults gets its own bucket instead of raising KeyError.
        previous = category_corrects.get(category_name, np.array([], dtype="bool"))
        category_corrects[category_name] = np.concatenate([previous, corrects], axis=0)
        overall_corrects = np.concatenate([overall_corrects, corrects], axis=0)

    print("Average accuracy: {}".format(_format_accuracy(overall_corrects)))
    for category_name, category_correct in category_corrects.items():
        print(" {} - {}".format(category_name, _format_accuracy(category_correct)))
if __name__ == "__main__":
    # Expose `evaluate` as the command-line entry point (flags as in the Usage header).
    fire.Fire(component=evaluate)