fix generation in seq2seq.py

Former-commit-id: 117594802921177272032e58eba7012ae4805b99
hiyouga 2023-06-26 18:07:06 +08:00
parent 8f1d99c926
commit c145ca4ad6
2 changed files with 10 additions and 6 deletions


@@ -25,7 +25,10 @@ def main():
     dataset = prepare_data(model_args, data_args)
     model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft")
     dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")
-    data_collator = DynamicDataCollatorWithPadding(tokenizer, data_args.ignore_pad_token_for_loss)
+    data_collator = DynamicDataCollatorWithPadding(
+        tokenizer=tokenizer,
+        ignore_pad_token_for_loss=(data_args.ignore_pad_token_for_loss and not training_args.predict_with_generate)
+    )
     # Override the decoding parameters of Seq2SeqTrainer
     training_args.generation_max_length = training_args.generation_max_length if \
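
Note on the new argument value: labels padded with IGNORE_INDEX cannot be decoded, so padding has to fall back to the regular pad token whenever predict_with_generate is enabled. A minimal sketch of that padding decision, with an illustrative helper name and IGNORE_INDEX = -100 assumed (this is not the repo's actual collator):

import torch

IGNORE_INDEX = -100  # positions with this value are skipped by the cross-entropy loss

def pad_labels(label_seqs, pad_token_id, ignore_pad_token_for_loss):
    # Pad a batch of label sequences to a common length. When the padding only
    # needs to be masked out of the loss, use IGNORE_INDEX; when the labels must
    # stay decodable (predict_with_generate), use the real pad token instead.
    pad_value = IGNORE_INDEX if ignore_pad_token_for_loss else pad_token_id
    max_len = max(len(seq) for seq in label_seqs)
    return torch.tensor([seq + [pad_value] * (max_len - len(seq)) for seq in label_seqs])

print(pad_labels([[5, 6], [7]], pad_token_id=0, ignore_pad_token_for_loss=True))
# tensor([[   5,    6],
#         [   7, -100]])
print(pad_labels([[5, 6], [7]], pad_token_id=0, ignore_pad_token_for_loss=False))
# tensor([[5, 6],
#         [7, 0]])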


@@ -23,8 +23,6 @@ logger = get_logger(__name__)
 class ComputeMetrics:
     r"""
     Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
-    Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/0c2806fea82683349194e21996dd6b3acc3c265b/ptuning/main.py#L307
     """
     tokenizer: PreTrainedTokenizer
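
For reference, ComputeMetrics appears to follow the compute_metrics contract of Seq2SeqTrainer: a callable that receives a (predictions, label_ids) pair and returns a dict of scalar metrics. A self-contained sketch of that contract with a toy exact-match metric (the class name and pad id are illustrative, not the repo's):

from dataclasses import dataclass

import numpy as np

@dataclass
class ExactMatch:
    pad_token_id: int

    def __call__(self, eval_preds):
        preds, labels = eval_preds
        # Compare each prediction to its label after dropping padding.
        matches = [
            int(np.array_equal(pred[pred != self.pad_token_id], label[label != self.pad_token_id]))
            for pred, label in zip(preds, labels)
        ]
        return {"exact_match": float(np.mean(matches))}

preds = np.array([[5, 6, 0], [7, 0, 0]])
labels = np.array([[5, 6, 0], [8, 0, 0]])
print(ExactMatch(pad_token_id=0)((preds, labels)))  # {'exact_match': 0.5}
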
@@ -34,15 +32,18 @@
         Uses the model predictions to compute metrics.
         """
         preds, labels = eval_preds
+        if isinstance(preds, tuple):
+            preds = preds[0]
-        # Replace IGNORE_INDEX in the labels with pad_token_id as we cannot decode them if ignore_pad_token_for_loss=True.
+        # Replace IGNORE_INDEX in the labels with pad_token_id as we cannot decode them.
+        preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
         labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
+        preds = preds[:, labels.shape[1]:] # remove prompts
         score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
         for pred, label in zip(preds, labels):
-            pred = pred[(pred == self.tokenizer.bos_token_id).nonzero()[0][0]:] # remove the query
             hypothesis = list(jieba.cut(self.tokenizer.decode(pred, skip_special_tokens=True)))
             reference = list(jieba.cut(self.tokenizer.decode(label, skip_special_tokens=True)))
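
The new preprocessing can be seen in isolation with toy arrays. This sketch assumes IGNORE_INDEX = -100, a pad id of 0, and that the collator pads input_ids and labels to the same length, which is what would make the labels.shape[1] slice strip exactly the prompt part of the generated sequences (these assumptions are mine, not stated in the diff):

import numpy as np

IGNORE_INDEX = -100
pad_token_id = 0

# generate() returns the prompt followed by the new tokens; here the prompt part
# is 4 tokens wide, the same width the labels were padded to.
preds = np.array([
    [11, 12, 13, 14, 55, 66],
    [21, 22, 0, 0, 77, 0],
])
labels = np.array([
    [IGNORE_INDEX, IGNORE_INDEX, 55, 66],
    [IGNORE_INDEX, IGNORE_INDEX, 77, IGNORE_INDEX],
])

preds = np.where(preds != IGNORE_INDEX, preds, pad_token_id)
labels = np.where(labels != IGNORE_INDEX, labels, pad_token_id)
preds = preds[:, labels.shape[1]:]  # drop the prompt columns
print(preds)   # [[55 66]
               #  [77  0]]
print(labels)  # [[ 0  0 55 66]
               #  [ 0  0 77  0]]
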
@@ -83,7 +84,7 @@ class Seq2SeqPeftTrainer(PeftTrainer):
         preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
         labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
-        preds = [pred[(pred == self.tokenizer.bos_token_id).nonzero()[0][0]:] for pred in preds] # remove the queries
+        preds = preds[:, labels.shape[1]:] # remove prompts
         preds = [tokenizer.decode(pred, skip_special_tokens=True).strip() for pred in preds]
         labels = [tokenizer.decode(label, skip_special_tokens=True).strip() for label in labels]
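
For context on what the replaced line did: it cut each prediction at its first BOS token, which is fragile when BOS opens the prompt (nothing gets removed) or is absent entirely (an IndexError). A toy illustration, assuming bos_token_id = 1; the exact motivation is my reading of the change, not stated in the commit:

import numpy as np

bos_token_id = 1

for pred in (np.array([1, 11, 12, 55, 66]),   # BOS starts the prompt
             np.array([11, 12, 55, 66])):     # no BOS at all
    try:
        cut = (pred == bos_token_id).nonzero()[0][0]
    except IndexError:
        cut = 0  # the old expression would raise here
    print(pred[cut:])
# [ 1 11 12 55 66]  -> the prompt survives the slice
# [11 12 55 66]     -> only printed because of the fallback above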