mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 03:32:50 +08:00
[trainer] fix KeyError at end of pretrain (#8099)
This commit is contained in:
parent
89a0d10c18
commit
16e26236eb
@ -77,12 +77,22 @@ def run_pt(
|
||||
# Evaluation
|
||||
if training_args.do_eval:
|
||||
metrics = trainer.evaluate(metric_key_prefix="eval")
|
||||
try:
|
||||
perplexity = math.exp(metrics["eval_loss"])
|
||||
except OverflowError:
|
||||
perplexity = float("inf")
|
||||
|
||||
metrics["perplexity"] = perplexity
|
||||
if isinstance(dataset_module.get("eval_dataset"), dict):
|
||||
for key in dataset_module["eval_dataset"].keys():
|
||||
try:
|
||||
perplexity = math.exp(metrics[f"eval_{key}_loss"])
|
||||
except OverflowError:
|
||||
perplexity = float("inf")
|
||||
metrics[f"eval_{key}_perplexity"] = perplexity
|
||||
else:
|
||||
try:
|
||||
perplexity = math.exp(metrics["eval_loss"])
|
||||
except OverflowError:
|
||||
perplexity = float("inf")
|
||||
metrics["eval_perplexity"] = perplexity
|
||||
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user