support InternLM

This commit is contained in:
hiyouga
2023-07-07 11:02:28 +08:00
parent caa00d3ac2
commit a2f507c562
3 changed files with 15 additions and 2 deletions

View File

@@ -104,8 +104,8 @@ class Seq2SeqPeftTrainer(PeftTrainer):
preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=True)
with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []