add eval acc

This commit is contained in:
hiyouga
2024-07-01 03:51:20 +08:00
parent fc2c15d713
commit 1856a08e87
3 changed files with 31 additions and 17 deletions

View File

@@ -135,21 +135,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
for i in range(len(preds)):
pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
if len(pad_len):
preds[i] = np.concatenate(
(preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1
) # move pad token to last
if len(pad_len): # move pad token to last
preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)
decoded_inputs = self.tokenizer.batch_decode(
dataset["input_ids"], skip_special_tokens=True, clean_up_tokenization_spaces=False
)
decoded_labels = self.tokenizer.batch_decode(
labels, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
decoded_inputs = self.tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True)
decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []
for text, label, pred in zip(decoded_inputs, decoded_labels, decoded_preds):
res.append(json.dumps({"prompt": text, "label": label, "predict": pred}, ensure_ascii=False))
writer.write("\n".join(res))