[infer] set env for vllm ascend (#7745)

hoshi-hiyouga 2025-04-17 01:08:55 +08:00, committed by GitHub
parent 125513fa5c
commit 4831552856
5 changed files with 28 additions and 21 deletions

README.md

@@ -181,7 +181,6 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
 #### Batch Generation using vLLM Tensor Parallel
 
 ```
-export VLLM_WORKER_MULTIPROC_METHOD=spawn
 python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
 ```
 
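Note: the manual `export` can be dropped because, as of this commit, the launcher sets the variable itself on Ascend NPU (see `_set_env_vars` in `src/llamafactory/hparams/parser.py` below). A minimal sketch of the programmatic equivalent, assuming vLLM reads the variable before it creates worker processes:

```python
import os

# Equivalent of the removed shell export; must run before vLLM spawns
# its workers, because the start method is read at worker creation time.
os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
```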

README_zh.md

@@ -181,7 +181,6 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
 #### 使用 vLLM+TP 批量推理
 
 ```
-export VLLM_WORKER_MULTIPROC_METHOD=spawn
 python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
 ```
 

src/llamafactory/cli.py

@@ -16,10 +16,8 @@ import os
 import subprocess
 import sys
 from copy import deepcopy
-from enum import Enum, unique
 from functools import partial
 
-from .extras import logging
 
 USAGE = (
@@ -37,19 +35,20 @@ USAGE = (
     + "-" * 70
 )
 
-logger = logging.get_logger(__name__)
-
 
 def main():
     from . import launcher
     from .api.app import run_api
     from .chat.chat_model import run_chat
     from .eval.evaluator import run_eval
+    from .extras import logging
     from .extras.env import VERSION, print_env
     from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
     from .train.tuner import export_model, run_exp
     from .webui.interface import run_web_demo, run_web_ui
 
+    logger = logging.get_logger(__name__)
+
     WELCOME = (
         "-" * 58
         + "\n"
@@ -62,7 +61,7 @@ def main():
         + "-" * 58
     )
 
-    COMMANDS = {
+    COMMAND_MAP = {
         "api": run_api,
         "chat": run_chat,
         "env": print_env,
@ -75,9 +74,9 @@ def main():
"help": partial(print, USAGE), "help": partial(print, USAGE),
} }
command = sys.argv.pop(1) if len(sys.argv) != 1 else "help" command = sys.argv.pop(1) if len(sys.argv) >= 1 else "help"
force_torchrun = is_env_enabled("FORCE_TORCHRUN") if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
if command == "train" and (force_torchrun or (get_device_count() > 1 and not use_ray())): # launch distributed training
nnodes = os.getenv("NNODES", "1") nnodes = os.getenv("NNODES", "1")
node_rank = os.getenv("NODE_RANK", "0") node_rank = os.getenv("NODE_RANK", "0")
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count())) nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
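Note: the `NNODES` / `NODE_RANK` / `NPROC_PER_NODE` variables read in this hunk parameterize a `torchrun` launch. A hedged sketch of how they map onto a command line (the entry-point script here is hypothetical; the real CLI resolves it via its `launcher` module):

```python
import os

# Build a torchrun invocation from the same environment variables the CLI
# reads above; "8" stands in for get_device_count() on an 8-device host.
cmd = [
    "torchrun",
    "--nnodes", os.getenv("NNODES", "1"),
    "--node_rank", os.getenv("NODE_RANK", "0"),
    "--nproc_per_node", os.getenv("NPROC_PER_NODE", "8"),
    "train.py",  # hypothetical entry point
]
print(" ".join(cmd))
```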
@@ -113,11 +112,14 @@ def main():
             check=True,
         )
         sys.exit(process.returncode)
+    elif command in COMMAND_MAP:
+        COMMAND_MAP[command]()
     else:
-        COMMANDS[command]()
+        print(f"Unknown command: {command}.\n{USAGE}")
 
 
 if __name__ == "__main__":
     from multiprocessing import freeze_support
 
     freeze_support()
     main()
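Note: the `COMMANDS` → `COMMAND_MAP` rename comes with a safer dispatch: an unknown command now prints the usage text instead of raising a `KeyError`. A self-contained sketch of the pattern, with the handlers stubbed out:

```python
from functools import partial

USAGE = "usage: llamafactory-cli <api|chat|env|...|help>"

# Stub handlers; the real CLI maps these names to run_api, run_chat, etc.
COMMAND_MAP = {
    "env": lambda: print("printing environment info"),
    "help": partial(print, USAGE),
}

def dispatch(command: str) -> None:
    if command in COMMAND_MAP:
        COMMAND_MAP[command]()  # look up and invoke the zero-argument handler
    else:
        print(f"Unknown command: {command}.\n{USAGE}")

dispatch("help")  # prints USAGE
dispatch("oops")  # takes the unknown-command branch
```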

src/llamafactory/hparams/parser.py

@@ -91,6 +91,14 @@ def _set_transformers_logging() -> None:
     transformers.utils.logging.enable_explicit_format()
 
 
+def _set_env_vars() -> None:
+    if is_torch_npu_available():
+        # avoid JIT compilation on NPU devices, see https://zhuanlan.zhihu.com/p/660875458
+        torch.npu.set_compile_mode(jit_compile=is_env_enabled("NPU_JIT_COMPILE"))
+        # avoid using the fork method on NPU devices, see https://github.com/hiyouga/LLaMA-Factory/issues/7447
+        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
+
 def _verify_model_args(
     model_args: "ModelArguments",
     data_args: "DataArguments",
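Note: `fork` duplicates an already-initialized process, which can corrupt device runtimes such as Ascend's, while `spawn` starts each worker from a clean interpreter. The env var name is vLLM's real knob; the rest of this illustration is only a sketch of roughly what happens downstream:

```python
import multiprocessing as mp
import os

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

def worker() -> None:
    print("hello from a spawned worker")

if __name__ == "__main__":
    # Roughly what vLLM does internally when the env var says "spawn".
    mp.set_start_method(os.environ["VLLM_WORKER_MULTIPROC_METHOD"], force=True)
    p = mp.Process(target=worker)
    p.start()
    p.join()
```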
@@ -279,12 +287,13 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
         raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")
 
-    if model_args.infer_backend == "vllm":
-        raise ValueError("vLLM backend is only available for API, CLI and Web.")
+    if model_args.infer_backend != EngineName.HF:
+        raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")
 
     if model_args.use_unsloth and is_deepspeed_zero3_enabled():
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
 
+    _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args, training_args)
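Note: comparing against `EngineName.HF` instead of the literal `"vllm"` lets a single check cover both vLLM and SGLang. A hedged sketch of the enum (member values are assumptions; the real definition lives in LLaMA-Factory's extras):

```python
from enum import Enum, unique

@unique
class EngineName(str, Enum):
    HF = "huggingface"  # assumed value
    VLLM = "vllm"       # assumed value
    SGLANG = "sglang"   # assumed value

# A str-mixin enum still compares equal to raw strings from a config file.
assert EngineName.VLLM == "vllm"
for backend in EngineName:
    if backend != EngineName.HF:
        print(f"{backend.value}: only available for API, CLI and Web")
```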
@@ -407,6 +416,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
         raise ValueError("vLLM only accepts a single adapter. Merge them first.")
 
+    _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
 
@@ -428,9 +438,10 @@ def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _E
     _set_transformers_logging()
 
     # Check arguments
-    if model_args.infer_backend == "vllm":
-        raise ValueError("vLLM backend is only available for API, CLI and Web.")
+    if model_args.infer_backend != EngineName.HF:
+        raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")
 
+    _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)

src/llamafactory/model/patcher.py

@@ -17,12 +17,12 @@ from typing import TYPE_CHECKING, Any
 
 import torch
 from peft import PeftModel
-from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
 
 from ..extras import logging
-from ..extras.misc import infer_optim_dtype, is_env_enabled
+from ..extras.misc import infer_optim_dtype
 from ..extras.packages import is_transformers_version_greater_than
 from .model_utils.attention import configure_attn_implementation, print_attn_implementation
 from .model_utils.checkpointing import prepare_model_for_training
@@ -95,10 +95,6 @@ def patch_config(
     else:
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
 
-    if is_torch_npu_available():
-        # avoid JIT compile on NPU devices, see https://zhuanlan.zhihu.com/p/660875458
-        torch.npu.set_compile_mode(jit_compile=is_env_enabled("NPU_JIT_COMPILE"))
-
     configure_attn_implementation(config, model_args, is_trainable)
     configure_rope(config, model_args, is_trainable)
     configure_longlora(config, model_args, is_trainable)
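Note: the block removed here was not dropped; it moved into `_set_env_vars` in `parser.py`, so environment setup runs once per entry point rather than during model patching. A sketch of the NPU-only guard in isolation (on Ascend hosts the `torch_npu` package patches the `torch.npu` namespace; elsewhere the guard simply falls through):

```python
import torch
from transformers import is_torch_npu_available

if is_torch_npu_available():
    # Disable JIT compilation to avoid long warm-up on Ascend NPUs.
    torch.npu.set_compile_mode(jit_compile=False)
else:
    print("No NPU detected; nothing to configure.")
```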