mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-29 18:20:35 +08:00
[trainer] fix gen_kwarg to eval during training (#5451)
* Correctly pass gen_kwarg to eval during model runs

* fix

* fix

---------

Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>

Former-commit-id: 845d16122496311e08263610a6a922f82604de7b
@@ -49,7 +49,11 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     """
 
     def __init__(
-        self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
+        self,
+        finetuning_args: "FinetuningArguments",
+        processor: Optional["ProcessorMixin"],
+        gen_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs,
     ) -> None:
         if is_transformers_version_greater_than("4.46"):
             kwargs["processing_class"] = kwargs.pop("tokenizer")
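For context, a rough sketch of how a caller might assemble the new gen_kwargs argument. The actual wiring in LLaMA-Factory's SFT workflow is not part of this diff, so the names below (generating_args, build_gen_kwargs) are illustrative assumptions rather than the project's API.

from typing import Any, Dict

def build_gen_kwargs(tokenizer, generating_args: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical helper: collect decoding options (max_new_tokens, temperature, top_p, ...)
    # once at setup time so they can be handed to the trainer constructor instead of
    # being attached only when predict() is called after training.
    gen_kwargs: Dict[str, Any] = dict(generating_args)
    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
    return gen_kwargs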
@@ -58,6 +62,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
 
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
+        if gen_kwargs is not None:
+            # https://github.com/huggingface/transformers/blob/v4.45.0/src/transformers/trainer_seq2seq.py#L287
+            self._gen_kwargs = gen_kwargs
 
         if processor is not None:
             self.add_callback(SaveProcessorCallback(processor))
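Why the assignment in __init__ matters: Seq2SeqTrainer falls back to self._gen_kwargs when its prediction step receives no explicit generation kwargs, which is what happens when the base Trainer triggers evaluation on its own during train(); the transformers link in the code comment points at the relevant Seq2SeqTrainer logic. Below is a minimal usage sketch under that assumption; the surrounding variables (model, training_args, datasets, generating_args) are placeholders assumed to be defined earlier, not values taken from this commit.

from llamafactory.train.sft.trainer import CustomSeq2SeqTrainer

trainer = CustomSeq2SeqTrainer(
    model=model,
    args=training_args,  # e.g. Seq2SeqTrainingArguments(predict_with_generate=True, eval_strategy="steps", ...)
    finetuning_args=finetuning_args,
    processor=None,
    gen_kwargs=build_gen_kwargs(tokenizer, generating_args),  # from the sketch above
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()  # mid-training evaluation now generates with the supplied gen_kwargs instead of bare defaults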