[api] support repetition_penalty and align presence_penalty with OpenAI Client (#7958)

This commit is contained in:
wangzhan 2025-05-26 18:45:11 +08:00 committed by GitHub
parent 52dead8775
commit c477ae6405
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 0 deletions

View File

@@ -186,6 +186,7 @@ async def create_chat_completion_response(
) -> "ChatCompletionResponse":
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
input_messages, system, tools, images, videos, audios = _process_request(request)
repetition_penalty = request.presence_penalty if request.presence_penalty else request.repetition_penalty
responses = await chat_model.achat(
input_messages,
system,
@@ -199,6 +200,7 @@ async def create_chat_completion_response(
max_new_tokens=request.max_tokens,
num_return_sequences=request.n,
stop=request.stop,
repetition_penalty=repetition_penalty,
)
prompt_length, response_length = 0, 0
@@ -248,6 +250,7 @@ async def create_stream_chat_completion_response(
yield _create_stream_chat_completion_chunk(
completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="")
)
repetition_penalty = request.presence_penalty if request.presence_penalty else request.repetition_penalty
async for new_token in chat_model.astream_chat(
input_messages,
system,
@@ -260,6 +263,7 @@ async def create_stream_chat_completion_response(
top_p=request.top_p,
max_new_tokens=request.max_tokens,
stop=request.stop,
repetition_penalty=repetition_penalty,
):
if len(new_token) != 0:
yield _create_stream_chat_completion_chunk(

View File

@@ -106,6 +106,8 @@ class ChatCompletionRequest(BaseModel):
max_tokens: Optional[int] = None
stop: Optional[Union[str, list[str]]] = None
stream: bool = False
presence_penalty: Optional[float] = None
repetition_penalty: Optional[float] = None
class ChatCompletionResponseChoice(BaseModel):