mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 03:32:50 +08:00
pin vllm version to 0.6.5 (#6629)
Former-commit-id: 1c7663d3049e00a9148c3e3c58204deca7a08c8d
This commit is contained in:
parent
201a495154
commit
5e699458e5
@ -19,17 +19,12 @@ from transformers import Seq2SeqTrainingArguments
|
||||
|
||||
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
|
||||
from llamafactory.extras.constants import IGNORE_INDEX
|
||||
from llamafactory.extras.misc import get_device_count
|
||||
from llamafactory.extras.packages import is_pillow_available, is_vllm_available
|
||||
from llamafactory.extras.misc import check_version, get_device_count
|
||||
from llamafactory.extras.packages import is_vllm_available
|
||||
from llamafactory.hparams import get_infer_args
|
||||
from llamafactory.model import load_tokenizer
|
||||
|
||||
|
||||
if is_pillow_available():
|
||||
from PIL import Image
|
||||
from PIL.Image import Image as ImageObject
|
||||
|
||||
|
||||
if is_vllm_available():
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.lora.request import LoRARequest
|
||||
@ -51,11 +46,13 @@ def vllm_infer(
|
||||
max_new_tokens: int = 1024,
|
||||
repetition_penalty: float = 1.0,
|
||||
pipeline_parallel_size: int = 1,
|
||||
image_resolution: int = 512 * 512,
|
||||
):
|
||||
r"""
|
||||
Performs batch generation using vLLM engine, which supports tensor parallelism.
|
||||
Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
|
||||
"""
|
||||
check_version("vllm>=0.4.3,<=0.6.5")
|
||||
if pipeline_parallel_size > get_device_count():
|
||||
raise ValueError("Pipeline parallel size should be smaller than the number of gpus.")
|
||||
|
||||
@ -88,15 +85,9 @@ def vllm_infer(
|
||||
inputs, prompts, labels = [], [], []
|
||||
for sample in dataset_module["train_dataset"]:
|
||||
if sample["images"]:
|
||||
multi_modal_data = {"image": []}
|
||||
for image in sample["images"]:
|
||||
if not isinstance(image, (str, ImageObject)):
|
||||
raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
|
||||
|
||||
if isinstance(image, str):
|
||||
image = Image.open(image).convert("RGB")
|
||||
|
||||
multi_modal_data["image"].append(image)
|
||||
multi_modal_data = {
|
||||
"image": template_obj.mm_plugin._regularize_images(sample["images"], image_resolution=image_resolution)
|
||||
}
|
||||
else:
|
||||
multi_modal_data = None
|
||||
|
||||
|
2
setup.py
2
setup.py
@ -54,7 +54,7 @@ extra_require = {
|
||||
"gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
|
||||
"awq": ["autoawq"],
|
||||
"aqlm": ["aqlm[gpu]>=1.1.0"],
|
||||
"vllm": ["vllm>=0.4.3,<0.6.7"],
|
||||
"vllm": ["vllm>=0.4.3,<=0.6.5"],
|
||||
"galore": ["galore-torch"],
|
||||
"badam": ["badam>=1.2.1"],
|
||||
"adam-mini": ["adam-mini"],
|
||||
|
@ -133,7 +133,7 @@ def _check_extra_dependencies(
|
||||
check_version("mixture-of-depth>=1.1.6", mandatory=True)
|
||||
|
||||
if model_args.infer_backend == "vllm":
|
||||
check_version("vllm>=0.4.3,<0.6.7")
|
||||
check_version("vllm>=0.4.3,<=0.6.5")
|
||||
check_version("vllm", mandatory=True)
|
||||
|
||||
if finetuning_args.use_galore:
|
||||
|
Loading…
x
Reference in New Issue
Block a user