mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 19:52:50 +08:00
[trainer] fix KeyError at end of pretrain (#8099)
This commit is contained in:
parent
89a0d10c18
commit
16e26236eb
@ -77,12 +77,22 @@ def run_pt(
|
|||||||
# Evaluation
|
# Evaluation
|
||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
metrics = trainer.evaluate(metric_key_prefix="eval")
|
metrics = trainer.evaluate(metric_key_prefix="eval")
|
||||||
|
|
||||||
|
if isinstance(dataset_module.get("eval_dataset"), dict):
|
||||||
|
for key in dataset_module["eval_dataset"].keys():
|
||||||
|
try:
|
||||||
|
perplexity = math.exp(metrics[f"eval_{key}_loss"])
|
||||||
|
except OverflowError:
|
||||||
|
perplexity = float("inf")
|
||||||
|
metrics[f"eval_{key}_perplexity"] = perplexity
|
||||||
|
else:
|
||||||
try:
|
try:
|
||||||
perplexity = math.exp(metrics["eval_loss"])
|
perplexity = math.exp(metrics["eval_loss"])
|
||||||
except OverflowError:
|
except OverflowError:
|
||||||
perplexity = float("inf")
|
perplexity = float("inf")
|
||||||
|
metrics["eval_perplexity"] = perplexity
|
||||||
|
|
||||||
|
|
||||||
metrics["perplexity"] = perplexity
|
|
||||||
trainer.log_metrics("eval", metrics)
|
trainer.log_metrics("eval", metrics)
|
||||||
trainer.save_metrics("eval", metrics)
|
trainer.save_metrics("eval", metrics)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user