Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-02 03:32:50 +08:00
[infer] support vllm-ascend (#7739)
commit e1fdd6e2f8 (parent d07983dceb)
README.md

@@ -181,6 +181,7 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml

#### Batch Generation using vLLM Tensor Parallel

```
export VLLM_WORKER_MULTIPROC_METHOD=spawn
python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
```
README_zh.md

@@ -181,6 +181,7 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml

#### Batch inference using vLLM+TP

```
export VLLM_WORKER_MULTIPROC_METHOD=spawn
python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
```
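Both README hunks put `export VLLM_WORKER_MULTIPROC_METHOD=spawn` ahead of the batch-inference command. vLLM's tensor-parallel backend starts one worker process per model shard, and the default `fork` start method can inherit an already-initialized accelerator context (CUDA, or the Ascend NPU this commit targets), which tends to crash the workers; `spawn` starts each one from a clean interpreter. A minimal sketch of the same setup done inside a script instead of the shell (the wrapper is hypothetical and not part of `scripts/vllm_infer.py`; the env var and the `vllm` calls are real):

```
import os

# Select "spawn" for vLLM's worker processes; this must happen before the
# engine is created, so the workers are not forked from a process that
# already holds a device context.
os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")

from vllm import LLM, SamplingParams  # noqa: E402  (import after env setup)

if __name__ == "__main__":  # required under spawn: workers re-import this module
    llm = LLM(model="path_to_merged_model", tensor_parallel_size=2)
    outputs = llm.generate(["Hello!"], SamplingParams(max_tokens=64))
    print(outputs[0].outputs[0].text)
```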
src/llamafactory/cli.py

@@ -18,16 +18,7 @@ import sys
from copy import deepcopy
from enum import Enum, unique

from . import launcher
from .api.app import run_api
from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .extras import logging
from .extras.env import VERSION, print_env
from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui

USAGE = (
    "-" * 70
@@ -44,18 +35,6 @@ USAGE = (
    + "-" * 70
)

WELCOME = (
    "-" * 58
    + "\n"
    + f"| Welcome to LLaMA Factory, version {VERSION}"
    + " " * (21 - len(VERSION))
    + "|\n|"
    + " " * 56
    + "|\n"
    + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
    + "-" * 58
)

logger = logging.get_logger(__name__)
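The `WELCOME` banner relies on a small width invariant: every row is exactly 58 columns, and the `" " * (21 - len(VERSION))` term pads the version row to that width, since the fixed prefix `"| Welcome to LLaMA Factory, version "` is 36 characters and the trailing `|` is one more (36 + len(VERSION) + 21 - len(VERSION) + 1 = 58). A quick check with an example version string:

```
VERSION = "0.9.2"  # example value; the real one comes from llamafactory.extras.env

row = "| Welcome to LLaMA Factory, version " + VERSION + " " * (21 - len(VERSION)) + "|"
assert len(row) == 58  # same width as the "-" * 58 rules above and below it
```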
@@ -72,8 +51,28 @@ class Command(str, Enum):
    VER = "version"
    HELP = "help"


def main():
    from . import launcher
    from .api.app import run_api
    from .chat.chat_model import run_chat
    from .eval.evaluator import run_eval
    from .extras.env import VERSION, print_env
    from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
    from .train.tuner import export_model, run_exp
    from .webui.interface import run_web_demo, run_web_ui

    WELCOME = (
        "-" * 58
        + "\n"
        + f"| Welcome to LLaMA Factory, version {VERSION}"
        + " " * (21 - len(VERSION))
        + "|\n|"
        + " " * 56
        + "|\n"
        + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
        + "-" * 58
    )

    command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
    if command == Command.API:
        run_api()
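Taken together with the first cli.py hunk, the effect is that only `from .extras import logging` stays at module level (the module-level `logger` needs it); everything heavy, such as the trainer, API server, and Web UI, now loads inside `main()`. That matters under `spawn`, because each worker re-imports the entry module, and lazy imports keep that re-import cheap. A standalone sketch of the pattern (not LLaMA-Factory code; `torch` stands in for any heavy dependency):

```
import sys


def main() -> None:
    command = sys.argv[1] if len(sys.argv) > 1 else "help"
    if command == "train":
        import torch  # deferred: only paid when the command actually needs it

        print(f"training with torch {torch.__version__}")
    else:
        print("usage: app.py [train|help]")


if __name__ == "__main__":
    main()
```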
@@ -138,4 +137,6 @@ def main():


if __name__ == "__main__":
    from multiprocessing import freeze_support
    freeze_support()
    main()
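`freeze_support()` is a no-op except in a frozen Windows executable, where it lets `multiprocessing` bootstrap child processes; the `if __name__ == "__main__"` guard is what actually protects spawn-based workers, since each one re-imports this module and would otherwise re-run `main()`. A standalone illustration (not LLaMA-Factory code):

```
import multiprocessing as mp
import os


def worker() -> None:
    print(f"worker pid={os.getpid()}")


if __name__ == "__main__":
    mp.freeze_support()           # no-op unless running as a frozen Windows exe
    mp.set_start_method("spawn")  # the method VLLM_WORKER_MULTIPROC_METHOD selects
    proc = mp.Process(target=worker)
    proc.start()
    proc.join()
```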