[infer] set env for vllm ascend (#7745)

hoshi-hiyouga 2025-04-17 01:08:55 +08:00 committed by GitHub
parent 125513fa5c
commit 4831552856
5 changed files with 28 additions and 21 deletions

View File

@@ -181,7 +181,6 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
 #### Batch Generation using vLLM Tensor Parallel
 
 ```
-export VLLM_WORKER_MULTIPROC_METHOD=spawn
 python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
 ```
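The `export` line is dropped from both READMEs because the parser change further down in this commit sets the variable automatically on Ascend NPU devices. For reference, a minimal sketch of what that export did, done from Python instead of the shell:

```
import os

# Make vLLM start worker processes with "spawn" rather than "fork";
# this must run before vLLM creates its workers.
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
```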

View File

@@ -181,7 +181,6 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
 #### Batch inference using vLLM + TP
 
 ```
-export VLLM_WORKER_MULTIPROC_METHOD=spawn
 python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
 ```

View File

@@ -16,10 +16,8 @@ import os
 import subprocess
 import sys
 from copy import deepcopy
-from enum import Enum, unique
 from functools import partial
-from .extras import logging
 
 USAGE = (
@@ -37,19 +35,20 @@ USAGE = (
     + "-" * 70
 )
 
-logger = logging.get_logger(__name__)
 
 def main():
     from . import launcher
     from .api.app import run_api
     from .chat.chat_model import run_chat
     from .eval.evaluator import run_eval
+    from .extras import logging
     from .extras.env import VERSION, print_env
     from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
     from .train.tuner import export_model, run_exp
     from .webui.interface import run_web_demo, run_web_ui
 
+    logger = logging.get_logger(__name__)
 
     WELCOME = (
         "-" * 58
         + "\n"
@@ -62,7 +61,7 @@ def main():
         + "-" * 58
     )
 
-    COMMANDS = {
+    COMMAND_MAP = {
         "api": run_api,
         "chat": run_chat,
         "env": print_env,
@@ -75,9 +74,9 @@
         "help": partial(print, USAGE),
     }
 
-    command = sys.argv.pop(1) if len(sys.argv) != 1 else "help"
-    force_torchrun = is_env_enabled("FORCE_TORCHRUN")
-    if command == "train" and (force_torchrun or (get_device_count() > 1 and not use_ray())):
+    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
+    if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
+        # launch distributed training
         nnodes = os.getenv("NNODES", "1")
         node_rank = os.getenv("NODE_RANK", "0")
         nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
@@ -113,11 +112,14 @@ def main():
             check=True,
         )
         sys.exit(process.returncode)
+    elif command in COMMAND_MAP:
+        COMMAND_MAP[command]()
     else:
-        COMMANDS[command]()
+        print(f"Unknown command: {command}.\n{USAGE}")
 
 
 if __name__ == "__main__":
     from multiprocessing import freeze_support
 
     freeze_support()
     main()
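The dispatch rewrite above swaps a bare `COMMANDS[command]()` lookup, which raised `KeyError` on an unknown command, for an explicit membership test on `COMMAND_MAP`. A standalone sketch of the same pattern (all names here are illustrative, not LLaMA-Factory's):

```
import sys
from functools import partial

USAGE = "usage: demo [greet|help]"

def greet() -> None:
    print("hello")

# Same shape as COMMAND_MAP above: unknown commands fall through to a
# friendly message instead of raising KeyError.
COMMAND_MAP = {"greet": greet, "help": partial(print, USAGE)}

def main() -> None:
    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
    if command in COMMAND_MAP:
        COMMAND_MAP[command]()
    else:
        print(f"Unknown command: {command}.\n{USAGE}")

if __name__ == "__main__":
    main()
```

`freeze_support()` is a no-op except in frozen Windows executables, so calling it unconditionally before `main()` is safe.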

View File

@@ -91,6 +91,14 @@ def _set_transformers_logging() -> None:
     transformers.utils.logging.enable_explicit_format()
 
 
+def _set_env_vars() -> None:
+    if is_torch_npu_available():
+        # avoid JIT compiling on NPU devices, see https://zhuanlan.zhihu.com/p/660875458
+        torch.npu.set_compile_mode(jit_compile=is_env_enabled("NPU_JIT_COMPILE"))
+
+        # avoid using the fork method on NPU devices, see https://github.com/hiyouga/LLaMA-Factory/issues/7447
+        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
+
 def _verify_model_args(
     model_args: "ModelArguments",
     data_args: "DataArguments",
@@ -279,12 +287,13 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
         raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")
 
-    if model_args.infer_backend == "vllm":
-        raise ValueError("vLLM backend is only available for API, CLI and Web.")
+    if model_args.infer_backend != EngineName.HF:
+        raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")
 
     if model_args.use_unsloth and is_deepspeed_zero3_enabled():
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
 
+    _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args, training_args)
@@ -407,6 +416,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
         raise ValueError("vLLM only accepts a single adapter. Merge them first.")
 
+    _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
@@ -428,9 +438,10 @@ def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _E
     _set_transformers_logging()
 
     # Check arguments
-    if model_args.infer_backend == "vllm":
-        raise ValueError("vLLM backend is only available for API, CLI and Web.")
+    if model_args.infer_backend != EngineName.HF:
+        raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")
 
+    _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
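Note the ordering in all three entry points above: `_set_env_vars()` runs before `_verify_model_args` and `_check_extra_dependencies`, so the environment is settled before any inference backend is constructed. A minimal standalone sketch of the same guard; `is_env_enabled` below is a simplified stand-in for LLaMA-Factory's helper, not its actual implementation:

```
import os

def is_env_enabled(env_var: str, default: str = "0") -> bool:
    # Simplified stand-in: treat "true", "y", or "1" (case-insensitive) as enabled.
    return os.getenv(env_var, default).strip().lower() in ("true", "y", "1")

def set_env_vars(npu_available: bool) -> None:
    # Mirrors the hunk above: mutate process-wide state only on NPU hosts.
    if npu_available:
        # On a real NPU host this is also where torch.npu.set_compile_mode(
        #     jit_compile=is_env_enabled("NPU_JIT_COMPILE")) would run.
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

set_env_vars(npu_available=False)  # no-op on a GPU/CPU host
```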

View File

@@ -17,12 +17,12 @@ from typing import TYPE_CHECKING, Any
 
 import torch
 from peft import PeftModel
-from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
 
 from ..extras import logging
-from ..extras.misc import infer_optim_dtype, is_env_enabled
+from ..extras.misc import infer_optim_dtype
 from ..extras.packages import is_transformers_version_greater_than
 from .model_utils.attention import configure_attn_implementation, print_attn_implementation
 from .model_utils.checkpointing import prepare_model_for_training
@@ -95,10 +95,6 @@
     else:
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
 
-    if is_torch_npu_available():
-        # avoid JIT compiling on NPU devices, see https://zhuanlan.zhihu.com/p/660875458
-        torch.npu.set_compile_mode(jit_compile=is_env_enabled("NPU_JIT_COMPILE"))
-
     configure_attn_implementation(config, model_args, is_trainable)
     configure_rope(config, model_args, is_trainable)
     configure_longlora(config, model_args, is_trainable)