diff --git a/src/llamafactory/chat/sglang_engine.py b/src/llamafactory/chat/sglang_engine.py
index 24d60604..99ca04ae 100644
--- a/src/llamafactory/chat/sglang_engine.py
+++ b/src/llamafactory/chat/sglang_engine.py
@@ -79,6 +79,10 @@ class SGLangEngine(BaseEngine):
         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
         self.template.mm_plugin.expand_mm_tokens = False  # for sglang generate
         self.generating_args = generating_args.to_dict()
+        if model_args.adapter_name_or_path is not None:
+            self.lora_request = True
+        else:
+            self.lora_request = False
 
         launch_cmd = [
             "python3 -m sglang.launch_server",
@@ -90,6 +94,15 @@ class SGLangEngine(BaseEngine):
             f"--download-dir {model_args.cache_dir}",
             "--log-level error",
         ]
+        if self.lora_request:
+            launch_cmd.extend(
+                [
+                    "--max-loras-per-batch 1",
+                    f"--lora-backend {model_args.sglang_lora_backend}",
+                    f"--lora-paths lora0={model_args.adapter_name_or_path[0]}",
+                    "--disable-radix-cache",
+                ]
+            )
         launch_cmd = " ".join(launch_cmd)
         logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
         try:
@@ -202,6 +215,8 @@ class SGLangEngine(BaseEngine):
             "sampling_params": sampling_params,
             "stream": True,
         }
+        if self.lora_request:
+            json_data["lora_request"] = ["lora0"]
         response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
         if response.status_code != 200:
             raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")
diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index e7a74046..eec9ceca 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -364,6 +364,12 @@ class SGLangArguments:
         default=None,
         metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
     )
+    sglang_lora_backend: Literal["triton", "flashinfer"] = field(
+        default="triton",
+        metadata={
+            "help": "The backend for running GEMM kernels for LoRA modules. The Triton backend is recommended for better performance and stability."
+        },
+    )
 
     def __post_init__(self):
         if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
diff --git a/tests/e2e/test_sglang.py b/tests/e2e/test_sglang.py
index 6016e5b0..de9a5c1c 100644
--- a/tests/e2e/test_sglang.py
+++ b/tests/e2e/test_sglang.py
@@ -20,7 +20,7 @@ from llamafactory.chat import ChatModel
 from llamafactory.extras.packages import is_sglang_available
 
 
-MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
+MODEL_NAME = "Qwen/Qwen2.5-0.5B"
 
 INFER_ARGS = {
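
A minimal usage sketch (not part of the diff): with these changes applied, supplying an adapter path together with the new `sglang_lora_backend` field should route inference through the LoRA server flags added above. The model name, template, and adapter path below are placeholder assumptions, following the `INFER_ARGS` style of `tests/e2e/test_sglang.py`.

# Hedged sketch of exercising the new LoRA path via ChatModel.
# "path/to/lora_adapter" is a hypothetical adapter directory.
from llamafactory.chat import ChatModel

infer_args = {
    "model_name_or_path": "Qwen/Qwen2.5-0.5B",
    "adapter_name_or_path": "path/to/lora_adapter",  # makes self.lora_request True
    "infer_backend": "sglang",
    "sglang_lora_backend": "triton",  # new field added in model_args.py
    "template": "qwen",
}
chat_model = ChatModel(infer_args)
messages = [{"role": "user", "content": "Hello"}]
print(chat_model.chat(messages)[0].response_text)

Under these assumptions, the engine would launch `sglang.launch_server` with `--lora-paths lora0=path/to/lora_adapter` and tag each `/generate` request with `"lora_request": ["lora0"]`, as shown in the first two hunks.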