[infer] Modify vllm_infer.py to batch preprocess to avoid too much files opened error (#8051)

Co-authored-by: Kingsley <82590017+Kuangdd01@users.noreply.github.com>
This commit is contained in:
Shawn Tao 2025-05-15 10:54:35 +08:00 committed by GitHub
parent 2b23c0a7a1
commit e8a18c17e9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import gc
import json import json
from typing import Optional from typing import Optional
import fire import fire
from transformers import Seq2SeqTrainingArguments from transformers import Seq2SeqTrainingArguments
from tqdm import tqdm
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX from llamafactory.extras.constants import IGNORE_INDEX
@ -53,6 +55,7 @@ def vllm_infer(
image_min_pixels: int = 32 * 32, image_min_pixels: int = 32 * 32,
video_fps: float = 2.0, video_fps: float = 2.0,
video_maxlen: int = 128, video_maxlen: int = 128,
batch_size: int = 1024,
): ):
r"""Perform batch generation using vLLM engine, which supports tensor parallelism. r"""Perform batch generation using vLLM engine, which supports tensor parallelism.
@ -85,57 +88,6 @@ def vllm_infer(
tokenizer = tokenizer_module["tokenizer"] tokenizer = tokenizer_module["tokenizer"]
template_obj = get_template_and_fix_tokenizer(tokenizer, data_args) template_obj = get_template_and_fix_tokenizer(tokenizer, data_args)
template_obj.mm_plugin.expand_mm_tokens = False # for vllm generate template_obj.mm_plugin.expand_mm_tokens = False # for vllm generate
dataset_module = get_dataset(template_obj, model_args, data_args, training_args, "ppo", **tokenizer_module)
inputs, prompts, labels = [], [], []
for sample in dataset_module["train_dataset"]:
if sample["images"]:
multi_modal_data = {
"image": template_obj.mm_plugin._regularize_images(
sample["images"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
)["images"]
}
elif sample["videos"]:
multi_modal_data = {
"video": template_obj.mm_plugin._regularize_videos(
sample["videos"],
image_max_pixels=image_max_pixels,
image_min_pixels=image_min_pixels,
video_fps=video_fps,
video_maxlen=video_maxlen,
)["videos"]
}
elif sample["audios"]:
audio_data = template_obj.mm_plugin._regularize_audios(
sample["audios"],
sampling_rate=16000,
)
multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
else:
multi_modal_data = None
inputs.append({"prompt_token_ids": sample["input_ids"], "multi_modal_data": multi_modal_data})
prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=skip_special_tokens))
labels.append(
tokenizer.decode(
list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=skip_special_tokens
)
)
sampling_params = SamplingParams(
repetition_penalty=generating_args.repetition_penalty or 1.0, # repetition_penalty must > 0
temperature=generating_args.temperature,
top_p=generating_args.top_p or 1.0, # top_p must > 0
top_k=generating_args.top_k or -1, # top_k must > 0
stop_token_ids=template_obj.get_stop_token_ids(tokenizer),
max_tokens=generating_args.max_new_tokens,
skip_special_tokens=skip_special_tokens,
seed=seed,
)
if model_args.adapter_name_or_path is not None:
lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
else:
lora_request = None
engine_args = { engine_args = {
"model": model_args.model_name_or_path, "model": model_args.model_name_or_path,
@ -153,14 +105,93 @@ def vllm_infer(
if isinstance(model_args.vllm_config, dict): if isinstance(model_args.vllm_config, dict):
engine_args.update(model_args.vllm_config) engine_args.update(model_args.vllm_config)
results = LLM(**engine_args).generate(inputs, sampling_params, lora_request=lora_request) llm = LLM(**engine_args)
preds = [result.outputs[0].text for result in results]
# load datasets
dataset_module = get_dataset(template_obj, model_args, data_args, training_args, "ppo", **tokenizer_module)
train_dataset = dataset_module["train_dataset"]
sampling_params = SamplingParams(
repetition_penalty=generating_args.repetition_penalty or 1.0, # repetition_penalty must > 0
temperature=generating_args.temperature,
top_p=generating_args.top_p or 1.0, # top_p must > 0
top_k=generating_args.top_k or -1, # top_k must > 0
stop_token_ids=template_obj.get_stop_token_ids(tokenizer),
max_tokens=generating_args.max_new_tokens,
skip_special_tokens=skip_special_tokens,
seed=seed,
)
if model_args.adapter_name_or_path is not None:
lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
else:
lora_request = None
# Store all results in these lists
all_prompts = []
all_preds = []
all_labels = []
# Add batch process to avoid the issue of too many files opened
for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
vllm_inputs, prompts, labels = [], [], []
batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
for j in range(len(batch["input_ids"])):
if batch["images"][j] is not None:
image = batch["images"][j]
multi_modal_data = {
"image": template_obj.mm_plugin._regularize_images(
image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
)["images"]
}
elif batch["videos"][j] is not None:
video = batch["videos"][j]
multi_modal_data = {
"video": template_obj.mm_plugin._regularize_videos(
video,
image_max_pixels=image_max_pixels,
image_min_pixels=image_min_pixels,
video_fps=video_fps,
video_maxlen=video_maxlen,
)["videos"]
}
elif batch["audios"][j] is not None:
audio = batch["audios"][j]
audio_data = template_obj.mm_plugin._regularize_audios(
audio,
sampling_rate=16000,
)
multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
else:
multi_modal_data = None
vllm_inputs.append({"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data})
prompts.append(tokenizer.decode(batch["input_ids"][j], skip_special_tokens=skip_special_tokens))
labels.append(
tokenizer.decode(
list(filter(lambda x: x != IGNORE_INDEX, batch["labels"][j])),
skip_special_tokens=skip_special_tokens,
)
)
results = llm.generate(vllm_inputs, sampling_params, lora_request=lora_request)
preds = [result.outputs[0].text for result in results]
# Accumulate results
all_prompts.extend(prompts)
all_preds.extend(preds)
all_labels.extend(labels)
gc.collect()
# Write all results at once outside the loop
with open(save_name, "w", encoding="utf-8") as f: with open(save_name, "w", encoding="utf-8") as f:
for text, pred, label in zip(prompts, preds, labels): for text, pred, label in zip(all_prompts, all_preds, all_labels):
f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n") f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")
print("*" * 70) print("*" * 70)
print(f"{len(prompts)} generated results have been saved at {save_name}.") print(f"{len(all_prompts)} total generated results have been saved at {save_name}.")
print("*" * 70) print("*" * 70)