mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-01-10 08:00:36 +08:00
[fix] Fix prediction metrics in scripts/vllm_infer.py to match Transformers (#9701)
Co-authored-by: xuht6 <xuht6@asiainfo.com>
@@ -14,11 +14,13 @@
 
 import gc
 import json
+import time
 
 import av
 import fire
 from tqdm import tqdm
 from transformers import Seq2SeqTrainingArguments
+from datasets import load_dataset
 
 from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
 from llamafactory.extras.constants import IGNORE_INDEX
@@ -27,6 +29,8 @@ from llamafactory.extras.packages import is_vllm_available
 from llamafactory.hparams import get_infer_args
 from llamafactory.model import load_tokenizer
 
+from eval_bleu_rouge import compute_metrics
+
 
 if is_vllm_available():
     from vllm import LLM, SamplingParams
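Note: the aggregation added later in this diff only relies on compute_metrics mapping one saved JSONL record to a flat dict of scalar scores, so that datasets' map() yields one column per metric. Below is a minimal sketch of that interface; the jieba / NLTK / rouge-chinese backends and the {"prompt", "predict", "label"} record shape are assumptions for illustration, and the authoritative implementation is scripts/eval_bleu_rouge.py in this repo.

# Sketch only -- see scripts/eval_bleu_rouge.py for the real implementation.
import jieba  # word segmentation; assumed backend
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from rouge_chinese import Rouge


def compute_metrics(sample: dict) -> dict:
    """Score one {"prompt", "predict", "label"} record from the saved JSONL."""
    hypothesis = list(jieba.cut(sample["predict"]))
    reference = list(jieba.cut(sample["label"]))

    # Character-level BLEU-4 with smoothing, so sparse overlaps still score.
    bleu = sentence_bleu(
        [list(sample["label"])],
        list(sample["predict"]),
        smoothing_function=SmoothingFunction().method3,
    )

    if not "".join(hypothesis).strip() or not "".join(reference).strip():
        rouge_result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
    else:
        rouge_result = Rouge().get_scores(" ".join(hypothesis), " ".join(reference))[0]

    # Flat dict of floats: Dataset.map() turns each key into a column,
    # which the averaging loop in this diff later reduces to a single score.
    scores = {key: round(value["f"] * 100, 4) for key, value in rouge_result.items()}
    scores["bleu-4"] = round(float(bleu) * 100, 4)
    return scores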
@@ -51,6 +55,7 @@ def vllm_infer(
     max_samples: int | None = None,
     vllm_config: str = "{}",
     save_name: str = "generated_predictions.jsonl",
+    matrix_save_name: str | None = None,
     temperature: float = 0.95,
     top_p: float = 0.7,
     top_k: int = 50,
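Since the script is dispatched through fire.Fire, the new parameter is exposed as a --matrix_save_name flag. A hypothetical invocation (the model path and dataset name are placeholders; the other flags are the script's existing parameters):

python scripts/vllm_infer.py \
    --model_name_or_path path/to/your_model \
    --dataset your_dataset \
    --save_name generated_predictions.jsonl \
    --matrix_save_name predict_results.json

Leaving --matrix_save_name unset preserves the old behavior: predictions are written to save_name, but no metrics file is produced.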
@@ -117,6 +122,7 @@ def vllm_infer(
     if isinstance(model_args.vllm_config, dict):
         engine_args.update(model_args.vllm_config)
 
+    model_preparation_start_time = time.time()
     llm = LLM(**engine_args)
 
     # load datasets
@@ -142,6 +148,7 @@ def vllm_infer(
     all_prompts, all_preds, all_labels = [], [], []
     need_video_kwargs = _need_video_kwargs(template)
 
+    model_predict_start_time = time.time()
     # Add batch process to avoid the issue of too many files opened
     for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
         vllm_inputs, prompts, labels = [], [], []
@@ -218,6 +225,7 @@ def vllm_infer(
         all_labels.extend(labels)
         gc.collect()
 
+    model_predict_end_time = time.time()
     # Write all results at once outside the loop
     with open(save_name, "w", encoding="utf-8") as f:
         for text, pred, label in zip(all_prompts, all_preds, all_labels):
@@ -227,6 +235,49 @@ def vllm_infer(
     print(f"{len(all_prompts)} total generated results have been saved at {save_name}.")
     print("*" * 70)
 
+    # Write aggregate metric results when matrix_save_name is not None.
+    # The output format mirrors src.llamafactory.train.sft.workflow.run_sft (lines 127~132),
+    # i.e. trainer.save_metrics("predict", predict_results.metrics), which produces e.g.:
+    #
+    # {
+    #     "predict_bleu-4": 4.349975,
+    #     "predict_model_preparation_time": 0.0128,
+    #     "predict_rouge-1": 21.873359375,
+    #     "predict_rouge-2": 4.144340625,
+    #     "predict_rouge-l": 10.83949375,
+    #     "predict_runtime": 131.664,
+    #     "predict_samples_per_second": 0.076,
+    #     "predict_steps_per_second": 0.008
+    # }
+    #
+    if matrix_save_name is not None:
+        predict_time = model_predict_end_time - model_predict_start_time
+        preparation_time = model_predict_start_time - model_preparation_start_time
+
+        start_time = time.time()
+        dataset = load_dataset("json", data_files=save_name, split="train")
+        dataset = dataset.map(compute_metrics, num_proc=8, remove_columns=dataset.column_names)
+        score_dict = dataset.to_dict()
+
+        average_score = {}
+        for task, scores in sorted(score_dict.items(), key=lambda x: x[0]):
+            score = sum(scores) / len(scores) if scores else 0.0
+            print(f"predict_{task}: {score:.4f}")
+            average_score["predict_" + task] = score
+
+        average_score["predict_model_preparation_time"] = preparation_time
+        average_score["predict_runtime"] = predict_time
+        num_steps = len(range(0, len(train_dataset), batch_size))
+        average_score["predict_samples_per_second"] = len(dataset) / predict_time if predict_time > 0 else 0.0
+        average_score["predict_steps_per_second"] = num_steps / predict_time if predict_time > 0 else 0.0
+
+        with open(matrix_save_name, "w", encoding="utf-8") as f:
+            json.dump(average_score, f, indent=4)
+
+        print("*" * 70)
+        print(f"\nDone in {time.time() - start_time:.3f}s.\nScore file saved to {matrix_save_name}.")
+        print("*" * 70)
+
 
 if __name__ == "__main__":
     fire.Fire(vllm_infer)
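The three time.time() probes added across the hunks split wall-clock time into preparation_time (engine construction through dataset loading) and predict_time (the batched generation loop only), matching the predict_model_preparation_time and predict_runtime fields Transformers reports. A quick sanity check of the throughput arithmetic, with made-up numbers:

# Made-up numbers, only to illustrate the reported throughput fields.
num_samples, batch_size, predict_time = 10, 4, 20.0
num_steps = len(range(0, num_samples, batch_size))  # 3 batches: [0:4], [4:8], [8:10]
samples_per_second = num_samples / predict_time     # 0.5  -> predict_samples_per_second
steps_per_second = num_steps / predict_time         # 0.15 -> predict_steps_per_second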