From 18656a6316e2849bb04289234857d2f729c9bef7 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 19 Jul 2023 00:01:14 +0800 Subject: [PATCH] fix API Former-commit-id: 29af67b015ff92e5dd9bf2985ce7723dc036d989 --- requirements.txt | 4 ++-- src/llmtuner/__init__.py | 2 +- src/llmtuner/api/app.py | 7 +++---- src/llmtuner/tuner/ppo/trainer.py | 6 +++++- src/llmtuner/tuner/rm/trainer.py | 2 +- 5 files changed, 12 insertions(+), 9 deletions(-) diff --git a/requirements.txt b/requirements.txt index e7f5bf16..12e907e5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,7 @@ rouge-chinese nltk gradio>=3.36.0 uvicorn -pydantic -fastapi +pydantic==1.10.11 +fastapi==0.95.1 sse-starlette matplotlib diff --git a/src/llmtuner/__init__.py b/src/llmtuner/__init__.py index b53af4de..fde74f57 100644 --- a/src/llmtuner/__init__.py +++ b/src/llmtuner/__init__.py @@ -4,4 +4,4 @@ from llmtuner.tuner import get_train_args, get_infer_args, load_model_and_tokeni from llmtuner.webui import create_ui -__version__ = "0.1.0" +__version__ = "0.1.1" diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py index 2a07527a..63b4fd10 100644 --- a/src/llmtuner/api/app.py +++ b/src/llmtuner/api/app.py @@ -1,4 +1,3 @@ -import json import uvicorn from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware @@ -96,7 +95,7 @@ def create_app(): finish_reason=None ) chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) - yield json.dumps(chunk, ensure_ascii=False) + yield chunk.json(exclude_unset=True, ensure_ascii=False) for new_text in chat_model.stream_chat( query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens @@ -110,7 +109,7 @@ def create_app(): finish_reason=None ) chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) - yield json.dumps(chunk, ensure_ascii=False) + yield chunk.json(exclude_unset=True, ensure_ascii=False) choice_data = ChatCompletionResponseStreamChoice( index=0, @@ -118,7 +117,7 @@ def create_app(): finish_reason=Finish.STOP ) chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) - yield json.dumps(chunk, ensure_ascii=False) + yield chunk.json(exclude_unset=True, ensure_ascii=False) yield "[DONE]" return app diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py index 0d84af6e..668d2f44 100644 --- a/src/llmtuner/tuner/ppo/trainer.py +++ b/src/llmtuner/tuner/ppo/trainer.py @@ -107,7 +107,11 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): # Compute rewards replace_model(unwrapped_model, target="reward") with torch.no_grad(): - _, _, values = self.model(**self.prepare_model_inputs(queries, responses)) + _, _, values = self.model( + **self.prepare_model_inputs(queries, responses), + output_hidden_states=True, + return_dict=True + ) rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type replace_model(unwrapped_model, target="default") diff --git a/src/llmtuner/tuner/rm/trainer.py b/src/llmtuner/tuner/rm/trainer.py index 199fecf4..749fa68d 100644 --- a/src/llmtuner/tuner/rm/trainer.py +++ b/src/llmtuner/tuner/rm/trainer.py @@ -32,7 +32,7 @@ class PairwisePeftTrainer(PeftTrainer): See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509 """ batch_size = inputs["input_ids"].size(0) // 2 - _, _, values = model(**inputs) + _, _, values = model(**inputs, output_hidden_states=True, return_dict=True) r_accept, r_reject = values[:, -1].split(batch_size, dim=0) loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean() return (loss, [loss, r_accept, r_reject]) if return_outputs else loss