Former-commit-id: 29af67b015ff92e5dd9bf2985ce7723dc036d989
commit 18656a6316 (parent af37ac077c)
Author: hiyouga
Date: 2023-07-19 00:01:14 +08:00
5 changed files with 12 additions and 9 deletions


@@ -10,7 +10,7 @@ rouge-chinese
 nltk
 gradio>=3.36.0
 uvicorn
-pydantic
-fastapi
+pydantic==1.10.11
+fastapi==0.95.1
 sse-starlette
 matplotlib
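
Note on the pins: pydantic 2.x (released June 2023) changed the BaseModel serialization API that the streaming endpoint below relies on, and fastapi 0.95.1 predates fastapi's pydantic-v2 migration, so both are pinned to pydantic-v1-compatible releases. A minimal startup guard along these lines could catch mismatches early; require_pydantic_v1 is a hypothetical helper, not part of this repo:

    # Hedged sketch: fail fast when an incompatible pydantic major version
    # is installed (the API code below depends on pydantic v1 semantics).
    import pydantic

    def require_pydantic_v1() -> None:
        major = int(pydantic.VERSION.split(".")[0])
        if major != 1:
            raise RuntimeError(
                f"pydantic {pydantic.VERSION} found; this code expects 1.x "
                "(see the pydantic==1.10.11 pin in requirements)"
            )

    require_pydantic_v1()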


@@ -4,4 +4,4 @@ from llmtuner.tuner import get_train_args, get_infer_args, load_model_and_tokenizer
 from llmtuner.webui import create_ui
 
-__version__ = "0.1.0"
+__version__ = "0.1.1"


@@ -1,4 +1,3 @@
-import json
 import uvicorn
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
@@ -96,7 +95,7 @@ def create_app():
             finish_reason=None
         )
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-        yield json.dumps(chunk, ensure_ascii=False)
+        yield chunk.json(exclude_unset=True, ensure_ascii=False)
 
         for new_text in chat_model.stream_chat(
             query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
@@ -110,7 +109,7 @@ def create_app():
                 finish_reason=None
             )
             chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-            yield json.dumps(chunk, ensure_ascii=False)
+            yield chunk.json(exclude_unset=True, ensure_ascii=False)
 
         choice_data = ChatCompletionResponseStreamChoice(
             index=0,
@@ -118,7 +117,7 @@ def create_app():
             finish_reason=Finish.STOP
         )
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-        yield json.dumps(chunk, ensure_ascii=False)
+        yield chunk.json(exclude_unset=True, ensure_ascii=False)
         yield "[DONE]"
 
     return app
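
The serialization change fixes an actual bug: json.dumps cannot handle a pydantic BaseModel and raises TypeError, whereas pydantic v1's .json() serializes the model and forwards keyword arguments such as ensure_ascii to json.dumps, with exclude_unset dropping fields the code never set. A minimal sketch under pydantic 1.x, with Chunk as a stand-in for ChatCompletionStreamResponse:

    from typing import Optional
    from pydantic import BaseModel  # pydantic 1.x, per the requirements pin

    class Chunk(BaseModel):  # stand-in for ChatCompletionStreamResponse
        delta: str = ""
        finish_reason: Optional[str] = None

    chunk = Chunk(delta="hello")
    # json.dumps(chunk, ensure_ascii=False) would raise:
    #   TypeError: Object of type Chunk is not JSON serializable
    print(chunk.json(exclude_unset=True, ensure_ascii=False))  # {"delta": "hello"}

This also explains the removed import json line and the pydantic==1.10.11 pin above: .json() with these arguments is a v1-only signature.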


@@ -107,7 +107,11 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         # Compute rewards
         replace_model(unwrapped_model, target="reward")
         with torch.no_grad():
-            _, _, values = self.model(**self.prepare_model_inputs(queries, responses))
+            _, _, values = self.model(
+                **self.prepare_model_inputs(queries, responses),
+                output_hidden_states=True,
+                return_dict=True
+            )
             rewards = [reward for reward in values[:, -1].to(torch.float32)]  # use float32 type
         replace_model(unwrapped_model, target="default")
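
Context for the new keyword arguments: the wrapped model carries a value head applied to the backbone's last hidden state, so the forward pass must expose hidden states (and a dict-style output to reach them); otherwise the third element of the returned tuple cannot be computed. A toy sketch of that three-tuple contract, assuming self.model behaves like trl's AutoModelForCausalLMWithValueHead:

    import torch
    import torch.nn as nn

    class ToyValueHeadModel(nn.Module):
        """Toy stand-in: a causal LM backbone plus a scalar value head."""

        def __init__(self, backbone: nn.Module, hidden_size: int):
            super().__init__()
            self.backbone = backbone  # any model returning .hidden_states
            self.v_head = nn.Linear(hidden_size, 1)

        def forward(self, **inputs):
            out = self.backbone(**inputs, output_hidden_states=True, return_dict=True)
            last_hidden = out.hidden_states[-1]            # (batch, seq, hidden)
            values = self.v_head(last_hidden).squeeze(-1)  # (batch, seq)
            return out.logits, None, values                # matches _, _, values = ...

values[:, -1] then reads the value of the final token, which the PPO step above treats as the sequence-level reward.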


@@ -32,7 +32,7 @@ class PairwisePeftTrainer(PeftTrainer):
         See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
         """
         batch_size = inputs["input_ids"].size(0) // 2
-        _, _, values = model(**inputs)
+        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
         r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
         loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
         return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
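
For reference, the last two unchanged lines implement the standard pairwise ranking (Bradley-Terry) objective, loss = -log(sigmoid(r_accept - r_reject)), which approaches zero as the accepted response outscores the rejected one. A toy check with made-up reward values:

    import torch

    # Made-up rewards: the first pair is well separated, the second mis-ordered.
    r_accept = torch.tensor([2.0, 0.5])
    r_reject = torch.tensor([0.0, 1.0])

    loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
    print(round(loss.item(), 4))  # 0.5505: mean of 0.1269 (good pair) and 0.9741 (bad pair)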