add averaging in evaluation

Former-commit-id: 5310e4d1829f36619c8f224d09ec15eeaf7a4877
hiyouga 2023-10-10 23:16:31 +08:00
parent 141937ead6
commit c350ba0f05


@@ -9,10 +9,11 @@ import fire
 import json
 import torch
 import numpy as np
-from tqdm import tqdm, trange
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
+from collections import Counter
 from datasets import load_dataset
 from dataclasses import dataclass
+from tqdm import tqdm, trange
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
 
 from llmtuner import ChatModel
@@ -86,10 +87,8 @@ def batch_inference(
     probs = torch.nn.functional.softmax(
         torch.stack(
             [
-                logits[:, -1, chat_model.tokenizer.encode(prefix_char + "A")[-1]],
-                logits[:, -1, chat_model.tokenizer.encode(prefix_char + "B")[-1]],
-                logits[:, -1, chat_model.tokenizer.encode(prefix_char + "C")[-1]],
-                logits[:, -1, chat_model.tokenizer.encode(prefix_char + "D")[-1]]
+                logits[:, -1, chat_model.tokenizer.encode(prefix_char + choice, add_special_tokens=False)[-1]]
+                for choice in choices
             ],
             dim=-1
         ),
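
A minimal sketch (not part of this commit) of what the rewritten comprehension computes, assuming a Hugging Face tokenizer, next-token logits of shape (batch, seq_len, vocab), and a choices list such as ["A", "B", "C", "D"]:

import torch

def score_choices(logits, tokenizer, choices=("A", "B", "C", "D"), prefix_char=" "):
    # For each choice letter, take the logit at the last position for that letter's token id.
    # add_special_tokens=False keeps BOS/EOS out of the encoding, so [-1] is the letter itself.
    choice_logits = torch.stack(
        [
            logits[:, -1, tokenizer.encode(prefix_char + choice, add_special_tokens=False)[-1]]
            for choice in choices
        ],
        dim=-1
    )
    probs = torch.nn.functional.softmax(choice_logits, dim=-1)
    # Map each argmax offset back to a choice letter, e.g. 0 -> "A", 1 -> "B".
    return [chr(ord("A") + offset.item()) for offset in torch.argmax(probs, dim=-1)]
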
@@ -108,11 +107,12 @@ def evaluate(
     split: Optional[Literal["validation", "test"]] = "validation",
     lang: Optional[Literal["zh", "en"]] = "zh",
     n_shot: Optional[int] = 5,
+    n_avg: Optional[int] = 1,
     batch_size: Optional[int] = 4,
     save_name: Optional[str] = None
 ):
     with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f:
-        categorys = json.load(f)
+        categorys: Dict[str, Dict[str, str]] = json.load(f)
 
     chat_model = ChatModel(dict(
         model_name_or_path=model_name_or_path,
@@ -124,17 +124,17 @@ def evaluate(
     assert chat_model.tokenizer.padding_side == "left", "only left-padded tensor can be accepted."
     category_corrects: Dict[str, np.ndarray] = {
-        subj: np.array([], dtype="bool") for subj in ["STEM", "Social Sciences", "Humanities", "Other"]
+        subj: np.array([], dtype="bool") for subj in ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
     }
-    overall_corrects = np.array([], dtype="bool")
     pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
     results = {}
     for subject in pbar:
-        pbar.set_postfix_str(categorys[subject]["name"])
-        inputs, labels = [], []
         dataset = load_dataset(os.path.join(dataset_dir, task), subject)
-        for i in range(len(dataset[split])):
+        labels, answers, all_outputs = [], [], []
+        for epoch in range(n_avg):
+            pbar.set_postfix_str("{} Trial: {}".format(categorys[subject]["name"], epoch))
+            inputs, outputs = [], []
+            for i in trange(len(dataset[split]), desc="Formatting batches", position=1, leave=False):
                 support_set = dataset["train"].shuffle().select(range(min(n_shot, len(dataset["train"]))))
                 query, resp, history = eval_template.format_example(
                     target_data=dataset[split][i],
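
For illustration only (toy data, not from the diff): each of the n_avg trials above redraws the few-shot support set via shuffle().select(), so the prompt context, and therefore the predictions, can differ from trial to trial.

from datasets import Dataset

# Hypothetical 10-example training split standing in for dataset["train"].
train = Dataset.from_dict({"question": ["q{}".format(i) for i in range(10)], "answer": ["A"] * 10})
n_shot, n_avg = 3, 2
for epoch in range(n_avg):
    # A fresh support set is drawn on every trial, as in the formatting loop above.
    support_set = train.shuffle().select(range(min(n_shot, len(train))))
    print(epoch, support_set["question"])
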
@@ -143,37 +143,34 @@ def evaluate(
                     use_history=chat_model.template.use_history
                 )
                 input_ids, _ = chat_model.template.encode_oneturn(
-                tokenizer=chat_model.tokenizer,
-                query=query,
-                resp=resp,
-                history=history
+                    tokenizer=chat_model.tokenizer, query=query, resp=resp, history=history
                 )
-            inputs.append({
-                "input_ids": input_ids,
-                "attention_mask": [1] * len(input_ids)
-            })
+                inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
+                if epoch == 0:
                     labels.append(resp)
 
-        outputs = []
-        for i in trange(0, len(inputs), batch_size, desc="Processing batches", position=1, leave=False):
+            for i in trange(0, len(inputs), batch_size, desc="Predicting batches", position=1, leave=False):
                 batch_input = chat_model.tokenizer.pad(
-                inputs[i : i + batch_size],
-                return_attention_mask=True,
-                return_tensors="pt"
+                    inputs[i : i + batch_size], return_attention_mask=True, return_tensors="pt"
                 ).to(chat_model.model.device)
                 preds = batch_inference(chat_model, batch_input, eval_template.prefix)
                 outputs += preds
+            all_outputs.append(outputs)
 
-        corrects = (np.array(outputs) == np.array(labels))
+        for i in range(len(all_outputs[0])):
+            count = Counter([all_outputs[epoch][i] for epoch in range(n_avg)])
+            answers.append(count.most_common(1)[0][0])
+
+        corrects = (np.array(answers) == np.array(labels))
         category_name = categorys[subject]["category"]
         category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
-        overall_corrects = np.concatenate([overall_corrects, corrects], axis=0)
-        results[subject] = {str(i): outputs[i] for i in range(len(outputs))}
+        category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
+        results[subject] = {str(i): answers[i] for i in range(len(answers))}
 
-    score_info = "Average accuracy: {:.2f}".format(100 * np.mean(overall_corrects))
-    for category_name, category_correct in category_corrects.items():
-        if len(category_correct):
-            score_info += "\n{:>16}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
+    score_info = "\n".join([
+        "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
+        for category_name, category_correct in category_corrects.items() if len(category_correct)
+    ])
 
     print(score_info)
     if save_name is not None:
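
A self-contained sketch (toy values, not from the commit) of the new averaging and reporting logic: each question's final answer is the most common prediction across the n_avg trials, every subject's corrects are also accumulated under "Average", and the report prints one right-aligned accuracy line per non-empty category.

import numpy as np
from collections import Counter

all_outputs = [["A", "B", "C"], ["A", "D", "C"], ["A", "B", "B"]]  # predictions from n_avg = 3 trials
labels = ["A", "B", "C"]                                           # gold answers

answers = []
for i in range(len(all_outputs[0])):
    count = Counter([all_outputs[epoch][i] for epoch in range(len(all_outputs))])
    answers.append(count.most_common(1)[0][0])  # majority vote per question -> ["A", "B", "C"]

corrects = np.array(answers) == np.array(labels)
category_corrects = {"Average": corrects, "STEM": corrects}  # toy: a single STEM subject
score_info = "\n".join([
    "{:>15}: {:.2f}".format(name, 100 * np.mean(correct))
    for name, correct in category_corrects.items() if len(correct)
])
print(score_info)  # both lines read 100.00 for this toy input
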