fix seq2seq predictions

Former-commit-id: 65e9ce2cdda3924170774f7b3c6e3fbdbdd87b7f
hiyouga 2023-07-04 22:56:51 +08:00
parent ce622cd93e
commit b9d56b2ac5
2 changed files with 32 additions and 12 deletions

View File

@@ -89,9 +89,9 @@ huggingface-cli login
 And **powerful GPUs**!
 
-If you want to enable LoRA(QLoRA) or Freeze quantization on Windows, you will be required to install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.1.
+If you want to enable quantized LoRA (QLoRA) on the Windows platform, you should install a pre-built version of the `bitsandbytes` library, which supports CUDA 11.1 to 12.1.
 
-```
+```bash
 pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
 ```
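
Not part of the commit, but before installing the wheel above it may help to confirm that PyTorch can actually see a CUDA device in the supported 11.1 to 12.1 range. A minimal sanity check, assuming PyTorch was installed with CUDA support:

```python
# Illustrative pre-flight check for QLoRA on Windows (not from this commit).
import torch

assert torch.cuda.is_available(), "No CUDA device visible to PyTorch"
print(torch.version.cuda)             # should fall within 11.1-12.1 for this wheel
print(torch.cuda.get_device_name(0))  # the GPU that bitsandbytes will run on
```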

View File

@@ -1,8 +1,10 @@
 import os
 import json
+import torch
 import numpy as np
+import torch.nn as nn
 from dataclasses import dataclass
-from typing import Dict, List, Sequence, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 from transformers.trainer import PredictionOutput
 from transformers.tokenization_utils import PreTrainedTokenizer
@@ -34,11 +36,10 @@ class ComputeMetrics:
         preds, labels = eval_preds
         score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
-        for pred, label in zip(preds, labels):
-            pred_pad_len, label_pad_len = np.sum(pred == IGNORE_INDEX), np.sum(label == IGNORE_INDEX)
-            pred = pred[len(label) - label_pad_len : len(pred) - pred_pad_len] # remove prompts
-            label = label[:len(label) - label_pad_len]
+        preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
+        labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
+        for pred, label in zip(preds, labels):
             hypothesis = list(jieba.cut(self.tokenizer.decode(pred, skip_special_tokens=True)))
             reference = list(jieba.cut(self.tokenizer.decode(label, skip_special_tokens=True)))
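
The idea behind this fix, as a standalone sketch: prompt and padding positions are filled with `IGNORE_INDEX` during preprocessing, so replacing them with the tokenizer's pad token lets `decode(..., skip_special_tokens=True)` drop them uniformly, without the fragile per-sequence length arithmetic of the removed lines. A minimal example, assuming `IGNORE_INDEX = -100` and a pad token id of `0` (both illustrative values):

```python
import numpy as np

IGNORE_INDEX = -100  # assumed value, following the usual HF labels convention
pad_token_id = 0     # illustrative; use tokenizer.pad_token_id in practice

labels = np.array([
    [-100, -100, 5, 6, 7],     # prompt positions masked out
    [-100, 8, 9, -100, -100],  # trailing padding masked out
])

# Replace every ignored position with the pad token so that decoding
# with skip_special_tokens=True removes it from the text.
labels = np.where(labels != IGNORE_INDEX, labels, pad_token_id)
print(labels)
# [[0 0 5 6 7]
#  [0 8 9 0 0]]
```
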
@@ -63,6 +64,25 @@ class Seq2SeqPeftTrainer(PeftTrainer):
     Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
     """
 
+    def prediction_step(
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[List[str]] = None,
+    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+        r"""
+        Removes the prompt part in the generated tokens.
+
+        Subclass and override to inject custom behavior.
+        """
+        input_ids = inputs["input_ids"]
+        loss, generated_tokens, labels = super().prediction_step(
+            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
+        )
+        generated_tokens = generated_tokens[:, input_ids.size(-1):] if generated_tokens is not None else None
+        return (loss, generated_tokens, labels)
+
     def save_predictions(
         self,
         predict_results: PredictionOutput
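
The slice in `prediction_step` relies on `generate()` returning sequences that start with the prompt tokens, so dropping the first `input_ids.size(-1)` positions keeps only the newly generated response. A toy illustration with made-up token ids:

```python
import torch

input_ids = torch.tensor([[101, 102, 103]])           # prompt, length 3
generated = torch.tensor([[101, 102, 103, 7, 8, 9]])  # generate() echoes the prompt first

response_only = generated[:, input_ids.size(-1):]     # drop the prompt prefix
print(response_only)  # tensor([[7, 8, 9]])
```

With batched generation this works provided the inputs share a common offset, e.g. when prompts are left-padded to the same length, so every row's response begins at the same position.
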
@@ -77,13 +97,13 @@ class Seq2SeqPeftTrainer(PeftTrainer):
         output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
         logger.info(f"Saving prediction results to {output_prediction_file}")
 
+        preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
+        labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
+
         with open(output_prediction_file, "w", encoding="utf-8") as writer:
             res: List[str] = []
-            for pred, label in zip(predict_results.predictions, predict_results.label_ids):
-                pred_pad_len, label_pad_len = np.sum(pred == IGNORE_INDEX), np.sum(label == IGNORE_INDEX)
-                pred = pred[len(label) - label_pad_len : len(pred) - pred_pad_len] # remove prompts
-                label = label[:len(label) - label_pad_len]
+            for pred, label in zip(preds, labels):
                 pred = self.tokenizer.decode(pred, skip_special_tokens=True)
                 label = self.tokenizer.decode(label, skip_special_tokens=True)
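
For reference, each decoded pair presumably ends up as one JSON object per line in `generated_predictions.jsonl`; the write-out itself falls below the visible part of this hunk. A sketch of one such record (the `label`/`predict` field names are an assumption, not shown in this diff):

```python
import json

# Hypothetical decoded strings for a single example.
label = "期望的回答"  # reference answer
pred = "模型的回答"   # model output

# ensure_ascii=False keeps non-ASCII text human-readable in the JSONL file.
print(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
# {"label": "期望的回答", "predict": "模型的回答"}
```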