From e1fdd6e2f89ef679d1645755c79d981dfe1ad29e Mon Sep 17 00:00:00 2001
From: leo-pony
Date: Wed, 16 Apr 2025 20:06:47 +0800
Subject: [PATCH] [infer] support vllm-ascend (#7739)

---
 examples/README.md      |  1 +
 examples/README_zh.md   |  1 +
 src/llamafactory/cli.py | 45 +++++++++++++++++++++++----------------------
 3 files changed, 25 insertions(+), 22 deletions(-)

diff --git a/examples/README.md b/examples/README.md
index 457ec87f..b58ddad8 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -181,6 +181,7 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
 #### Batch Generation using vLLM Tensor Parallel
 
 ```
+export VLLM_WORKER_MULTIPROC_METHOD=spawn
 python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
 ```
 
diff --git a/examples/README_zh.md b/examples/README_zh.md
index 4899e279..1f24dbb2 100644
--- a/examples/README_zh.md
+++ b/examples/README_zh.md
@@ -181,6 +181,7 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
 #### 使用 vLLM+TP 批量推理
 
 ```
+export VLLM_WORKER_MULTIPROC_METHOD=spawn
 python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
 ```
 
diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py
index 7e7e06cf..99a089b2 100644
--- a/src/llamafactory/cli.py
+++ b/src/llamafactory/cli.py
@@ -18,16 +18,7 @@ import sys
 from copy import deepcopy
 from enum import Enum, unique
 
-from . import launcher
-from .api.app import run_api
-from .chat.chat_model import run_chat
-from .eval.evaluator import run_eval
 from .extras import logging
-from .extras.env import VERSION, print_env
-from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
-from .train.tuner import export_model, run_exp
-from .webui.interface import run_web_demo, run_web_ui
-
 
 USAGE = (
     "-" * 70
@@ -44,18 +35,6 @@ USAGE = (
     + "-" * 70
 )
 
-WELCOME = (
-    "-" * 58
-    + "\n"
-    + f"| Welcome to LLaMA Factory, version {VERSION}"
-    + " " * (21 - len(VERSION))
-    + "|\n|"
-    + " " * 56
-    + "|\n"
-    + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
-    + "-" * 58
-)
-
 logger = logging.get_logger(__name__)
 
 
@@ -72,8 +51,28 @@ class Command(str, Enum):
     VER = "version"
     HELP = "help"
 
-
 def main():
+    from . import launcher
+    from .api.app import run_api
+    from .chat.chat_model import run_chat
+    from .eval.evaluator import run_eval
+    from .extras.env import VERSION, print_env
+    from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
+    from .train.tuner import export_model, run_exp
+    from .webui.interface import run_web_demo, run_web_ui
+
+    WELCOME = (
+        "-" * 58
+        + "\n"
+        + f"| Welcome to LLaMA Factory, version {VERSION}"
+        + " " * (21 - len(VERSION))
+        + "|\n|"
+        + " " * 56
+        + "|\n"
+        + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
+        + "-" * 58
+    )
+
     command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
     if command == Command.API:
         run_api()
@@ -138,4 +137,6 @@ def main():
 
 
 if __name__ == "__main__":
+    from multiprocessing import freeze_support
+    freeze_support()
     main()