From 4b2b92fd9aecc6e6f40c44d212f5889d9f692446 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Mon, 10 Jun 2024 23:56:00 +0800 Subject: [PATCH] update evaluator Former-commit-id: bb8661e62481ff7027b8969f3d8a6a17290c9da3 --- src/llamafactory/eval/evaluator.py | 4 +- src/llamafactory/eval/template.py | 9 ++-- tests/eval/test_eval_template.py | 77 ++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 9 deletions(-) create mode 100644 tests/eval/test_eval_template.py diff --git a/src/llamafactory/eval/evaluator.py b/src/llamafactory/eval/evaluator.py index 192f4815..5c6fb104 100644 --- a/src/llamafactory/eval/evaluator.py +++ b/src/llamafactory/eval/evaluator.py @@ -26,9 +26,7 @@ class Evaluator: self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template) self.model = load_model(self.tokenizer, self.model_args, finetuning_args) self.eval_template = get_eval_template(self.eval_args.lang) - self.choice_inputs = [ - self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES - ] + self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES] @torch.inference_mode() def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]: diff --git a/src/llamafactory/eval/template.py b/src/llamafactory/eval/template.py index a4a6ef0e..2cbb5aaf 100644 --- a/src/llamafactory/eval/template.py +++ b/src/llamafactory/eval/template.py @@ -10,7 +10,6 @@ class EvalTemplate: system: str choice: str answer: str - prefix: str def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]: r""" @@ -42,8 +41,8 @@ class EvalTemplate: eval_templates: Dict[str, "EvalTemplate"] = {} -def _register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None: - eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix) +def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None: + eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer) def get_eval_template(name: str) -> "EvalTemplate": @@ -56,8 +55,7 @@ _register_eval_template( name="en", system="The following are multiple choice questions (with answers) about {subject}.\n\n", choice="\n{choice}. {content}", - answer="\nAnswer: ", - prefix=" ", + answer="\nAnswer:", ) @@ -66,5 +64,4 @@ _register_eval_template( system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n", choice="\n{choice}. {content}", answer="\n答案:", - prefix=" ", ) diff --git a/tests/eval/test_eval_template.py b/tests/eval/test_eval_template.py new file mode 100644 index 00000000..f6a91a67 --- /dev/null +++ b/tests/eval/test_eval_template.py @@ -0,0 +1,77 @@ +from llamafactory.eval.template import get_eval_template + + +def test_eval_template_en(): + support_set = [ + { + "question": "Fewshot question", + "A": "Fewshot1", + "B": "Fewshot2", + "C": "Fewshot3", + "D": "Fewshot4", + "answer": "B", + } + ] + example = { + "question": "Target question", + "A": "Target1", + "B": "Target2", + "C": "Target3", + "D": "Target4", + "answer": "C", + } + template = get_eval_template(name="en") + messages = template.format_example(example, support_set=support_set, subject_name="SubName") + assert messages == [ + { + "role": "user", + "content": ( + "The following are multiple choice questions (with answers) about SubName.\n\n" + "Fewshot question\nA. Fewshot1\nB. Fewshot2\nC. Fewshot3\nD. Fewshot4\nAnswer:" + ), + }, + {"role": "assistant", "content": "B"}, + { + "role": "user", + "content": "Target question\nA. Target1\nB. Target2\nC. Target3\nD. Target4\nAnswer:", + }, + {"role": "assistant", "content": "C"}, + ] + + +def test_eval_template_zh(): + support_set = [ + { + "question": "示例问题", + "A": "示例答案1", + "B": "示例答案2", + "C": "示例答案3", + "D": "示例答案4", + "answer": "B", + } + ] + example = { + "question": "目标问题", + "A": "目标答案1", + "B": "目标答案2", + "C": "目标答案3", + "D": "目标答案4", + "answer": "C", + } + template = get_eval_template(name="zh") + messages = template.format_example(example, support_set=support_set, subject_name="主题") + assert messages == [ + { + "role": "user", + "content": ( + "以下是中国关于主题考试的单项选择题,请选出其中的正确答案。\n\n" + "示例问题\nA. 示例答案1\nB. 示例答案2\nC. 示例答案3\nD. 示例答案4\n答案:" + ), + }, + {"role": "assistant", "content": "B"}, + { + "role": "user", + "content": "目标问题\nA. 目标答案1\nB. 目标答案2\nC. 目标答案3\nD. 目标答案4\n答案:", + }, + {"role": "assistant", "content": "C"}, + ]