Former-commit-id: 78589cf90c6e12e612f269b1c771f19f3dad83d2
This commit is contained in:
hiyouga 2024-06-15 04:34:55 +08:00
parent a3f4925c2c
commit ab66ae8cd2
2 changed files with 8 additions and 4 deletions

View File

@@ -13,6 +13,7 @@ from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
 if TYPE_CHECKING:
+    from torch.utils.data import Dataset
     from transformers import ProcessorMixin
     from transformers.trainer import PredictionOutput
@@ -94,7 +95,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor  # adopt left-padding
         return padded_tensor.contiguous()  # in contiguous memory

-    def save_predictions(self, predict_results: "PredictionOutput") -> None:
+    def save_predictions(self, dataset: "Dataset", predict_results: "PredictionOutput") -> None:
        r"""
        Saves model predictions to `output_dir`.
@@ -120,6 +121,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
                     (preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1
                 )  # move pad token to last

+        decoded_inputs = self.tokenizer.batch_decode(
+            dataset["input_ids"], skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
         decoded_labels = self.tokenizer.batch_decode(
             labels, skip_special_tokens=True, clean_up_tokenization_spaces=False
         )
@@ -127,6 +131,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         with open(output_prediction_file, "w", encoding="utf-8") as writer:
             res: List[str] = []
-            for label, pred in zip(decoded_labels, decoded_preds):
-                res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
+            for text, label, pred in zip(decoded_inputs, decoded_labels, decoded_preds):
+                res.append(json.dumps({"prompt": text, "label": label, "predict": pred}, ensure_ascii=False))
             writer.write("\n".join(res))

View File

@@ -93,7 +93,7 @@ def run_sft(
         predict_results.metrics.pop("predict_loss", None)
         trainer.log_metrics("predict", predict_results.metrics)
         trainer.save_metrics("predict", predict_results.metrics)
-        trainer.save_predictions(predict_results)
+        trainer.save_predictions(dataset, predict_results)

     # Create model card
     create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)