Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-02 03:32:50 +08:00
[infer] set env for vllm ascend (#7745)
This commit is contained in:
parent 125513fa5c · commit 4831552856
README.md
@@ -181,7 +181,6 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml

 #### Batch Generation using vLLM Tensor Parallel

 ```
-export VLLM_WORKER_MULTIPROC_METHOD=spawn
 python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
 ```
README_zh.md
@@ -181,7 +181,6 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml

 #### 使用 vLLM+TP 批量推理 (Batch Generation using vLLM Tensor Parallel)

 ```
-export VLLM_WORKER_MULTIPROC_METHOD=spawn
 python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
 ```
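After this change, `_set_env_vars()` (added to parser.py below) exports `VLLM_WORKER_MULTIPROC_METHOD=spawn` automatically on Ascend NPU devices, so both READMEs drop the manual `export` step. The ordering still matters: the variable must exist before vLLM creates its worker processes, since the default "fork" start method is unsafe on NPU (see LLaMA-Factory issue #7447). A minimal standalone sketch of that requirement, assuming vLLM is installed and using a hypothetical model path, not LLaMA-Factory's own code:

```python
import os

# Must be set before vLLM spawns its tensor-parallel workers.
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

from vllm import LLM, SamplingParams  # import only after the variable is set

llm = LLM(model="path_to_merged_model", tensor_parallel_size=2)  # hypothetical path
outputs = llm.generate(["Hello!"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```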
src/llamafactory/cli.py
@@ -16,10 +16,8 @@ import os
 import subprocess
 import sys
 from copy import deepcopy
-from enum import Enum, unique
 from functools import partial

-from .extras import logging


 USAGE = (
@@ -37,19 +35,20 @@ USAGE = (
     + "-" * 70
 )

-logger = logging.get_logger(__name__)
-

 def main():
     from . import launcher
     from .api.app import run_api
     from .chat.chat_model import run_chat
     from .eval.evaluator import run_eval
+    from .extras import logging
     from .extras.env import VERSION, print_env
     from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
     from .train.tuner import export_model, run_exp
     from .webui.interface import run_web_demo, run_web_ui

+    logger = logging.get_logger(__name__)
+
     WELCOME = (
         "-" * 58
         + "\n"
@@ -62,7 +61,7 @@ def main():
         + "-" * 58
     )

-    COMMANDS = {
+    COMMAND_MAP = {
         "api": run_api,
         "chat": run_chat,
         "env": print_env,
@@ -75,9 +74,9 @@ def main():
         "help": partial(print, USAGE),
     }

-    command = sys.argv.pop(1) if len(sys.argv) != 1 else "help"
-    force_torchrun = is_env_enabled("FORCE_TORCHRUN")
-    if command == "train" and (force_torchrun or (get_device_count() > 1 and not use_ray())):
+    command = sys.argv.pop(1) if len(sys.argv) >= 1 else "help"
+    if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
         # launch distributed training
         nnodes = os.getenv("NNODES", "1")
         node_rank = os.getenv("NODE_RANK", "0")
         nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
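When the command is `train` and more than one device is visible (or `FORCE_TORCHRUN` is set), the CLI delegates to `torchrun`, reading the cluster topology from environment variables. A rough sketch of that pattern, with a hypothetical entry script and default, not the exact cli.py code:

```python
import os
import subprocess
import sys

# Cluster topology comes from env vars, mirroring the branch above.
nnodes = os.getenv("NNODES", "1")
node_rank = os.getenv("NODE_RANK", "0")
nproc_per_node = os.getenv("NPROC_PER_NODE", "1")  # hypothetical default

process = subprocess.run(
    [
        "torchrun",
        "--nnodes", nnodes,
        "--node_rank", node_rank,
        "--nproc_per_node", nproc_per_node,
        "train.py",  # hypothetical training entry point
    ],
    check=False,
)
sys.exit(process.returncode)  # propagate torchrun's exit status
```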
@@ -113,11 +112,14 @@ def main():
             check=True,
         )
         sys.exit(process.returncode)
+    elif command in COMMAND_MAP:
+        COMMAND_MAP[command]()
     else:
-        COMMANDS[command]()
+        print(f"Unknown command: {command}.\n{USAGE}")


 if __name__ == "__main__":
     from multiprocessing import freeze_support

     freeze_support()
     main()
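The dict rename (`COMMANDS` to `COMMAND_MAP`) comes with a behavioral fix: the old `else: COMMANDS[command]()` raised a bare `KeyError` on a typo, while the new membership check prints the usage text instead. A condensed sketch of the dispatch-table pattern, with hypothetical handlers and usage string:

```python
import sys
from functools import partial

USAGE = "usage: mycli [api|chat|help]"  # hypothetical usage string

def run_api() -> None:
    print("api server would start here")

def run_chat() -> None:
    print("chat loop would start here")

COMMAND_MAP = {
    "api": run_api,
    "chat": run_chat,
    "help": partial(print, USAGE),
}

command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
if command in COMMAND_MAP:
    COMMAND_MAP[command]()  # dict lookup replaces a long if/elif chain
else:
    print(f"Unknown command: {command}.\n{USAGE}")  # graceful failure, no traceback
```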
src/llamafactory/hparams/parser.py
@@ -91,6 +91,14 @@ def _set_transformers_logging() -> None:
     transformers.utils.logging.enable_explicit_format()


+def _set_env_vars() -> None:
+    if is_torch_npu_available():
+        # avoid JIT compile on NPU devices, see https://zhuanlan.zhihu.com/p/660875458
+        torch.npu.set_compile_mode(jit_compile=is_env_enabled("NPU_JIT_COMPILE"))
+        # avoid use fork method on NPU devices, see https://github.com/hiyouga/LLaMA-Factory/issues/7447
+        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
+
 def _verify_model_args(
     model_args: "ModelArguments",
     data_args: "DataArguments",
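`_set_env_vars()` keeps NPU JIT compilation off unless the user opts back in through `NPU_JIT_COMPILE`. A sketch of the truthy-string check that `is_env_enabled` performs, assumed behavior rather than the library's exact code:

```python
import os

def is_env_enabled(env_var: str, default: str = "0") -> bool:
    """Treat "1", "true", or "y" (any casing) as enabled."""
    return os.getenv(env_var, default).lower() in ("true", "y", "1")

# Usage: JIT stays disabled by default; opt back in per run with
#   NPU_JIT_COMPILE=1 llamafactory-cli train ...
print(is_env_enabled("NPU_JIT_COMPILE"))  # False unless the variable is set
```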
@@ -279,12 +287,13 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
         raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")

-    if model_args.infer_backend == "vllm":
-        raise ValueError("vLLM backend is only available for API, CLI and Web.")
+    if model_args.infer_backend != EngineName.HF:
+        raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")

     if model_args.use_unsloth and is_deepspeed_zero3_enabled():
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")

+    _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args, training_args)
@@ -407,6 +416,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
         raise ValueError("vLLM only accepts a single adapter. Merge them first.")

+    _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
@@ -428,9 +438,10 @@ def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _E
     _set_transformers_logging()

     # Check arguments
-    if model_args.infer_backend == "vllm":
-        raise ValueError("vLLM backend is only available for API, CLI and Web.")
+    if model_args.infer_backend != EngineName.HF:
+        raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")

+    _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
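Switching the comparison from the literal `"vllm"` to `!= EngineName.HF` also catches the SGLang backend. It works because `EngineName` is a str-valued enum, so its members compare equal to plain strings parsed from the CLI. A sketch with assumed member values (the real enum lives in LLaMA-Factory's sources and may differ):

```python
from enum import Enum, unique

@unique
class EngineName(str, Enum):  # assumed definition for illustration
    HF = "huggingface"
    VLLM = "vllm"
    SGLANG = "sglang"

# str-valued enum members compare equal to raw strings,
# so infer_backend can be tested against EngineName.HF directly:
assert EngineName.HF == "huggingface"
assert "vllm" != EngineName.HF    # rejected, as before
assert "sglang" != EngineName.HF  # now rejected too
```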
src/llamafactory/model/patcher.py
@@ -17,12 +17,12 @@ from typing import TYPE_CHECKING, Any

 import torch
 from peft import PeftModel
-from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled

 from ..extras import logging
-from ..extras.misc import infer_optim_dtype, is_env_enabled
+from ..extras.misc import infer_optim_dtype
 from ..extras.packages import is_transformers_version_greater_than
 from .model_utils.attention import configure_attn_implementation, print_attn_implementation
 from .model_utils.checkpointing import prepare_model_for_training
@@ -95,10 +95,6 @@ def patch_config(
     else:
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))

-    if is_torch_npu_available():
-        # avoid JIT compile on NPU devices, see https://zhuanlan.zhihu.com/p/660875458
-        torch.npu.set_compile_mode(jit_compile=is_env_enabled("NPU_JIT_COMPILE"))
-
     configure_attn_implementation(config, model_args, is_trainable)
     configure_rope(config, model_args, is_trainable)
     configure_longlora(config, model_args, is_trainable)
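Net effect: the NPU setup moves from `patch_config()` (which runs only when a model is loaded through the HF path) into argument parsing, so the worker start method is fixed before any inference engine constructs its processes. A minimal standard-library illustration of why the start method must be chosen up front, with a hypothetical worker function:

```python
import multiprocessing as mp

def worker() -> None:
    print("running in a freshly spawned interpreter")

if __name__ == "__main__":
    # "spawn" starts a clean process instead of fork-copying the parent's
    # state, which is what VLLM_WORKER_MULTIPROC_METHOD=spawn selects for
    # vLLM workers; the choice cannot be changed once workers exist.
    ctx = mp.get_context("spawn")
    proc = ctx.Process(target=worker)
    proc.start()
    proc.join()
```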