Mirror of https://github.com/hiyouga/LLaMA-Factory.git

Commit e4a869dc42 (parent 194f38df8f): update api
Former-commit-id: f030b09924f0fb07305c244115759ac295e957c7
@@ -49,6 +49,7 @@ class ChatCompletionRequest(BaseModel):
     messages: List[ChatMessage]
     temperature: Optional[float] = None
     top_p: Optional[float] = None
+    max_length: Optional[int] = None
     max_new_tokens: Optional[int] = None
     stream: Optional[bool] = False
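A minimal, self-contained sketch of the extended request schema; the role/content shape of ChatMessage is an assumption, and fields outside this hunk are omitted:

from typing import List, Optional
from pydantic import BaseModel

class ChatMessage(BaseModel):
    # role/content shape is assumed for illustration; the hunk only shows the request fields
    role: str
    content: str

class ChatCompletionRequest(BaseModel):
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_length: Optional[int] = None      # new optional field added by this commit
    max_new_tokens: Optional[int] = None
    stream: Optional[bool] = False

req = ChatCompletionRequest(
    messages=[ChatMessage(role="user", content="Hello")],
    max_length=256,  # request a total-length budget instead of a new-token budget
)
print(req.max_length, req.max_new_tokens)  # 256 None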
@@ -100,9 +101,14 @@ async def create_chat_completion(request: ChatCompletionRequest):
         "input_ids": inputs["input_ids"],
         "temperature": request.temperature if request.temperature else gen_kwargs["temperature"],
         "top_p": request.top_p if request.top_p else gen_kwargs["top_p"],
-        "max_new_tokens": request.max_new_tokens if request.max_new_tokens else gen_kwargs["max_new_tokens"],
         "logits_processor": get_logits_processor()
     })
+    if request.max_length:
+        gen_kwargs.pop("max_new_tokens", None)
+        gen_kwargs["max_length"] = request.max_length
+    if request.max_new_tokens:
+        gen_kwargs.pop("max_length", None)
+        gen_kwargs["max_new_tokens"] = request.max_new_tokens

     if request.stream:
         generate = predict(gen_kwargs, request.model)
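Because the max_new_tokens branch runs after the max_length branch, a request that sets both ends up with only max_new_tokens in gen_kwargs. A standalone sketch of that precedence (the helper name is illustrative, not from the commit):

# Illustration of the override order implemented above:
# max_length is applied first, then max_new_tokens replaces it if also given.
def resolve_length_kwargs(gen_kwargs, max_length=None, max_new_tokens=None):
    if max_length:
        gen_kwargs.pop("max_new_tokens", None)
        gen_kwargs["max_length"] = max_length
    if max_new_tokens:
        gen_kwargs.pop("max_length", None)
        gen_kwargs["max_new_tokens"] = max_new_tokens
    return gen_kwargs

print(resolve_length_kwargs({"max_new_tokens": 512}, max_length=256))
# -> {'max_length': 256}
print(resolve_length_kwargs({"max_new_tokens": 512}, max_length=256, max_new_tokens=64))
# -> {'max_new_tokens': 64}  (max_new_tokens wins when both are set)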
@@ -171,8 +171,8 @@ def load_pretrained(
         padding_side="left",
         **config_kwargs
     )
-    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
-    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id == 64000 else tokenizer.pad_token_id # for baichuan model (older version)
+    if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000: # 64000 for baichuan model (older version)
+        tokenizer.pad_token_id = 0 # set as the <unk> token

     config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
     is_mergeable = True
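The two previous one-liners are folded into a single condition: a missing pad token or the legacy Baichuan id 64000 both fall back to 0, the <unk> token. A toy illustration (DummyTokenizer is a stand-in, not the Hugging Face class):

# Toy stand-in showing the merged condition: pad_token_id of None or 64000 becomes 0.
class DummyTokenizer:
    def __init__(self, pad_token_id):
        self.pad_token_id = pad_token_id

for tok in (DummyTokenizer(None), DummyTokenizer(64000), DummyTokenizer(2)):
    if tok.pad_token_id is None or tok.pad_token_id == 64000:  # 64000: older baichuan checkpoints
        tok.pad_token_id = 0  # fall back to the <unk> token id
    print(tok.pad_token_id)
# prints 0, then 0, then 2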
@@ -277,6 +277,10 @@ class GeneratingArguments:
         default=1,
         metadata={"help": "Number of beams for beam search. 1 means no beam search."}
     )
+    max_length: Optional[int] = field(
+        default=None,
+        metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
+    )
     max_new_tokens: Optional[int] = field(
         default=512,
         metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
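A trimmed, runnable sketch of just the two length-related fields as declared after this hunk; the rest of GeneratingArguments is omitted, so this is not the full class:

# Trimmed sketch of the new max_length field alongside the existing max_new_tokens.
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class GeneratingArguments:
    max_length: Optional[int] = field(
        default=None,
        metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
    )
    max_new_tokens: Optional[int] = field(
        default=512,
        metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
    )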
@@ -291,4 +295,7 @@ class GeneratingArguments:
     )

     def to_dict(self) -> Dict[str, Any]:
-        return asdict(self)
+        args = asdict(self)
+        if args.get("max_new_tokens", None):
+            args.pop("max_length", None)
+        return args
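Since max_new_tokens defaults to 512 (truthy), to_dict now drops max_length unless max_new_tokens is explicitly falsy, so the serialized kwargs carry at most one of the two limits. A self-contained sketch reusing the trimmed dataclass assumption from above:

# Sketch of the new to_dict behaviour on a trimmed version of the dataclass.
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional

@dataclass
class GeneratingArguments:
    max_length: Optional[int] = field(default=None)
    max_new_tokens: Optional[int] = field(default=512)

    def to_dict(self) -> Dict[str, Any]:
        args = asdict(self)
        if args.get("max_new_tokens", None):
            args.pop("max_length", None)  # keep only one length limit in the kwargs
        return args

print(GeneratingArguments().to_dict())
# {'max_new_tokens': 512}
print(GeneratingArguments(max_length=256, max_new_tokens=0).to_dict())
# {'max_length': 256, 'max_new_tokens': 0}  (falsy max_new_tokens keeps max_length)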