add CMMLU, update eval script

Former-commit-id: 47f31f06a946eefa5a972e4a566cf3ce05e1e111
hiyouga 2023-09-23 21:10:17 +08:00
parent f7cecd20e3
commit 73c48d0463
5 changed files with 237 additions and 61 deletions

README.md

@@ -14,7 +14,7 @@
 ## Changelog
 
-[23/09/23] We integrated MMLU and C-Eval benchmarks in this repo. See [this example](#evaluation-mmlu--c-eval) to evaluate your models.
+[23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models.
 
 [23/09/10] We supported using **[FlashAttention](https://github.com/Dao-AILab/flash-attention)** for the LLaMA models. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
@@ -371,7 +371,8 @@ python src/export_model.py \
     --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
-    --output_dir path_to_export
+    --output_dir path_to_export \
+    --fp16
 ```
 
 ### API Demo
@@ -407,7 +408,22 @@ python src/web_demo.py \
     --checkpoint_dir path_to_checkpoint
 ```
 
-### Evaluation and Predict (BLEU & ROUGE_CHINESE)
+### Evaluation
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
+    --model_name_or_path path_to_llama_model \
+    --finetuning_type lora \
+    --checkpoint_dir path_to_checkpoint \
+    --template vanilla \
+    --task mmlu \
+    --split test \
+    --lang en \
+    --n_shot 5 \
+    --batch_size 4
+```
+
+### Predict
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -425,22 +441,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 ```
 
 > [!NOTE]
-> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit evaluation.
+> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit predict.
 
-### Evaluation (MMLU & C-Eval)
-
-```bash
-CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
-    --model_name_or_path path_to_llama_model \
-    --finetuning_type lora \
-    --checkpoint_dir path_to_checkpoint \
-    --template vanilla \
-    --task mmlu \
-    --split test \
-    --lang en \
-    --n_shot 5 \
-    --batch_size 4
-```
-
 ## License

README_zh.md

@@ -14,7 +14,7 @@
 ## Changelog
 
-[23/09/23] We integrated the MMLU and C-Eval benchmarks into this repo. See [this example](#模型评估mmlu-和-c-eval) to evaluate your models.
+[23/09/23] We integrated the MMLU, C-Eval and CMMLU benchmarks into this repo. See [this example](#模型评估) to evaluate your models.
 
 [23/09/10] We supported **[FlashAttention](https://github.com/Dao-AILab/flash-attention)** for the LLaMA models. If you are using RTX4090, A100 or H100 GPUs, please use the `--flash_attn` argument to enable FlashAttention-2 (experimental feature).
@@ -370,7 +370,8 @@ python src/export_model.py \
     --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
-    --output_dir path_to_export
+    --output_dir path_to_export \
+    --fp16
 ```
 
 ### API Demo
@@ -406,7 +407,22 @@ python src/web_demo.py \
     --checkpoint_dir path_to_checkpoint
 ```
 
-### Evaluation and Prediction (BLEU & Chinese ROUGE scores)
+### Evaluation
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
+    --model_name_or_path path_to_llama_model \
+    --finetuning_type lora \
+    --checkpoint_dir path_to_checkpoint \
+    --template vanilla \
+    --task ceval \
+    --split validation \
+    --lang zh \
+    --n_shot 5 \
+    --batch_size 4
+```
+
+### Predict
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -424,22 +440,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 ```
 
 > [!NOTE]
-> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` when evaluating quantized models.
+> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` when predicting with quantized models.
 
-### Evaluation (MMLU & C-Eval)
-
-```bash
-CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
-    --model_name_or_path path_to_llama_model \
-    --finetuning_type lora \
-    --checkpoint_dir path_to_checkpoint \
-    --template vanilla \
-    --task ceval \
-    --split validation \
-    --lang zh \
-    --n_shot 5 \
-    --batch_size 4
-```
-
 ## License

evaluation/ceval/ceval.py

@@ -92,14 +92,14 @@ task_list = [
 ]
 
 
-class CevalExamConfig(datasets.BuilderConfig):
+class CevalConfig(datasets.BuilderConfig):
     def __init__(self, **kwargs):
         super().__init__(version=datasets.Version("1.0.0"), **kwargs)
 
 
-class CevalExam(datasets.GeneratorBasedBuilder):
+class Ceval(datasets.GeneratorBasedBuilder):
     BUILDER_CONFIGS = [
-        CevalExamConfig(
+        CevalConfig(
             name=task_name,
         )
         for task_name in task_list

evaluation/cmmlu/cmmlu.py Normal file

@@ -0,0 +1,163 @@
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import datasets
import pandas as pd
_CITATION = """\
@article{li2023cmmlu,
title={CMMLU: Measuring massive multitask language understanding in Chinese},
author={Haonan Li and Yixuan Zhang and Fajri Koto and Yifei Yang and Hai Zhao and Yeyun Gong and Nan Duan and Timothy Baldwin},
journal={arXiv preprint arXiv:2306.09212},
year={2023}
}
"""
_DESCRIPTION = """\
CMMLU is a comprehensive Chinese assessment suite specifically designed to evaluate the advanced knowledge and reasoning abilities of LLMs within the Chinese language and cultural context.
"""
_HOMEPAGE = "https://github.com/haonan-li/CMMLU"
_LICENSE = "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License"
_URL = "cmmlu.zip"
task_list = [
'agronomy',
'anatomy',
'ancient_chinese',
'arts',
'astronomy',
'business_ethics',
'chinese_civil_service_exam',
'chinese_driving_rule',
'chinese_food_culture',
'chinese_foreign_policy',
'chinese_history',
'chinese_literature',
'chinese_teacher_qualification',
'clinical_knowledge',
'college_actuarial_science',
'college_education',
'college_engineering_hydrology',
'college_law',
'college_mathematics',
'college_medical_statistics',
'college_medicine',
'computer_science',
'computer_security',
'conceptual_physics',
'construction_project_management',
'economics',
'education',
'electrical_engineering',
'elementary_chinese',
'elementary_commonsense',
'elementary_information_and_technology',
'elementary_mathematics',
'ethnology',
'food_science',
'genetics',
'global_facts',
'high_school_biology',
'high_school_chemistry',
'high_school_geography',
'high_school_mathematics',
'high_school_physics',
'high_school_politics',
'human_sexuality',
'international_law',
'journalism',
'jurisprudence',
'legal_and_moral_basis',
'logical',
'machine_learning',
'management',
'marketing',
'marxist_theory',
'modern_chinese',
'nutrition',
'philosophy',
'professional_accounting',
'professional_law',
'professional_medicine',
'professional_psychology',
'public_relations',
'security_study',
'sociology',
'sports_science',
'traditional_chinese_medicine',
'virology',
'world_history',
'world_religions',
]
class CMMLUConfig(datasets.BuilderConfig):
    def __init__(self, **kwargs):
        super().__init__(version=datasets.Version("1.0.1"), **kwargs)


class CMMLU(datasets.GeneratorBasedBuilder):
    BUILDER_CONFIGS = [
        CMMLUConfig(
            name=task_name,
        )
        for task_name in task_list
    ]

    def _info(self):
        features = datasets.Features(
            {
                "question": datasets.Value("string"),
                "A": datasets.Value("string"),
                "B": datasets.Value("string"),
                "C": datasets.Value("string"),
                "D": datasets.Value("string"),
                "answer": datasets.Value("string"),
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        data_dir = dl_manager.download_and_extract(_URL)
        task_name = self.config.name
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                gen_kwargs={
                    "filepath": os.path.join(data_dir, f"test/{task_name}.csv"),
                },
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={
                    "filepath": os.path.join(data_dir, f"dev/{task_name}.csv"),
                },
            ),
        ]

    def _generate_examples(self, filepath):
        df = pd.read_csv(filepath, header=0, index_col=0, encoding="utf-8")
        for i, instance in enumerate(df.to_dict(orient="records")):
            yield i, instance
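The new builder mirrors the existing `ceval.py` script: every CMMLU subject is a separate config, the test split is scored, and the dev split supplies the few-shot support examples. A minimal sketch of loading one subject through this script, assuming the working directory is the repository root so that `evaluation/cmmlu` points at the folder holding `cmmlu.py` and the bundled `cmmlu.zip`:

```python
# Minimal sketch: load one CMMLU subject via the dataset script above.
from datasets import load_dataset

dataset = load_dataset("evaluation/cmmlu", "agronomy")  # any entry of task_list
print(dataset["test"][0])     # {"question": ..., "A": ..., "B": ..., "C": ..., "D": ..., "answer": ...}
print(len(dataset["train"]))  # dev examples, used as the few-shot support set
```

This is the same call pattern `evaluate.py` uses internally: `load_dataset(os.path.join(dataset_dir, task), subject)`.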

src/evaluate.py

@@ -1,16 +1,15 @@
 # coding=utf-8
 # Evaluates the performance of pre-trained models.
 # Usage: python evaluate.py --model_name_or_path path_to_model --checkpoint_dir path_to_ckpt --template vanilla
-#        --task ceval --split validation --lang zh --n_shot 5 --batch_size 4
+#        --task ceval --split validation --lang zh --n_shot 5 --batch_size 4 --save_name result
 # Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
 
 import os
 import fire
 import json
 import torch
-import random
 import numpy as np
-from tqdm import tqdm
+from tqdm import tqdm, trange
 from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
 from datasets import load_dataset
 from dataclasses import dataclass
@@ -30,6 +29,7 @@ class EvalTemplate:
     system: str
     choice: str
     answer: str
+    prefix: str
 
     def parse_example(
         self,
@@ -49,7 +49,6 @@
         history = [self.parse_example(support_set[k]) for k in range(len(support_set))]
 
         if len(history):
-            random.shuffle(history)
             temp = history.pop(0)
             history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
         else:
@@ -65,12 +64,14 @@ eval_templates = {
     "en": EvalTemplate(
         system="The following are multiple choice questions (with answers) about {subject}.\n\n",
         choice="\n{choice}. {content}",
-        answer="\nAnswer: "
+        answer="\nAnswer: ",
+        prefix=" "
     ),
     "zh": EvalTemplate(
         system="以下是中国关于{subject}考试的单项选择题，请选出其中的正确答案。\n\n",
         choice="\n{choice}. {content}",
-        answer="\n答案："
+        answer="\n答案：",
+        prefix="\n"
     )
 }
 
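For reference, the `prefix` field added here is the character placed in front of the candidate answer letter when its token probability is scored (a space for the English template, a newline for the Chinese one), while `system`, `choice` and `answer` build the prompt itself. A rough, self-contained illustration of how one English example would be rendered from these fields; the sample question and helper function below are illustrative, not part of the repository:

```python
# Illustrative only: render one multiple-choice example with the template
# fields shown above (system / choice / answer); not the repository's code.
SYSTEM = "The following are multiple choice questions (with answers) about {subject}.\n\n"
CHOICE = "\n{choice}. {content}"
ANSWER = "\nAnswer: "

def render_example(example: dict, subject: str) -> str:
    text = SYSTEM.format(subject=subject) + example["question"]
    for letter in ["A", "B", "C", "D"]:
        text += CHOICE.format(choice=letter, content=example[letter])
    return text + ANSWER  # the model is scored on which letter it predicts next

example = {
    "question": "Which gas do plants mainly absorb during photosynthesis?",
    "A": "Oxygen", "B": "Carbon dioxide", "C": "Nitrogen", "D": "Hydrogen",
}
print(render_example(example, "high school biology"))
```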
@@ -79,9 +80,8 @@ eval_templates = {
 def batch_inference(
     chat_model: ChatModel,
     batch_input: Dict[str, torch.Tensor],
-    lang: Literal["zh", "en"]
+    prefix_char: str
 ) -> List[str]:
-    prefix_char = "\n" if lang == "zh" else " "
     logits = chat_model.model(**batch_input).logits
     probs = torch.nn.functional.softmax(
         torch.stack(
@@ -108,7 +108,8 @@
     split: Optional[Literal["validation", "test"]] = "validation",
     lang: Optional[Literal["zh", "en"]] = "zh",
     n_shot: Optional[int] = 5,
-    batch_size: Optional[int] = 4
+    batch_size: Optional[int] = 4,
+    save_name: Optional[str] = None
 ):
     with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f:
         categorys = json.load(f)
@@ -119,25 +120,25 @@
         checkpoint_dir=checkpoint_dir,
         template=template
     ))
+    chat_model.tokenizer.padding_side = "left"
     eval_template = eval_templates[lang]
-    assert chat_model.tokenizer.padding_side == "left", "only left-padded tensor can be accepted."
     category_corrects: Dict[str, np.ndarray] = {
-        "STEM": np.array([], dtype="bool"),
-        "Social Sciences": np.array([], dtype="bool"),
-        "Humanities": np.array([], dtype="bool"),
-        "Other": np.array([], dtype="bool")
+        subj: np.array([], dtype="bool") for subj in ["STEM", "Social Sciences", "Humanities", "Other"]
     }
     overall_corrects = np.array([], dtype="bool")
-    pbar = tqdm(categorys.keys())
+    pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
+    results = {}
     for subject in pbar:
         pbar.set_postfix_str(categorys[subject]["name"])
         inputs, labels = [], []
         dataset = load_dataset(os.path.join(dataset_dir, task), subject)
         for i in range(len(dataset[split])):
+            support_set = dataset["train"].shuffle().select(range(min(n_shot, len(dataset["train"]))))
             query, resp, history = eval_template.format_example(
                 target_data=dataset[split][i],
-                support_set=dataset["train"].select(range(min(n_shot, len(dataset["train"])))),
+                support_set=support_set,
                 subject_name=categorys[subject]["name"],
                 use_history=chat_model.template.use_history
             )
@@ -154,23 +155,33 @@
             labels.append(resp)
 
         outputs = []
-        for i in range(0, len(inputs), batch_size):
+        for i in trange(0, len(inputs), batch_size, desc="Processing batches", position=1, leave=False):
             batch_input = chat_model.tokenizer.pad(
                 inputs[i : i + batch_size],
                 return_attention_mask=True,
                 return_tensors="pt"
             ).to(chat_model.model.device)
-            preds = batch_inference(chat_model, batch_input, lang)
+            preds = batch_inference(chat_model, batch_input, eval_template.prefix)
             outputs += preds
 
         corrects = (np.array(outputs) == np.array(labels))
         category_name = categorys[subject]["category"]
         category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
         overall_corrects = np.concatenate([overall_corrects, corrects], axis=0)
+        results[subject] = {str(i): outputs[i] for i in range(len(outputs))}
 
-    print("Average accuracy: {:.2f}".format(100 * np.mean(overall_corrects)))
+    score_info = "Average accuracy: {:.2f}".format(100 * np.mean(overall_corrects))
     for category_name, category_correct in category_corrects.items():
-        print(" {} - {:.2f}".format(category_name, 100 * np.mean(category_correct)))
+        if len(category_correct):
+            score_info += "\n{:>16}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
+
+    print(score_info)
+
+    if save_name is not None:
+        with open(save_name + ".json", "w", encoding="utf-8", newline="\n") as f:
+            json.dump(results, f, indent=2)
+        with open(save_name + ".log", "w", encoding="utf-8", newline="\n") as f:
+            f.write(score_info)
 
 
 if __name__ == "__main__":
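Taken together, the updated script renders each question with the template, left-pads a batch of prompts, and compares the next-token probabilities of the four option letters (tokenized with the language-specific `prefix` character) instead of generating text; per-subject answers go into `results` and are written to `<save_name>.json` alongside a `<save_name>.log` score summary. A standalone sketch of the choice-scoring idea follows, with generic `model`/`tokenizer` arguments standing in for `chat_model.model`/`chat_model.tokenizer`; it is not a copy of the actual `batch_inference` body, which is truncated in the hunk above:

```python
# Standalone sketch of prefix-aware choice scoring for multiple-choice eval.
# "model" and "tokenizer" are generic Hugging Face objects standing in for
# chat_model.model / chat_model.tokenizer; the real batch_inference may differ.
from typing import Dict, List
import torch

def score_choices(model, tokenizer, batch_input: Dict[str, torch.Tensor], prefix_char: str) -> List[str]:
    choices = ["A", "B", "C", "D"]
    # Token id of each option letter when preceded by the prefix character
    # (" " for the English template, "\n" for the Chinese one).
    choice_ids = [tokenizer.encode(prefix_char + c, add_special_tokens=False)[-1] for c in choices]
    with torch.no_grad():
        logits = model(**batch_input).logits  # [batch, seq_len, vocab]
    next_token_logits = logits[:, -1, :]      # with left padding, the last position follows the answer slot
    probs = torch.nn.functional.softmax(next_token_logits[:, choice_ids], dim=-1)
    return [choices[i] for i in probs.argmax(dim=-1).tolist()]
```

Since the batch is left-padded, position `-1` of every sequence is the slot right after the `answer` string, so the argmax over the four letter tokens yields the predicted option.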