update requirements

Former-commit-id: 66a91e1fe39483b83c7636c8199c8a87cf6a599e
This commit is contained in:
hiyouga 2023-11-06 19:01:21 +08:00
parent 5c19786f7c
commit 936297aeac
4 changed files with 8 additions and 9 deletions

View File

@ -1,9 +1,10 @@
torch>=1.13.1 torch>=1.13.1
transformers>=4.31.0 transformers>=4.31.0,<4.35.0
datasets>=2.12.0 datasets>=2.12.0
accelerate>=0.21.0 accelerate>=0.21.0
peft>=0.4.0 peft>=0.4.0
trl>=0.7.2 trl>=0.7.2
gradio>=3.38.0,<4.0.0
scipy scipy
sentencepiece sentencepiece
protobuf protobuf
@ -12,9 +13,8 @@ fire
jieba jieba
rouge-chinese rouge-chinese
nltk nltk
gradio==3.50.2
uvicorn uvicorn
pydantic==1.10.11 pydantic
fastapi==0.95.1 fastapi
sse-starlette sse-starlette
matplotlib matplotlib

View File

@ -32,9 +32,9 @@ async def lifespan(app: FastAPI): # collects GPU memory
def to_json(data: BaseModel) -> str: def to_json(data: BaseModel) -> str:
try: try: # pydantic v2
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False) return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
except: except: # pydantic v1
return data.json(exclude_unset=True, ensure_ascii=False) return data.json(exclude_unset=True, ensure_ascii=False)

View File

@ -14,7 +14,6 @@ from transformers import (
PreTrainedTokenizerBase PreTrainedTokenizerBase
) )
from transformers.models.llama import modeling_llama as LlamaModule from transformers.models.llama import modeling_llama as LlamaModule
from transformers.utils import check_min_version
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
@ -39,7 +38,7 @@ if TYPE_CHECKING:
logger = get_logger(__name__) logger = get_logger(__name__)
check_min_version("4.31.0") require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0") require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0") require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0") require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")

View File

@ -14,7 +14,7 @@ from llmtuner.webui.css import CSS
from llmtuner.webui.engine import Engine from llmtuner.webui.engine import Engine
require_version("gradio==3.50.2", "To fix: pip install gradio==3.50.2") require_version("gradio>=3.38.0,<4.0.0", "To fix: pip install \"gradio>=3.38.0,<4.0.0\"")
def create_ui() -> gr.Blocks: def create_ui() -> gr.Blocks: