diff --git a/README.md b/README.md
index d9a97c53..14941517 100644
--- a/README.md
+++ b/README.md
@@ -9,7 +9,7 @@
 
 ## Changelog
 
-[23/06/29] We provide a reproducible example of training a chat model using instruction-following datasets, see this [HuggingFace Repo](https://huggingface.co/baichuan-inc/baichuan-7B) for details.
+[23/06/29] We provide a reproducible example of training a chat model using instruction-following datasets, see this [HuggingFace Repo](https://huggingface.co/hiyouga/baichuan-7b-sft) for details.
 
 [23/06/22] Now we align the [demo API](src/api_demo.py) with the [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format where you can insert the fine-tuned model in arbitrary ChatGPT-based applications.
 
diff --git a/src/api_demo.py b/src/api_demo.py
index 5f8d99d6..fd1d450a 100644
--- a/src/api_demo.py
+++ b/src/api_demo.py
@@ -10,6 +10,7 @@ import uvicorn
 from threading import Thread
 from pydantic import BaseModel, Field
 from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
 from contextlib import asynccontextmanager
 from transformers import TextIteratorStreamer
 from starlette.responses import StreamingResponse
@@ -34,6 +35,15 @@ async def lifespan(app: FastAPI): # collects GPU memory
 
 app = FastAPI(lifespan=lifespan)
 
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
 class ModelCard(BaseModel):
     id: str
     object: str = "model"
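For context, the CORS block added to src/api_demo.py is what lets browser-based chat front ends hosted on other origins call the OpenAI-style demo API. A minimal way to check that the middleware is active is to simulate a browser preflight request and look for the Access-Control-Allow-Origin response header. The sketch below assumes the demo is running locally on port 8000 and exposes a /v1/chat/completions route; both are assumptions about api_demo.py, not taken from this diff, so adjust them to your setup:

# Hypothetical smoke test for the CORS setup above -- the base URL and the
# /v1/chat/completions path are assumptions, not taken from this diff.
import requests

BASE_URL = "http://localhost:8000"

# Simulate the preflight request a browser sends before a cross-origin POST.
response = requests.options(
    f"{BASE_URL}/v1/chat/completions",
    headers={
        "Origin": "http://example.com",
        "Access-Control-Request-Method": "POST",
    },
)

# With the CORSMiddleware registered, an allow-origin header should come back.
print(response.status_code)
print(response.headers.get("access-control-allow-origin"))

For anything beyond local testing, replacing allow_origins=["*"] with the specific origins of your front end is the usual hardening step.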