mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-01 11:12:50 +08:00
[infer] support lora adapter for SGLang backend (#8067)
This commit is contained in:
parent
66f719dd96
commit
820ed764c4
@ -79,6 +79,10 @@ class SGLangEngine(BaseEngine):
|
|||||||
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
|
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
|
||||||
self.template.mm_plugin.expand_mm_tokens = False # for sglang generate
|
self.template.mm_plugin.expand_mm_tokens = False # for sglang generate
|
||||||
self.generating_args = generating_args.to_dict()
|
self.generating_args = generating_args.to_dict()
|
||||||
|
if model_args.adapter_name_or_path is not None:
|
||||||
|
self.lora_request = True
|
||||||
|
else:
|
||||||
|
self.lora_request = False
|
||||||
|
|
||||||
launch_cmd = [
|
launch_cmd = [
|
||||||
"python3 -m sglang.launch_server",
|
"python3 -m sglang.launch_server",
|
||||||
@ -90,6 +94,15 @@ class SGLangEngine(BaseEngine):
|
|||||||
f"--download-dir {model_args.cache_dir}",
|
f"--download-dir {model_args.cache_dir}",
|
||||||
"--log-level error",
|
"--log-level error",
|
||||||
]
|
]
|
||||||
|
if self.lora_request:
|
||||||
|
launch_cmd.extend(
|
||||||
|
[
|
||||||
|
"--max-loras-per-batch 1",
|
||||||
|
f"--lora-backend {model_args.sglang_lora_backend}",
|
||||||
|
f"--lora-paths lora0={model_args.adapter_name_or_path[0]}",
|
||||||
|
"--disable-radix-cache",
|
||||||
|
]
|
||||||
|
)
|
||||||
launch_cmd = " ".join(launch_cmd)
|
launch_cmd = " ".join(launch_cmd)
|
||||||
logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
|
logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
|
||||||
try:
|
try:
|
||||||
@ -202,6 +215,8 @@ class SGLangEngine(BaseEngine):
|
|||||||
"sampling_params": sampling_params,
|
"sampling_params": sampling_params,
|
||||||
"stream": True,
|
"stream": True,
|
||||||
}
|
}
|
||||||
|
if self.lora_request:
|
||||||
|
json_data["lora_request"] = ["lora0"]
|
||||||
response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
|
response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")
|
raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")
|
||||||
|
@ -364,6 +364,12 @@ class SGLangArguments:
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
|
metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
|
||||||
)
|
)
|
||||||
|
sglang_lora_backend: Literal["triton", "flashinfer"] = field(
|
||||||
|
default="triton",
|
||||||
|
metadata={
|
||||||
|
"help": "The backend of running GEMM kernels for Lora modules. Recommend using the Triton LoRA backend for better performance and stability."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
|
if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
|
||||||
|
@ -20,7 +20,7 @@ from llamafactory.chat import ChatModel
|
|||||||
from llamafactory.extras.packages import is_sglang_available
|
from llamafactory.extras.packages import is_sglang_available
|
||||||
|
|
||||||
|
|
||||||
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
MODEL_NAME = "Qwen/Qwen2.5-0.5B"
|
||||||
|
|
||||||
|
|
||||||
INFER_ARGS = {
|
INFER_ARGS = {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user