mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 19:52:50 +08:00
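1. Change the task name format to `<task>_<split>` (e.g. `mmlu_test`).
2. Delete the `split` param in data_args.py; the evaluator now takes the split from the task name.

Former-commit-id: 645211dc01b5d4db3ccd0e3dce03a53860eded26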
This commit is contained in:
parent
22859b8734
commit
76046dfda8
@@ -6,8 +6,7 @@ adapter_name_or_path: saves/llama3-8b/lora/sft
|
|||||||
finetuning_type: lora
|
finetuning_type: lora
|
||||||
|
|
||||||
### dataset
|
### dataset
|
||||||
task: mmlu
|
task: mmlu_test
|
||||||
split: test
|
|
||||||
template: fewshot
|
template: fewshot
|
||||||
lang: en
|
lang: en
|
||||||
n_shot: 5
|
n_shot: 5
|
||||||
|
@@ -73,8 +73,11 @@ class Evaluator:
|
|||||||
return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
|
return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
|
||||||
|
|
||||||
def eval(self) -> None:
|
def eval(self) -> None:
|
||||||
|
task = self.eval_args.task.split("_")[0]
|
||||||
|
split = self.eval_args.task.split("_")[1]
|
||||||
|
|
||||||
mapping = cached_file(
|
mapping = cached_file(
|
||||||
path_or_repo_id=os.path.join(self.eval_args.task_dir, self.eval_args.task),
|
path_or_repo_id=os.path.join(self.eval_args.task_dir, task),
|
||||||
filename="mapping.json",
|
filename="mapping.json",
|
||||||
cache_dir=self.model_args.cache_dir,
|
cache_dir=self.model_args.cache_dir,
|
||||||
token=self.model_args.hf_hub_token,
|
token=self.model_args.hf_hub_token,
|
||||||
@@ -88,7 +91,7 @@ class Evaluator:
|
|||||||
results = {}
|
results = {}
|
||||||
for subject in pbar:
|
for subject in pbar:
|
||||||
dataset = load_dataset(
|
dataset = load_dataset(
|
||||||
path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
|
path=os.path.join(self.eval_args.task_dir, task),
|
||||||
name=subject,
|
name=subject,
|
||||||
cache_dir=self.model_args.cache_dir,
|
cache_dir=self.model_args.cache_dir,
|
||||||
download_mode=self.eval_args.download_mode,
|
download_mode=self.eval_args.download_mode,
|
||||||
@@ -97,12 +100,12 @@ class Evaluator:
|
|||||||
)
|
)
|
||||||
pbar.set_postfix_str(categorys[subject]["name"])
|
pbar.set_postfix_str(categorys[subject]["name"])
|
||||||
inputs, outputs, labels = [], [], []
|
inputs, outputs, labels = [], [], []
|
||||||
for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
|
for i in trange(len(dataset[split]), desc="Formatting batches", position=1, leave=False):
|
||||||
support_set = (
|
support_set = (
|
||||||
dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
|
dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
|
||||||
)
|
)
|
||||||
messages = self.eval_template.format_example(
|
messages = self.eval_template.format_example(
|
||||||
target_data=dataset[self.data_args.split][i],
|
target_data=dataset[split][i],
|
||||||
support_set=support_set,
|
support_set=support_set,
|
||||||
subject_name=categorys[subject]["name"],
|
subject_name=categorys[subject]["name"],
|
||||||
)
|
)
|
||||||
|
@@ -41,10 +41,6 @@ class DataArguments:
|
|||||||
default="data",
|
default="data",
|
||||||
metadata={"help": "Path to the folder containing the datasets."},
|
metadata={"help": "Path to the folder containing the datasets."},
|
||||||
)
|
)
|
||||||
split: str = field(
|
|
||||||
default="train",
|
|
||||||
metadata={"help": "Which dataset split to use for training and evaluation."},
|
|
||||||
)
|
|
||||||
cutoff_len: int = field(
|
cutoff_len: int = field(
|
||||||
default=1024,
|
default=1024,
|
||||||
metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
|
metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
|
||||||
|
Loading…
x
Reference in New Issue
Block a user