diff --git a/src/train_sft.py b/src/train_sft.py
index f82d254e..da104fdd 100644
--- a/src/train_sft.py
+++ b/src/train_sft.py
@@ -25,7 +25,10 @@ def main():
     dataset = prepare_data(model_args, data_args)
     model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft")
     dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")
-    data_collator = DynamicDataCollatorWithPadding(tokenizer, data_args.ignore_pad_token_for_loss)
+    data_collator = DynamicDataCollatorWithPadding(
+        tokenizer=tokenizer,
+        ignore_pad_token_for_loss=(data_args.ignore_pad_token_for_loss and not training_args.predict_with_generate)
+    )
 
     # Override the decoding parameters of Seq2SeqTrainer
     training_args.generation_max_length = training_args.generation_max_length if \
diff --git a/src/utils/seq2seq.py b/src/utils/seq2seq.py
index 47bbcc1d..90a9810e 100644
--- a/src/utils/seq2seq.py
+++ b/src/utils/seq2seq.py
@@ -23,8 +23,6 @@ logger = get_logger(__name__)
 class ComputeMetrics:
     r"""
     Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
-
-    Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/0c2806fea82683349194e21996dd6b3acc3c265b/ptuning/main.py#L307
     """
 
     tokenizer: PreTrainedTokenizer
@@ -34,15 +32,18 @@ class ComputeMetrics:
         Uses the model predictions to compute metrics.
         """
         preds, labels = eval_preds
+
         if isinstance(preds, tuple):
             preds = preds[0]
-        # Replace IGNORE_INDEX in the labels with pad_token_id as we cannot decode them if ignore_pad_token_for_loss=True.
+
+        # Replace IGNORE_INDEX in the labels with pad_token_id as we cannot decode them.
         preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
         labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
+        preds = preds[:, labels.shape[1]:] # remove prompts
 
         score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
+
         for pred, label in zip(preds, labels):
-            pred = pred[(pred == self.tokenizer.bos_token_id).nonzero()[0][0]:] # remove the query
             hypothesis = list(jieba.cut(self.tokenizer.decode(pred, skip_special_tokens=True)))
             reference = list(jieba.cut(self.tokenizer.decode(label, skip_special_tokens=True)))
@@ -83,7 +84,7 @@ class Seq2SeqPeftTrainer(PeftTrainer):
         preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
         labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
 
-        preds = [pred[(pred == self.tokenizer.bos_token_id).nonzero()[0][0]:] for pred in preds] # remove the queries
+        preds = preds[:, labels.shape[1]:] # remove prompts
         preds = [tokenizer.decode(pred, skip_special_tokens=True).strip() for pred in preds]
         labels = [tokenizer.decode(label, skip_special_tokens=True).strip() for label in labels]
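
For context only (not part of the patch): a minimal sketch of what the new `preds[:, labels.shape[1]:]` slicing does, assuming generated sequences are laid out as prompt tokens followed by generated tokens and that prompts occupy the first `labels.shape[1]` positions. All arrays and values below are hypothetical toy data, not taken from the repository.

```python
import numpy as np

PAD = 0  # hypothetical pad token id

# Two generated sequences: 3 prompt tokens, then the generated continuation.
preds = np.array([
    [11, 12, 13, 101, 102, 103],
    [21, 22, PAD, 201, 202, PAD],
])
# Labels padded to the same length as the prompts (assumption stated above).
labels = np.array([
    [101, 102, 103],
    [201, 202, PAD],
])

# Drop the prompt part by position instead of searching for a BOS token,
# which also works for tokenizers that do not emit a BOS token.
generated = preds[:, labels.shape[1]:]
print(generated)  # [[101 102 103] [201 202   0]]
```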