mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	update evaluator
Former-commit-id: bb8661e62481ff7027b8969f3d8a6a17290c9da3
This commit is contained in:
		
							parent
							
								
									784088db3f
								
							
						
					
					
						commit
						4b2b92fd9a
					
				@ -26,9 +26,7 @@ class Evaluator:
 | 
			
		||||
        self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
 | 
			
		||||
        self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
 | 
			
		||||
        self.eval_template = get_eval_template(self.eval_args.lang)
 | 
			
		||||
        self.choice_inputs = [
 | 
			
		||||
            self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES
 | 
			
		||||
        ]
 | 
			
		||||
        self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
 | 
			
		||||
 | 
			
		||||
    @torch.inference_mode()
 | 
			
		||||
    def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
 | 
			
		||||
 | 
			
		||||
@ -10,7 +10,6 @@ class EvalTemplate:
 | 
			
		||||
    system: str
 | 
			
		||||
    choice: str
 | 
			
		||||
    answer: str
 | 
			
		||||
    prefix: str
 | 
			
		||||
 | 
			
		||||
    def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
 | 
			
		||||
        r"""
 | 
			
		||||
@ -42,8 +41,8 @@ class EvalTemplate:
 | 
			
		||||
eval_templates: Dict[str, "EvalTemplate"] = {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
 | 
			
		||||
    eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)
 | 
			
		||||
def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None:
 | 
			
		||||
    eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_eval_template(name: str) -> "EvalTemplate":
 | 
			
		||||
@ -56,8 +55,7 @@ _register_eval_template(
 | 
			
		||||
    name="en",
 | 
			
		||||
    system="The following are multiple choice questions (with answers) about {subject}.\n\n",
 | 
			
		||||
    choice="\n{choice}. {content}",
 | 
			
		||||
    answer="\nAnswer: ",
 | 
			
		||||
    prefix=" ",
 | 
			
		||||
    answer="\nAnswer:",
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -66,5 +64,4 @@ _register_eval_template(
 | 
			
		||||
    system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
 | 
			
		||||
    choice="\n{choice}. {content}",
 | 
			
		||||
    answer="\n答案:",
 | 
			
		||||
    prefix=" ",
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										77
									
								
								tests/eval/test_eval_template.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								tests/eval/test_eval_template.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,77 @@
 | 
			
		||||
from llamafactory.eval.template import get_eval_template
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_eval_template_en():
 | 
			
		||||
    support_set = [
 | 
			
		||||
        {
 | 
			
		||||
            "question": "Fewshot question",
 | 
			
		||||
            "A": "Fewshot1",
 | 
			
		||||
            "B": "Fewshot2",
 | 
			
		||||
            "C": "Fewshot3",
 | 
			
		||||
            "D": "Fewshot4",
 | 
			
		||||
            "answer": "B",
 | 
			
		||||
        }
 | 
			
		||||
    ]
 | 
			
		||||
    example = {
 | 
			
		||||
        "question": "Target question",
 | 
			
		||||
        "A": "Target1",
 | 
			
		||||
        "B": "Target2",
 | 
			
		||||
        "C": "Target3",
 | 
			
		||||
        "D": "Target4",
 | 
			
		||||
        "answer": "C",
 | 
			
		||||
    }
 | 
			
		||||
    template = get_eval_template(name="en")
 | 
			
		||||
    messages = template.format_example(example, support_set=support_set, subject_name="SubName")
 | 
			
		||||
    assert messages == [
 | 
			
		||||
        {
 | 
			
		||||
            "role": "user",
 | 
			
		||||
            "content": (
 | 
			
		||||
                "The following are multiple choice questions (with answers) about SubName.\n\n"
 | 
			
		||||
                "Fewshot question\nA. Fewshot1\nB. Fewshot2\nC. Fewshot3\nD. Fewshot4\nAnswer:"
 | 
			
		||||
            ),
 | 
			
		||||
        },
 | 
			
		||||
        {"role": "assistant", "content": "B"},
 | 
			
		||||
        {
 | 
			
		||||
            "role": "user",
 | 
			
		||||
            "content": "Target question\nA. Target1\nB. Target2\nC. Target3\nD. Target4\nAnswer:",
 | 
			
		||||
        },
 | 
			
		||||
        {"role": "assistant", "content": "C"},
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_eval_template_zh():
 | 
			
		||||
    support_set = [
 | 
			
		||||
        {
 | 
			
		||||
            "question": "示例问题",
 | 
			
		||||
            "A": "示例答案1",
 | 
			
		||||
            "B": "示例答案2",
 | 
			
		||||
            "C": "示例答案3",
 | 
			
		||||
            "D": "示例答案4",
 | 
			
		||||
            "answer": "B",
 | 
			
		||||
        }
 | 
			
		||||
    ]
 | 
			
		||||
    example = {
 | 
			
		||||
        "question": "目标问题",
 | 
			
		||||
        "A": "目标答案1",
 | 
			
		||||
        "B": "目标答案2",
 | 
			
		||||
        "C": "目标答案3",
 | 
			
		||||
        "D": "目标答案4",
 | 
			
		||||
        "answer": "C",
 | 
			
		||||
    }
 | 
			
		||||
    template = get_eval_template(name="zh")
 | 
			
		||||
    messages = template.format_example(example, support_set=support_set, subject_name="主题")
 | 
			
		||||
    assert messages == [
 | 
			
		||||
        {
 | 
			
		||||
            "role": "user",
 | 
			
		||||
            "content": (
 | 
			
		||||
                "以下是中国关于主题考试的单项选择题,请选出其中的正确答案。\n\n"
 | 
			
		||||
                "示例问题\nA. 示例答案1\nB. 示例答案2\nC. 示例答案3\nD. 示例答案4\n答案:"
 | 
			
		||||
            ),
 | 
			
		||||
        },
 | 
			
		||||
        {"role": "assistant", "content": "B"},
 | 
			
		||||
        {
 | 
			
		||||
            "role": "user",
 | 
			
		||||
            "content": "目标问题\nA. 目标答案1\nB. 目标答案2\nC. 目标答案3\nD. 目标答案4\n答案:",
 | 
			
		||||
        },
 | 
			
		||||
        {"role": "assistant", "content": "C"},
 | 
			
		||||
    ]
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user