Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-22 13:42:51 +08:00)

Commit f42c0b26d1: Merge branch 'main' of https://github.com/BUAADreamer/LLaMA-Factory

Former-commit-id: 43d7ad5eccfc04fd7f31a481b278a7101c64a2fa
@@ -10,6 +10,7 @@ from .base_engine import BaseEngine, Response
 
 if is_vllm_available():
     from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
+    from vllm.lora.request import LoRARequest
 
 if TYPE_CHECKING:
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -24,7 +25,8 @@ class VllmEngine(BaseEngine):
         generating_args: "GeneratingArguments",
     ) -> None:
         config = load_config(model_args)  # may download model from ms hub
-        load_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
+        infer_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
+        infer_dtype = str(infer_dtype).split(".")[-1]
 
         self.can_generate = finetuning_args.stage == "sft"
         self.tokenizer = load_tokenizer(model_args)
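
Side note on the renamed variable: `infer_optim_dtype` returns a torch dtype object, while vLLM's `dtype` argument expects a plain string such as "bfloat16", which is why the new line strips the "torch." prefix. A minimal sketch of that conversion (the helper name is illustrative, not part of the patch):

import torch

def dtype_to_str(dtype: "torch.dtype") -> str:  # illustrative helper, not in the patch
    # torch dtypes print as "torch.bfloat16"; vLLM wants just "bfloat16"
    return str(dtype).split(".")[-1]

assert dtype_to_str(torch.bfloat16) == "bfloat16"
assert dtype_to_str(torch.float16) == "float16"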
@@ -36,15 +38,20 @@ class VllmEngine(BaseEngine):
             model=model_args.model_name_or_path,
             trust_remote_code=True,
             download_dir=model_args.cache_dir,
-            dtype=str(load_dtype).split(".")[-1],
+            dtype=infer_dtype,
             max_model_len=model_args.vllm_maxlen,
             tensor_parallel_size=get_device_count() or 1,
             gpu_memory_utilization=model_args.vllm_gpu_util,
             disable_log_stats=True,
             disable_log_requests=True,
             enforce_eager=model_args.vllm_enforce_eager,
+            enable_lora=model_args.adapter_name_or_path is not None,
         )
         self.model = AsyncLLMEngine.from_engine_args(engine_args)
+        if model_args.adapter_name_or_path is not None:
+            self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
+        else:
+            self.lora_request = None
 
     async def _generate(
         self,
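
Condensed, the pattern this hunk introduces looks roughly like the sketch below: build the engine with `enable_lora=True` whenever an adapter path is configured, and keep a `LoRARequest` describing that adapter for later generate calls. The model and adapter paths are placeholders.

from vllm import AsyncEngineArgs, AsyncLLMEngine
from vllm.lora.request import LoRARequest

adapter_path = "saves/my-lora-adapter"  # placeholder path

engine_args = AsyncEngineArgs(
    model="path/to/base-model",  # placeholder path
    enable_lora=adapter_path is not None,
)
engine = AsyncLLMEngine.from_engine_args(engine_args)

# the name "default" and the integer id 1 are arbitrary identifiers for this adapter
lora_request = LoRARequest("default", 1, adapter_path) if adapter_path is not None else None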
@@ -98,7 +105,11 @@ class VllmEngine(BaseEngine):
             skip_special_tokens=True,
         )
         result_generator = self.model.generate(
-            prompt=None, sampling_params=sampling_params, request_id=request_id, prompt_token_ids=prompt_ids
+            prompt=None,
+            sampling_params=sampling_params,
+            request_id=request_id,
+            prompt_token_ids=prompt_ids,
+            lora_request=self.lora_request,
         )
         return result_generator
 
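
The stored request object (or `None`) is then forwarded on every call to `AsyncLLMEngine.generate`, so the adapter is applied per request while the base weights stay shared. A hedged sketch continuing the one above, using the same keyword arguments as the patch (valid for the vLLM API targeted here):

from vllm import SamplingParams

async def generate_once(engine, prompt_ids, lora_request):
    # stream RequestOutput objects and keep the last (final) one
    params = SamplingParams(temperature=0.7, max_tokens=64)
    final = None
    async for output in engine.generate(
        prompt=None,
        sampling_params=params,
        request_id="chatcmpl-0",  # placeholder request id
        prompt_token_ids=prompt_ids,
        lora_request=lora_request,  # None disables LoRA for this request
    ):
        final = output
    return final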
@@ -308,15 +308,15 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
         if finetuning_args.stage != "sft":
             raise ValueError("vLLM engine only supports auto-regressive models.")
 
-        if model_args.adapter_name_or_path is not None:
-            raise ValueError("vLLM engine does not support LoRA adapters. Merge them first.")
-
         if model_args.quantization_bit is not None:
-            raise ValueError("vLLM engine does not support quantization.")
+            raise ValueError("vLLM engine does not support bnb quantization (GPTQ and AWQ are supported).")
 
         if model_args.rope_scaling is not None:
             raise ValueError("vLLM engine does not support RoPE scaling.")
 
+        if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
+            raise ValueError("vLLM only accepts a single adapter. Merge them first.")
+
     _verify_model_args(model_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args)
 
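
In other words, the blanket rejection of adapters is replaced by a relaxed rule: zero or one LoRA adapter is accepted, more than one still raises. A standalone sketch of just that rule (the function name is hypothetical):

from typing import List, Optional

def check_vllm_adapters(adapter_name_or_path: Optional[List[str]]) -> None:
    # hypothetical helper mirroring the new check: at most one adapter is allowed
    if adapter_name_or_path is not None and len(adapter_name_or_path) != 1:
        raise ValueError("vLLM only accepts a single adapter. Merge them first.")

check_vllm_adapters(None)             # fine: no adapter
check_vllm_adapters(["saves/lora"])   # fine: a single adapter (placeholder path)
# check_vllm_adapters(["a", "b"])     # would raise ValueError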
@@ -31,7 +31,10 @@ class WebChatModel(ChatModel):
         if demo_mode and os.environ.get("DEMO_MODEL") and os.environ.get("DEMO_TEMPLATE"):  # load demo model
             model_name_or_path = os.environ.get("DEMO_MODEL")
             template = os.environ.get("DEMO_TEMPLATE")
-            super().__init__(dict(model_name_or_path=model_name_or_path, template=template))
+            infer_backend = os.environ.get("DEMO_BACKEND", "huggingface")
+            super().__init__(
+                dict(model_name_or_path=model_name_or_path, template=template, infer_backend=infer_backend)
+            )
 
     @property
     def loaded(self) -> bool:
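
With this change the Web UI demo can pick its inference backend from the environment, falling back to "huggingface" when `DEMO_BACKEND` is unset. For example (the model and template values are placeholders):

import os

os.environ["DEMO_MODEL"] = "meta-llama/Llama-2-7b-hf"   # placeholder model
os.environ["DEMO_TEMPLATE"] = "llama2"                  # placeholder template
os.environ["DEMO_BACKEND"] = "vllm"                     # omit to keep the default

infer_backend = os.environ.get("DEMO_BACKEND", "huggingface")
print(infer_backend)  # -> vllm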