From ca40e42b3c7e91b3070c195f2d09ea73c16f3af3 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Mon, 15 Jul 2024 23:09:50 +0800 Subject: [PATCH] tiny fix Former-commit-id: bda302fbfbdb114dee7782d405732600d2d73279 --- src/llamafactory/eval/evaluator.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/llamafactory/eval/evaluator.py b/src/llamafactory/eval/evaluator.py index c5661997..f05e01a1 100644 --- a/src/llamafactory/eval/evaluator.py +++ b/src/llamafactory/eval/evaluator.py @@ -73,11 +73,11 @@ class Evaluator: return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)] def eval(self) -> None: - task = self.eval_args.task.split("_")[0] - split = self.eval_args.task.split("_")[1] + eval_task = self.eval_args.task.split("_")[0] + eval_split = self.eval_args.task.split("_")[1] mapping = cached_file( - path_or_repo_id=os.path.join(self.eval_args.task_dir, task), + path_or_repo_id=os.path.join(self.eval_args.task_dir, eval_task), filename="mapping.json", cache_dir=self.model_args.cache_dir, token=self.model_args.hf_hub_token, @@ -91,7 +91,7 @@ class Evaluator: results = {} for subject in pbar: dataset = load_dataset( - path=os.path.join(self.eval_args.task_dir, task), + path=os.path.join(self.eval_args.task_dir, eval_task), name=subject, cache_dir=self.model_args.cache_dir, download_mode=self.eval_args.download_mode, @@ -100,12 +100,12 @@ class Evaluator: ) pbar.set_postfix_str(categorys[subject]["name"]) inputs, outputs, labels = [], [], [] - for i in trange(len(dataset[split]), desc="Formatting batches", position=1, leave=False): + for i in trange(len(dataset[eval_split]), desc="Formatting batches", position=1, leave=False): support_set = ( dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"])))) ) messages = self.eval_template.format_example( - target_data=dataset[split][i], + target_data=dataset[eval_split][i], support_set=support_set, subject_name=categorys[subject]["name"], )