mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-22 22:02:51 +08:00
fix generation in seq2seq.py
Former-commit-id: 117594802921177272032e58eba7012ae4805b99
This commit is contained in:
parent
8f1d99c926
commit
c145ca4ad6
@ -25,7 +25,10 @@ def main():
|
||||
dataset = prepare_data(model_args, data_args)
|
||||
model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft")
|
||||
dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")
|
||||
data_collator = DynamicDataCollatorWithPadding(tokenizer, data_args.ignore_pad_token_for_loss)
|
||||
data_collator = DynamicDataCollatorWithPadding(
|
||||
tokenizer=tokenizer,
|
||||
ignore_pad_token_for_loss=(data_args.ignore_pad_token_for_loss and not training_args.predict_with_generate)
|
||||
)
|
||||
|
||||
# Override the decoding parameters of Seq2SeqTrainer
|
||||
training_args.generation_max_length = training_args.generation_max_length if \
|
||||
|
@ -23,8 +23,6 @@ logger = get_logger(__name__)
|
||||
class ComputeMetrics:
|
||||
r"""
|
||||
Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
|
||||
|
||||
Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/0c2806fea82683349194e21996dd6b3acc3c265b/ptuning/main.py#L307
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizer
|
||||
@ -34,15 +32,18 @@ class ComputeMetrics:
|
||||
Uses the model predictions to compute metrics.
|
||||
"""
|
||||
preds, labels = eval_preds
|
||||
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[0]
|
||||
# Replace IGNORE_INDEX in the labels with pad_token_id as we cannot decode them if ignore_pad_token_for_loss=True.
|
||||
|
||||
# Replace IGNORE_INDEX in the labels with pad_token_id as we cannot decode them.
|
||||
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
|
||||
labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
|
||||
|
||||
preds = preds[:, labels.shape[1]:] # remove prompts
|
||||
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
|
||||
|
||||
for pred, label in zip(preds, labels):
|
||||
pred = pred[(pred == self.tokenizer.bos_token_id).nonzero()[0][0]:] # remove the query
|
||||
hypothesis = list(jieba.cut(self.tokenizer.decode(pred, skip_special_tokens=True)))
|
||||
reference = list(jieba.cut(self.tokenizer.decode(label, skip_special_tokens=True)))
|
||||
|
||||
@ -83,7 +84,7 @@ class Seq2SeqPeftTrainer(PeftTrainer):
|
||||
preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
|
||||
labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
|
||||
|
||||
preds = [pred[(pred == self.tokenizer.bos_token_id).nonzero()[0][0]:] for pred in preds] # remove the queries
|
||||
preds = preds[:, labels.shape[1]:] # remove prompts
|
||||
preds = [tokenizer.decode(pred, skip_special_tokens=True).strip() for pred in preds]
|
||||
labels = [tokenizer.decode(label, skip_special_tokens=True).strip() for label in labels]
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user