Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)
[infer] Modify vllm_infer.py to batch preprocess to avoid the "too many open files" error (#8051)
Co-authored-by: Kingsley <82590017+Kuangdd01@users.noreply.github.com>
This commit is contained in:
parent 712c57f3b4
commit 0b773234e5
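In outline, the change stops building vLLM inputs for every sample up front (which keeps image/video/audio handles referenced for the whole run and can hit the OS open-file limit) and instead preprocesses and generates in fixed-size chunks, releasing each chunk before the next. A minimal, self-contained sketch of that pattern follows; `batched_infer`, `preprocess`, and `generate` are illustrative stand-ins, not functions from the repository.

import gc
import json


def batched_infer(samples, preprocess, generate, save_name, batch_size=1024):
    """Preprocess and generate in fixed-size chunks so resources opened for one
    chunk (e.g. image/video/audio file handles) are released before the next
    chunk starts, instead of staying open for the whole dataset at once."""
    all_prompts, all_preds = [], []
    for start in range(0, len(samples), batch_size):
        chunk = samples[start : start + batch_size]
        inputs = [preprocess(sample) for sample in chunk]  # opens files only for this chunk
        preds = generate(inputs)  # one engine-level generate call per chunk
        all_prompts.extend(sample["prompt"] for sample in chunk)
        all_preds.extend(preds)
        gc.collect()  # drop chunk-local references (and their file handles) eagerly

    # write everything once, outside the loop
    with open(save_name, "w", encoding="utf-8") as f:
        for prompt, pred in zip(all_prompts, all_preds):
            f.write(json.dumps({"prompt": prompt, "predict": pred}, ensure_ascii=False) + "\n")
    return len(all_preds)


if __name__ == "__main__":
    # toy stand-ins: "generation" is just upper-casing the prompt text
    data = [{"prompt": f"question {i}"} for i in range(10)]
    n = batched_infer(data, lambda s: s["prompt"], lambda xs: [x.upper() for x in xs], "preds.jsonl", batch_size=4)
    print(f"{n} generated results have been saved at preds.jsonl.")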
					
scripts/vllm_infer.py
@@ -12,11 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import gc
 import json
 from typing import Optional
 
 import fire
 from transformers import Seq2SeqTrainingArguments
+from tqdm import tqdm
 
 from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
 from llamafactory.extras.constants import IGNORE_INDEX
@@ -53,6 +55,7 @@ def vllm_infer(
     image_min_pixels: int = 32 * 32,
     video_fps: float = 2.0,
     video_maxlen: int = 128,
+    batch_size: int = 1024,
 ):
     r"""Perform batch generation using vLLM engine, which supports tensor parallelism.
 
@@ -85,57 +88,6 @@ def vllm_infer(
     tokenizer = tokenizer_module["tokenizer"]
     template_obj = get_template_and_fix_tokenizer(tokenizer, data_args)
     template_obj.mm_plugin.expand_mm_tokens = False  # for vllm generate
-    dataset_module = get_dataset(template_obj, model_args, data_args, training_args, "ppo", **tokenizer_module)
-
-    inputs, prompts, labels = [], [], []
-    for sample in dataset_module["train_dataset"]:
-        if sample["images"]:
-            multi_modal_data = {
-                "image": template_obj.mm_plugin._regularize_images(
-                    sample["images"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
-                )["images"]
-            }
-        elif sample["videos"]:
-            multi_modal_data = {
-                "video": template_obj.mm_plugin._regularize_videos(
-                    sample["videos"],
-                    image_max_pixels=image_max_pixels,
-                    image_min_pixels=image_min_pixels,
-                    video_fps=video_fps,
-                    video_maxlen=video_maxlen,
-                )["videos"]
-            }
-        elif sample["audios"]:
-            audio_data = template_obj.mm_plugin._regularize_audios(
-                sample["audios"],
-                sampling_rate=16000,
-            )
-            multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
-        else:
-            multi_modal_data = None
-
-        inputs.append({"prompt_token_ids": sample["input_ids"], "multi_modal_data": multi_modal_data})
-        prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=skip_special_tokens))
-        labels.append(
-            tokenizer.decode(
-                list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=skip_special_tokens
-            )
-        )
-
-    sampling_params = SamplingParams(
-        repetition_penalty=generating_args.repetition_penalty or 1.0,  # repetition_penalty must > 0
-        temperature=generating_args.temperature,
-        top_p=generating_args.top_p or 1.0,  # top_p must > 0
-        top_k=generating_args.top_k or -1,  # top_k must > 0
-        stop_token_ids=template_obj.get_stop_token_ids(tokenizer),
-        max_tokens=generating_args.max_new_tokens,
-        skip_special_tokens=skip_special_tokens,
-        seed=seed,
-    )
-    if model_args.adapter_name_or_path is not None:
-        lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
-    else:
-        lora_request = None
 
     engine_args = {
         "model": model_args.model_name_or_path,
@@ -153,14 +105,93 @@ def vllm_infer(
     if isinstance(model_args.vllm_config, dict):
         engine_args.update(model_args.vllm_config)
 
-    results = LLM(**engine_args).generate(inputs, sampling_params, lora_request=lora_request)
-    preds = [result.outputs[0].text for result in results]
+    llm = LLM(**engine_args)
+
+    # load datasets
+    dataset_module = get_dataset(template_obj, model_args, data_args, training_args, "ppo", **tokenizer_module)
+    train_dataset = dataset_module["train_dataset"]
+
+    sampling_params = SamplingParams(
+        repetition_penalty=generating_args.repetition_penalty or 1.0,  # repetition_penalty must > 0
+        temperature=generating_args.temperature,
+        top_p=generating_args.top_p or 1.0,  # top_p must > 0
+        top_k=generating_args.top_k or -1,  # top_k must > 0
+        stop_token_ids=template_obj.get_stop_token_ids(tokenizer),
+        max_tokens=generating_args.max_new_tokens,
+        skip_special_tokens=skip_special_tokens,
+        seed=seed,
+    )
+    if model_args.adapter_name_or_path is not None:
+        lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
+    else:
+        lora_request = None
+
+    # Store all results in these lists
+    all_prompts = []
+    all_preds = []
+    all_labels = []
+
+    # Add batch process to avoid the issue of too many files opened
+    for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
+        vllm_inputs, prompts, labels = [], [], []
+
+        batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
+
+        for j in range(len(batch["input_ids"])):
+            if batch["images"][j] is not None:
+                image = batch["images"][j]
+                multi_modal_data = {
+                    "image": template_obj.mm_plugin._regularize_images(
+                        image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
+                    )["images"]
+                }
+            elif batch["videos"][j] is not None:
+                video = batch["videos"][j]
+                multi_modal_data = {
+                    "video": template_obj.mm_plugin._regularize_videos(
+                        video,
+                        image_max_pixels=image_max_pixels,
+                        image_min_pixels=image_min_pixels,
+                        video_fps=video_fps,
+                        video_maxlen=video_maxlen,
+                    )["videos"]
+                }
+            elif batch["audios"][j] is not None:
+                audio = batch["audios"][j]
+                audio_data = template_obj.mm_plugin._regularize_audios(
+                    audio,
+                    sampling_rate=16000,
+                )
+                multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
+            else:
+                multi_modal_data = None
+
+            vllm_inputs.append({"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data})
+            prompts.append(tokenizer.decode(batch["input_ids"][j], skip_special_tokens=skip_special_tokens))
+            labels.append(
+                tokenizer.decode(
+                    list(filter(lambda x: x != IGNORE_INDEX, batch["labels"][j])),
+                    skip_special_tokens=skip_special_tokens,
+                )
+            )
+
+        results = llm.generate(vllm_inputs, sampling_params, lora_request=lora_request)
+
+        preds = [result.outputs[0].text for result in results]
+
+        # Accumulate results
+        all_prompts.extend(prompts)
+        all_preds.extend(preds)
+        all_labels.extend(labels)
+
+        gc.collect()
+    # Write all results at once outside the loop
     with open(save_name, "w", encoding="utf-8") as f:
-        for text, pred, label in zip(prompts, preds, labels):
+        for text, pred, label in zip(all_prompts, all_preds, all_labels):
             f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")
 
     print("*" * 70)
-    print(f"{len(prompts)} generated results have been saved at {save_name}.")
+    print(f"{len(all_prompts)} total generated results have been saved at {save_name}.")
     print("*" * 70)
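Usage sketch: the script exposes `vllm_infer` through a `fire` entry point, so the new `batch_size` argument can be passed on the command line or by calling the function directly. The model path, dataset name, and `save_name` values below are placeholders and assumptions, not values taken from this commit.

# Command-line form (placeholders, assuming the script's fire entry point):
#   python scripts/vllm_infer.py --model_name_or_path <model_path> --dataset <dataset_name> --batch_size 512
# Equivalent direct call:
from vllm_infer import vllm_infer  # assumes scripts/ is importable

vllm_infer(
    model_name_or_path="<model_path>",
    dataset="<dataset_name>",
    save_name="generated_predictions.jsonl",
    batch_size=512,  # smaller chunks keep fewer media files open at once
)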