Mirror of https://github.com/hiyouga/LLaMA-Factory.git
Synced 2025-12-17 12:20:37 +08:00
[inference] support sglang backend (#7278)
* Mimic SGLang offline Engine
* Add more tests and args
* Pass all current tests
* Clean code
* Fix sample_params
* Clean code
* Fix stream chat
* Change SGLang from engine mode to server mode
* Fix
* Fix review issues
* Use SGLang built-in utilities
* Fix SGLang test
* Fix some doc issues
* Fix SGLang engine
* Add README

Co-authored-by: Jin Pan <jpan236@wisc.edu>
Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
@@ -302,7 +302,7 @@ class VllmArguments:
         metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
     )
     vllm_gpu_util: float = field(
-        default=0.9,
+        default=0.7,
         metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."},
     )
     vllm_enforce_eager: bool = field(
@@ -324,7 +324,35 @@ class VllmArguments:
 
 
 @dataclass
-class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments):
+class SGLangArguments:
+    r"""Arguments pertaining to the SGLang worker."""
+
+    sglang_maxlen: int = field(
+        default=4096,
+        metadata={"help": "Maximum sequence (prompt + response) length of the SGLang engine."},
+    )
+    sglang_mem_fraction: float = field(
+        default=0.7,
+        metadata={"help": "The memory fraction (0-1) to be used for the SGLang engine."},
+    )
+    sglang_tp_size: int = field(
+        default=-1,
+        metadata={"help": "Tensor parallel size for the SGLang engine."},
+    )
+    sglang_config: Optional[Union[dict, str]] = field(
+        default=None,
+        metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
+    )
+
+    def __post_init__(self):
+        if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
+            self.sglang_config = _convert_str_dict(json.loads(self.sglang_config))
+
+
+@dataclass
+class ModelArguments(
+    SGLangArguments, VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments
+):
     r"""Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
 
     The class on the most right will be displayed first.
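For a quick sense of what the new `sglang_config` field does with a JSON string, here is a minimal, self-contained sketch. It trims the dataclass to the `sglang_config` field and uses a simplified stand-in for the repository's `_convert_str_dict` helper (which is not part of this diff), and the config keys are illustrative only, so treat it as an approximation rather than the actual implementation.

# Minimal sketch (not the repository code): how a JSON string passed as
# sglang_config is turned into a dict in __post_init__.
import json
from dataclasses import dataclass, field
from typing import Optional, Union


def _convert_str_dict(data: dict) -> dict:
    # Stand-in for the repo helper: cast numeric-looking string values, leave the rest as-is.
    for key, value in data.items():
        if isinstance(value, str) and value.replace(".", "", 1).isdigit():
            data[key] = float(value) if "." in value else int(value)
    return data


@dataclass
class SGLangArguments:
    sglang_config: Optional[Union[dict, str]] = field(default=None)

    def __post_init__(self):
        # A value such as '{"dtype": "half"}' arrives as a string and is parsed here.
        if isinstance(self.sglang_config, str) and self.sglang_config.startswith("{"):
            self.sglang_config = _convert_str_dict(json.loads(self.sglang_config))


args = SGLangArguments(sglang_config='{"dtype": "half", "random_seed": "42"}')
print(args.sglang_config)  # {'dtype': 'half', 'random_seed': 42}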
@@ -356,6 +384,7 @@ class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, Quantiz
         ProcessorArguments.__post_init__(self)
         ExportArguments.__post_init__(self)
         VllmArguments.__post_init__(self)
+        SGLangArguments.__post_init__(self)
 
     @classmethod
     def copyfrom(cls, source: "Self", **kwargs) -> "Self":
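One detail worth noting about the hunk above: a dataclass's generated __init__ calls only the single __post_init__ resolved through the MRO, so with multiple inheritance the parents' hooks run only when called explicitly. That is what ModelArguments.__post_init__ does, and why the new SGLangArguments.__post_init__(self) line is added to the chain. A tiny illustration with toy class names:

# Illustration only (toy class names), showing that base-class __post_init__
# hooks do not run automatically and must be chained by hand.
from dataclasses import dataclass


@dataclass
class Base1:
    def __post_init__(self):
        print("Base1.__post_init__")


@dataclass
class Base2:
    def __post_init__(self):
        print("Base2.__post_init__")


@dataclass
class Combined(Base2, Base1):
    def __post_init__(self):
        # Without these explicit calls, neither base hook would run.
        Base1.__post_init__(self)
        Base2.__post_init__(self)
        print("Combined.__post_init__")


Combined()  # prints Base1, Base2, then Combined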