This commit is contained in:
hiyouga
2023-06-27 23:54:24 +08:00
parent 18f87c1b25
commit 450910c1db
4 changed files with 15 additions and 5 deletions

View File

@@ -35,8 +35,9 @@ class ComputeMetrics:
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
for pred, label in zip(preds, labels):
pred = pred[len(label) - np.sum(label == IGNORE_INDEX) : len(pred) - np.sum(pred == IGNORE_INDEX)] # remove prompts
label = label[:len(label) - np.sum(label == IGNORE_INDEX)]
pred_pad_len, label_pad_len = np.sum(pred == IGNORE_INDEX), np.sum(label == IGNORE_INDEX)
pred = pred[len(label) - label_pad_len : len(pred) - pred_pad_len] # remove prompts
label = label[:len(label) - label_pad_len]
hypothesis = list(jieba.cut(self.tokenizer.decode(pred, skip_special_tokens=True)))
reference = list(jieba.cut(self.tokenizer.decode(label, skip_special_tokens=True)))
@@ -79,8 +80,9 @@ class Seq2SeqPeftTrainer(PeftTrainer):
with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []
for pred, label in zip(predict_results.predictions, predict_results.label_ids):
pred = pred[len(label) - np.sum(label == IGNORE_INDEX) : len(pred) - np.sum(pred == IGNORE_INDEX)] # remove prompts
label = label[:len(label) - np.sum(label == IGNORE_INDEX)]
pred_pad_len, label_pad_len = np.sum(pred == IGNORE_INDEX), np.sum(label == IGNORE_INDEX)
pred = pred[len(label) - label_pad_len : len(pred) - pred_pad_len] # remove prompts
label = label[:len(label) - label_pad_len]
pred = self.tokenizer.decode(pred, skip_special_tokens=True)
label = self.tokenizer.decode(label, skip_special_tokens=True)