Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 03:32:50 +08:00)
[infer] set env for vllm ascend (#7745)
This commit is contained in: commit 4831552856, parent 125513fa5c.
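This commit centralizes the environment setup needed to run vLLM inference on Ascend NPU devices. A new `_set_env_vars()` helper in the argument-parsing module configures the NPU JIT compile mode and forces vLLM's worker start method to `spawn`, and is called from `get_train_args`, `get_infer_args`, and `get_eval_args`. The same logic is removed from the model patcher, the manual `export` instruction is dropped from both READMEs, and the CLI command dispatch is tidied up along the way.

In the English README, the batch-generation example no longer needs the manual export: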
````diff
@@ -181,7 +181,6 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
 #### Batch Generation using vLLM Tensor Parallel

 ```
-export VLLM_WORKER_MULTIPROC_METHOD=spawn
 python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
 ```

````
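The Chinese README receives the identical change (the heading translates to "Batch Generation using vLLM Tensor Parallel"):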
````diff
@@ -181,7 +181,6 @@ llamafactory-cli export examples/merge_lora/llama3_full_sft.yaml
 #### 使用 vLLM+TP 批量推理

 ```
-export VLLM_WORKER_MULTIPROC_METHOD=spawn
 python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
 ```

````
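In the CLI entry point, the `from enum import Enum, unique` import and the module-level `from .extras import logging` import are removed; logging is re-imported locally inside `main()` below: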
```diff
@@ -16,10 +16,8 @@ import os
 import subprocess
 import sys
 from copy import deepcopy
-from enum import Enum, unique
 from functools import partial

-from .extras import logging


 USAGE = (
```
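The module-level logger goes with them; `main()` now imports `logging` alongside its other deferred imports and creates the logger itself: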
```diff
@@ -37,19 +35,20 @@ USAGE = (
     + "-" * 70
 )

-logger = logging.get_logger(__name__)


 def main():
     from . import launcher
     from .api.app import run_api
     from .chat.chat_model import run_chat
     from .eval.evaluator import run_eval
+    from .extras import logging
     from .extras.env import VERSION, print_env
     from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
     from .train.tuner import export_model, run_exp
     from .webui.interface import run_web_demo, run_web_ui

+    logger = logging.get_logger(__name__)

     WELCOME = (
         "-" * 58
         + "\n"
```
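The command table is renamed from `COMMANDS` to `COMMAND_MAP`: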
```diff
@@ -62,7 +61,7 @@ def main():
         + "-" * 58
     )

-    COMMANDS = {
+    COMMAND_MAP = {
         "api": run_api,
         "chat": run_chat,
         "env": print_env,
```
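Command parsing changes next: the `force_torchrun` temporary is inlined into the condition and a comment marks the distributed-training branch. One caveat: the new guard `len(sys.argv) >= 1` is always true (`sys.argv[0]` is the program name), so a bare `llamafactory-cli` invocation would reach `sys.argv.pop(1)` and raise `IndexError` rather than fall back to `help`; a `> 1` comparison would keep the old fallback: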
```diff
@@ -75,9 +74,9 @@ def main():
         "help": partial(print, USAGE),
     }

-    command = sys.argv.pop(1) if len(sys.argv) != 1 else "help"
-    force_torchrun = is_env_enabled("FORCE_TORCHRUN")
-    if command == "train" and (force_torchrun or (get_device_count() > 1 and not use_ray())):
+    command = sys.argv.pop(1) if len(sys.argv) >= 1 else "help"
+    if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
+        # launch distributed training
         nnodes = os.getenv("NNODES", "1")
         node_rank = os.getenv("NODE_RANK", "0")
         nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
```
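Previously an unknown command raised a bare `KeyError` from the dict lookup; dispatch now goes through an explicit membership test with a readable fallback message: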
```diff
@@ -113,11 +112,14 @@ def main():
             check=True,
         )
         sys.exit(process.returncode)
+    elif command in COMMAND_MAP:
+        COMMAND_MAP[command]()
     else:
-        COMMANDS[command]()
+        print(f"Unknown command: {command}.\n{USAGE}")


 if __name__ == "__main__":
     from multiprocessing import freeze_support

     freeze_support()
     main()
```
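For readers unfamiliar with this pattern, here is a self-contained sketch of the resulting dispatch flow. It is illustrative only: the command table is truncated, the bodies are stand-ins, and it uses a `> 1` guard so that a bare invocation still defaults to `help` (see the note above).

```python
# Minimal sketch of dict-based CLI dispatch with an unknown-command
# fallback; names and bodies are illustrative, not the repository's.
import sys

USAGE = "usage: llamafactory-cli {api,chat,env,train,...} ..."

COMMAND_MAP = {
    "help": lambda: print(USAGE),
    "version": lambda: print("version: 0.0.0 (stand-in)"),
}


def dispatch() -> None:
    # Default to "help" when no subcommand is given (hence "> 1").
    command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
    if command in COMMAND_MAP:
        COMMAND_MAP[command]()
    else:
        print(f"Unknown command: {command}.\n{USAGE}")


if __name__ == "__main__":
    dispatch()
```

The heart of the commit lives in the argument-parsing module. A new `_set_env_vars()` helper gathers the NPU-specific environment setup in one place, next to `_set_transformers_logging()`: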
```diff
@@ -91,6 +91,14 @@ def _set_transformers_logging() -> None:
     transformers.utils.logging.enable_explicit_format()


+def _set_env_vars() -> None:
+    if is_torch_npu_available():
+        # avoid JIT compile on NPU devices, see https://zhuanlan.zhihu.com/p/660875458
+        torch.npu.set_compile_mode(jit_compile=is_env_enabled("NPU_JIT_COMPILE"))
+        # avoid using the fork method on NPU devices, see https://github.com/hiyouga/LLaMA-Factory/issues/7447
+        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
+
 def _verify_model_args(
     model_args: "ModelArguments",
     data_args: "DataArguments",
```
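The helper leans on two things worth spelling out: `is_env_enabled`, which reads a boolean-ish flag from the environment, and the `spawn` start method, which vLLM workers apparently need on NPU devices (the comment points to issue #7447 for the failure mode). A rough, self-contained approximation follows; `is_env_enabled`'s body and the `npu_available` flag are assumptions for illustration, not the repository's actual implementations.

```python
# Sketch of what _set_env_vars() amounts to; helper bodies here are
# assumed for illustration and are not the repository's code.
import os


def is_env_enabled(env_var: str, default: str = "0") -> bool:
    # Assumed semantics: "1"/"true"/"y" (case-insensitive) count as enabled.
    return os.getenv(env_var, default).lower() in ("true", "y", "1")


def set_env_vars_sketch(npu_available: bool) -> None:
    if npu_available:
        # JIT compilation stays off unless the user opts in via NPU_JIT_COMPILE=1.
        jit_compile = is_env_enabled("NPU_JIT_COMPILE")
        print(f"would call torch.npu.set_compile_mode(jit_compile={jit_compile})")
        # Force vLLM to spawn worker processes instead of forking them.
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


set_env_vars_sketch(npu_available=True)
print(os.environ["VLLM_WORKER_MULTIPROC_METHOD"])  # -> spawn
```

With the helper in place, `get_train_args` calls it just before the existing validation steps. The backend check is also generalized from a string comparison to an `EngineName` enum test, so the error now covers SGLang as well as vLLM: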
```diff
@@ -279,12 +287,13 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
         raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")

-    if model_args.infer_backend == "vllm":
-        raise ValueError("vLLM backend is only available for API, CLI and Web.")
+    if model_args.infer_backend != EngineName.HF:
+        raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")

     if model_args.use_unsloth and is_deepspeed_zero3_enabled():
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")

+    _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args, training_args)
```
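`get_infer_args` adds the same call after its adapter check: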
```diff
@@ -407,6 +416,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
         raise ValueError("vLLM only accepts a single adapter. Merge them first.")

+    _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
```
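as does `get_eval_args`, with the same `EngineName` generalization: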
```diff
@@ -428,9 +438,10 @@ def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _E
     _set_transformers_logging()

     # Check arguments
-    if model_args.infer_backend == "vllm":
-        raise ValueError("vLLM backend is only available for API, CLI and Web.")
+    if model_args.infer_backend != EngineName.HF:
+        raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")

+    _set_env_vars()
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
```
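With the parsers handling environment setup, the model patcher no longer needs its NPU special case. Its imports shrink first: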
```diff
@@ -17,12 +17,12 @@ from typing import TYPE_CHECKING, Any

 import torch
 from peft import PeftModel
-from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled

 from ..extras import logging
-from ..extras.misc import infer_optim_dtype, is_env_enabled
+from ..extras.misc import infer_optim_dtype
 from ..extras.packages import is_transformers_version_greater_than
 from .model_utils.attention import configure_attn_implementation, print_attn_implementation
 from .model_utils.checkpointing import prepare_model_for_training
```
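and the now-redundant NPU block disappears from `patch_config`: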
```diff
@@ -95,10 +95,6 @@ def patch_config(
     else:
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))

-    if is_torch_npu_available():
-        # avoid JIT compile on NPU devices, see https://zhuanlan.zhihu.com/p/660875458
-        torch.npu.set_compile_mode(jit_compile=is_env_enabled("NPU_JIT_COMPILE"))
-
     configure_attn_implementation(config, model_args, is_trainable)
     configure_rope(config, model_args, is_trainable)
     configure_longlora(config, model_args, is_trainable)
```
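Net effect: on Ascend NPU devices, the JIT compile mode and `VLLM_WORKER_MULTIPROC_METHOD=spawn` are applied automatically whenever arguments are parsed for training, inference, or evaluation, which is why the manual `export` line could be deleted from both READMEs.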