diff --git a/README.md b/README.md
index 092edc3f..c057a4c6 100644
--- a/README.md
+++ b/README.md
@@ -79,8 +79,8 @@ Choose your path:
 - **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and PiSSA.
 - **Practical tricks**: [FlashAttention-2](https://github.com/Dao-AILab/flash-attention), [Unsloth](https://github.com/unslothai/unsloth), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), RoPE scaling, NEFTune and rsLoRA.
 - **Wide tasks**: Multi-turn dialogue, tool using, image understanding, visual grounding, video recognition, audio understanding, etc.
-- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, SwanLab, etc.
-- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
+- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, [SwanLab](https://github.com/SwanHubX/SwanLab), etc.
+- **Faster inference**: OpenAI-style API, Gradio UI and CLI with [vLLM worker](https://github.com/vllm-project/vllm) or [SGLang worker](https://github.com/sgl-project/sglang).
 
 ### Day-N Support for Fine-Tuning Cutting-Edge Models
 
@@ -106,6 +106,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 
 ## Changelog
 
+[25/03/15] We supported **[SGLang](https://github.com/sgl-project/sglang)** as an inference backend. Try `infer_backend: sglang` to accelerate inference.
+
 [25/03/12] We supported fine-tuning the **[Gemma-3](https://huggingface.co/blog/gemma3)** model.
 
 [25/02/24] Announcing **[EasyR1](https://github.com/hiyouga/EasyR1)**, an efficient, scalable and multi-modality RL training framework for efficient GRPO training.
@@ -437,7 +439,7 @@ cd LLaMA-Factory
 pip install -e ".[torch,metrics]"
 ```
 
-Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, apollo, badam, adam-mini, qwen, minicpm_v, modelscope, openmind, swanlab, quality
+Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, sglang, galore, apollo, badam, adam-mini, qwen, minicpm_v, modelscope, openmind, swanlab, quality
 
 > [!TIP]
 > Use `pip install --no-deps -e .` to resolve package conflicts.
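The changelog entry above only names the new `infer_backend: sglang` option. As a quick illustration (a minimal sketch, not part of the patch itself; it relies on the `ChatModel` entry point that the new `tests/e2e/test_sglang.py` at the end of this diff also exercises, and on the model used by `examples/inference/llama3_sglang.yaml`), programmatic use would look roughly like this:

```python
# Minimal sketch of the new backend in use (mirrors tests/e2e/test_sglang.py below).
from llamafactory.chat import ChatModel

chat_model = ChatModel({
    "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",  # same model as the example YAML
    "template": "llama3",
    "infer_backend": "sglang",  # the new engine wired up in chat_model.py below
    "trust_remote_code": True,
})
print(chat_model.chat([{"role": "user", "content": "Hi"}])[0].response_text)
```

The same options can equally be supplied as a YAML file, which is what the new `examples/inference/llama3_sglang.yaml` in this patch provides.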
diff --git a/README_zh.md b/README_zh.md
index d3a1c0c7..8b09090e 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -81,8 +81,8 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 - **先进算法**：[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[APOLLO](https://github.com/zhuhanqing/APOLLO)、[Adam-mini](https://github.com/zyushun/Adam-mini)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 PiSSA。
 - **实用技巧**：[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)、[Unsloth](https://github.com/unslothai/unsloth)、[Liger Kernel](https://github.com/linkedin/Liger-Kernel)、RoPE scaling、NEFTune 和 rsLoRA。
 - **广泛任务**：多轮对话、工具调用、图像理解、视觉定位、视频识别和语音理解等等。
-- **实验监控**：LlamaBoard、TensorBoard、Wandb、MLflow、SwanLab 等等。
-- **极速推理**：基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
+- **实验监控**：LlamaBoard、TensorBoard、Wandb、MLflow、[SwanLab](https://github.com/SwanHubX/SwanLab) 等等。
+- **极速推理**：基于 [vLLM](https://github.com/vllm-project/vllm) 或 [SGLang](https://github.com/sgl-project/sglang) 的 OpenAI 风格 API、浏览器界面和命令行接口。
 
 ### 最新模型的 Day-N 微调适配
 
@@ -108,6 +108,8 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 
 ## 更新日志
 
+[25/03/15] 我们支持了 **[SGLang](https://github.com/sgl-project/sglang)** 推理后端，请使用 `infer_backend: sglang` 启用。
+
 [25/03/12] 我们支持了 **[Gemma-3](https://huggingface.co/blog/gemma3)** 模型的微调。
 
 [25/02/24] 我们宣布开源 **[EasyR1](https://github.com/hiyouga/EasyR1)**，一个高效可扩展的多模态强化学习框架，支持高效的 GRPO 训练。
@@ -439,7 +441,7 @@ cd LLaMA-Factory
 pip install -e ".[torch,metrics]"
 ```
 
-可选的额外依赖项：torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、apollo、badam、adam-mini、qwen、minicpm_v、modelscope、openmind、swanlab、quality
+可选的额外依赖项：torch、torch-npu、metrics、deepspeed、liger-kernel、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、sglang、galore、apollo、badam、adam-mini、qwen、minicpm_v、modelscope、openmind、swanlab、quality
 
 > [!TIP]
 > 遇到包冲突时，可使用 `pip install --no-deps -e .` 解决。
diff --git a/examples/inference/llama3_sglang.yaml b/examples/inference/llama3_sglang.yaml
new file mode 100644
index 00000000..82418981
--- /dev/null
+++ b/examples/inference/llama3_sglang.yaml
@@ -0,0 +1,4 @@
+model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
+template: llama3
+infer_backend: sglang
+trust_remote_code: true
diff --git a/setup.py b/setup.py
index d9b4a905..83528a45 100644
--- a/setup.py
+++ b/setup.py
@@ -54,6 +54,7 @@ extra_require = {
     "awq": ["autoawq"],
     "aqlm": ["aqlm[gpu]>=1.1.0"],
     "vllm": ["vllm>=0.4.3,<=0.7.3"],
+    "sglang": ["sglang>=0.4.4"],
     "galore": ["galore-torch"],
    "apollo": ["apollo-torch"],
     "badam": ["badam>=1.2.1"],
diff --git a/src/llamafactory/chat/chat_model.py b/src/llamafactory/chat/chat_model.py
index 8a604619..0022eed9 100644
--- a/src/llamafactory/chat/chat_model.py
+++ b/src/llamafactory/chat/chat_model.py
@@ -25,6 +25,7 @@ from ..extras.constants import EngineName
 from ..extras.misc import torch_gc
 from ..hparams import get_infer_args
 from .hf_engine import HuggingfaceEngine
+from .sglang_engine import SGLangEngine
 from .vllm_engine import VllmEngine
 
 
@@ -52,6 +53,8 @@ class ChatModel:
             self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
         elif model_args.infer_backend == EngineName.VLLM:
             self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
+        elif model_args.infer_backend == EngineName.SGLANG:
+            self.engine: BaseEngine = SGLangEngine(model_args, data_args, finetuning_args, generating_args)
         else:
             raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index 813c0976..ef39c417 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import asyncio
-import concurrent.futures
 import os
 from collections.abc import AsyncGenerator
 from threading import Thread
@@ -349,7 +348,6 @@ class HuggingfaceEngine(BaseEngine):
         if not self.can_generate:
             raise ValueError("The current model does not support `chat`.")
 
-        loop = asyncio.get_running_loop()
         input_args = (
             self.model,
             self.tokenizer,
@@ -365,8 +363,7 @@
             input_kwargs,
         )
         async with self.semaphore:
-            with concurrent.futures.ThreadPoolExecutor() as pool:
-                return await loop.run_in_executor(pool, self._chat, *input_args)
+            return await asyncio.to_thread(self._chat, *input_args)
 
     @override
     async def stream_chat(
@@ -382,7 +379,6 @@
         if not self.can_generate:
             raise ValueError("The current model does not support `stream_chat`.")
 
-        loop = asyncio.get_running_loop()
         input_args = (
             self.model,
             self.tokenizer,
@@ -398,13 +394,12 @@
             input_kwargs,
         )
         async with self.semaphore:
-            with concurrent.futures.ThreadPoolExecutor() as pool:
-                stream = self._stream_chat(*input_args)
-                while True:
-                    try:
-                        yield await loop.run_in_executor(pool, stream)
-                    except StopAsyncIteration:
-                        break
+            stream = self._stream_chat(*input_args)
+            while True:
+                try:
+                    yield await asyncio.to_thread(stream)
+                except StopAsyncIteration:
+                    break
 
     @override
     async def get_scores(
@@ -415,8 +410,6 @@
         if self.can_generate:
             raise ValueError("Cannot get scores using an auto-regressive model.")
 
-        loop = asyncio.get_running_loop()
         input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
         async with self.semaphore:
-            with concurrent.futures.ThreadPoolExecutor() as pool:
-                return await loop.run_in_executor(pool, self._get_scores, *input_args)
+            return await asyncio.to_thread(self._get_scores, *input_args)
diff --git a/src/llamafactory/chat/sglang_engine.py b/src/llamafactory/chat/sglang_engine.py
new file mode 100644
index 00000000..7bcb05d2
--- /dev/null
+++ b/src/llamafactory/chat/sglang_engine.py
@@ -0,0 +1,282 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import atexit
+import json
+from collections.abc import AsyncGenerator, AsyncIterator, Sequence
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+import requests
+from typing_extensions import override
+
+from ..data import get_template_and_fix_tokenizer
+from ..extras import logging
+from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
+from ..extras.misc import get_device_count, torch_gc
+from ..extras.packages import is_sglang_available
+from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
+from ..model import load_config, load_tokenizer
+from ..model.model_utils.quantization import QuantizationMethod
+from .base_engine import BaseEngine, Response
+
+
+if is_sglang_available():
+    from sglang.utils import launch_server_cmd, terminate_process, wait_for_server
+
+
+if TYPE_CHECKING:
+    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
+
+
+logger = logging.get_logger(__name__)
+
+
+class SGLangEngine(BaseEngine):
+    """Inference engine for SGLang models.
+
+    This class wraps the SGLang engine to provide a consistent interface for text generation
+    that matches LLaMA Factory's requirements. It uses the SGLang HTTP server approach for
+    better interaction and performance. The engine launches a server process and communicates
+    with it via HTTP requests.
+
+    For more details on the SGLang HTTP server approach, see:
+    https://docs.sglang.ai/backend/send_request.html
+    """
+
+    def __init__(
+        self,
+        model_args: "ModelArguments",
+        data_args: "DataArguments",
+        finetuning_args: "FinetuningArguments",
+        generating_args: "GeneratingArguments",
+    ) -> None:
+        self.name = EngineName.SGLANG
+        self.model_args = model_args
+        config = load_config(model_args)  # may download model from ms hub
+        if getattr(config, "quantization_config", None):  # gptq models should use float16
+            quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
+            quant_method = quantization_config.get("quant_method", "")
+            if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
+                model_args.infer_dtype = "float16"
+
+        self.can_generate = finetuning_args.stage == "sft"
+        tokenizer_module = load_tokenizer(model_args)
+        self.tokenizer = tokenizer_module["tokenizer"]
+        self.processor = tokenizer_module["processor"]
+        self.tokenizer.padding_side = "left"
+        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
+        self.template.mm_plugin.expand_mm_tokens = False  # for sglang generate
+        self.generating_args = generating_args.to_dict()
+
+        launch_cmd = [
+            "python3 -m sglang.launch_server",
+            f"--model-path {model_args.model_name_or_path}",
+            f"--dtype {model_args.infer_dtype}",
+            f"--context-length {model_args.sglang_maxlen}",
+            f"--mem-fraction-static {model_args.sglang_mem_fraction}",
+            f"--tp-size {model_args.sglang_tp_size if model_args.sglang_tp_size != -1 else get_device_count() or 1}",
+            f"--download-dir {model_args.cache_dir}",
+            "--log-level error",
+        ]
+        launch_cmd = " ".join(launch_cmd)
+        logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
+        try:
+            torch_gc()
+            self.server_process, port = launch_server_cmd(launch_cmd)
+            self.base_url = f"http://localhost:{port}"
+            atexit.register(self._cleanup_server)
+
+            logger.info_rank0(f"Waiting for SGLang server to be ready at {self.base_url}")
+            wait_for_server(self.base_url, timeout=300)
+            logger.info_rank0(f"SGLang server initialized successfully at {self.base_url}")
+            try:
+                response = requests.get(f"{self.base_url}/get_model_info", timeout=5)
+                if response.status_code == 200:
+                    model_info = response.json()
+                    logger.info(f"SGLang server model info: {model_info}")
+            except Exception as e:
+                logger.debug(f"Note: could not get model info: {str(e)}")
+
+        except Exception as e:
+            logger.error(f"Failed to start SGLang server: {str(e)}")
+            self._cleanup_server()  # make sure to clean up any started process
+            raise RuntimeError(f"SGLang server initialization failed: {str(e)}.")
+
+    def _cleanup_server(self):
+        r"""Clean up the server process when the engine is destroyed."""
+        if hasattr(self, "server_process") and self.server_process:
+            try:
+                logger.info("Terminating SGLang server process")
+                terminate_process(self.server_process)
+                logger.info("SGLang server process terminated")
+            except Exception as e:
+                logger.warning(f"Error terminating SGLang server: {str(e)}")
+
+    async def _generate(
+        self,
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
+        **input_kwargs,
+    ) -> AsyncIterator[dict[str, Any]]:
+        mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
+        if images is not None:
+            mm_input_dict.update({"images": images, "imglens": [len(images)]})
+            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
+
+        if videos is not None:
+            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
+            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
+
+        if audios is not None:
+            mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
+            if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
+
+        messages = self.template.mm_plugin.process_messages(
+            messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], self.processor
+        )
+        paired_messages = messages + [{"role": "assistant", "content": ""}]
+        system = system or self.generating_args["default_system"]
+        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
+        prompt_length = len(prompt_ids)
+
+        temperature: Optional[float] = input_kwargs.pop("temperature", None)
+        top_p: Optional[float] = input_kwargs.pop("top_p", None)
+        top_k: Optional[float] = input_kwargs.pop("top_k", None)
+        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
+        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
+        skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
+        max_length: Optional[int] = input_kwargs.pop("max_length", None)
+        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
+        stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
+
+        if num_return_sequences != 1:
+            raise NotImplementedError("SGLang only supports n=1.")
+
+        if "max_new_tokens" in self.generating_args:
+            max_tokens = self.generating_args["max_new_tokens"]
+        elif "max_length" in self.generating_args:
+            if self.generating_args["max_length"] > prompt_length:
+                max_tokens = self.generating_args["max_length"] - prompt_length
+            else:
+                max_tokens = 1
+
+        if max_length:
+            max_tokens = max_length - prompt_length if max_length > prompt_length else 1
+
+        if max_new_tokens:
+            max_tokens = max_new_tokens
+
+        sampling_params = {
+            "temperature": temperature if temperature is not None else self.generating_args["temperature"],
+            "top_p": (top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must > 0
+            "top_k": (top_k if top_k is not None else self.generating_args["top_k"]) or -1,  # top_k must > 0
+            "stop": stop,
+            "stop_token_ids": self.template.get_stop_token_ids(self.tokenizer),
+            "max_new_tokens": max_tokens,
+            "repetition_penalty": (
+                repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
+            )
+            or 1.0,  # repetition_penalty must > 0
+            "skip_special_tokens": skip_special_tokens
+            if skip_special_tokens is not None
+            else self.generating_args["skip_special_tokens"],
+        }
+
+        def stream_request():
+            json_data = {
+                "input_ids": prompt_ids,
+                "sampling_params": sampling_params,
+                "stream": True,
+            }
+            response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
+            if response.status_code != 200:
+                raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")
+
+            for chunk in response.iter_lines(decode_unicode=False):
+                chunk = str(chunk.decode("utf-8"))
+                if chunk == "data: [DONE]":
+                    break
+
+                if chunk and chunk.startswith("data:"):
+                    yield json.loads(chunk[5:].strip("\n"))
+
+        return await asyncio.to_thread(stream_request)
+
+    @override
+    async def chat(
+        self,
+        messages: Sequence[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
+        **input_kwargs,
+    ) -> list["Response"]:
+        final_output = None
+        generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
+        for request_output in generator:
+            final_output = request_output
+
+        results = [
+            Response(
+                response_text=final_output["text"],
+                response_length=final_output["meta_info"]["completion_tokens"],
+                prompt_length=final_output["meta_info"]["prompt_tokens"],
+                finish_reason="stop" if final_output["meta_info"]["finish_reason"] == "stop" else "length",
+            )
+        ]
+        return results
+
+    @override
+    async def stream_chat(
+        self,
+        messages: list[dict[str, str]],
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
+        **input_kwargs,
+    ) -> AsyncGenerator[str, None]:
+        generated_text = ""
+        generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
+        for result in generator:
+            delta_text = result["text"][len(generated_text) :]
+            generated_text = result["text"]
+            yield delta_text
+
+    @override
+    async def get_scores(
+        self,
+        batch_input: list[str],
+        **input_kwargs,
+    ) -> list[float]:
+        raise NotImplementedError("SGLang engine does not support `get_scores`.")
+
+    def __del__(self):
+        r"""Ensure server is cleaned up when object is deleted."""
+        self._cleanup_server()
+        try:
+            atexit.unregister(self._cleanup_server)
+        except Exception:
+            pass
diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index 4d37e81f..ef2405bc 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -252,4 +252,4 @@ class VllmEngine(BaseEngine):
         batch_input: list[str],
         **input_kwargs,
     ) -> list[float]:
-        raise NotImplementedError("vLLM engine does not support get_scores.")
+        raise NotImplementedError("vLLM engine does not support `get_scores`.")
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index 4fd77d0a..a6bde0dd 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -106,6 +106,7 @@ class AttentionFunction(str, Enum):
 class EngineName(str, Enum):
     HF = "huggingface"
     VLLM = "vllm"
+    SGLANG = "sglang"
 
 
 class DownloadSource(str, Enum):
diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py
index 235b3043..6a56606a 100644
--- a/src/llamafactory/extras/misc.py
+++ b/src/llamafactory/extras/misc.py
@@ -274,3 +274,14 @@ def use_openmind() -> bool:
 
 def use_ray() -> bool:
     return is_env_enabled("USE_RAY")
+
+
+def find_available_port() -> int:
+    """Find an available port on the local machine."""
+    import socket
+
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    sock.bind(("", 0))
+    port = sock.getsockname()[1]
+    sock.close()
+    return port
diff --git a/src/llamafactory/extras/packages.py b/src/llamafactory/extras/packages.py
index b474633e..6b70f4ac 100644
--- a/src/llamafactory/extras/packages.py
+++ b/src/llamafactory/extras/packages.py
@@ -97,3 +97,7 @@ def is_uvicorn_available():
 
 def is_vllm_available():
     return _is_package_available("vllm")
+
+
+def is_sglang_available():
+    return _is_package_available("sglang")
diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 4d7693e8..1cf2271b 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -302,7 +302,7 @@ class VllmArguments:
         metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
     )
     vllm_gpu_util: float = field(
-        default=0.9,
+        default=0.7,
         metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."},
     )
     vllm_enforce_eager: bool = field(
@@ -324,7 +324,35 @@
 
 
 @dataclass
-class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments):
+class SGLangArguments:
+    r"""Arguments pertaining to the SGLang worker."""
+
+    sglang_maxlen: int = field(
+        default=4096,
+        metadata={"help": "Maximum sequence (prompt + response) length of the SGLang engine."},
+    )
+    sglang_mem_fraction: float = field(
+        default=0.7,
+        metadata={"help": "The memory fraction (0-1) to be used for the SGLang engine."},
+    )
+    sglang_tp_size: int = field(
+        default=-1,
+        metadata={"help": "Tensor parallel size for the SGLang engine."},
+    )
+    sglang_config: Optional[Union[dict, str]] = field(
+        default=None,
+        metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
+    )
+
+    def __post_init__(self):
+        if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
+            self.sglang_config = _convert_str_dict(json.loads(self.sglang_config))
+
+
+@dataclass
+class ModelArguments(
+    SGLangArguments, VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments
+):
     r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
 
     The class on the most right will be displayed first.
@@ -356,6 +384,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
         ProcessorArguments.__post_init__(self)
         ExportArguments.__post_init__(self)
         VllmArguments.__post_init__(self)
+        SGLangArguments.__post_init__(self)
 
     @classmethod
     def copyfrom(cls, source: "Self", **kwargs) -> "Self":
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 9ec819c7..0b3bf759 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -31,7 +31,7 @@ from transformers.training_args import ParallelMode
 from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
 
 from ..extras import logging
-from ..extras.constants import CHECKPOINT_NAMES
+from ..extras.constants import CHECKPOINT_NAMES, EngineName
 from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
@@ -134,9 +134,12 @@ def _check_extra_dependencies(
     if model_args.mixture_of_depths is not None:
         check_version("mixture-of-depth>=1.1.6", mandatory=True)
 
-    if model_args.infer_backend == "vllm":
+    if model_args.infer_backend == EngineName.VLLM:
         check_version("vllm>=0.4.3,<=0.7.3")
         check_version("vllm", mandatory=True)
+    elif model_args.infer_backend == EngineName.SGLANG:
+        check_version("sglang>=0.4.4")
+        check_version("sglang", mandatory=True)
 
     if finetuning_args.use_galore:
         check_version("galore_torch", mandatory=True)
diff --git a/src/llamafactory/webui/components/infer.py b/src/llamafactory/webui/components/infer.py
index 5d0f4e88..677036b9 100644
--- a/src/llamafactory/webui/components/infer.py
+++ b/src/llamafactory/webui/components/infer.py
@@ -34,7 +34,7 @@ def create_infer_tab(engine: "Engine") -> dict[str, "Component"]:
     elem_dict = dict()
 
     with gr.Row():
-        infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface")
+        infer_backend = gr.Dropdown(choices=["huggingface", "vllm", "sglang"], value="huggingface")
         infer_dtype = gr.Dropdown(choices=["auto", "float16", "bfloat16", "float32"], value="auto")
 
     with gr.Row():
diff --git a/tests/e2e/test_sglang.py b/tests/e2e/test_sglang.py
new file mode 100644
index 00000000..6016e5b0
--- /dev/null
+++ b/tests/e2e/test_sglang.py
@@ -0,0 +1,71 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+
+import pytest
+
+from llamafactory.chat import ChatModel
+from llamafactory.extras.packages import is_sglang_available
+
+
+MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
+
+
+INFER_ARGS = {
+    "model_name_or_path": MODEL_NAME,
+    "finetuning_type": "lora",
+    "template": "llama3",
+    "infer_dtype": "float16",
+    "infer_backend": "sglang",
+    "do_sample": False,
+    "max_new_tokens": 1,
+}
+
+
+MESSAGES = [
+    {"role": "user", "content": "Hi"},
+]
+
+
+@pytest.mark.skipif(not is_sglang_available(), reason="SGLang is not installed")
+def test_chat():
+    r"""Test the SGLang engine's basic chat functionality."""
+    chat_model = ChatModel(INFER_ARGS)
+    response = chat_model.chat(MESSAGES)[0]
+    # TODO: Change to EXPECTED_RESPONSE
+    print(response.response_text)
+
+
+@pytest.mark.skipif(not is_sglang_available(), reason="SGLang is not installed")
+def test_stream_chat():
+    r"""Test the SGLang engine's streaming chat functionality."""
+    chat_model = ChatModel(INFER_ARGS)
+
+    response = ""
+    for token in chat_model.stream_chat(MESSAGES):
+        response += token
+
+    print("Complete response:", response)
+    assert response, "Should receive a non-empty response"
+
+
+# Run tests if executed directly
+if __name__ == "__main__":
+    if not is_sglang_available():
+        print("SGLang is not available. Please install it.")
+        sys.exit(1)
+
+    test_chat()
+    test_stream_chat()
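As a closing usage note on the `SGLangArguments` added in `model_args.py` above: the SGLang-specific knobs travel through the same argument dict (or YAML) as any other model argument. The sketch below is illustrative only; the field names and defaults are taken from the dataclass in this patch, while the concrete override values are arbitrary.

```python
# Sketch: overriding the SGLang-specific arguments introduced by this patch.
# Field names and defaults come from SGLangArguments in model_args.py above.
from llamafactory.chat import ChatModel

chat_model = ChatModel({
    "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",
    "template": "llama3",
    "infer_backend": "sglang",
    "sglang_maxlen": 8192,       # default 4096: maximum prompt + response length
    "sglang_mem_fraction": 0.8,  # default 0.7: fraction of GPU memory given to the server
    "sglang_tp_size": 2,         # default -1: -1 falls back to the visible device count
})
for token in chat_model.stream_chat([{"role": "user", "content": "Hi"}]):
    print(token, end="", flush=True)
```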