diff --git a/src/llamafactory/train/sft/metric.py b/src/llamafactory/train/sft/metric.py
index 95bfcb69..72faef0a 100644
--- a/src/llamafactory/train/sft/metric.py
+++ b/src/llamafactory/train/sft/metric.py
@@ -17,9 +17,11 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Dict
 
 import numpy as np
+import torch
+from transformers import EvalPrediction
 from transformers.utils import is_jieba_available, is_nltk_available
 
 from ...extras.constants import IGNORE_INDEX
@@ -42,6 +44,22 @@ if is_rouge_available():
     from rouge_chinese import Rouge
 
 
+def compute_accuracy(eval_preds: "EvalPrediction") -> Dict[str, float]:
+    preds, labels = eval_preds.predictions, eval_preds.label_ids
+    accuracies = []
+    for i in range(len(preds)):
+        pred, label = preds[i, :-1], labels[i, 1:]
+        label_mask = label != IGNORE_INDEX
+        accuracies.append(np.mean(pred[label_mask] == label[label_mask]))
+
+    return {"accuracy": float(np.mean(accuracies))}
+
+
+def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
+    logits = logits[0] if isinstance(logits, (list, tuple)) else logits
+    return torch.argmax(logits, dim=-1)
+
+
 @dataclass
 class ComputeMetrics:
     r"""
@@ -50,11 +68,11 @@ class ComputeMetrics:
 
     tokenizer: "PreTrainedTokenizer"
 
-    def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
+    def __call__(self, eval_preds: "EvalPrediction") -> Dict[str, float]:
         r"""
         Uses the model predictions to compute metrics.
         """
-        preds, labels = eval_preds
+        preds, labels = eval_preds.predictions, eval_preds.label_ids
         score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
 
         preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index 06bd2b6b..954bb69f 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -135,21 +135,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
 
         for i in range(len(preds)):
             pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
-            if len(pad_len):
-                preds[i] = np.concatenate(
-                    (preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1
-                )  # move pad token to last
+            if len(pad_len):  # move pad token to last
+                preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)
 
-        decoded_inputs = self.tokenizer.batch_decode(
-            dataset["input_ids"], skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )
-        decoded_labels = self.tokenizer.batch_decode(
-            labels, skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )
-        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        decoded_inputs = self.tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True)
+        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
 
         with open(output_prediction_file, "w", encoding="utf-8") as writer:
             res: List[str] = []
             for text, label, pred in zip(decoded_inputs, decoded_labels, decoded_preds):
                 res.append(json.dumps({"prompt": text, "label": label, "predict": pred}, ensure_ascii=False))
+
             writer.write("\n".join(res))
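A note on the index arithmetic in `compute_accuracy` above: in a causal LM, the logits at position t score the token at position t + 1, so the argmax ids produced by `eval_logit_processor` must be compared against labels shifted one step to the left, i.e. `preds[i, :-1]` against `labels[i, 1:]` (the same shift `transformers` applies internally when computing the loss). A minimal, runnable sketch of that pairing; the toy token ids and sequence length are illustrative, not taken from the PR:

```python
import numpy as np

IGNORE_INDEX = -100  # same sentinel the SFT collator uses to mask prompt tokens

# One toy sequence: preds[t] is the argmax of the logits at position t,
# i.e. the model's guess for the token at position t + 1.
preds = np.array([[5, 6, 9, 0]])
labels = np.array([[IGNORE_INDEX, IGNORE_INDEX, 6, 7]])  # prompt positions masked

# Drop the last prediction and the first label so preds[t] lines up with
# labels[t + 1], exactly as compute_accuracy does.
pred, label = preds[0, :-1], labels[0, 1:]
label_mask = label != IGNORE_INDEX  # skip masked prompt positions
print(float(np.mean(pred[label_mask] == label[label_mask])))  # 0.5: 6 hits, 9 misses 7
```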
diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py
index 885bc7ac..0c3f9b11 100644
--- a/src/llamafactory/train/sft/workflow.py
+++ b/src/llamafactory/train/sft/workflow.py
@@ -25,7 +25,7 @@ from ...extras.misc import get_logits_processor
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..trainer_utils import create_modelcard_and_push
-from .metric import ComputeMetrics
+from .metric import ComputeMetrics, compute_accuracy, eval_logit_processor
 from .trainer import CustomSeq2SeqTrainer
 
 
@@ -72,7 +72,8 @@ def run_sft(
         finetuning_args=finetuning_args,
         data_collator=data_collator,
         callbacks=callbacks,
-        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
+        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else compute_accuracy,
+        preprocess_logits_for_metrics=None if training_args.predict_with_generate else eval_logit_processor,
         **tokenizer_module,
         **split_dataset(dataset, data_args, training_args),
     )
@@ -91,7 +92,7 @@
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         if trainer.is_world_process_zero() and finetuning_args.plot_loss:
-            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"])
 
     # Evaluation
     if training_args.do_eval:
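For readers unfamiliar with the two `Seq2SeqTrainer` hooks wired up above: when `predict_with_generate` is disabled, `preprocess_logits_for_metrics` runs inside the evaluation loop on every batch and reduces the `(batch, seq_len, vocab)` logits tensor to `(batch, seq_len)` token ids before results are gathered, so full-vocabulary logits never accumulate in memory; the gathered ids and labels then reach `compute_accuracy` wrapped in an `EvalPrediction`. A rough sketch of that hand-off, not the Trainer's internals verbatim (shapes and vocabulary size are illustrative):

```python
import torch
from transformers import EvalPrediction

vocab_size = 32
logits = torch.randn(2, 6, vocab_size)         # one eval batch: (batch, seq_len, vocab)
labels = torch.randint(0, vocab_size, (2, 6))  # aligned label ids

# Step 1: the Trainer calls preprocess_logits_for_metrics(logits, labels);
# eval_logit_processor shrinks the logits to argmax token ids on-device.
ids = torch.argmax(logits, dim=-1)             # (2, 6) instead of (2, 6, 32)

# Step 2: once all batches are gathered, compute_metrics receives the
# stacked ids and labels wrapped in an EvalPrediction.
eval_preds = EvalPrediction(predictions=ids.numpy(), label_ids=labels.numpy())
print(eval_preds.predictions.shape)            # (2, 6)
```

Because the Trainer prefixes evaluation metrics with `eval_`, the value returned by `compute_accuracy` is logged as `eval_accuracy`, which is why the `plot_loss` call above can now include it as a plot key.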