mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-17 20:30:36 +08:00
fix #6499
This commit is contained in:
@@ -117,8 +117,6 @@ def run_sft(
|
||||
# Evaluation
|
||||
if training_args.do_eval:
|
||||
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
|
||||
if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
|
||||
metrics.pop("eval_loss", None)
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
@@ -126,8 +124,6 @@ def run_sft(
|
||||
if training_args.do_predict:
|
||||
logger.warning_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
|
||||
predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
|
||||
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
|
||||
predict_results.metrics.pop("predict_loss", None)
|
||||
trainer.log_metrics("predict", predict_results.metrics)
|
||||
trainer.save_metrics("predict", predict_results.metrics)
|
||||
trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)
|
||||
|
||||
Reference in New Issue
Block a user