Mirror of https://github.com/hiyouga/LLaMA-Factory.git
fix bleu score
Former-commit-id: 6874dce4444e6e6ce9d6125275dbf3dfdfb4fb22
commit 4b093996a7
parent e4e36a2d74
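The change replaces per-example tokenizer.decode calls with a single tokenizer.batch_decode over the masked prediction and label arrays, both in the metric loop of ComputeMetrics and in the prediction dump of Seq2SeqPeftTrainer.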
@@ -39,9 +39,12 @@ class ComputeMetrics:
         preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
         labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
 
-        for pred, label in zip(preds, labels):
-            hypothesis = list(jieba.cut(self.tokenizer.decode(pred, skip_special_tokens=True)))
-            reference = list(jieba.cut(self.tokenizer.decode(label, skip_special_tokens=True)))
+        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
+        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+        for pred, label in zip(decoded_preds, decoded_labels):
+            hypothesis = list(jieba.cut(pred))
+            reference = list(jieba.cut(label))
 
             if len(" ".join(hypothesis).split()) == 0:
                 result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
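For context, a minimal standalone sketch of the scoring loop this hunk lands in, assuming the jieba, rouge-chinese, and nltk packages that LLaMA-Factory's metric code depends on; the score_pairs wrapper and the averaging at the end are illustrative, not the repo's exact code:

```python
# Hypothetical standalone sketch of the fixed metric loop.
# Assumed deps: pip install jieba rouge-chinese nltk
import jieba
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def score_pairs(decoded_preds, decoded_labels):
    score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
    for pred, label in zip(decoded_preds, decoded_labels):
        # Word-segment the text so ROUGE sees space-separated tokens.
        hypothesis = list(jieba.cut(pred))
        reference = list(jieba.cut(label))

        if len(" ".join(hypothesis).split()) == 0:
            # An empty hypothesis would make Rouge raise; report zero F-scores.
            result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
        else:
            result = Rouge().get_scores(" ".join(hypothesis), " ".join(reference))[0]

        for k, v in result.items():
            score_dict[k].append(round(v["f"] * 100, 4))

        # Character-level BLEU-4 with smoothing, computed on the decoded strings.
        bleu = sentence_bleu([list(label)], list(pred),
                             smoothing_function=SmoothingFunction().method3)
        score_dict["bleu-4"].append(round(bleu * 100, 4))
    return {k: sum(v) / max(len(v), 1) for k, v in score_dict.items()}
```

The point of the fix is visible here: BLEU and ROUGE now operate on strings produced by one batch_decode call, instead of each loop iteration paying for its own tokenizer.decode.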
@@ -101,12 +104,11 @@ class Seq2SeqPeftTrainer(PeftTrainer):
         preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
         labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
 
+        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
+        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+
         with open(output_prediction_file, "w", encoding="utf-8") as writer:
             res: List[str] = []
-            for pred, label in zip(preds, labels):
-                pred = self.tokenizer.decode(pred, skip_special_tokens=True)
-                label = self.tokenizer.decode(label, skip_special_tokens=True)
-
+            for pred, label in zip(decoded_preds, decoded_labels):
                 res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
-
             writer.write("\n".join(res))
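The save path admits the same standalone rendering. A sketch under the assumptions that predict_results is a transformers PredictionOutput with predictions and label_ids arrays and that IGNORE_INDEX is -100 as in the repo's constants; the function signature is simplified here:

```python
import json
import numpy as np
from typing import List

IGNORE_INDEX = -100  # masked label positions (assumed, per the repo's constants)

def save_predictions(predict_results, tokenizer, output_prediction_file: str) -> None:
    # Map IGNORE_INDEX back to pad_token_id so the tokenizer can decode it.
    preds = np.where(predict_results.predictions != IGNORE_INDEX,
                     predict_results.predictions, tokenizer.pad_token_id)
    labels = np.where(predict_results.label_ids != IGNORE_INDEX,
                      predict_results.label_ids, tokenizer.pad_token_id)

    # One batch_decode per array replaces a decode call per example.
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    with open(output_prediction_file, "w", encoding="utf-8") as writer:
        res: List[str] = []
        for pred, label in zip(decoded_preds, decoded_labels):
            res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
        writer.write("\n".join(res))  # one JSON object per line
```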