[infer] support lora adapter for SGLang backend (#8067)

Saiya 2025-05-16 23:33:47 +08:00 committed by GitHub
parent 66f719dd96
commit 820ed764c4
3 changed files with 22 additions and 1 deletion


@@ -79,6 +79,10 @@ class SGLangEngine(BaseEngine):
         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
         self.template.mm_plugin.expand_mm_tokens = False  # for sglang generate
         self.generating_args = generating_args.to_dict()
+        if model_args.adapter_name_or_path is not None:
+            self.lora_request = True
+        else:
+            self.lora_request = False
 
         launch_cmd = [
             "python3 -m sglang.launch_server",
@@ -90,6 +94,15 @@ class SGLangEngine(BaseEngine):
             f"--download-dir {model_args.cache_dir}",
             "--log-level error",
         ]
+        if self.lora_request:
+            launch_cmd.extend(
+                [
+                    "--max-loras-per-batch 1",
+                    f"--lora-backend {model_args.sglang_lora_backend}",
+                    f"--lora-paths lora0={model_args.adapter_name_or_path[0]}",
+                    "--disable-radix-cache",
+                ]
+            )
         launch_cmd = " ".join(launch_cmd)
         logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
         try:
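For illustration, assuming a single adapter saved at saves/qwen2.5-0.5b/lora/sft (a hypothetical path) and the default backend, the joined launch command would look roughly like:

    python3 -m sglang.launch_server ... --download-dir /path/to/cache --log-level error --max-loras-per-batch 1 --lora-backend triton --lora-paths lora0=saves/qwen2.5-0.5b/lora/sft --disable-radix-cache

The radix cache is disabled here, presumably because its prefix reuse is not adapter-aware.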
@@ -202,6 +215,8 @@ class SGLangEngine(BaseEngine):
             "sampling_params": sampling_params,
             "stream": True,
         }
+        if self.lora_request:
+            json_data["lora_request"] = ["lora0"]
         response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
         if response.status_code != 200:
             raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")
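As a standalone sketch of the resulting streaming request (the prompt, sampling values, and port are illustrative; 30000 is SGLang's default port, and "lora0" is the adapter name registered via --lora-paths above):

    import requests

    base_url = "http://127.0.0.1:30000"  # assumed default SGLang port
    json_data = {
        "text": "Hello, who are you?",  # illustrative prompt
        "sampling_params": {"temperature": 0.7, "max_new_tokens": 64},
        "stream": True,
        "lora_request": ["lora0"],  # adapter name registered at launch
    }
    response = requests.post(f"{base_url}/generate", json=json_data, stream=True)
    response.raise_for_status()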


@@ -364,6 +364,12 @@ class SGLangArguments:
         default=None,
         metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
     )
+    sglang_lora_backend: Literal["triton", "flashinfer"] = field(
+        default="triton",
+        metadata={
+            "help": "Backend for running GEMM kernels of LoRA modules. The Triton backend is recommended for better performance and stability."
+        },
+    )
 
     def __post_init__(self):
         if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
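With the new argument in place, LoRA inference over SGLang could be enabled along these lines (model, template, and adapter path are illustrative):

    from llamafactory.chat import ChatModel

    chat_model = ChatModel({
        "model_name_or_path": "Qwen/Qwen2.5-0.5B",
        "adapter_name_or_path": "saves/qwen2.5-0.5b/lora/sft",  # hypothetical adapter path
        "template": "qwen",
        "infer_backend": "sglang",
        "sglang_lora_backend": "triton",
    })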


@@ -20,7 +20,7 @@ from llamafactory.chat import ChatModel
 from llamafactory.extras.packages import is_sglang_available
 
-MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
+MODEL_NAME = "Qwen/Qwen2.5-0.5B"
 INFER_ARGS = {