Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 14:22:51 +08:00)
[trainer] fix gen_kwarg to eval during training (#5451)
* Correctly pass gen_kwarg to eval during model runs
* fix
* fix

Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: 11eac71c13cd432322b69ae74a3b8fa17af31bc4
parent 0ad9f7f058
commit 48173b606c
src/llamafactory/train/sft/trainer.py

@@ -49,7 +49,11 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     """
 
     def __init__(
-        self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
+        self,
+        finetuning_args: "FinetuningArguments",
+        processor: Optional["ProcessorMixin"],
+        gen_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs,
     ) -> None:
         if is_transformers_version_greater_than("4.46"):
             kwargs["processing_class"] = kwargs.pop("tokenizer")
@@ -58,6 +62,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
 
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
+        if gen_kwargs is not None:
+            # https://github.com/huggingface/transformers/blob/v4.45.0/src/transformers/trainer_seq2seq.py#L287
+            self._gen_kwargs = gen_kwargs
 
         if processor is not None:
             self.add_callback(SaveProcessorCallback(processor))
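The trainer change works because Seq2SeqTrainer's prediction step falls back to self._gen_kwargs when no generation kwargs are passed explicitly (the transformers v4.45.0 link in the added comment points at that fallback), so stashing the kwargs at construction time is enough for evaluations triggered during training to pick them up. Below is a minimal, self-contained sketch of that fallback pattern; MockSeq2SeqTrainer and MockCustomSeq2SeqTrainer are illustrative stand-ins, not the real transformers or LLaMA-Factory classes.

from typing import Any, Dict, Optional


class MockSeq2SeqTrainer:
    """Stand-in for transformers.Seq2SeqTrainer (heavily simplified)."""

    def prediction_step(self, **gen_kwargs: Any) -> Dict[str, Any]:
        # The fallback the diff relies on: if no explicit generation kwargs
        # were given, reuse whatever was stored on the instance earlier.
        if not gen_kwargs and hasattr(self, "_gen_kwargs"):
            gen_kwargs = self._gen_kwargs.copy()
        return gen_kwargs


class MockCustomSeq2SeqTrainer(MockSeq2SeqTrainer):
    """Mirrors the __init__ change above: stash gen_kwargs for later use."""

    def __init__(self, gen_kwargs: Optional[Dict[str, Any]] = None) -> None:
        if gen_kwargs is not None:
            self._gen_kwargs = gen_kwargs


trainer = MockCustomSeq2SeqTrainer(gen_kwargs={"max_new_tokens": 128, "do_sample": False})
# An eval call made mid-training passes no generation kwargs, yet still sees them:
print(trainer.prediction_step())  # {'max_new_tokens': 128, 'do_sample': False}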
src/llamafactory/train/sft/workflow.py

@@ -78,6 +78,12 @@ def run_sft(
         metric_module["compute_metrics"] = ComputeAccuracy()
         metric_module["preprocess_logits_for_metrics"] = eval_logit_processor
 
+    # Keyword arguments for `model.generate`
+    gen_kwargs = generating_args.to_dict(obey_generation_config=True)
+    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
+    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
+    gen_kwargs["logits_processor"] = get_logits_processor()
+
     # Initialize our Trainer
     trainer = CustomSeq2SeqTrainer(
         model=model,
@@ -85,17 +91,12 @@ def run_sft(
         finetuning_args=finetuning_args,
         data_collator=data_collator,
         callbacks=callbacks,
+        gen_kwargs=gen_kwargs,
         **dataset_module,
         **tokenizer_module,
         **metric_module,
     )
 
-    # Keyword arguments for `model.generate`
-    gen_kwargs = generating_args.to_dict(obey_generation_config=True)
-    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
-    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
-    gen_kwargs["logits_processor"] = get_logits_processor()
-
     # Training
     if training_args.do_train:
         train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
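On the workflow side, the gen_kwargs block is moved ahead of the trainer construction and handed over via the new gen_kwargs parameter, so evaluation runs that happen during trainer.train() use the configured stop tokens, padding id, and logits processor instead of model defaults. A rough sketch of how the eos_token_id list is assembled is below; StubTokenizer and build_gen_kwargs are hypothetical stand-ins for a Hugging Face tokenizer and the inline code in run_sft.

from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class StubTokenizer:
    """Hypothetical stand-in for a Hugging Face tokenizer."""
    eos_token_id: int = 2
    pad_token_id: int = 0
    # Some chat templates register extra stop tokens (e.g. an end-of-turn token).
    additional_special_tokens_ids: List[int] = field(default_factory=lambda: [32007])


def build_gen_kwargs(tokenizer: StubTokenizer, generating_args: Dict[str, Any]) -> Dict[str, Any]:
    # Mirrors the lines added in run_sft (minus the logits processor, which
    # needs transformers): collect every valid stop token so generation during
    # evaluation terminates the same way it would at inference time.
    gen_kwargs = dict(generating_args)  # stands in for generating_args.to_dict(...)
    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
    return gen_kwargs


print(build_gen_kwargs(StubTokenizer(), {"max_new_tokens": 64}))
# {'max_new_tokens': 64, 'eos_token_id': [2, 32007], 'pad_token_id': 0}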