Mirror of https://github.com/hiyouga/LLaMA-Factory.git
commit ab41f7956c (parent 52b23f9e56)
	[infer] support lora adapter for SGLang backend (#8067)
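This change wires LoRA adapters into the SGLang inference backend: the engine records whether `adapter_name_or_path` is set, registers the first adapter with the SGLang server at launch via `--lora-paths lora0=...`, and attaches it to every `/generate` request. A new `sglang_lora_backend` argument selects the kernel backend for the LoRA GEMMs.

A minimal sketch of exercising the new path, modeled on the test file touched below; the adapter path and template name are placeholders, and `infer_backend="sglang"` is assumed to follow the existing backend-selection convention:

```python
from llamafactory.chat import ChatModel

# Hypothetical setup: a base model plus one LoRA adapter served via SGLang.
chat_model = ChatModel({
    "model_name_or_path": "Qwen/Qwen2.5-0.5B",
    "adapter_name_or_path": "path/to/lora_adapter",  # placeholder adapter path
    "template": "qwen",                              # placeholder template
    "infer_backend": "sglang",
})
messages = [{"role": "user", "content": "Hello!"}]
print(chat_model.chat(messages)[0].response_text)
```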
@@ -79,6 +79,10 @@ class SGLangEngine(BaseEngine):
         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
         self.template.mm_plugin.expand_mm_tokens = False  # for sglang generate
         self.generating_args = generating_args.to_dict()
+        if model_args.adapter_name_or_path is not None:
+            self.lora_request = True
+        else:
+            self.lora_request = False
 
         launch_cmd = [
             "python3 -m sglang.launch_server",
@@ -90,6 +94,15 @@ class SGLangEngine(BaseEngine):
             f"--download-dir {model_args.cache_dir}",
             "--log-level error",
         ]
+        if self.lora_request:
+            launch_cmd.extend(
+                [
+                    "--max-loras-per-batch 1",
+                    f"--lora-backend {model_args.sglang_lora_backend}",
+                    f"--lora-paths lora0={model_args.adapter_name_or_path[0]}",
+                    "--disable-radix-cache",
+                ]
+            )
         launch_cmd = " ".join(launch_cmd)
         logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
         try:
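For reference, with an adapter configured the hunk above assembles a one-line launch command roughly like the sketch below; the `--model-path` flag comes from unchanged lines outside the hunk, and all paths are placeholders:

```python
# Illustrative value of launch_cmd after " ".join(launch_cmd); not the
# literal command, since model, cache, and adapter paths vary per setup.
launch_cmd = (
    "python3 -m sglang.launch_server"
    " --model-path Qwen/Qwen2.5-0.5B"            # assumed, from elided lines
    " --download-dir /path/to/cache"             # placeholder cache dir
    " --log-level error"
    " --max-loras-per-batch 1"
    " --lora-backend triton"                     # default sglang_lora_backend
    " --lora-paths lora0=path/to/lora_adapter"   # placeholder adapter path
    " --disable-radix-cache"
)
print(launch_cmd)
```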
@@ -202,6 +215,8 @@ class SGLangEngine(BaseEngine):
                 "sampling_params": sampling_params,
                 "stream": True,
             }
+            if self.lora_request:
+                json_data["lora_request"] = ["lora0"]
             response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
             if response.status_code != 200:
                 raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")
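The same streaming call can be reproduced against a running server; a sketch assuming SGLang's default port, with a placeholder prompt and sampling parameters (`lora_request` names the adapter registered at launch via `--lora-paths`):

```python
import requests

base_url = "http://127.0.0.1:30000"  # assumed default SGLang server address
json_data = {
    "text": "Hello!",  # placeholder prompt
    "sampling_params": {"max_new_tokens": 64, "temperature": 0.7},
    "stream": True,
    "lora_request": ["lora0"],  # adapter name registered via --lora-paths
}
response = requests.post(f"{base_url}/generate", json=json_data, stream=True)
if response.status_code != 200:
    raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")
```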
@@ -364,6 +364,12 @@ class SGLangArguments:
         default=None,
         metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
     )
+    sglang_lora_backend: Literal["triton", "flashinfer"] = field(
+        default="triton",
+        metadata={
+            "help": "The backend for running GEMM kernels for LoRA modules. The Triton backend is recommended for better performance and stability."
+        },
+    )
 
     def __post_init__(self):
         if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
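A minimal sketch of the new argument in isolation; the import path is an assumption about where this dataclass lives:

```python
# Assumed module path; adjust if SGLangArguments is defined elsewhere.
from llamafactory.hparams.model_args import SGLangArguments

args = SGLangArguments(sglang_lora_backend="flashinfer")  # default is "triton"
print(args.sglang_lora_backend)
```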
@@ -20,7 +20,7 @@ from llamafactory.chat import ChatModel
 from llamafactory.extras.packages import is_sglang_available
 
 
-MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
+MODEL_NAME = "Qwen/Qwen2.5-0.5B"
 
 
 INFER_ARGS = {