mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-22 13:42:51 +08:00
add eval acc
Former-commit-id: 1856a08e87b150fa4bffcb0af703ed84d848e24b
This commit is contained in:
parent a475d808f2
commit 54e786346e
@@ -17,9 +17,11 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Dict
 
 import numpy as np
+import torch
+from transformers import EvalPrediction
 from transformers.utils import is_jieba_available, is_nltk_available
 
 from ...extras.constants import IGNORE_INDEX
@@ -42,6 +44,22 @@ if is_rouge_available():
     from rouge_chinese import Rouge
 
 
+def compute_accuracy(eval_preds: "EvalPrediction") -> Dict[str, float]:
+    preds, labels = eval_preds.predictions, eval_preds.label_ids
+    accuracies = []
+    for i in range(len(preds)):
+        pred, label = preds[i, 1:], labels[i, :-1]
+        label_mask = label != IGNORE_INDEX
+        accuracies.append(np.mean(pred[label_mask] == label[label_mask]))
+
+    return {"accuracy": float(np.mean(accuracies))}
+
+
+def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
+    logits = logits[0] if isinstance(logits, (list, tuple)) else logits
+    return torch.argmax(logits, dim=-1)
+
+
 @dataclass
 class ComputeMetrics:
     r"""
@@ -50,11 +68,11 @@ class ComputeMetrics:
 
     tokenizer: "PreTrainedTokenizer"
 
-    def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
+    def __call__(self, eval_preds: "EvalPrediction") -> Dict[str, float]:
         r"""
         Uses the model predictions to compute metrics.
         """
-        preds, labels = eval_preds
+        preds, labels = eval_preds.predictions, eval_preds.label_ids
         score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
 
         preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
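
Note on the two helpers added above (a reading aid, not part of the commit): eval_logit_processor is intended to be registered as preprocess_logits_for_metrics so that each batch of logits is reduced to argmax token ids, and compute_accuracy then compares those ids against the labels offset by one position, skipping IGNORE_INDEX entries. A minimal standalone sketch of that accuracy computation follows; the array values are made up for illustration, and IGNORE_INDEX is assumed to be the usual -100 sentinel:

    import numpy as np

    IGNORE_INDEX = -100  # assumed sentinel, mirroring ...extras.constants

    # Toy batch: argmax token ids (what eval_logit_processor would return) and labels.
    preds = np.array([[5, 9, 11, 22]])
    labels = np.array([[IGNORE_INDEX, 11, 22, 40]])

    accuracies = []
    for i in range(len(preds)):
        pred, label = preds[i, 1:], labels[i, :-1]  # same one-position offset as compute_accuracy
        mask = label != IGNORE_INDEX                # drop prompt / padding positions
        accuracies.append(np.mean(pred[mask] == label[mask]))

    print({"accuracy": float(np.mean(accuracies))})  # -> {'accuracy': 1.0} for this toy batch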
@@ -135,21 +135,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
 
         for i in range(len(preds)):
             pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
-            if len(pad_len):
-                preds[i] = np.concatenate(
-                    (preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1
-                )  # move pad token to last
+            if len(pad_len):  # move pad token to last
+                preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)
 
-        decoded_inputs = self.tokenizer.batch_decode(
-            dataset["input_ids"], skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )
-        decoded_labels = self.tokenizer.batch_decode(
-            labels, skip_special_tokens=True, clean_up_tokenization_spaces=False
-        )
-        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        decoded_inputs = self.tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True)
+        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
 
         with open(output_prediction_file, "w", encoding="utf-8") as writer:
             res: List[str] = []
             for text, label, pred in zip(decoded_inputs, decoded_labels, decoded_preds):
                 res.append(json.dumps({"prompt": text, "label": label, "predict": pred}, ensure_ascii=False))
 
             writer.write("\n".join(res))
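
The condensed pad-token handling above keeps the original behavior: any leading pad tokens in a generated sequence are rotated to the end so decoding starts from the first real token. A tiny illustration under an assumed pad id of 0 (the trainer uses self.tokenizer.pad_token_id):

    import numpy as np

    pad_token_id = 0  # assumed pad id for illustration only

    seq = np.array([0, 0, 101, 102, 103])          # left-padded generated ids
    nonpad = np.nonzero(seq != pad_token_id)[0]
    if len(nonpad):  # move pad tokens to last, mirroring the concatenate call above
        seq = np.concatenate((seq[nonpad[0]:], seq[:nonpad[0]]), axis=-1)

    print(seq)  # -> [101 102 103   0   0]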
@@ -25,7 +25,7 @@ from ...extras.misc import get_logits_processor
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..trainer_utils import create_modelcard_and_push
-from .metric import ComputeMetrics
+from .metric import ComputeMetrics, compute_accuracy, eval_logit_processor
 from .trainer import CustomSeq2SeqTrainer
 
 
@@ -72,7 +72,8 @@ def run_sft(
         finetuning_args=finetuning_args,
         data_collator=data_collator,
         callbacks=callbacks,
-        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
+        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else compute_accuracy,
+        preprocess_logits_for_metrics=None if training_args.predict_with_generate else eval_logit_processor,
         **tokenizer_module,
         **split_dataset(dataset, data_args, training_args),
     )
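
With this wiring, when predict_with_generate is disabled the trainer evaluates on logits rather than generated text: eval_logit_processor is applied to each batch's logits and only the resulting token ids are gathered into the EvalPrediction that compute_accuracy receives. A rough sketch of why this matters for memory (shapes are illustrative; real vocabularies are far larger):

    import torch

    batch_logits = torch.randn(2, 16, 1000)          # (batch, seq_len, vocab) for one eval step
    token_ids = torch.argmax(batch_logits, dim=-1)   # what eval_logit_processor returns

    # Accumulating ids instead of raw logits shrinks the per-step payload
    # from batch * seq_len * vocab floats to batch * seq_len integers.
    print(batch_logits.numel(), "->", token_ids.numel())  # 32000 -> 32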
@@ -91,7 +92,7 @@ def run_sft(
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         if trainer.is_world_process_zero() and finetuning_args.plot_loss:
-            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"])
 
     # Evaluation
     if training_args.do_eval: