Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-02 03:32:50 +08:00
[infer] Modify vllm_infer.py to preprocess in batches and avoid the "too many open files" error (#8051)
Co-authored-by: Kingsley <82590017+Kuangdd01@users.noreply.github.com>
parent 2b23c0a7a1
commit e8a18c17e9
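In outline, the change keeps a single vLLM engine alive and feeds it the dataset in fixed-size slices, collecting garbage after every slice so that the image/video/audio files opened while preprocessing one batch are released before the next batch is read. Below is a minimal, self-contained sketch of that batching pattern, assuming only a generic list-like dataset and a generate callable; `run_batched` and `generate_fn` are illustrative names, not part of the repository.

import gc

def run_batched(dataset, generate_fn, batch_size=1024):
    """Process `dataset` in fixed-size slices so objects (and any file handles)
    created while preparing one slice can be freed before the next slice starts."""
    all_outputs = []
    for start in range(0, len(dataset), batch_size):
        batch = dataset[start : min(start + batch_size, len(dataset))]
        all_outputs.extend(generate_fn(batch))  # generate only for this slice
        gc.collect()  # eagerly drop per-batch objects between slices
    return all_outputs

# Example: 2500 items with batch_size=1000 -> slices of 1000, 1000, 500.
print(len(run_batched(list(range(2500)), lambda b: b, batch_size=1000)))  # 2500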
@@ -12,11 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import gc
 import json
 from typing import Optional
 
 import fire
 from transformers import Seq2SeqTrainingArguments
+from tqdm import tqdm
 
 from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
 from llamafactory.extras.constants import IGNORE_INDEX
@@ -53,6 +55,7 @@ def vllm_infer(
     image_min_pixels: int = 32 * 32,
     video_fps: float = 2.0,
     video_maxlen: int = 128,
+    batch_size: int = 1024,
 ):
     r"""Perform batch generation using vLLM engine, which supports tensor parallelism.
 
@@ -85,57 +88,6 @@ def vllm_infer(
     tokenizer = tokenizer_module["tokenizer"]
     template_obj = get_template_and_fix_tokenizer(tokenizer, data_args)
     template_obj.mm_plugin.expand_mm_tokens = False  # for vllm generate
-    dataset_module = get_dataset(template_obj, model_args, data_args, training_args, "ppo", **tokenizer_module)
-
-    inputs, prompts, labels = [], [], []
-    for sample in dataset_module["train_dataset"]:
-        if sample["images"]:
-            multi_modal_data = {
-                "image": template_obj.mm_plugin._regularize_images(
-                    sample["images"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
-                )["images"]
-            }
-        elif sample["videos"]:
-            multi_modal_data = {
-                "video": template_obj.mm_plugin._regularize_videos(
-                    sample["videos"],
-                    image_max_pixels=image_max_pixels,
-                    image_min_pixels=image_min_pixels,
-                    video_fps=video_fps,
-                    video_maxlen=video_maxlen,
-                )["videos"]
-            }
-        elif sample["audios"]:
-            audio_data = template_obj.mm_plugin._regularize_audios(
-                sample["audios"],
-                sampling_rate=16000,
-            )
-            multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
-        else:
-            multi_modal_data = None
-
-        inputs.append({"prompt_token_ids": sample["input_ids"], "multi_modal_data": multi_modal_data})
-        prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=skip_special_tokens))
-        labels.append(
-            tokenizer.decode(
-                list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=skip_special_tokens
-            )
-        )
-
-    sampling_params = SamplingParams(
-        repetition_penalty=generating_args.repetition_penalty or 1.0,  # repetition_penalty must > 0
-        temperature=generating_args.temperature,
-        top_p=generating_args.top_p or 1.0,  # top_p must > 0
-        top_k=generating_args.top_k or -1,  # top_k must > 0
-        stop_token_ids=template_obj.get_stop_token_ids(tokenizer),
-        max_tokens=generating_args.max_new_tokens,
-        skip_special_tokens=skip_special_tokens,
-        seed=seed,
-    )
-    if model_args.adapter_name_or_path is not None:
-        lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
-    else:
-        lora_request = None
-
     engine_args = {
         "model": model_args.model_name_or_path,
@@ -153,14 +105,93 @@ def vllm_infer(
     if isinstance(model_args.vllm_config, dict):
         engine_args.update(model_args.vllm_config)
 
-    results = LLM(**engine_args).generate(inputs, sampling_params, lora_request=lora_request)
+    llm = LLM(**engine_args)
+
+    # load datasets
+    dataset_module = get_dataset(template_obj, model_args, data_args, training_args, "ppo", **tokenizer_module)
+    train_dataset = dataset_module["train_dataset"]
+
+    sampling_params = SamplingParams(
+        repetition_penalty=generating_args.repetition_penalty or 1.0,  # repetition_penalty must > 0
+        temperature=generating_args.temperature,
+        top_p=generating_args.top_p or 1.0,  # top_p must > 0
+        top_k=generating_args.top_k or -1,  # top_k must > 0
+        stop_token_ids=template_obj.get_stop_token_ids(tokenizer),
+        max_tokens=generating_args.max_new_tokens,
+        skip_special_tokens=skip_special_tokens,
+        seed=seed,
+    )
+    if model_args.adapter_name_or_path is not None:
+        lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
+    else:
+        lora_request = None
+
+    # Store all results in these lists
+    all_prompts = []
+    all_preds = []
+    all_labels = []
+
+    # Add batch process to avoid the issue of too many files opened
+    for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
+        vllm_inputs, prompts, labels = [], [], []
+
+        batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
+
+        for j in range(len(batch["input_ids"])):
+            if batch["images"][j] is not None:
+                image = batch["images"][j]
+                multi_modal_data = {
+                    "image": template_obj.mm_plugin._regularize_images(
+                        image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
+                    )["images"]
+                }
+            elif batch["videos"][j] is not None:
+                video = batch["videos"][j]
+                multi_modal_data = {
+                    "video": template_obj.mm_plugin._regularize_videos(
+                        video,
+                        image_max_pixels=image_max_pixels,
+                        image_min_pixels=image_min_pixels,
+                        video_fps=video_fps,
+                        video_maxlen=video_maxlen,
+                    )["videos"]
+                }
+            elif batch["audios"][j] is not None:
+                audio = batch["audios"][j]
+                audio_data = template_obj.mm_plugin._regularize_audios(
+                    audio,
+                    sampling_rate=16000,
+                )
+                multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
+            else:
+                multi_modal_data = None
+
+            vllm_inputs.append({"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data})
+            prompts.append(tokenizer.decode(batch["input_ids"][j], skip_special_tokens=skip_special_tokens))
+            labels.append(
+                tokenizer.decode(
+                    list(filter(lambda x: x != IGNORE_INDEX, batch["labels"][j])),
+                    skip_special_tokens=skip_special_tokens,
+                )
+            )
+
+        results = llm.generate(vllm_inputs, sampling_params, lora_request=lora_request)
+
+        preds = [result.outputs[0].text for result in results]
+
+        # Accumulate results
+        all_prompts.extend(prompts)
+        all_preds.extend(preds)
+        all_labels.extend(labels)
+
+        gc.collect()
+    # Write all results at once outside the loop
     with open(save_name, "w", encoding="utf-8") as f:
-        for text, pred, label in zip(prompts, preds, labels):
+        for text, pred, label in zip(all_prompts, all_preds, all_labels):
             f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")
 
     print("*" * 70)
-    print(f"{len(prompts)} generated results have been saved at {save_name}.")
+    print(f"{len(all_prompts)} total generated results have been saved at {save_name}.")
     print("*" * 70)
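For reference, a hedged sketch of driving the updated function directly from Python with the new `batch_size` knob (the script is normally launched via `fire` from the command line; the module path and every argument other than `batch_size` and `save_name` are assumptions about the unchanged parts of the signature, not shown in this diff):

# Hedged usage sketch: only batch_size and save_name appear in this diff;
# the remaining argument names are assumptions about the rest of vllm_infer's signature.
from vllm_infer import vllm_infer  # assumes the script's directory is on sys.path

vllm_infer(
    model_name_or_path="path/to/model",       # assumed parameter / placeholder model
    dataset="alpaca_en_demo",                 # assumed parameter / placeholder dataset
    save_name="generated_predictions.jsonl",
    batch_size=512,  # new in this commit: smaller slices keep fewer media files open at once
)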