diff --git a/scripts/vllm_infer.py b/scripts/vllm_infer.py
index 29ea8425..ad785253 100644
--- a/scripts/vllm_infer.py
+++ b/scripts/vllm_infer.py
@@ -12,11 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import gc
 import json
 from typing import Optional
 
 import fire
 from transformers import Seq2SeqTrainingArguments
+from tqdm import tqdm
 
 from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
 from llamafactory.extras.constants import IGNORE_INDEX
@@ -53,6 +55,7 @@ def vllm_infer(
     image_min_pixels: int = 32 * 32,
     video_fps: float = 2.0,
     video_maxlen: int = 128,
+    batch_size: int = 1024,
 ):
     r"""Perform batch generation using vLLM engine, which supports tensor parallelism.
 
@@ -85,57 +88,6 @@ def vllm_infer(
     tokenizer = tokenizer_module["tokenizer"]
     template_obj = get_template_and_fix_tokenizer(tokenizer, data_args)
     template_obj.mm_plugin.expand_mm_tokens = False  # for vllm generate
-    dataset_module = get_dataset(template_obj, model_args, data_args, training_args, "ppo", **tokenizer_module)
-
-    inputs, prompts, labels = [], [], []
-    for sample in dataset_module["train_dataset"]:
-        if sample["images"]:
-            multi_modal_data = {
-                "image": template_obj.mm_plugin._regularize_images(
-                    sample["images"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
-                )["images"]
-            }
-        elif sample["videos"]:
-            multi_modal_data = {
-                "video": template_obj.mm_plugin._regularize_videos(
-                    sample["videos"],
-                    image_max_pixels=image_max_pixels,
-                    image_min_pixels=image_min_pixels,
-                    video_fps=video_fps,
-                    video_maxlen=video_maxlen,
-                )["videos"]
-            }
-        elif sample["audios"]:
-            audio_data = template_obj.mm_plugin._regularize_audios(
-                sample["audios"],
-                sampling_rate=16000,
-            )
-            multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
-        else:
-            multi_modal_data = None
-
-        inputs.append({"prompt_token_ids": sample["input_ids"], "multi_modal_data": multi_modal_data})
-        prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=skip_special_tokens))
-        labels.append(
-            tokenizer.decode(
-                list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=skip_special_tokens
-            )
-        )
-
-    sampling_params = SamplingParams(
-        repetition_penalty=generating_args.repetition_penalty or 1.0,  # repetition_penalty must > 0
-        temperature=generating_args.temperature,
-        top_p=generating_args.top_p or 1.0,  # top_p must > 0
-        top_k=generating_args.top_k or -1,  # top_k must > 0
-        stop_token_ids=template_obj.get_stop_token_ids(tokenizer),
-        max_tokens=generating_args.max_new_tokens,
-        skip_special_tokens=skip_special_tokens,
-        seed=seed,
-    )
-    if model_args.adapter_name_or_path is not None:
-        lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
-    else:
-        lora_request = None
 
     engine_args = {
         "model": model_args.model_name_or_path,
@@ -153,14 +105,93 @@ def vllm_infer(
 
     if isinstance(model_args.vllm_config, dict):
         engine_args.update(model_args.vllm_config)
 
-    results = LLM(**engine_args).generate(inputs, sampling_params, lora_request=lora_request)
-    preds = [result.outputs[0].text for result in results]
+    llm = LLM(**engine_args)
+
+    # load datasets
+    dataset_module = get_dataset(template_obj, model_args, data_args, training_args, "ppo", **tokenizer_module)
+    train_dataset = dataset_module["train_dataset"]
+
+    sampling_params = SamplingParams(
+        repetition_penalty=generating_args.repetition_penalty or 1.0,  # repetition_penalty must > 0
+        temperature=generating_args.temperature,
+        top_p=generating_args.top_p or 1.0,  # top_p must > 0
+        top_k=generating_args.top_k or -1,  # top_k must > 0
+        stop_token_ids=template_obj.get_stop_token_ids(tokenizer),
+        max_tokens=generating_args.max_new_tokens,
+        skip_special_tokens=skip_special_tokens,
+        seed=seed,
+    )
+    if model_args.adapter_name_or_path is not None:
+        lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
+    else:
+        lora_request = None
+
+    # Store all results in these lists
+    all_prompts = []
+    all_preds = []
+    all_labels = []
+
+    # Add batch process to avoid the issue of too many files opened
+    for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
+        vllm_inputs, prompts, labels = [], [], []
+
+        batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
+
+        for j in range(len(batch["input_ids"])):
+            if batch["images"][j] is not None:
+                image = batch["images"][j]
+                multi_modal_data = {
+                    "image": template_obj.mm_plugin._regularize_images(
+                        image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
+                    )["images"]
+                }
+            elif batch["videos"][j] is not None:
+                video = batch["videos"][j]
+                multi_modal_data = {
+                    "video": template_obj.mm_plugin._regularize_videos(
+                        video,
+                        image_max_pixels=image_max_pixels,
+                        image_min_pixels=image_min_pixels,
+                        video_fps=video_fps,
+                        video_maxlen=video_maxlen,
+                    )["videos"]
+                }
+            elif batch["audios"][j] is not None:
+                audio = batch["audios"][j]
+                audio_data = template_obj.mm_plugin._regularize_audios(
+                    audio,
+                    sampling_rate=16000,
+                )
+                multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
+            else:
+                multi_modal_data = None
+
+            vllm_inputs.append({"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data})
+            prompts.append(tokenizer.decode(batch["input_ids"][j], skip_special_tokens=skip_special_tokens))
+            labels.append(
+                tokenizer.decode(
+                    list(filter(lambda x: x != IGNORE_INDEX, batch["labels"][j])),
+                    skip_special_tokens=skip_special_tokens,
+                )
+            )
+
+        results = llm.generate(vllm_inputs, sampling_params, lora_request=lora_request)
+
+        preds = [result.outputs[0].text for result in results]
+
+        # Accumulate results
+        all_prompts.extend(prompts)
+        all_preds.extend(preds)
+        all_labels.extend(labels)
+
+        gc.collect()
+
+    # Write all results at once outside the loop
     with open(save_name, "w", encoding="utf-8") as f:
-        for text, pred, label in zip(prompts, preds, labels):
+        for text, pred, label in zip(all_prompts, all_preds, all_labels):
             f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")
 
     print("*" * 70)
-    print(f"{len(prompts)} generated results have been saved at {save_name}.")
+    print(f"{len(all_prompts)} total generated results have been saved at {save_name}.")
    print("*" * 70)