diff --git a/scripts/cal_ppl.py b/scripts/cal_ppl.py
index bdfc210b..6c8c6174 100644
--- a/scripts/cal_ppl.py
+++ b/scripts/cal_ppl.py
@@ -1,6 +1,6 @@
 # coding=utf-8
-# Calculates the ppl of pre-trained models.
-# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
+# Calculates the ppl of pre-trained models on a given dataset.
+# Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
 
 import json
 from typing import Dict
@@ -19,6 +19,7 @@ from llmtuner.model import load_model, load_tokenizer
 
 def cal_ppl(
     model_name_or_path: str,
+    save_name: str,
     batch_size: int = 4,
     stage: str = "sft",
     dataset: str = "alpaca_en",
@@ -69,10 +70,10 @@ def cal_ppl(
         sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
         perplexities.extend(sentence_logps.exp().tolist())
 
-    with open("ppl.json", "w", encoding="utf-8") as f:
+    with open(save_name, "w", encoding="utf-8") as f:
         json.dump(perplexities, f, indent=2)
 
-    print("Perplexities have been saved at ppl.json.")
+    print("Perplexities have been saved to {}.".format(save_name))
 
 
 if __name__ == "__main__":
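For reviewers unfamiliar with the masked averaging in the last hunk, here is a minimal, self-contained sketch of the per-sentence perplexity formula that `cal_ppl` writes to `save_name`. The tensor values are made up, and it assumes `token_logps` holds per-token negative log-likelihoods (e.g. from a `reduction="none"` cross-entropy), which is what makes `sentence_logps.exp()` a perplexity.

```python
import torch

# Hypothetical per-token NLLs for two sentences (values are illustrative only);
# padded positions carry a 0 in loss_mask and are excluded from the average.
token_logps = torch.tensor([[2.0, 1.5, 1.0, 0.0],
                            [1.2, 0.8, 0.0, 0.0]])
loss_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0],
                          [1.0, 1.0, 0.0, 0.0]])

# Same formula as in cal_ppl: mean NLL over unmasked tokens, then exp().
sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
perplexities = sentence_logps.exp().tolist()
print(perplexities)  # [exp(1.5), exp(1.0)] ≈ [4.48, 2.72]
```

With the new `--save_name` flag, the resulting list lands wherever the caller points it, so downstream tooling can `json.load` that file instead of relying on a hard-coded `ppl.json`.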