Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-17 04:10:36 +08:00)
[trainer] fix: pass gen_kwargs to eval during training (#5451)

* Correctly pass gen_kwargs to eval runs during training
* fix
* fix

Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: 845d16122496311e08263610a6a922f82604de7b
@@ -78,6 +78,12 @@ def run_sft(
         metric_module["compute_metrics"] = ComputeAccuracy()
         metric_module["preprocess_logits_for_metrics"] = eval_logit_processor
 
+    # Keyword arguments for `model.generate`
+    gen_kwargs = generating_args.to_dict(obey_generation_config=True)
+    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
+    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
+    gen_kwargs["logits_processor"] = get_logits_processor()
+
     # Initialize our Trainer
     trainer = CustomSeq2SeqTrainer(
         model=model,
@@ -85,17 +91,12 @@ def run_sft(
         finetuning_args=finetuning_args,
         data_collator=data_collator,
         callbacks=callbacks,
+        gen_kwargs=gen_kwargs,
         **dataset_module,
         **tokenizer_module,
         **metric_module,
     )
 
-    # Keyword arguments for `model.generate`
-    gen_kwargs = generating_args.to_dict(obey_generation_config=True)
-    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
-    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
-    gen_kwargs["logits_processor"] = get_logits_processor()
-
     # Training
     if training_args.do_train:
         train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
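
Why moving the block matters: the gen_kwargs were previously built only after the trainer was constructed, so evaluations triggered inside the training loop never saw them. Passing them to the constructor fixes that. Below is a minimal sketch of this constructor-injection pattern on top of transformers' Seq2SeqTrainer; the class name GenKwargsSeq2SeqTrainer and the merge behavior are illustrative assumptions, not CustomSeq2SeqTrainer's actual internals (which this diff does not show).

# Minimal sketch of the constructor-injection pattern the fix relies on.
# The class name and merge behavior are assumptions for illustration;
# CustomSeq2SeqTrainer's real internals are not part of this diff.
from typing import Any, Dict, Optional

from transformers import Seq2SeqTrainer


class GenKwargsSeq2SeqTrainer(Seq2SeqTrainer):  # hypothetical name
    def __init__(self, gen_kwargs: Optional[Dict[str, Any]] = None, **kwargs):
        super().__init__(**kwargs)
        # Stored once at construction, so every evaluation the training
        # loop triggers (e.g. periodic eval during training) sees the same
        # kwargs, not only explicit trainer.evaluate(...) calls.
        self._stored_gen_kwargs = gen_kwargs or {}

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval", **gen_kwargs):
        # Per-call kwargs take precedence over the stored defaults.
        merged = {**self._stored_gen_kwargs, **gen_kwargs}
        return super().evaluate(
            eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix, **merged
        )

A distinct attribute name (_stored_gen_kwargs) is used here because Seq2SeqTrainer.evaluate assigns self._gen_kwargs internally in recent transformers versions; storing the defaults under that name in __init__ could be silently overwritten on the first evaluate call.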