From 4831552856b4fb33d6f9f125bc07d80e3dd590e9 Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Thu, 17 Apr 2025 01:08:55 +0800
Subject: [PATCH] [infer] set env for vllm ascend (#7745)

---
 examples/README.md                 |  1 -
 examples/README_zh.md              |  1 -
 src/llamafactory/cli.py            | 20 +++++++++++---------
 src/llamafactory/hparams/parser.py | 19 +++++++++++++++----
 src/llamafactory/model/patcher.py  |  8 ++------
 5 files changed, 28 insertions(+), 21 deletions(-)

diff --git a/examples/README.md b/examples/README.md
index b58ddad8..457ec87f 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -181,7 +181,6 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
 #### Batch Generation using vLLM Tensor Parallel
 
 ```
-export VLLM_WORKER_MULTIPROC_METHOD=spawn
 python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
 ```
 
diff --git a/examples/README_zh.md b/examples/README_zh.md
index 1f24dbb2..4899e279 100644
--- a/examples/README_zh.md
+++ b/examples/README_zh.md
@@ -181,7 +181,6 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
 #### 使用 vLLM+TP 批量推理
 
 ```
-export VLLM_WORKER_MULTIPROC_METHOD=spawn
 python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
 ```
 
diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py
index 92085c80..8515e1bb 100644
--- a/src/llamafactory/cli.py
+++ b/src/llamafactory/cli.py
@@ -16,10 +16,8 @@ import os
 import subprocess
 import sys
 from copy import deepcopy
-from enum import Enum, unique
 from functools import partial
 
-from .extras import logging
 
 
 USAGE = (
@@ -37,19 +35,20 @@ USAGE = (
     + "-" * 70
 )
 
-logger = logging.get_logger(__name__)
-
 
 def main():
     from . import launcher
     from .api.app import run_api
     from .chat.chat_model import run_chat
     from .eval.evaluator import run_eval
+    from .extras import logging
     from .extras.env import VERSION, print_env
     from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
     from .train.tuner import export_model, run_exp
     from .webui.interface import run_web_demo, run_web_ui
 
+    logger = logging.get_logger(__name__)
+
     WELCOME = (
         "-" * 58
         + "\n"
@@ -62,7 +61,7 @@ def main():
         + "-" * 58
     )
 
-    COMMANDS = {
+    COMMAND_MAP = {
         "api": run_api,
         "chat": run_chat,
         "env": print_env,
@@ -75,9 +74,9 @@ def main():
         "help": partial(print, USAGE),
     }
 
-    command = sys.argv.pop(1) if len(sys.argv) != 1 else "help"
-    force_torchrun = is_env_enabled("FORCE_TORCHRUN")
-    if command == "train" and (force_torchrun or (get_device_count() > 1 and not use_ray())):
+    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
+    if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
+        # launch distributed training
         nnodes = os.getenv("NNODES", "1")
         node_rank = os.getenv("NODE_RANK", "0")
         nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
@@ -113,11 +112,14 @@ def main():
             check=True,
         )
         sys.exit(process.returncode)
+    elif command in COMMAND_MAP:
+        COMMAND_MAP[command]()
     else:
-        COMMANDS[command]()
+        print(f"Unknown command: {command}.\n{USAGE}")
 
 
 if __name__ == "__main__":
     from multiprocessing import freeze_support
+
     freeze_support()
     main()
 
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 8d2e9c5b..bc200f60 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -91,6 +91,14 @@ def _set_transformers_logging() -> None:
     transformers.utils.logging.enable_explicit_format()
 
 
+def _set_env_vars() -> None:
+    if is_torch_npu_available():
+        # avoid JIT compile on NPU devices, see https://zhuanlan.zhihu.com/p/660875458
+        torch.npu.set_compile_mode(jit_compile=is_env_enabled("NPU_JIT_COMPILE"))
+        # avoid use fork method on NPU devices, see https://github.com/hiyouga/LLaMA-Factory/issues/7447
+        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
+
 def _verify_model_args(
     model_args: "ModelArguments",
     data_args: "DataArguments",
@@ -279,12 +287,13 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
         raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")
 
-    if model_args.infer_backend == "vllm":
-        raise ValueError("vLLM backend is only available for API, CLI and Web.")
+    if model_args.infer_backend != EngineName.HF:
+        raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")
 
     if model_args.use_unsloth and is_deepspeed_zero3_enabled():
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
 
+    _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args, training_args)
 
@@ -407,6 +416,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
         if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
             raise ValueError("vLLM only accepts a single adapter. Merge them first.")
 
+    _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
 
@@ -428,9 +438,10 @@ def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _E
     _set_transformers_logging()
 
     # Check arguments
-    if model_args.infer_backend == "vllm":
-        raise ValueError("vLLM backend is only available for API, CLI and Web.")
+    if model_args.infer_backend != EngineName.HF:
+        raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")
 
+    _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
 
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 996fe7ef..28014de9 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -17,12 +17,12 @@ from typing import TYPE_CHECKING, Any
 
 import torch
 from peft import PeftModel
-from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
 
 from ..extras import logging
-from ..extras.misc import infer_optim_dtype, is_env_enabled
+from ..extras.misc import infer_optim_dtype
 from ..extras.packages import is_transformers_version_greater_than
 from .model_utils.attention import configure_attn_implementation, print_attn_implementation
 from .model_utils.checkpointing import prepare_model_for_training
@@ -95,10 +95,6 @@ def patch_config(
     else:
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
 
-    if is_torch_npu_available():
-        # avoid JIT compile on NPU devices, see https://zhuanlan.zhihu.com/p/660875458
-        torch.npu.set_compile_mode(jit_compile=is_env_enabled("NPU_JIT_COMPILE"))
-
     configure_attn_implementation(config, model_args, is_trainable)
     configure_rope(config, model_args, is_trainable)
     configure_longlora(config, model_args, is_trainable)
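
The sketch below is not part of the patch; it only illustrates the behavior the new `_set_env_vars()` helper adds, assuming a hypothetical standalone function name (`force_spawn_for_vllm`). On Ascend NPU hosts, vLLM worker processes must be started with the "spawn" method rather than "fork", which users previously had to request manually via `export VLLM_WORKER_MULTIPROC_METHOD=spawn` before running `scripts/vllm_infer.py`; with this patch the argument parsers set the variable automatically.

```
# Illustrative sketch only: set the vLLM worker start method the way the
# patched _set_env_vars() does when an NPU device is available.
import os


def force_spawn_for_vllm(npu_available: bool) -> None:
    # hypothetical helper name; the env var and "spawn" value come from the patch
    if npu_available:
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


force_spawn_for_vllm(npu_available=True)
assert os.environ["VLLM_WORKER_MULTIPROC_METHOD"] == "spawn"
```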