[infer] support vllm-ascend (#7739)

This commit is contained in:
leo-pony 2025-04-16 20:06:47 +08:00 committed by GitHub
parent d07983dceb
commit e1fdd6e2f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 25 additions and 22 deletions

View File

@@ -181,6 +181,7 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
#### Batch Generation using vLLM Tensor Parallel
```
export VLLM_WORKER_MULTIPROC_METHOD=spawn
python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
```

View File

@@ -181,6 +181,7 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
#### 使用 vLLM+TP 批量推理
```
export VLLM_WORKER_MULTIPROC_METHOD=spawn
python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
```

View File

@@ -18,16 +18,7 @@ import sys
from copy import deepcopy
from enum import Enum, unique
from . import launcher
from .api.app import run_api
from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .extras import logging
from .extras.env import VERSION, print_env
from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui
USAGE = (
"-" * 70
@@ -44,18 +35,6 @@ USAGE = (
+ "-" * 70
)
# Startup banner: a 58-column box showing the CLI version and project URL.
# The trailing space run pads the version line so the right border aligns.
WELCOME = "".join(
    (
        "-" * 58,
        "\n",
        f"| Welcome to LLaMA Factory, version {VERSION}",
        " " * (21 - len(VERSION)),
        "|\n|",
        " " * 56,
        "|\n",
        "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n",
        "-" * 58,
    )
)

logger = logging.get_logger(__name__)
@@ -72,8 +51,28 @@ class Command(str, Enum):
VER = "version"
HELP = "help"
def main():
from . import launcher
from .api.app import run_api
from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .extras.env import VERSION, print_env
from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui
# Boxed welcome banner (58 columns wide); the version line is right-padded
# so the closing "|" lines up with the border rows.
border = "-" * 58
WELCOME = "\n".join(
    [
        border,
        f"| Welcome to LLaMA Factory, version {VERSION}" + " " * (21 - len(VERSION)) + "|",
        "|" + " " * 56 + "|",
        "| Project page: https://github.com/hiyouga/LLaMA-Factory |",
        border,
    ]
)
command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
if command == Command.API:
run_api()
@@ -138,4 +137,6 @@ def main():
if __name__ == "__main__":
    # freeze_support() is needed when a frozen Windows executable spawns
    # worker processes; on other platforms it is a harmless no-op.
    import multiprocessing

    multiprocessing.freeze_support()
    main()