mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 22:32:54 +08:00
fix generation in seq2seq.py
Former-commit-id: 117594802921177272032e58eba7012ae4805b99
This commit is contained in:
parent 8f1d99c926
commit c145ca4ad6
@@ -25,7 +25,10 @@ def main():
     dataset = prepare_data(model_args, data_args)
     model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft")
     dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")
-    data_collator = DynamicDataCollatorWithPadding(tokenizer, data_args.ignore_pad_token_for_loss)
+    data_collator = DynamicDataCollatorWithPadding(
+        tokenizer=tokenizer,
+        ignore_pad_token_for_loss=(data_args.ignore_pad_token_for_loss and not training_args.predict_with_generate)
+    )
 
     # Override the decoding parameters of Seq2SeqTrainer
     training_args.generation_max_length = training_args.generation_max_length if \
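The collator change above ties loss masking to whether generation-based evaluation is enabled. A minimal sketch of the padding behavior this flag toggles (hypothetical helper, not the project's actual DynamicDataCollatorWithPadding; the IGNORE_INDEX value of -100 is an assumption based on common transformers usage):

from typing import List

IGNORE_INDEX = -100  # assumed loss-masking constant, as referenced elsewhere in this diff

def pad_labels(batch_labels: List[List[int]], pad_token_id: int, ignore_pad_token_for_loss: bool) -> List[List[int]]:
    # Pad every label sequence to the batch maximum. When pads are ignored for
    # the loss they become IGNORE_INDEX; otherwise they stay as pad_token_id so
    # the labels remain decodable during predict_with_generate evaluation.
    max_len = max(len(labels) for labels in batch_labels)
    fill = IGNORE_INDEX if ignore_pad_token_for_loss else pad_token_id
    return [labels + [fill] * (max_len - len(labels)) for labels in batch_labels]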
@@ -23,8 +23,6 @@ logger = get_logger(__name__)
 class ComputeMetrics:
     r"""
     Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
-
-    Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/0c2806fea82683349194e21996dd6b3acc3c265b/ptuning/main.py#L307
     """
 
     tokenizer: PreTrainedTokenizer
@@ -34,15 +32,18 @@ class ComputeMetrics:
         Uses the model predictions to compute metrics.
         """
         preds, labels = eval_preds
 
         if isinstance(preds, tuple):
             preds = preds[0]
-        # Replace IGNORE_INDEX in the labels with pad_token_id as we cannot decode them if ignore_pad_token_for_loss=True.
+        # Replace IGNORE_INDEX in the labels with pad_token_id as we cannot decode them.
         preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
         labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
 
+        preds = preds[:, labels.shape[1]:] # remove prompts
         score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
 
         for pred, label in zip(preds, labels):
-            pred = pred[(pred == self.tokenizer.bos_token_id).nonzero()[0][0]:] # remove the query
             hypothesis = list(jieba.cut(self.tokenizer.decode(pred, skip_special_tokens=True)))
             reference = list(jieba.cut(self.tokenizer.decode(label, skip_special_tokens=True)))
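This hunk (and the one below) replaces the bos-token lookup with a fixed-width slice. A standalone toy with made-up token ids, assuming prompts and labels are padded to the same width so the first labels.shape[1] columns of preds hold the prompt:

import numpy as np

# Each generated row is [prompt ..., response ...]; labels hold only the responses.
preds = np.array([
    [1, 7, 8, 9, 2001, 2002, 2003, 0],
    [1, 5, 6, 0, 3001, 3002, 3003, 0],
])
labels = np.array([
    [2001, 2002, 2003, 0],
    [3001, 3002, 3003, 0],
])
responses = preds[:, labels.shape[1]:]  # drop the 4 prompt columns
assert responses.shape == labels.shape  # only the generated responses remain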
@@ -83,7 +84,7 @@ class Seq2SeqPeftTrainer(PeftTrainer):
         preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
         labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
 
-        preds = [pred[(pred == self.tokenizer.bos_token_id).nonzero()[0][0]:] for pred in preds] # remove the queries
+        preds = preds[:, labels.shape[1]:] # remove prompts
         preds = [tokenizer.decode(pred, skip_special_tokens=True).strip() for pred in preds]
         labels = [tokenizer.decode(label, skip_special_tokens=True).strip() for label in labels]
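For reference, the np.where calls above map masked label positions back to decodable ids before decoding. A small illustration (IGNORE_INDEX = -100 and pad_token_id = 0 are assumed placeholders, not values taken from this commit):

import numpy as np

IGNORE_INDEX = -100  # assumed loss-masking constant
pad_token_id = 0     # placeholder; use tokenizer.pad_token_id in practice

label_ids = np.array([[15, 27, IGNORE_INDEX, IGNORE_INDEX]])
# Positions masked with IGNORE_INDEX cannot be decoded, so map them to the pad token.
label_ids = np.where(label_ids != IGNORE_INDEX, label_ids, pad_token_id)
print(label_ids)  # [[15 27  0  0]]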