diff --git a/README.md b/README.md
index 5f9465e3..bbcb3b6f 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,7 @@
 ## Changelog
 
-[23/09/23] We integrated MMLU and C-Eval benchmarks in this repo. See [this example](#evaluation-mmlu--c-eval) to evaluate your models.
+[23/09/23] We integrated the MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models.
 
 [23/09/10] We supported using **[FlashAttention](https://github.com/Dao-AILab/flash-attention)** for the LLaMA models. Try the `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
 
@@ -371,7 +371,8 @@ python src/export_model.py \
     --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
-    --output_dir path_to_export
+    --output_dir path_to_export \
+    --fp16
 ```
 
 ### API Demo
@@ -407,7 +408,22 @@ python src/web_demo.py \
     --checkpoint_dir path_to_checkpoint
 ```
 
-### Evaluation and Predict (BLEU & ROUGE_CHINESE)
+### Evaluation
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
+    --model_name_or_path path_to_llama_model \
+    --finetuning_type lora \
+    --checkpoint_dir path_to_checkpoint \
+    --template vanilla \
+    --task mmlu \
+    --split test \
+    --lang en \
+    --n_shot 5 \
+    --batch_size 4
+```
+
+### Predict
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -425,22 +441,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 ```
 
 > [!NOTE]
-> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit evaluation.
-
-### Evaluation (MMLU & C-Eval)
-
-```bash
-CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
-    --model_name_or_path path_to_llama_model \
-    --finetuning_type lora \
-    --checkpoint_dir path_to_checkpoint \
-    --template vanilla \
-    --task mmlu \
-    --split test \
-    --lang en \
-    --n_shot 5 \
-    --batch_size 4
-```
+> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` for 4/8-bit prediction.
 
 ## License
 
diff --git a/README_zh.md b/README_zh.md
index bdabd6d7..e9f0ccd7 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -14,7 +14,7 @@
 ## 更新日志
 
-[23/09/23] 我们在项目中集成了 MMLU 和 C-Eval 评估集。使用方法请参阅[此示例](#模型评估mmlu-和-c-eval)。
+[23/09/23] 我们在项目中集成了 MMLU、C-Eval 和 CMMLU 评估集。使用方法请参阅[此示例](#模型评估)。
 
 [23/09/10] 我们支持了 LLaMA 模型的 **[FlashAttention](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU,请使用 `--flash_attn` 参数以启用 FlashAttention-2(实验性功能)。
 
@@ -370,7 +370,8 @@ python src/export_model.py \
     --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
-    --output_dir path_to_export
+    --output_dir path_to_export \
+    --fp16
 ```
 
 ### API 服务
@@ -406,7 +407,22 @@ python src/web_demo.py \
     --checkpoint_dir path_to_checkpoint
 ```
 
-### 指标评估与模型预测(BLEU 分数和汉语 ROUGE 分数)
+### 模型评估
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
+    --model_name_or_path path_to_llama_model \
+    --finetuning_type lora \
+    --checkpoint_dir path_to_checkpoint \
+    --template vanilla \
+    --task ceval \
+    --split validation \
+    --lang zh \
+    --n_shot 5 \
+    --batch_size 4
+```
+
+### 模型预测
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -424,22 +440,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 ```
 
 > [!NOTE]
-> 我们建议在量化模型的评估中使用 `--per_device_eval_batch_size=1` 和 `--max_target_length 128`。
-
-### 模型评估(MMLU 和 C-Eval)
-
-```bash
-CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
-    --model_name_or_path path_to_llama_model \
-    --finetuning_type lora \
-    --checkpoint_dir path_to_checkpoint \
-    --template vanilla \
-    --task ceval \
-    --split validation \
-    --lang zh \
-    --n_shot 5 \
-    --batch_size 4
-```
+> 我们建议在量化模型的预测中使用 `--per_device_eval_batch_size=1` 和 `--max_target_length 128`。
 
 ## 协议
 
diff --git a/evaluation/ceval/ceval.py b/evaluation/ceval/ceval.py
index 44cd531c..33005de3 100644
--- a/evaluation/ceval/ceval.py
+++ b/evaluation/ceval/ceval.py
@@ -92,14 +92,14 @@ task_list = [
 ]
 
 
-class CevalExamConfig(datasets.BuilderConfig):
+class CevalConfig(datasets.BuilderConfig):
     def __init__(self, **kwargs):
         super().__init__(version=datasets.Version("1.0.0"), **kwargs)
 
 
-class CevalExam(datasets.GeneratorBasedBuilder):
+class Ceval(datasets.GeneratorBasedBuilder):
     BUILDER_CONFIGS = [
-        CevalExamConfig(
+        CevalConfig(
             name=task_name,
         )
         for task_name in task_list
diff --git a/evaluation/cmmlu/cmmlu.py b/evaluation/cmmlu/cmmlu.py
new file mode 100644
index 00000000..e89164fd
--- /dev/null
+++ b/evaluation/cmmlu/cmmlu.py
@@ -0,0 +1,163 @@
+# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import datasets
+import pandas as pd
+
+
+_CITATION = """\
+@article{li2023cmmlu,
+  title={CMMLU: Measuring massive multitask language understanding in Chinese},
+  author={Haonan Li and Yixuan Zhang and Fajri Koto and Yifei Yang and Hai Zhao and Yeyun Gong and Nan Duan and Timothy Baldwin},
+  journal={arXiv preprint arXiv:2306.09212},
+  year={2023}
+}
+"""
+
+_DESCRIPTION = """\
+CMMLU is a comprehensive Chinese assessment suite specifically designed to evaluate the advanced knowledge and reasoning abilities of LLMs within the Chinese language and cultural context.
+"""
+
+_HOMEPAGE = "https://github.com/haonan-li/CMMLU"
+
+_LICENSE = "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License"
+
+_URL = "cmmlu.zip"
+
+task_list = [
+    'agronomy',
+    'anatomy',
+    'ancient_chinese',
+    'arts',
+    'astronomy',
+    'business_ethics',
+    'chinese_civil_service_exam',
+    'chinese_driving_rule',
+    'chinese_food_culture',
+    'chinese_foreign_policy',
+    'chinese_history',
+    'chinese_literature',
+    'chinese_teacher_qualification',
+    'clinical_knowledge',
+    'college_actuarial_science',
+    'college_education',
+    'college_engineering_hydrology',
+    'college_law',
+    'college_mathematics',
+    'college_medical_statistics',
+    'college_medicine',
+    'computer_science',
+    'computer_security',
+    'conceptual_physics',
+    'construction_project_management',
+    'economics',
+    'education',
+    'electrical_engineering',
+    'elementary_chinese',
+    'elementary_commonsense',
+    'elementary_information_and_technology',
+    'elementary_mathematics',
+    'ethnology',
+    'food_science',
+    'genetics',
+    'global_facts',
+    'high_school_biology',
+    'high_school_chemistry',
+    'high_school_geography',
+    'high_school_mathematics',
+    'high_school_physics',
+    'high_school_politics',
+    'human_sexuality',
+    'international_law',
+    'journalism',
+    'jurisprudence',
+    'legal_and_moral_basis',
+    'logical',
+    'machine_learning',
+    'management',
+    'marketing',
+    'marxist_theory',
+    'modern_chinese',
+    'nutrition',
+    'philosophy',
+    'professional_accounting',
+    'professional_law',
+    'professional_medicine',
+    'professional_psychology',
+    'public_relations',
+    'security_study',
+    'sociology',
+    'sports_science',
+    'traditional_chinese_medicine',
+    'virology',
+    'world_history',
+    'world_religions',
+]
+
+
+class CMMLUConfig(datasets.BuilderConfig):
+    def __init__(self, **kwargs):
+        super().__init__(version=datasets.Version("1.0.1"), **kwargs)
+
+
+class CMMLU(datasets.GeneratorBasedBuilder):
+    BUILDER_CONFIGS = [
+        CMMLUConfig(
+            name=task_name,
+        )
+        for task_name in task_list
+    ]
+
+    def _info(self):
+        features = datasets.Features(
+            {
+                "question": datasets.Value("string"),
+                "A": datasets.Value("string"),
+                "B": datasets.Value("string"),
+                "C": datasets.Value("string"),
+                "D": datasets.Value("string"),
+                "answer": datasets.Value("string"),
+            }
+        )
+        return datasets.DatasetInfo(
+            description=_DESCRIPTION,
+            features=features,
+            homepage=_HOMEPAGE,
+            license=_LICENSE,
+            citation=_CITATION,
+        )
+
+    def _split_generators(self, dl_manager):
+        data_dir = dl_manager.download_and_extract(_URL)
+        task_name = self.config.name
+        return [
+            datasets.SplitGenerator(
+                name=datasets.Split.TEST,
+                gen_kwargs={
+                    "filepath": os.path.join(data_dir, f"test/{task_name}.csv"),
+                },
+            ),
+            datasets.SplitGenerator(
+                name=datasets.Split.TRAIN,
+                gen_kwargs={
+                    "filepath": os.path.join(data_dir, f"dev/{task_name}.csv"),
+                },
+            ),
+        ]
+
+    def _generate_examples(self, filepath):
+        df = pd.read_csv(filepath, header=0, index_col=0, encoding="utf-8")
+        for i, instance in enumerate(df.to_dict(orient="records")):
+            yield i, instance
diff --git a/evaluation/cmmlu/cmmlu.zip b/evaluation/cmmlu/cmmlu.zip
new file mode 100644
index 00000000..c6bede1d
Binary files /dev/null and b/evaluation/cmmlu/cmmlu.zip differ
diff --git a/evaluation/cmmlu/mapping.json b/evaluation/cmmlu/mapping.json
new file mode 100644
index 00000000..57329d01
--- /dev/null
+++ b/evaluation/cmmlu/mapping.json
@@ -0,0 +1,270 @@
+{
+  "agronomy": {
+    "name": "农学",
+    "category": "Other"
+  },
+  "anatomy": {
+    "name": "解剖学",
+    "category": "STEM"
+  },
+  "ancient_chinese": {
+    "name": "古汉语",
+    "category": "Social Sciences"
+  },
+  "arts": {
+    "name": "艺术学",
+    "category": "Humanities"
+  },
+  "astronomy": {
+    "name": "天文学",
+    "category": "STEM"
+  },
+  "business_ethics": {
+    "name": "商业伦理",
+    "category": "Social Sciences"
+  },
+  "chinese_civil_service_exam": {
+    "name": "中国公务员考试",
+    "category": "Social Sciences"
+  },
+  "chinese_driving_rule": {
+    "name": "中国驾驶规则",
+    "category": "Other"
+  },
+  "chinese_food_culture": {
+    "name": "中国饮食文化",
+    "category": "Social Sciences"
+  },
+  "chinese_foreign_policy": {
+    "name": "中国外交政策",
+    "category": "Social Sciences"
+  },
+  "chinese_history": {
+    "name": "中国历史",
+    "category": "Humanities"
+  },
+  "chinese_literature": {
+    "name": "中国文学",
+    "category": "Humanities"
+  },
+  "chinese_teacher_qualification": {
+    "name": "中国教师资格",
+    "category": "Social Sciences"
+  },
+  "college_actuarial_science": {
+    "name": "大学精算学",
+    "category": "STEM"
+  },
+  "college_education": {
+    "name": "大学教育学",
+    "category": "Social Sciences"
+  },
+  "college_engineering_hydrology": {
+    "name": "大学工程水文学",
+    "category": "STEM"
+  },
+  "college_law": {
+    "name": "大学法律",
+    "category": "Humanities"
+  },
+  "college_mathematics": {
+    "name": "大学数学",
+    "category": "STEM"
+  },
+  "college_medical_statistics": {
+    "name": "大学医学统计",
+    "category": "STEM"
+  },
+  "clinical_knowledge": {
+    "name": "临床知识",
+    "category": "Other"
+  },
+  "college_medicine": {
+    "name": "大学医学",
+    "category": "Other"
+  },
+  "computer_science": {
+    "name": "计算机科学",
+    "category": "STEM"
+  },
+  "computer_security": {
+    "name": "计算机安全",
+    "category": "Other"
+  },
+  "conceptual_physics": {
+    "name": "概念物理学",
+    "category": "STEM"
+  },
+  "construction_project_management": {
+    "name": "建设工程管理",
+    "category": "Other"
+  },
+  "economics": {
+    "name": "经济学",
+    "category": "Social Sciences"
+  },
+  "education": {
+    "name": "教育学",
+    "category": "Social Sciences"
+  },
+  "elementary_chinese": {
+    "name": "小学语文",
+    "category": "Social Sciences"
+  },
+  "elementary_commonsense": {
+    "name": "小学常识",
+    "category": "Other"
+  },
+  "elementary_information_and_technology": {
+    "name": "小学信息技术",
+    "category": "Other"
+  },
+  "electrical_engineering": {
+    "name": "电气工程",
+    "category": "STEM"
+  },
+  "elementary_mathematics": {
+    "name": "初等数学",
+    "category": "STEM"
+  },
+  "ethnology": {
+    "name": "民族学",
+    "category": "Social Sciences"
+  },
+  "food_science": {
+    "name": "食品科学",
+    "category": "Other"
+  },
+  "genetics": {
+    "name": "遗传学",
+    "category": "STEM"
+  },
+  "global_facts": {
+    "name": "全球事实",
+    "category": "Humanities"
+  },
+  "high_school_biology": {
+    "name": "高中生物",
+    "category": "STEM"
+  },
+  "high_school_chemistry": {
+    "name": "高中化学",
+    "category": "STEM"
+  },
+  "high_school_geography": {
+    "name": "高中地理",
+    "category": "Social Sciences"
+  },
+  "high_school_mathematics": {
+    "name": "高中数学",
+    "category": "STEM"
+  },
+  "high_school_physics": {
+    "name": "高中物理学",
+    "category": "STEM"
+  },
+  "high_school_politics": {
+    "name": "高中政治",
"category": "Social Sciences" + }, + "human_sexuality": { + "name": "人类性行为", + "category": "Other" + }, + "international_law": { + "name": "国际法学", + "category": "Humanities" + }, + "journalism": { + "name": "新闻学", + "category": "Social Sciences" + }, + "jurisprudence": { + "name": "法理学", + "category": "Humanities" + }, + "legal_and_moral_basis": { + "name": "法律与道德基础", + "category": "Other" + }, + "logical": { + "name": "逻辑学", + "category": "Humanities" + }, + "machine_learning": { + "name": "机器学习", + "category": "STEM" + }, + "management": { + "name": "管理学", + "category": "Social Sciences" + }, + "marketing": { + "name": "市场营销", + "category": "Social Sciences" + }, + "marxist_theory": { + "name": "马克思主义理论", + "category": "Humanities" + }, + "modern_chinese": { + "name": "现代汉语", + "category": "Social Sciences" + }, + "nutrition": { + "name": "营养学", + "category": "Other" + }, + "philosophy": { + "name": "哲学", + "category": "Humanities" + }, + "professional_accounting": { + "name": "专业会计", + "category": "Social Sciences" + }, + "professional_law": { + "name": "专业法学", + "category": "Humanities" + }, + "professional_medicine": { + "name": "专业医学", + "category": "Other" + }, + "professional_psychology": { + "name": "专业心理学", + "category": "Social Sciences" + }, + "public_relations": { + "name": "公共关系", + "category": "Social Sciences" + }, + "security_study": { + "name": "安全研究", + "category": "Social Sciences" + }, + "sociology": { + "name": "社会学", + "category": "Social Sciences" + }, + "sports_science": { + "name": "体育学", + "category": "Other" + }, + "traditional_chinese_medicine": { + "name": "中医中药", + "category": "Other" + }, + "virology": { + "name": "病毒学", + "category": "STEM" + }, + "world_history": { + "name": "世界历史", + "category": "Humanities" + }, + "world_religions": { + "name": "世界宗教", + "category": "Humanities" + } +} \ No newline at end of file diff --git a/src/evaluate.py b/src/evaluate.py index fecf8d2a..72f27cf3 100644 --- a/src/evaluate.py +++ b/src/evaluate.py @@ -1,16 +1,15 @@ # coding=utf-8 # Evaluates the performance of pre-trained models. # Usage: python evaluate.py --model_name_or_path path_to_model --checkpoint_dir path_to_ckpt --template vanilla -# --task ceval --split validation --lang zh --n_shot 5 --batch_size 4 +# --task ceval --split validation --lang zh --n_shot 5 --batch_size 4 --save_name result # Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py import os import fire import json import torch -import random import numpy as np -from tqdm import tqdm +from tqdm import tqdm, trange from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple from datasets import load_dataset from dataclasses import dataclass @@ -30,6 +29,7 @@ class EvalTemplate: system: str choice: str answer: str + prefix: str def parse_example( self, @@ -49,7 +49,6 @@ class EvalTemplate: history = [self.parse_example(support_set[k]) for k in range(len(support_set))] if len(history): - random.shuffle(history) temp = history.pop(0) history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1])) else: @@ -65,12 +64,14 @@ eval_templates = { "en": EvalTemplate( system="The following are multiple choice questions (with answers) about {subject}.\n\n", choice="\n{choice}. {content}", - answer="\nAnswer: " + answer="\nAnswer: ", + prefix=" " ), "zh": EvalTemplate( system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n", choice="\n{choice}. 
{content}", - answer="\n答案:" + answer="\n答案:", + prefix="\n" ) } @@ -79,9 +80,8 @@ eval_templates = { def batch_inference( chat_model: ChatModel, batch_input: Dict[str, torch.Tensor], - lang: Literal["zh", "en"] + prefix_char: str ) -> List[str]: - prefix_char = "\n" if lang == "zh" else " " logits = chat_model.model(**batch_input).logits probs = torch.nn.functional.softmax( torch.stack( @@ -108,7 +108,8 @@ def evaluate( split: Optional[Literal["validation", "test"]] = "validation", lang: Optional[Literal["zh", "en"]] = "zh", n_shot: Optional[int] = 5, - batch_size: Optional[int] = 4 + batch_size: Optional[int] = 4, + save_name: Optional[str] = None ): with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f: categorys = json.load(f) @@ -119,25 +120,25 @@ def evaluate( checkpoint_dir=checkpoint_dir, template=template )) - chat_model.tokenizer.padding_side = "left" eval_template = eval_templates[lang] + assert chat_model.tokenizer.padding_side == "left", "only left-padded tensor can be accepted." + category_corrects: Dict[str, np.ndarray] = { - "STEM": np.array([], dtype="bool"), - "Social Sciences": np.array([], dtype="bool"), - "Humanities": np.array([], dtype="bool"), - "Other": np.array([], dtype="bool") + subj: np.array([], dtype="bool") for subj in ["STEM", "Social Sciences", "Humanities", "Other"] } overall_corrects = np.array([], dtype="bool") - pbar = tqdm(categorys.keys()) + pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0) + results = {} for subject in pbar: pbar.set_postfix_str(categorys[subject]["name"]) inputs, labels = [], [] dataset = load_dataset(os.path.join(dataset_dir, task), subject) for i in range(len(dataset[split])): + support_set = dataset["train"].shuffle().select(range(min(n_shot, len(dataset["train"])))) query, resp, history = eval_template.format_example( target_data=dataset[split][i], - support_set=dataset["train"].select(range(min(n_shot, len(dataset["train"])))), + support_set=support_set, subject_name=categorys[subject]["name"], use_history=chat_model.template.use_history ) @@ -154,23 +155,33 @@ def evaluate( labels.append(resp) outputs = [] - for i in range(0, len(inputs), batch_size): + for i in trange(0, len(inputs), batch_size, desc="Processing batches", position=1, leave=False): batch_input = chat_model.tokenizer.pad( inputs[i : i + batch_size], return_attention_mask=True, return_tensors="pt" ).to(chat_model.model.device) - preds = batch_inference(chat_model, batch_input, lang) + preds = batch_inference(chat_model, batch_input, eval_template.prefix) outputs += preds corrects = (np.array(outputs) == np.array(labels)) category_name = categorys[subject]["category"] category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0) overall_corrects = np.concatenate([overall_corrects, corrects], axis=0) + results[subject] = {str(i): outputs[i] for i in range(len(outputs))} - print("Average accuracy: {:.2f}".format(100 * np.mean(overall_corrects))) + score_info = "Average accuracy: {:.2f}".format(100 * np.mean(overall_corrects)) for category_name, category_correct in category_corrects.items(): - print(" {} - {:.2f}".format(category_name, 100 * np.mean(category_correct))) + if len(category_correct): + score_info += "\n{:>16}: {:.2f}".format(category_name, 100 * np.mean(category_correct)) + + print(score_info) + if save_name is not None: + with open(save_name + ".json", "w", encoding="utf-8", newline="\n") as f: + json.dump(results, f, indent=2) + + with open(save_name + 
".log", "w", encoding="utf-8", newline="\n") as f: + f.write(score_info) if __name__ == "__main__":