diff --git a/setup.py b/setup.py
index 7af9b825..6fe7180d 100644
--- a/setup.py
+++ b/setup.py
@@ -53,7 +53,7 @@ extra_require = {
     "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
     "awq": ["autoawq"],
     "aqlm": ["aqlm[gpu]>=1.1.0"],
-    "vllm": ["vllm>=0.4.3,<=0.7.3"],
+    "vllm": ["vllm>=0.4.3,<=0.8.1"],
     "sglang": ["sglang[srt]>=0.4.4", "transformers==4.48.3"],
     "galore": ["galore-torch"],
     "apollo": ["apollo-torch"],
diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py
index 33aa735b..075ef845 100644
--- a/src/llamafactory/cli.py
+++ b/src/llamafactory/cli.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import os
-import random
 import subprocess
 import sys
 from enum import Enum, unique
@@ -24,7 +23,7 @@ from .chat.chat_model import run_chat
 from .eval.evaluator import run_eval
 from .extras import logging
 from .extras.env import VERSION, print_env
-from .extras.misc import get_device_count, is_env_enabled, use_ray
+from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
 from .train.tuner import export_model, run_exp
 from .webui.interface import run_web_demo, run_web_ui
 
@@ -92,7 +91,7 @@ def main():
         node_rank = os.getenv("NODE_RANK", "0")
         nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
         master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
-        master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
+        master_port = os.getenv("MASTER_PORT", str(find_available_port()))
         logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
         if int(nnodes) > 1:
             print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 0b3bf759..66bea4f3 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -135,7 +135,7 @@ def _check_extra_dependencies(
         check_version("mixture-of-depth>=1.1.6", mandatory=True)
 
     if model_args.infer_backend == EngineName.VLLM:
-        check_version("vllm>=0.4.3,<=0.7.3")
+        check_version("vllm>=0.4.3,<=0.8.1")
         check_version("vllm", mandatory=True)
     elif model_args.infer_backend == EngineName.SGLANG:
         check_version("sglang>=0.4.4")
diff --git a/src/llamafactory/webui/chatter.py b/src/llamafactory/webui/chatter.py
index 575cd584..ebdd9e7a 100644
--- a/src/llamafactory/webui/chatter.py
+++ b/src/llamafactory/webui/chatter.py
@@ -122,6 +122,7 @@ class WebChatModel(ChatModel):
             enable_liger_kernel=(get("top.booster") == "liger_kernel"),
             infer_backend=get("infer.infer_backend"),
             infer_dtype=get("infer.infer_dtype"),
+            vllm_enforce_eager=True,
             trust_remote_code=True,
         )
 
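
Note: the `cli.py` change replaces the random `MASTER_PORT` fallback with `find_available_port()` imported from `extras/misc.py`. That helper's body is not part of this patch; a minimal sketch of the standard idiom it would rely on (binding to port 0 so the kernel assigns a free port) looks like this:

```python
import socket


def find_available_port() -> int:
    """Ask the OS for a free TCP port by binding to port 0.

    Sketch of the helper imported in cli.py above; the actual
    implementation in extras/misc.py may differ.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))  # port 0 => kernel picks an unused port
        return sock.getsockname()[1]  # read back the assigned port number
```

Unlike `random.randint(20001, 29999)`, this guarantees the port is actually free at the moment of the check, although a small race window remains between releasing the socket and torchrun binding the port.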
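The Web UI change pins `vllm_enforce_eager=True`, which presumably feeds vLLM's `enforce_eager` engine argument: it skips CUDA graph capture, reducing GPU memory overhead and startup time at some cost in decode throughput. A hypothetical standalone illustration of the same flag in plain vLLM (the model name is an arbitrary placeholder):

```python
from vllm import LLM

# enforce_eager=True disables CUDA graph capture: lower GPU memory use and
# faster engine startup, traded against some decoding throughput.
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", enforce_eager=True)
print(llm.generate(["Hello, world"])[0].outputs[0].text)
```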