mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 06:12:50 +08:00
tiny fix
Former-commit-id: bda302fbfbdb114dee7782d405732600d2d73279
This commit is contained in:
parent
9d8e0f0837
commit
ca40e42b3c
@ -73,11 +73,11 @@ class Evaluator:
|
||||
return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
|
||||
|
||||
def eval(self) -> None:
|
||||
task = self.eval_args.task.split("_")[0]
|
||||
split = self.eval_args.task.split("_")[1]
|
||||
eval_task = self.eval_args.task.split("_")[0]
|
||||
eval_split = self.eval_args.task.split("_")[1]
|
||||
|
||||
mapping = cached_file(
|
||||
path_or_repo_id=os.path.join(self.eval_args.task_dir, task),
|
||||
path_or_repo_id=os.path.join(self.eval_args.task_dir, eval_task),
|
||||
filename="mapping.json",
|
||||
cache_dir=self.model_args.cache_dir,
|
||||
token=self.model_args.hf_hub_token,
|
||||
@ -91,7 +91,7 @@ class Evaluator:
|
||||
results = {}
|
||||
for subject in pbar:
|
||||
dataset = load_dataset(
|
||||
path=os.path.join(self.eval_args.task_dir, task),
|
||||
path=os.path.join(self.eval_args.task_dir, eval_task),
|
||||
name=subject,
|
||||
cache_dir=self.model_args.cache_dir,
|
||||
download_mode=self.eval_args.download_mode,
|
||||
@ -100,12 +100,12 @@ class Evaluator:
|
||||
)
|
||||
pbar.set_postfix_str(categorys[subject]["name"])
|
||||
inputs, outputs, labels = [], [], []
|
||||
for i in trange(len(dataset[split]), desc="Formatting batches", position=1, leave=False):
|
||||
for i in trange(len(dataset[eval_split]), desc="Formatting batches", position=1, leave=False):
|
||||
support_set = (
|
||||
dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
|
||||
)
|
||||
messages = self.eval_template.format_example(
|
||||
target_data=dataset[split][i],
|
||||
target_data=dataset[eval_split][i],
|
||||
support_set=support_set,
|
||||
subject_name=categorys[subject]["name"],
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user