166  evaluation/ceval/ceval.py  Normal file
@@ -0,0 +1,166 @@
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import datasets
import pandas as pd


_CITATION = """\
@article{huang2023ceval,
title={C-Eval: A Multi-Level Multi-Discipline Chinese Evaluation Suite for Foundation Models},
author={Huang, Yuzhen and Bai, Yuzhuo and Zhu, Zhihao and Zhang, Junlei and Zhang, Jinghan and Su, Tangjun and Liu, Junteng and Lv, Chuancheng and Zhang, Yikai and Lei, Jiayi and Fu, Yao and Sun, Maosong and He, Junxian},
journal={arXiv preprint arXiv:2305.08322},
year={2023}
}
"""

_DESCRIPTION = """\
C-Eval is a comprehensive Chinese evaluation suite for foundation models. It consists of 13948 multi-choice questions spanning 52 diverse disciplines and four difficulty levels.
"""

_HOMEPAGE = "https://cevalbenchmark.com"

_LICENSE = "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License"

_URL = "ceval.zip"

task_list = [
    "computer_network",
    "operating_system",
    "computer_architecture",
    "college_programming",
    "college_physics",
    "college_chemistry",
    "advanced_mathematics",
    "probability_and_statistics",
    "discrete_mathematics",
    "electrical_engineer",
    "metrology_engineer",
    "high_school_mathematics",
    "high_school_physics",
    "high_school_chemistry",
    "high_school_biology",
    "middle_school_mathematics",
    "middle_school_biology",
    "middle_school_physics",
    "middle_school_chemistry",
    "veterinary_medicine",
    "college_economics",
    "business_administration",
    "marxism",
    "mao_zedong_thought",
    "education_science",
    "teacher_qualification",
    "high_school_politics",
    "high_school_geography",
    "middle_school_politics",
    "middle_school_geography",
    "modern_chinese_history",
    "ideological_and_moral_cultivation",
    "logic",
    "law",
    "chinese_language_and_literature",
    "art_studies",
    "professional_tour_guide",
    "legal_professional",
    "high_school_chinese",
    "high_school_history",
    "middle_school_history",
    "civil_servant",
    "sports_science",
    "plant_protection",
    "basic_medicine",
    "clinical_medicine",
    "urban_and_rural_planner",
    "accountant",
    "fire_engineer",
    "environmental_impact_assessment_engineer",
    "tax_accountant",
    "physician",
]


class CevalExamConfig(datasets.BuilderConfig):
    def __init__(self, **kwargs):
        super().__init__(version=datasets.Version("1.0.0"), **kwargs)


class CevalExam(datasets.GeneratorBasedBuilder):
    BUILDER_CONFIGS = [
        CevalExamConfig(
            name=task_name,
        )
        for task_name in task_list
    ]

    def _info(self):
        features = datasets.Features(
            {
                "id": datasets.Value("int32"),
                "question": datasets.Value("string"),
                "A": datasets.Value("string"),
                "B": datasets.Value("string"),
                "C": datasets.Value("string"),
                "D": datasets.Value("string"),
                "answer": datasets.Value("string"),
                "explanation": datasets.Value("string"),
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        data_dir = dl_manager.download_and_extract(_URL)
        task_name = self.config.name
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                gen_kwargs={
                    "filepath": os.path.join(
                        data_dir, "test", f"{task_name}_test.csv"
                    ),
                },
            ),
            datasets.SplitGenerator(
                name=datasets.Split.VALIDATION,
                gen_kwargs={
                    "filepath": os.path.join(
                        data_dir, "val", f"{task_name}_val.csv"
                    ),
                },
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={
                    "filepath": os.path.join(
                        data_dir, "dev", f"{task_name}_dev.csv"
                    ),
                },
            ),
        ]

    def _generate_examples(self, filepath):
        df = pd.read_csv(filepath, encoding="utf-8")
        for i, instance in enumerate(df.to_dict(orient="records")):
            # The test split ships without labels; fill them in to keep the schema consistent.
            if "answer" not in instance.keys():
                instance["answer"] = ""
            if "explanation" not in instance.keys():
                instance["explanation"] = ""
            yield i, instance
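Note: since this is a standard `datasets` loading script, it can be smoke-tested on its own before running the full evaluation. A minimal sketch, assuming it is invoked from the repository root (the same relative path that evaluate.py below builds with os.path.join(dataset_dir, task)):

from datasets import load_dataset

# "evaluation/ceval" resolves to the script above; the config name picks one subject.
dataset = load_dataset("evaluation/ceval", "computer_network")
print(dataset["validation"][0])  # id, question, A/B/C/D, answer, explanation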
BIN  evaluation/ceval/ceval.zip  Normal file
Binary file not shown.
210  evaluation/ceval/mapping.json  Normal file
@@ -0,0 +1,210 @@
{
  "accountant": {
    "name": "注册会计师",
    "category": "Other"
  },
  "advanced_mathematics": {
    "name": "高等数学",
    "category": "STEM"
  },
  "art_studies": {
    "name": "艺术学",
    "category": "Humanities"
  },
  "basic_medicine": {
    "name": "基础医学",
    "category": "Other"
  },
  "business_administration": {
    "name": "工商管理",
    "category": "Social Sciences"
  },
  "chinese_language_and_literature": {
    "name": "中国语言文学",
    "category": "Humanities"
  },
  "civil_servant": {
    "name": "公务员",
    "category": "Other"
  },
  "clinical_medicine": {
    "name": "临床医学",
    "category": "Other"
  },
  "college_chemistry": {
    "name": "大学化学",
    "category": "STEM"
  },
  "college_economics": {
    "name": "大学经济学",
    "category": "Social Sciences"
  },
  "college_physics": {
    "name": "大学物理",
    "category": "STEM"
  },
  "college_programming": {
    "name": "大学编程",
    "category": "STEM"
  },
  "computer_architecture": {
    "name": "计算机组成",
    "category": "STEM"
  },
  "computer_network": {
    "name": "计算机网络",
    "category": "STEM"
  },
  "discrete_mathematics": {
    "name": "离散数学",
    "category": "STEM"
  },
  "education_science": {
    "name": "教育学",
    "category": "Social Sciences"
  },
  "electrical_engineer": {
    "name": "注册电气工程师",
    "category": "STEM"
  },
  "environmental_impact_assessment_engineer": {
    "name": "环境影响评价工程师",
    "category": "Other"
  },
  "fire_engineer": {
    "name": "注册消防工程师",
    "category": "Other"
  },
  "high_school_biology": {
    "name": "高中生物",
    "category": "STEM"
  },
  "high_school_chemistry": {
    "name": "高中化学",
    "category": "STEM"
  },
  "high_school_chinese": {
    "name": "高中语文",
    "category": "Humanities"
  },
  "high_school_geography": {
    "name": "高中地理",
    "category": "Social Sciences"
  },
  "high_school_history": {
    "name": "高中历史",
    "category": "Humanities"
  },
  "high_school_mathematics": {
    "name": "高中数学",
    "category": "STEM"
  },
  "high_school_physics": {
    "name": "高中物理",
    "category": "STEM"
  },
  "high_school_politics": {
    "name": "高中政治",
    "category": "Social Sciences"
  },
  "ideological_and_moral_cultivation": {
    "name": "思想道德修养与法律基础",
    "category": "Humanities"
  },
  "law": {
    "name": "法学",
    "category": "Humanities"
  },
  "legal_professional": {
    "name": "法律职业资格",
    "category": "Humanities"
  },
  "logic": {
    "name": "逻辑学",
    "category": "Humanities"
  },
  "mao_zedong_thought": {
    "name": "毛泽东思想和中国特色社会主义理论体系概论",
    "category": "Social Sciences"
  },
  "marxism": {
    "name": "马克思主义基本原理",
    "category": "Social Sciences"
  },
  "metrology_engineer": {
    "name": "注册计量师",
    "category": "STEM"
  },
  "middle_school_biology": {
    "name": "初中生物",
    "category": "STEM"
  },
  "middle_school_chemistry": {
    "name": "初中化学",
    "category": "STEM"
  },
  "middle_school_geography": {
    "name": "初中地理",
    "category": "Social Sciences"
  },
  "middle_school_history": {
    "name": "初中历史",
    "category": "Humanities"
  },
  "middle_school_mathematics": {
    "name": "初中数学",
    "category": "STEM"
  },
  "middle_school_physics": {
    "name": "初中物理",
    "category": "STEM"
  },
  "middle_school_politics": {
    "name": "初中政治",
    "category": "Social Sciences"
  },
  "modern_chinese_history": {
    "name": "近代史纲要",
    "category": "Humanities"
  },
  "operating_system": {
    "name": "操作系统",
    "category": "STEM"
  },
  "physician": {
    "name": "医师资格",
    "category": "Other"
  },
  "plant_protection": {
    "name": "植物保护",
    "category": "Other"
  },
  "probability_and_statistics": {
    "name": "概率统计",
    "category": "STEM"
  },
  "professional_tour_guide": {
    "name": "导游资格",
    "category": "Humanities"
  },
  "sports_science": {
    "name": "体育学",
    "category": "Other"
  },
  "tax_accountant": {
    "name": "税务师",
    "category": "Other"
  },
  "teacher_qualification": {
    "name": "教师资格",
    "category": "Social Sciences"
  },
  "urban_and_rural_planner": {
    "name": "注册城乡规划师",
    "category": "Other"
  },
  "veterinary_medicine": {
    "name": "兽医学",
    "category": "STEM"
  }
}
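evaluate.py, added later in this commit, reads this mapping to turn each subject key into a display name and one of the four scoring categories (STEM, Social Sciences, Humanities, Other). A minimal sketch of that lookup, assuming the repository root as the working directory:

import json

with open("evaluation/ceval/mapping.json", "r", encoding="utf-8") as f:
    mapping = json.load(f)

print(mapping["accountant"]["name"])      # 注册会计师
print(mapping["accountant"]["category"])  # Other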
173  evaluation/evaluate.py  Normal file
@@ -0,0 +1,173 @@
# coding=utf-8
# Evaluates the performance of pre-trained models.
# Usage: python evaluate.py --model_name_or_path path_to_model --checkpoint_dir path_to_ckpt --template vanilla
#        --task ceval --split validation --lang zh --n_shot 5 --batch_size 4
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py

import os
import fire
import json
import torch
import numpy as np
from tqdm import tqdm
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
from datasets import load_dataset
from dataclasses import dataclass

from llmtuner import ChatModel

if TYPE_CHECKING:
    from datasets import Dataset


choices = ["A", "B", "C", "D"]


@dataclass
class EvalTemplate:

    system: str
    choice: str
    answer: str

    def parse_example(
        self,
        example: Dict[str, str]
    ) -> Tuple[str, str]:
        # Render one example as (prompt, label), e.g. "<question>\nA. ...\nD. ...\nAnswer: " and "C".
        candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in choices if ch in example]
        return "".join([example["question"]] + candidates + [self.answer]), example["answer"]

    def format_example(
        self,
        target_data: Dict[str, str],
        support_set: "Dataset",
        subject_name: str,
        use_history: bool
    ) -> Tuple[str, str, List[Tuple[str, str]]]:
        # Build the n-shot prompt: the support set becomes the chat history,
        # with the subject-specific system prefix attached to the first turn.
        query, resp = self.parse_example(target_data)
        history = [self.parse_example(support_set[k]) for k in range(len(support_set))]

        if len(history):
            temp = history.pop(0)
            history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
        else:
            query = self.system.format(subject=subject_name) + query

        if not use_history:
            # Templates without multi-turn support get a single concatenated prompt.
            query = "\n\n".join(["".join(item) for item in history] + [query])
            history = []
        return query, resp, history


eval_templates = {
    "en": EvalTemplate(
        system="The following are multiple choice questions (with answers) about {subject}.\n\n",
        choice="\n{choice}. {content}",
        answer="\nAnswer: "
    ),
    "zh": EvalTemplate(
        system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
        choice="\n{choice}. {content}",
        answer="\n答案:"
    )
}


@torch.inference_mode()
def batch_inference(
    chat_model: ChatModel,
    batch_input: Dict[str, torch.Tensor]
) -> List[str]:
    # Instead of decoding free text, compare the logits that the four choice
    # letters receive at the last prompt position; encode("\nX")[-1] takes the
    # final piece so tokenizers that merge "\n" with the letter still work.
    logits = chat_model.model(**batch_input).logits
    probs = torch.nn.functional.softmax(
        torch.stack(
            [
                logits[:, -1, chat_model.tokenizer.encode("\nA")[-1]],
                logits[:, -1, chat_model.tokenizer.encode("\nB")[-1]],
                logits[:, -1, chat_model.tokenizer.encode("\nC")[-1]],
                logits[:, -1, chat_model.tokenizer.encode("\nD")[-1]]
            ],
            dim=-1
        ),
        dim=-1
    ).detach()
    return [chr(ord("A") + offset.item()) for offset in torch.argmax(probs, dim=-1)]


def evaluate(
    model_name_or_path: str,
    finetuning_type: Optional[str] = "lora",
    checkpoint_dir: Optional[str] = None,
    template: Optional[str] = "vanilla",
    task: Optional[str] = "ceval",
    dataset_dir: Optional[str] = "evaluation",
    split: Optional[Literal["validation", "test"]] = "validation",
    lang: Optional[Literal["zh", "en"]] = "zh",
    n_shot: Optional[int] = 5,
    batch_size: Optional[int] = 4
):
    with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f:
        categorys = json.load(f)

    chat_model = ChatModel(dict(
        model_name_or_path=model_name_or_path,
        finetuning_type=finetuning_type,
        checkpoint_dir=checkpoint_dir,
        template=template
    ))
    chat_model.tokenizer.padding_side = "left"  # pad left so the answer slot stays at position -1
    eval_template = eval_templates[lang]
    category_corrects: Dict[str, np.ndarray] = {
        "STEM": np.array([], dtype="bool"),
        "Social Sciences": np.array([], dtype="bool"),
        "Humanities": np.array([], dtype="bool"),
        "Other": np.array([], dtype="bool")
    }
    overall_corrects = np.array([], dtype="bool")

    pbar = tqdm(categorys.keys())
    for subject in pbar:
        pbar.set_postfix_str(categorys[subject]["name"])
        inputs, labels = [], []
        dataset = load_dataset(os.path.join(dataset_dir, task), subject)
        for i in range(len(dataset[split])):
            query, resp, history = eval_template.format_example(
                target_data=dataset[split][i],
                support_set=dataset["train"].select(range(min(n_shot, len(dataset["train"])))),
                subject_name=categorys[subject]["name"],
                use_history=chat_model.template.use_history
            )
            input_ids, _ = chat_model.template.encode_oneturn(
                tokenizer=chat_model.tokenizer,
                query=query,
                resp=resp,
                history=history
            )
            inputs.append({
                "input_ids": input_ids,
                "attention_mask": [1] * len(input_ids)
            })
            labels.append(resp)

        outputs = []
        for i in range(0, len(inputs), batch_size):
            batch_input = chat_model.tokenizer.pad(
                inputs[i : i + batch_size],
                return_attention_mask=True,
                return_tensors="pt"
            ).to(chat_model.model.device)
            preds = batch_inference(chat_model, batch_input)
            outputs += preds

        corrects = (np.array(outputs) == np.array(labels))
        category_name = categorys[subject]["category"]
        category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
        overall_corrects = np.concatenate([overall_corrects, corrects], axis=0)

    print("Average accuracy: {:.2f}".format(100 * np.mean(overall_corrects)))
    for category_name, category_correct in category_corrects.items():
        print("    {} - {:.2f}".format(category_name, 100 * np.mean(category_correct)))


if __name__ == "__main__":
    fire.Fire(evaluate)
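The batch_inference routine above never generates free-form text: it takes the logits at the final prompt position and ranks only the four entries corresponding to the choice letters. A standalone sketch of the same scoring idea with plain transformers (the gpt2 checkpoint and toy prompt are placeholders for illustration, not part of this patch):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "Question: 1 + 1 = ?\nA. 1\nB. 2\nC. 3\nD. 4\nAnswer:"
inputs = tokenizer(prompt, return_tensors="pt")
with torch.inference_mode():
    logits = model(**inputs).logits  # shape: [batch, seq_len, vocab]

# Token id of each letter (last piece of "\nA" etc., as in batch_inference),
# then pick the letter with the highest next-token logit.
choice_ids = [tokenizer.encode("\n" + ch)[-1] for ch in ["A", "B", "C", "D"]]
prediction = "ABCD"[logits[0, -1, choice_ids].argmax().item()]
print(prediction)

An end-to-end run follows the usage line in the file header, e.g. python evaluate.py --model_name_or_path path_to_model --template vanilla --task ceval --split validation --lang zh --n_shot 5 --batch_size 4.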
230  evaluation/mmlu/mapping.json  Normal file
@@ -0,0 +1,230 @@
{
  "abstract_algebra": {
    "name": "abstract algebra",
    "category": "STEM"
  },
  "anatomy": {
    "name": "anatomy",
    "category": "Other"
  },
  "astronomy": {
    "name": "astronomy",
    "category": "STEM"
  },
  "business_ethics": {
    "name": "business ethics",
    "category": "Other"
  },
  "clinical_knowledge": {
    "name": "clinical knowledge",
    "category": "Other"
  },
  "college_biology": {
    "name": "college biology",
    "category": "STEM"
  },
  "college_chemistry": {
    "name": "college chemistry",
    "category": "STEM"
  },
  "college_computer_science": {
    "name": "college computer science",
    "category": "STEM"
  },
  "college_mathematics": {
    "name": "college mathematics",
    "category": "STEM"
  },
  "college_medicine": {
    "name": "college medicine",
    "category": "Other"
  },
  "college_physics": {
    "name": "college physics",
    "category": "STEM"
  },
  "computer_security": {
    "name": "computer security",
    "category": "STEM"
  },
  "conceptual_physics": {
    "name": "conceptual physics",
    "category": "STEM"
  },
  "econometrics": {
    "name": "econometrics",
    "category": "Social Sciences"
  },
  "electrical_engineering": {
    "name": "electrical engineering",
    "category": "STEM"
  },
  "elementary_mathematics": {
    "name": "elementary mathematics",
    "category": "STEM"
  },
  "formal_logic": {
    "name": "formal logic",
    "category": "Humanities"
  },
  "global_facts": {
    "name": "global facts",
    "category": "Other"
  },
  "high_school_biology": {
    "name": "high school biology",
    "category": "STEM"
  },
  "high_school_chemistry": {
    "name": "high school chemistry",
    "category": "STEM"
  },
  "high_school_computer_science": {
    "name": "high school computer science",
    "category": "STEM"
  },
  "high_school_european_history": {
    "name": "high school european history",
    "category": "Humanities"
  },
  "high_school_geography": {
    "name": "high school geography",
    "category": "Social Sciences"
  },
  "high_school_government_and_politics": {
    "name": "high school government and politics",
    "category": "Social Sciences"
  },
  "high_school_macroeconomics": {
    "name": "high school macroeconomics",
    "category": "Social Sciences"
  },
  "high_school_mathematics": {
    "name": "high school mathematics",
    "category": "STEM"
  },
  "high_school_microeconomics": {
    "name": "high school microeconomics",
    "category": "Social Sciences"
  },
  "high_school_physics": {
    "name": "high school physics",
    "category": "STEM"
  },
  "high_school_psychology": {
    "name": "high school psychology",
    "category": "Social Sciences"
  },
  "high_school_statistics": {
    "name": "high school statistics",
    "category": "STEM"
  },
  "high_school_us_history": {
    "name": "high school us history",
    "category": "Humanities"
  },
  "high_school_world_history": {
    "name": "high school world history",
    "category": "Humanities"
  },
  "human_aging": {
    "name": "human aging",
    "category": "Other"
  },
  "human_sexuality": {
    "name": "human sexuality",
    "category": "Social Sciences"
  },
  "international_law": {
    "name": "international law",
    "category": "Humanities"
  },
  "jurisprudence": {
    "name": "jurisprudence",
    "category": "Humanities"
  },
  "logical_fallacies": {
    "name": "logical fallacies",
    "category": "Humanities"
  },
  "machine_learning": {
    "name": "machine learning",
    "category": "STEM"
  },
  "management": {
    "name": "management",
    "category": "Other"
  },
  "marketing": {
    "name": "marketing",
    "category": "Other"
  },
  "medical_genetics": {
    "name": "medical genetics",
    "category": "Other"
  },
  "miscellaneous": {
    "name": "miscellaneous",
    "category": "Other"
  },
  "moral_disputes": {
    "name": "moral disputes",
    "category": "Humanities"
  },
  "moral_scenarios": {
    "name": "moral scenarios",
    "category": "Humanities"
  },
  "nutrition": {
    "name": "nutrition",
    "category": "Other"
  },
  "philosophy": {
    "name": "philosophy",
    "category": "Humanities"
  },
  "prehistory": {
    "name": "prehistory",
    "category": "Humanities"
  },
  "professional_accounting": {
    "name": "professional accounting",
    "category": "Other"
  },
  "professional_law": {
    "name": "professional law",
    "category": "Humanities"
  },
  "professional_medicine": {
    "name": "professional medicine",
    "category": "Other"
  },
  "professional_psychology": {
    "name": "professional psychology",
    "category": "Social Sciences"
  },
  "public_relations": {
    "name": "public relations",
    "category": "Social Sciences"
  },
  "security_studies": {
    "name": "security studies",
    "category": "Social Sciences"
  },
  "sociology": {
    "name": "sociology",
    "category": "Social Sciences"
  },
  "us_foreign_policy": {
    "name": "us foreign policy",
    "category": "Social Sciences"
  },
  "virology": {
    "name": "virology",
    "category": "Other"
  },
  "world_religions": {
    "name": "world religions",
    "category": "Humanities"
  }
}
167  evaluation/mmlu/mmlu.py  Normal file
@@ -0,0 +1,167 @@
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import datasets
import pandas as pd


_CITATION = """\
@article{hendryckstest2021,
title={Measuring Massive Multitask Language Understanding},
author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt},
journal={Proceedings of the International Conference on Learning Representations (ICLR)},
year={2021}
}
"""

_DESCRIPTION = """\
Measuring Massive Multitask Language Understanding by Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, and Jacob Steinhardt (ICLR 2021).
"""

_HOMEPAGE = "https://github.com/hendrycks/test"

_LICENSE = "MIT"

_URL = "mmlu.zip"

task_list = [
    "high_school_european_history",
    "business_ethics",
    "clinical_knowledge",
    "medical_genetics",
    "high_school_us_history",
    "high_school_physics",
    "high_school_world_history",
    "virology",
    "high_school_microeconomics",
    "econometrics",
    "college_computer_science",
    "high_school_biology",
    "abstract_algebra",
    "professional_accounting",
    "philosophy",
    "professional_medicine",
    "nutrition",
    "global_facts",
    "machine_learning",
    "security_studies",
    "public_relations",
    "professional_psychology",
    "prehistory",
    "anatomy",
    "human_sexuality",
    "college_medicine",
    "high_school_government_and_politics",
    "college_chemistry",
    "logical_fallacies",
    "high_school_geography",
    "elementary_mathematics",
    "human_aging",
    "college_mathematics",
    "high_school_psychology",
    "formal_logic",
    "high_school_statistics",
    "international_law",
    "high_school_mathematics",
    "high_school_computer_science",
    "conceptual_physics",
    "miscellaneous",
    "high_school_chemistry",
    "marketing",
    "professional_law",
    "management",
    "college_physics",
    "jurisprudence",
    "world_religions",
    "sociology",
    "us_foreign_policy",
    "high_school_macroeconomics",
    "computer_security",
    "moral_scenarios",
    "moral_disputes",
    "electrical_engineering",
    "astronomy",
    "college_biology",
]


class MMLUConfig(datasets.BuilderConfig):
    def __init__(self, **kwargs):
        super().__init__(version=datasets.Version("1.0.0"), **kwargs)


class MMLU(datasets.GeneratorBasedBuilder):
    BUILDER_CONFIGS = [
        MMLUConfig(
            name=task_name,
        )
        for task_name in task_list
    ]

    def _info(self):
        features = datasets.Features(
            {
                "question": datasets.Value("string"),
                "A": datasets.Value("string"),
                "B": datasets.Value("string"),
                "C": datasets.Value("string"),
                "D": datasets.Value("string"),
                "answer": datasets.Value("string"),
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        data_dir = dl_manager.download_and_extract(_URL)
        task_name = self.config.name
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                gen_kwargs={
                    "filepath": os.path.join(
                        data_dir, "data", "test", f"{task_name}_test.csv"
                    ),
                },
            ),
            datasets.SplitGenerator(
                name=datasets.Split.VALIDATION,
                gen_kwargs={
                    "filepath": os.path.join(
                        data_dir, "data", "val", f"{task_name}_val.csv"
                    ),
                },
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={
                    "filepath": os.path.join(
                        data_dir, "data", "dev", f"{task_name}_dev.csv"
                    ),
                },
            ),
        ]

    def _generate_examples(self, filepath):
        df = pd.read_csv(filepath)
        df.columns = ["question", "A", "B", "C", "D", "answer"]  # normalize column names

        for i, instance in enumerate(df.to_dict(orient="records")):
            yield i, instance
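As with the C-Eval script, this loader can be smoke-tested directly; evaluate.py then draws the few-shot examples from the dev split, which this script exposes as train. A minimal sketch, assuming the repository root as the working directory:

from datasets import load_dataset

dataset = load_dataset("evaluation/mmlu", "abstract_algebra")
support_set = dataset["train"].select(range(5))  # dev split, reused as the 5-shot pool
print(len(support_set), dataset["validation"][0]["question"])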
BIN  evaluation/mmlu/mmlu.zip  Normal file
Binary file not shown.