mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-22 13:42:51 +08:00)
fix chat engines

do not use pop(key, default) since the API assigns None to dict values

Former-commit-id: d52fae2fa866afeb6156dc98388ce5cc6d5eca77
parent 6955042c10
commit 864da49139
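The rationale behind the change, as a minimal sketch (the dict contents are illustrative, not from the codebase): dict.pop(key, default) only falls back to default when the key is absent, so when the API layer pre-populates every known key with an explicit None, the old pattern returned None instead of the configured default.

    # Illustrative payload: the API sets known keys to None when not provided.
    input_kwargs = {"temperature": None}
    generating_args = {"temperature": 0.7}

    # Old pattern: the key exists, so the default is ignored and None comes back.
    temperature = input_kwargs.pop("temperature", generating_args["temperature"])
    assert temperature is None

    # Fixed pattern: pop with None, then coalesce against the default explicitly.
    input_kwargs = {"temperature": None}
    temperature = input_kwargs.pop("temperature", None)
    temperature = temperature if temperature is not None else generating_args["temperature"]
    assert temperature == 0.7

The hunks below apply this pattern throughout the HuggingFace and vLLM chat engines.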
@@ -2,7 +2,7 @@ import asyncio
 import concurrent.futures
 import os
 from threading import Thread
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import torch
 from transformers import GenerationConfig, TextIteratorStreamer
@@ -66,16 +66,16 @@ class HuggingfaceEngine(BaseEngine):
         prompt_length = len(prompt_ids)
         inputs = torch.tensor([prompt_ids], device=model.device)
 
-        do_sample = input_kwargs.pop("do_sample", generating_args["do_sample"])
-        temperature = input_kwargs.pop("temperature", generating_args["temperature"])
-        top_p = input_kwargs.pop("top_p", generating_args["top_p"])
-        top_k = input_kwargs.pop("top_k", generating_args["top_k"])
-        num_return_sequences = input_kwargs.pop("num_return_sequences", 1)
-        repetition_penalty = input_kwargs.pop("repetition_penalty", generating_args["repetition_penalty"])
-        length_penalty = input_kwargs.pop("length_penalty", generating_args["length_penalty"])
-        max_length = input_kwargs.pop("max_length", None)
-        max_new_tokens = input_kwargs.pop("max_new_tokens", None)
-        stop = input_kwargs.pop("stop", None)
+        do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
+        temperature: Optional[float] = input_kwargs.pop("temperature", None)
+        top_p: Optional[float] = input_kwargs.pop("top_p", None)
+        top_k: Optional[float] = input_kwargs.pop("top_k", None)
+        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
+        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
+        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
+        max_length: Optional[int] = input_kwargs.pop("max_length", None)
+        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
+        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
 
         if stop is not None:
             raise ValueError("Stop parameter is not supported in Huggingface engine yet.")
@@ -83,20 +83,23 @@ class HuggingfaceEngine(BaseEngine):
         generating_args = generating_args.copy()
         generating_args.update(
             dict(
-                do_sample=do_sample,
-                temperature=temperature,
-                top_p=top_p,
-                top_k=top_k,
+                do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
+                temperature=temperature if temperature is not None else generating_args["temperature"],
+                top_p=top_p if top_p is not None else generating_args["top_p"],
+                top_k=top_k if top_k is not None else generating_args["top_k"],
                 num_return_sequences=num_return_sequences,
-                repetition_penalty=repetition_penalty,
-                length_penalty=length_penalty,
+                repetition_penalty=repetition_penalty
+                if repetition_penalty is not None
+                else generating_args["repetition_penalty"],
+                length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
                 eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
                 pad_token_id=tokenizer.pad_token_id,
             )
         )
 
-        if isinstance(num_return_sequences, int) and num_return_sequences > 1:
+        if isinstance(num_return_sequences, int) and num_return_sequences > 1:  # do_sample needs temperature > 0
            generating_args["do_sample"] = True
+            generating_args["temperature"] = generating_args["temperature"] or 1.0
 
         if not generating_args["temperature"]:
             generating_args["do_sample"] = False
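The two trailing checks in this hunk interact: requesting several return sequences forces sampling on and bumps a zero temperature to 1.0 (sampling at temperature 0 is degenerate), while a zero temperature with a single sequence disables sampling, i.e. greedy decoding. A minimal sketch of both branches, with illustrative values:

    # Illustrative: multiple sequences force sampling with a usable temperature.
    generating_args = {"do_sample": False, "temperature": 0.0}
    num_return_sequences = 4
    if isinstance(num_return_sequences, int) and num_return_sequences > 1:
        generating_args["do_sample"] = True
        generating_args["temperature"] = generating_args["temperature"] or 1.0
    assert generating_args == {"do_sample": True, "temperature": 1.0}

    # Illustrative: a single sequence at temperature 0 falls back to greedy decoding.
    generating_args = {"do_sample": True, "temperature": 0.0}
    if not generating_args["temperature"]:
        generating_args["do_sample"] = False
    assert generating_args["do_sample"] is False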
@@ -1,5 +1,5 @@
 import uuid
-from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence
+from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
 
 from ..data import get_template_and_fix_tokenizer
 from ..extras.logging import get_logger
@@ -102,18 +102,25 @@ class VllmEngine(BaseEngine):
         )
         prompt_length = len(prompt_ids)
 
-        use_beam_search = self.generating_args["num_beams"] > 1
-        temperature = input_kwargs.pop("temperature", self.generating_args["temperature"])
-        top_p = input_kwargs.pop("top_p", self.generating_args["top_p"])
-        top_k = input_kwargs.pop("top_k", self.generating_args["top_k"])
-        num_return_sequences = input_kwargs.pop("num_return_sequences", 1)
-        repetition_penalty = input_kwargs.pop("repetition_penalty", self.generating_args["repetition_penalty"])
-        length_penalty = input_kwargs.pop("length_penalty", self.generating_args["length_penalty"])
-        max_length = input_kwargs.pop("max_length", None)
-        max_new_tokens = input_kwargs.pop("max_new_tokens", None)
-        stop = input_kwargs.pop("stop", None)
+        use_beam_search: bool = self.generating_args["num_beams"] > 1
+        temperature: Optional[float] = input_kwargs.pop("temperature", None)
+        top_p: Optional[float] = input_kwargs.pop("top_p", None)
+        top_k: Optional[float] = input_kwargs.pop("top_k", None)
+        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
+        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
+        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
+        max_length: Optional[int] = input_kwargs.pop("max_length", None)
+        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
+        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
 
-        if "max_new_tokens" in self.generating_args:
-            max_tokens = self.generating_args["max_new_tokens"]
-        elif "max_length" in self.generating_args:
-            if self.generating_args["max_length"] > prompt_length:
-                max_tokens = self.generating_args["max_length"] - prompt_length
-            else:
-                max_tokens = 1
+        max_tokens = self.generating_args["max_new_tokens"] or self.generating_args["max_length"]
+        if max_length:
+            max_tokens = max_length - prompt_length if max_length > prompt_length else 1
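The rewritten max_tokens logic folds the old if/elif into a coalescing chain: the default comes from generating_args (max_new_tokens wins over max_length when nonzero), and a per-request max_length then overrides it, clamped to at least 1 token once the prompt is subtracted. A minimal sketch with illustrative numbers:

    # Illustrative defaults; names mirror the diff above.
    generating_args = {"max_new_tokens": 512, "max_length": 2048}
    prompt_length = 1000

    max_tokens = generating_args["max_new_tokens"] or generating_args["max_length"]
    assert max_tokens == 512  # max_new_tokens takes precedence when nonzero

    max_length = 1200  # hypothetical per-request override
    if max_length:
        max_tokens = max_length - prompt_length if max_length > prompt_length else 1
    assert max_tokens == 200  # room left after the prompt, floored at 1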
@@ -122,12 +129,15 @@ class VllmEngine(BaseEngine):
 
         sampling_params = SamplingParams(
             n=num_return_sequences,
-            repetition_penalty=repetition_penalty,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
+            repetition_penalty=(
+                repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
+            )
+            or 1.0,  # repetition_penalty must > 0
+            temperature=temperature if temperature is not None else self.generating_args["temperature"],
+            top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must > 0
+            top_k=top_k if top_k is not None else self.generating_args["top_k"],
             use_beam_search=use_beam_search,
-            length_penalty=length_penalty,
+            length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
             stop=stop,
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             max_tokens=max_tokens,
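The "or 1.0" guards deserve a note: per the diff's own comments, vLLM rejects top_p and repetition_penalty values of 0, yet a zeroed default can legitimately sit in generating_args. The guard rewrites such a value to a valid one after the coalescing step. A minimal sketch with illustrative values:

    # Illustrative values; top_p=0 would be rejected by vLLM's SamplingParams.
    top_p = None          # not supplied per-request
    default_top_p = 0.0   # zeroed default from generating_args
    effective_top_p = (top_p if top_p is not None else default_top_p) or 1.0
    assert effective_top_p == 1.0  # the guard turns 0 into a valid value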
@@ -68,8 +68,8 @@ class Template:
         self,
         tokenizer: "PreTrainedTokenizer",
         messages: List[Dict[str, str]],
-        system: str,
-        tools: str,
+        system: Optional[str],
+        tools: Optional[str],
         cutoff_len: int,
         reserved_label_len: int,
     ) -> Sequence[Tuple[List[int], List[int]]]: