diff --git a/src/llamafactory/api/chat.py b/src/llamafactory/api/chat.py
index c97197de..8228d588 100644
--- a/src/llamafactory/api/chat.py
+++ b/src/llamafactory/api/chat.py
@@ -186,6 +186,7 @@ async def create_chat_completion_response(
 ) -> "ChatCompletionResponse":
     completion_id = f"chatcmpl-{uuid.uuid4().hex}"
     input_messages, system, tools, images, videos, audios = _process_request(request)
+    repetition_penalty = request.presence_penalty if request.presence_penalty else request.repetition_penalty
     responses = await chat_model.achat(
         input_messages,
         system,
@@ -199,6 +200,7 @@
         max_new_tokens=request.max_tokens,
         num_return_sequences=request.n,
         stop=request.stop,
+        repetition_penalty=repetition_penalty,
     )
 
     prompt_length, response_length = 0, 0
@@ -248,6 +250,7 @@ async def create_stream_chat_completion_response(
     yield _create_stream_chat_completion_chunk(
         completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="")
     )
+    repetition_penalty = request.presence_penalty if request.presence_penalty else request.repetition_penalty
     async for new_token in chat_model.astream_chat(
         input_messages,
         system,
@@ -260,6 +263,7 @@
         top_p=request.top_p,
         max_new_tokens=request.max_tokens,
         stop=request.stop,
+        repetition_penalty=repetition_penalty,
     ):
         if len(new_token) != 0:
             yield _create_stream_chat_completion_chunk(
diff --git a/src/llamafactory/api/protocol.py b/src/llamafactory/api/protocol.py
index ac9746ef..2c5520ea 100644
--- a/src/llamafactory/api/protocol.py
+++ b/src/llamafactory/api/protocol.py
@@ -106,6 +106,8 @@ class ChatCompletionRequest(BaseModel):
     max_tokens: Optional[int] = None
     stop: Optional[Union[str, list[str]]] = None
     stream: bool = False
+    presence_penalty: Optional[float] = None
+    repetition_penalty: Optional[float] = None
 
 
 class ChatCompletionResponseChoice(BaseModel):
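
For reviewers, a minimal sketch of how a client could exercise the new request fields once this patch is applied. The host, port, model name, and penalty value below are illustrative assumptions, not part of the change; the route is assumed to be the OpenAI-compatible /v1/chat/completions endpoint this API serves. Note that, as wired in chat.py above, a truthy presence_penalty takes precedence, and repetition_penalty is only forwarded when presence_penalty is unset or zero.

# Sketch only: assumes the API server is reachable at http://localhost:8000
# and exposes the OpenAI-compatible /v1/chat/completions route.
import requests

payload = {
    "model": "test-model",  # placeholder model name
    "messages": [{"role": "user", "content": "Write a short poem about the sea."}],
    "max_tokens": 128,
    # New field from this patch; forwarded to chat_model.achat as repetition_penalty.
    "repetition_penalty": 1.1,
    # If "presence_penalty" were also set to a non-zero value, it would take
    # precedence over repetition_penalty per the logic added in chat.py.
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])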