diff --git a/README.md b/README.md index f9ea4303..f5fd8345 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,9 @@ ## LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory -Launch **LLaMA Board** via `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`. (multiple GPUs are not supported yet) +Preview LLaMA Board at **[Hugging Face Space](https://huggingface.co/spaces/hiyouga/LLaMA-Board)**. + +Launch LLaMA Board via `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`. (multiple GPUs are not supported yet) Here is an example of altering the self-cognition of an instruction-tuned language model within 10 minutes on a single GPU. @@ -71,7 +73,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846 > > For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "chat" models. -Please refer to [template.py](src/llmtuner/extras/template.py) for a full list of models we supported. +Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list of models we supported. ## Supported Training Approaches @@ -79,9 +81,9 @@ Please refer to [template.py](src/llmtuner/extras/template.py) for a full list o | ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ | | Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | | Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | -| Reward Modeling | | | :white_check_mark: | :white_check_mark: | -| PPO Training | | | :white_check_mark: | :white_check_mark: | -| DPO Training | :white_check_mark: | | :white_check_mark: | :white_check_mark: | +| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | > [!NOTE] > Use `--quantization_bit 4/8` argument to enable QLoRA. @@ -122,6 +124,7 @@ Please refer to [template.py](src/llmtuner/extras/template.py) for a full list o - [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus) - [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) - [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) +- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca) - [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct) - [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) diff --git a/README_zh.md b/README_zh.md index 74b9362a..42aa70fc 100644 --- a/README_zh.md +++ b/README_zh.md @@ -14,7 +14,9 @@ ## LLaMA Board: 通过一站式网页界面快速上手 LLaMA Factory -使用 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 启动 **LLaMA Board**。(该界面目前仅支持单卡训练) +通过 **[Hugging Face Space](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** 预览 LLaMA Board。 + +使用 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 启动 LLaMA Board。(该界面目前仅支持单卡训练) 下面是使用单张 GPU 在 10 分钟内更改对话式大型语言模型自我认知的示例。 @@ -71,7 +73,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846 > > 对于所有“基座”(Base)模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Chat)模型请务必使用**对应的模板**。 -项目所支持模型的完整列表请参阅 [template.py](src/llmtuner/extras/template.py)。 +项目所支持模型的完整列表请参阅 [constants.py](src/llmtuner/extras/constants.py)。 ## 训练方法 @@ -79,9 +81,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846 | ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ | | 预训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | | 指令监督微调 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | -| 奖励模型训练 | | | :white_check_mark: | :white_check_mark: | -| PPO 训练 | | | :white_check_mark: | :white_check_mark: | -| DPO 训练 | :white_check_mark: | | :white_check_mark: | :white_check_mark: | +| 奖励模型训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| PPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | +| DPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | > [!NOTE] > 请使用 `--quantization_bit 4/8` 参数来启用 QLoRA 训练。 @@ -122,6 +124,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846 - [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus) - [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) - [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) +- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca) - [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct) - [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) diff --git a/data/belle_multiturn/belle_multiturn.py b/data/belle_multiturn/belle_multiturn.py index 816a38bf..4b42527a 100644 --- a/data/belle_multiturn/belle_multiturn.py +++ b/data/belle_multiturn/belle_multiturn.py @@ -24,9 +24,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder): def _info(self): features = datasets.Features({ - "instruction": datasets.Value("string"), - "output": datasets.Value("string"), - "history": datasets.Sequence(datasets.Sequence(datasets.Value("string"))) + "conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}] }) return datasets.DatasetInfo( description=_DESCRIPTION, @@ -51,6 +49,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder): with open(filepath, "r", encoding="utf-8") as f: for key, row in enumerate(f): data = json.loads(row) + conversations = [] prompt = data["instruction"].strip() response = data["output"].strip() @@ -58,7 +57,8 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder): human_idx = prompt.rfind("Human:") query = prompt[human_idx+6:assist_idx].strip() prompt = prompt[:human_idx].strip() - history = [] + conversations.insert(0, {"from": "gpt", "value": response}) + conversations.insert(0, {"from": "human", "value": query}) while prompt.rfind("Assistant:") != -1: assist_idx = prompt.rfind("Assistant:") @@ -66,13 +66,10 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder): if human_idx != -1: old_query = prompt[human_idx+6:assist_idx].strip() old_resp = prompt[assist_idx+10:].strip() - history.insert(0, (old_query, old_resp)) + conversations.insert(0, {"from": "gpt", "value": old_resp}) + conversations.insert(0, {"from": "human", "value": old_query}) else: break prompt = prompt[:human_idx].strip() - yield key, { - "instruction": query, - "output": response, - "history": history - } + yield key, {"conversations": conversations} diff --git a/data/dataset_info.json b/data/dataset_info.json index 7a7557bc..33d5dd12 100644 --- a/data/dataset_info.json +++ b/data/dataset_info.json @@ -88,11 +88,7 @@ }, "belle_multiturn": { "script_url": "belle_multiturn", - "columns": { - "prompt": "instruction", - "response": "output", - "history": "history" - } + "formatting": "sharegpt" }, "ultra_chat": { "script_url": "ultra_chat", @@ -107,6 +103,13 @@ "alpaca_cot": { "hf_hub_url": "QingyiSi/Alpaca-CoT" }, + "openorca": { + "hf_hub_url": "Open-Orca/OpenOrca", + "columns": { + "prompt": "question", + "response": "response" + } + }, "mathinstruct": { "hf_hub_url": "TIGER-Lab/MathInstruct", "columns": { diff --git a/data/ultra_chat/ultra_chat.py b/data/ultra_chat/ultra_chat.py index c187abb2..df6a23fb 100644 --- a/data/ultra_chat/ultra_chat.py +++ b/data/ultra_chat/ultra_chat.py @@ -66,6 +66,4 @@ class UltraChat(datasets.GeneratorBasedBuilder): "from": "human" if i % 2 == 0 else "gpt", "value": content[i] } for i in range(len(content))] - yield key, { - "conversations": conversations - } + yield key, {"conversations": conversations} diff --git a/src/llmtuner/__init__.py b/src/llmtuner/__init__.py index 895a2c48..c0778bca 100644 --- a/src/llmtuner/__init__.py +++ b/src/llmtuner/__init__.py @@ -1,9 +1,9 @@ -# Level: api, webui > chat, eval > tuner > dsets > extras, hparams +# Level: api, webui > chat, eval, train > data, model > extras, hparams from llmtuner.api import create_app from llmtuner.chat import ChatModel from llmtuner.eval import Evaluator -from llmtuner.tuner import export_model, run_exp +from llmtuner.train import export_model, run_exp from llmtuner.webui import create_ui, create_web_demo diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py index 27fb19e0..c01fa0df 100644 --- a/src/llmtuner/api/app.py +++ b/src/llmtuner/api/app.py @@ -1,14 +1,8 @@ import json -import uvicorn -from fastapi import FastAPI, HTTPException, status -from fastapi.middleware.cors import CORSMiddleware -from contextlib import asynccontextmanager -from sse_starlette import EventSourceResponse from typing import List, Tuple from pydantic import BaseModel +from contextlib import asynccontextmanager -from llmtuner.extras.misc import torch_gc -from llmtuner.chat import ChatModel from llmtuner.api.protocol import ( Role, Finish, @@ -23,10 +17,28 @@ from llmtuner.api.protocol import ( ChatCompletionResponseStreamChoice, ChatCompletionResponseUsage ) +from llmtuner.chat import ChatModel +from llmtuner.extras.misc import torch_gc +from llmtuner.extras.packages import ( + is_fastapi_availble, is_starlette_available, is_uvicorn_available +) + + +if is_fastapi_availble(): + from fastapi import FastAPI, HTTPException, status + from fastapi.middleware.cors import CORSMiddleware + + +if is_starlette_available(): + from sse_starlette import EventSourceResponse + + +if is_uvicorn_available(): + import uvicorn @asynccontextmanager -async def lifespan(app: FastAPI): # collects GPU memory +async def lifespan(app: "FastAPI"): # collects GPU memory yield torch_gc() @@ -38,7 +50,7 @@ def to_json(data: BaseModel) -> str: return data.json(exclude_unset=True, ensure_ascii=False) -def create_app(chat_model: ChatModel) -> FastAPI: +def create_app(chat_model: "ChatModel") -> "FastAPI": app = FastAPI(lifespan=lifespan) app.add_middleware( @@ -56,12 +68,12 @@ def create_app(chat_model: ChatModel) -> FastAPI: @app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK) async def create_chat_completion(request: ChatCompletionRequest): - if len(request.messages) < 1 or request.messages[-1].role != Role.USER: + if len(request.messages) == 0 or request.messages[-1].role != Role.USER: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request") query = request.messages[-1].content prev_messages = request.messages[:-1] - if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM: + if len(prev_messages) and prev_messages[0].role == Role.SYSTEM: system = prev_messages.pop(0).content else: system = None @@ -73,12 +85,14 @@ def create_app(chat_model: ChatModel) -> FastAPI: history.append([prev_messages[i].content, prev_messages[i+1].content]) else: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...") + else: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...") if request.stream: generate = predict(query, history, system, request) return EventSourceResponse(generate, media_type="text/event-stream") - response, (prompt_length, response_length) = chat_model.chat( + responses = chat_model.chat( query, history, system, do_sample=request.do_sample, temperature=request.temperature, @@ -87,18 +101,23 @@ def create_app(chat_model: ChatModel) -> FastAPI: num_return_sequences=request.n ) + prompt_length, response_length = 0, 0 + choices = [] + for i, response in enumerate(responses): + choices.append(ChatCompletionResponseChoice( + index=i, + message=ChatMessage(role=Role.ASSISTANT, content=response.response_text), + finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH + )) + prompt_length = response.prompt_length + response_length += response.response_length + usage = ChatCompletionResponseUsage( prompt_tokens=prompt_length, completion_tokens=response_length, total_tokens=prompt_length+response_length ) - choices = [ChatCompletionResponseChoice( - index=i, - message=ChatMessage(role=Role.ASSISTANT, content=choice), - finish_reason=Finish.STOP - ) for i, choice in enumerate(response)] - return ChatCompletionResponse(model=request.model, choices=choices, usage=usage) async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest): diff --git a/src/llmtuner/chat/__init__.py b/src/llmtuner/chat/__init__.py index ba240d05..f86efe96 100644 --- a/src/llmtuner/chat/__init__.py +++ b/src/llmtuner/chat/__init__.py @@ -1 +1 @@ -from llmtuner.chat.stream_chat import ChatModel +from llmtuner.chat.chat_model import ChatModel diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/chat_model.py similarity index 72% rename from src/llmtuner/chat/stream_chat.py rename to src/llmtuner/chat/chat_model.py index cc815d1b..9966a813 100644 --- a/src/llmtuner/chat/stream_chat.py +++ b/src/llmtuner/chat/chat_model.py @@ -1,11 +1,21 @@ import torch -from typing import Any, Dict, Generator, List, Optional, Tuple +from dataclasses import dataclass +from typing import Any, Dict, Generator, List, Literal, Optional, Tuple from threading import Thread from transformers import GenerationConfig, TextIteratorStreamer -from llmtuner.extras.misc import dispatch_model, get_logits_processor -from llmtuner.extras.template import get_template_and_fix_tokenizer -from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer +from llmtuner.data.template import get_template_and_fix_tokenizer +from llmtuner.extras.misc import get_logits_processor +from llmtuner.model import dispatch_model, get_infer_args, load_model_and_tokenizer + + +@dataclass +class Response: + + response_text: str + response_length: int + prompt_length: int + finish_reason: Literal["stop", "length"] class ChatModel: @@ -18,7 +28,7 @@ class ChatModel: self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer) self.system_prompt = data_args.system_prompt - def process_args( + def _process_args( self, query: str, history: Optional[List[Tuple[str, str]]] = None, @@ -79,17 +89,30 @@ class ChatModel: history: Optional[List[Tuple[str, str]]] = None, system: Optional[str] = None, **input_kwargs - ) -> Tuple[List[str], Tuple[int, int]]: - gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs) + ) -> List[Response]: + r""" + Args: query, history, system, **input_kwargs + + Returns: [(response_text, prompt_length, response_length)] * n (default n=1) + """ + gen_kwargs, prompt_length = self._process_args(query, history, system, **input_kwargs) generate_output = self.model.generate(**gen_kwargs) response_ids = generate_output[:, prompt_length:] - response = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) - response_length = 0 - for i in range(len(response_ids)): + response = self.tokenizer.batch_decode( + response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) + results = [] + for i in range(len(response)): eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero() - response_length += eos_index[0].item() if len(eos_index) else len(response_ids[i]) + response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i]) + results.append(Response( + response_text=response[i], + response_length=response_length, + prompt_length=prompt_length, + finish_reason="stop" if len(eos_index) else "length" + )) - return response, (prompt_length, response_length) + return results @torch.inference_mode() def stream_chat( @@ -99,7 +122,7 @@ class ChatModel: system: Optional[str] = None, **input_kwargs ) -> Generator[str, None, None]: - gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs) + gen_kwargs, _ = self._process_args(query, history, system, **input_kwargs) streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) gen_kwargs["streamer"] = streamer diff --git a/src/llmtuner/data/__init__.py b/src/llmtuner/data/__init__.py new file mode 100644 index 00000000..35f7caa3 --- /dev/null +++ b/src/llmtuner/data/__init__.py @@ -0,0 +1,4 @@ +from llmtuner.data.loader import get_dataset +from llmtuner.data.preprocess import preprocess_dataset +from llmtuner.data.template import get_template_and_fix_tokenizer +from llmtuner.data.utils import split_dataset diff --git a/src/llmtuner/dsets/loader.py b/src/llmtuner/data/loader.py similarity index 99% rename from src/llmtuner/dsets/loader.py rename to src/llmtuner/data/loader.py index 98d495e9..b2a64075 100644 --- a/src/llmtuner/dsets/loader.py +++ b/src/llmtuner/data/loader.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union from datasets import concatenate_datasets, interleave_datasets, load_dataset -from llmtuner.dsets.utils import checksum, EXT2TYPE +from llmtuner.data.utils import checksum, EXT2TYPE from llmtuner.extras.logging import get_logger if TYPE_CHECKING: diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/data/preprocess.py similarity index 99% rename from src/llmtuner/dsets/preprocess.py rename to src/llmtuner/data/preprocess.py index 1554345f..2d2b2db6 100644 --- a/src/llmtuner/dsets/preprocess.py +++ b/src/llmtuner/data/preprocess.py @@ -5,9 +5,9 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Un from datasets import load_from_disk +from llmtuner.data.template import get_template_and_fix_tokenizer from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.logging import get_logger -from llmtuner.extras.template import get_template_and_fix_tokenizer if TYPE_CHECKING: from datasets import Dataset, IterableDataset diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/data/template.py similarity index 86% rename from src/llmtuner/extras/template.py rename to src/llmtuner/data/template.py index bcb9ffa0..03b3c011 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/data/template.py @@ -225,9 +225,6 @@ def get_template_and_fix_tokenizer( return template -r""" -Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff -""" register_template( name="alpaca", prefix=[ @@ -246,11 +243,6 @@ register_template( ) -r""" -Supports: https://huggingface.co/BAAI/AquilaChat-7B - https://huggingface.co/BAAI/AquilaChat2-7B - https://huggingface.co/BAAI/AquilaChat2-34B -""" register_template( name="aquila", prefix=[ @@ -273,9 +265,6 @@ register_template( ) -r""" -Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat -""" register_template( name="baichuan", prefix=[ @@ -292,10 +281,6 @@ register_template( ) -r""" -Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat - https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat -""" register_template( name="baichuan2", prefix=[ @@ -312,9 +297,6 @@ register_template( ) -r""" -Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B -""" register_template( name="belle", prefix=[ @@ -330,9 +312,6 @@ register_template( ) -r""" -Supports: https://huggingface.co/vivo-ai/BlueLM-7B-Chat -""" register_template( name="bluelm", prefix=[ @@ -348,9 +327,6 @@ register_template( ) -r""" -Supports: https://huggingface.co/THUDM/chatglm2-6b -""" register_template( name="chatglm2", prefix=[ @@ -369,9 +345,6 @@ register_template( ) -r""" -Supports: https://huggingface.co/THUDM/chatglm3-6b -""" register_template( name="chatglm3", prefix=[ @@ -395,11 +368,6 @@ register_template( ) -r""" -Supports: https://huggingface.co/deepseek-ai/deepseek-coder-1.3b-instruct - https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct - https://huggingface.co/deepseek-ai/deepseek-coder-33b-instruct -""" register_template( name="deepseek", prefix=[ @@ -426,9 +394,6 @@ register_template( ) -r""" -Default template. -""" register_template( name="default", prefix=[ @@ -447,9 +412,6 @@ register_template( ) -r""" -Supports: https://huggingface.co/tiiuae/falcon-180B-chat -""" register_template( name="falcon", prefix=[ @@ -466,10 +428,6 @@ register_template( ) -r""" -Supports: https://huggingface.co/internlm/internlm-chat-7b - https://huggingface.co/internlm/internlm-chat-20b -""" register_template( name="intern", prefix=[ @@ -492,11 +450,6 @@ register_template( ) -r""" -Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf - https://huggingface.co/meta-llama/Llama-2-13b-chat-hf - https://huggingface.co/meta-llama/Llama-2-70b-chat-hf -""" register_template( name="llama2", prefix=[ @@ -519,10 +472,6 @@ register_template( ) -r""" -Supports: https://huggingface.co/ziqingyang/chinese-alpaca-2-7b - https://huggingface.co/ziqingyang/chinese-alpaca-2-13b -""" register_template( name="llama2_zh", prefix=[ @@ -536,9 +485,6 @@ register_template( ) -r""" -Supports: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1 -""" register_template( name="mistral", prefix=[ @@ -552,9 +498,6 @@ register_template( ) -r""" -Supports: https://huggingface.co/openchat/openchat_3.5 -""" register_template( name="openchat", prefix=[ @@ -576,10 +519,6 @@ register_template( ) -r""" -Supports: https://huggingface.co/Qwen/Qwen-7B-Chat - https://huggingface.co/Qwen/Qwen-14B-Chat -""" register_template( name="qwen", prefix=[ @@ -606,10 +545,6 @@ register_template( ) -r""" -Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha - https://huggingface.co/HuggingFaceH4/starchat-beta -""" register_template( name="starchat", prefix=[ @@ -650,10 +585,6 @@ register_template( ) -r""" -Supports: https://huggingface.co/lmsys/vicuna-7b-v1.5 - https://huggingface.co/lmsys/vicuna-13b-v1.5 -""" register_template( name="vicuna", prefix=[ @@ -670,10 +601,6 @@ register_template( ) -r""" -Supports: https://huggingface.co/xverse/XVERSE-7B-Chat - https://huggingface.co/xverse/XVERSE-13B-Chat -""" register_template( name="xverse", prefix=[ @@ -687,11 +614,6 @@ register_template( ) -r""" -Supports: https://huggingface.co/wenge-research/yayi-7b - https://huggingface.co/wenge-research/yayi-7b-llama2 - https://huggingface.co/wenge-research/yayi-13b-llama2 -""" register_template( name="yayi", prefix=[ @@ -724,10 +646,6 @@ register_template( ) -r""" -Supports: https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha - https://huggingface.co/HuggingFaceH4/zephyr-7b-beta -""" register_template( name="zephyr", prefix=[ @@ -746,11 +664,6 @@ register_template( ) -r""" -Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1 - https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1.1 - https://huggingface.co/IDEA-CCNL/Ziya2-13B-Chat -""" register_template( name="ziya", prefix=[ diff --git a/src/llmtuner/dsets/utils.py b/src/llmtuner/data/utils.py similarity index 100% rename from src/llmtuner/dsets/utils.py rename to src/llmtuner/data/utils.py diff --git a/src/llmtuner/dsets/__init__.py b/src/llmtuner/dsets/__init__.py deleted file mode 100644 index cccbd745..00000000 --- a/src/llmtuner/dsets/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from llmtuner.dsets.loader import get_dataset -from llmtuner.dsets.preprocess import preprocess_dataset -from llmtuner.dsets.utils import split_dataset diff --git a/src/llmtuner/eval/__init__.py b/src/llmtuner/eval/__init__.py index 10584817..a7c9a127 100644 --- a/src/llmtuner/eval/__init__.py +++ b/src/llmtuner/eval/__init__.py @@ -1 +1 @@ -from llmtuner.eval.engine import Evaluator +from llmtuner.eval.evaluator import Evaluator diff --git a/src/llmtuner/eval/constants.py b/src/llmtuner/eval/constants.py deleted file mode 100644 index 433ad39b..00000000 --- a/src/llmtuner/eval/constants.py +++ /dev/null @@ -1,3 +0,0 @@ -CHOICES = ["A", "B", "C", "D"] - -SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"] diff --git a/src/llmtuner/eval/engine.py b/src/llmtuner/eval/evaluator.py similarity index 95% rename from src/llmtuner/eval/engine.py rename to src/llmtuner/eval/evaluator.py index 10dff844..b2e04bec 100644 --- a/src/llmtuner/eval/engine.py +++ b/src/llmtuner/eval/evaluator.py @@ -11,12 +11,10 @@ from typing import Any, Dict, List, Optional from datasets import load_dataset from transformers.utils import cached_file -from llmtuner.eval.constants import CHOICES, SUBJECTS -from llmtuner.eval.parser import get_eval_args +from llmtuner.data.template import get_template_and_fix_tokenizer from llmtuner.eval.template import get_eval_template -from llmtuner.extras.misc import dispatch_model -from llmtuner.extras.template import get_template_and_fix_tokenizer -from llmtuner.tuner.core import load_model_and_tokenizer +from llmtuner.extras.constants import CHOICES, SUBJECTS +from llmtuner.model import dispatch_model, get_eval_args, load_model_and_tokenizer class Evaluator: diff --git a/src/llmtuner/eval/parser.py b/src/llmtuner/eval/parser.py deleted file mode 100644 index cef38048..00000000 --- a/src/llmtuner/eval/parser.py +++ /dev/null @@ -1,49 +0,0 @@ -import transformers -from typing import Any, Dict, Optional, Tuple -from transformers import HfArgumentParser - -from llmtuner.extras.misc import parse_args -from llmtuner.hparams import ( - ModelArguments, - DataArguments, - EvaluationArguments, - FinetuningArguments -) - - -def parse_eval_args( - args: Optional[Dict[str, Any]] = None -) -> Tuple[ - ModelArguments, - DataArguments, - EvaluationArguments, - FinetuningArguments -]: - parser = HfArgumentParser(( - ModelArguments, - DataArguments, - EvaluationArguments, - FinetuningArguments - )) - return parse_args(parser, args) - - -def get_eval_args( - args: Optional[Dict[str, Any]] = None -) -> Tuple[ - ModelArguments, - DataArguments, - EvaluationArguments, - FinetuningArguments -]: - model_args, data_args, eval_args, finetuning_args = parse_eval_args(args) - - if data_args.template is None: - raise ValueError("Please specify which `template` to use.") - - if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora": - raise ValueError("Quantization is only compatible with the LoRA method.") - - transformers.set_seed(eval_args.seed) - - return model_args, data_args, eval_args, finetuning_args diff --git a/src/llmtuner/eval/template.py b/src/llmtuner/eval/template.py index 44cb3c6d..2251ad57 100644 --- a/src/llmtuner/eval/template.py +++ b/src/llmtuner/eval/template.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Tuple -from llmtuner.eval.constants import CHOICES +from llmtuner.extras.constants import CHOICES if TYPE_CHECKING: from datasets import Dataset diff --git a/src/llmtuner/extras/callbacks.py b/src/llmtuner/extras/callbacks.py index 7398d424..5cf62cdc 100644 --- a/src/llmtuner/extras/callbacks.py +++ b/src/llmtuner/extras/callbacks.py @@ -12,6 +12,7 @@ from llmtuner.extras.logging import get_logger if TYPE_CHECKING: from transformers import TrainingArguments, TrainerState, TrainerControl + from trl import AutoModelForCausalLMWithValueHead logger = get_logger(__name__) @@ -25,18 +26,24 @@ class SavePeftModelCallback(TrainerCallback): """ if args.should_save: output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)) - model = kwargs.pop("model") + model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model") + model.pretrained_model.config.save_pretrained(output_dir) + if model.pretrained_model.can_generate(): + model.pretrained_model.generation_config.save_pretrained(output_dir) if getattr(model, "is_peft_model", False): - getattr(model, "pretrained_model").save_pretrained(output_dir) + model.pretrained_model.save_pretrained(output_dir) def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): r""" Event called at the end of training. """ if args.should_save: - model = kwargs.pop("model") + model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model") + model.pretrained_model.config.save_pretrained(args.output_dir) + if model.pretrained_model.can_generate(): + model.pretrained_model.generation_config.save_pretrained(args.output_dir) if getattr(model, "is_peft_model", False): - getattr(model, "pretrained_model").save_pretrained(args.output_dir) + model.pretrained_model.save_pretrained(args.output_dir) class LogCallback(TrainerCallback): diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py index 95916b69..cc9ac290 100644 --- a/src/llmtuner/extras/constants.py +++ b/src/llmtuner/extras/constants.py @@ -2,12 +2,24 @@ from collections import defaultdict, OrderedDict from typing import Dict, Optional +CHOICES = ["A", "B", "C", "D"] + +DEFAULT_MODULE = defaultdict(str) + +DEFAULT_TEMPLATE = defaultdict(str) + IGNORE_INDEX = -100 +LAYERNORM_NAMES = {"norm", "ln"} + LOG_FILE_NAME = "trainer_log.jsonl" METHODS = ["full", "freeze", "lora"] +SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"] + +SUPPORTED_MODELS = OrderedDict() + TRAINING_STAGES = { "Supervised Fine-Tuning": "sft", "Reward Modeling": "rm", @@ -16,14 +28,6 @@ TRAINING_STAGES = { "Pre-Training": "pt" } -LAYERNORM_NAMES = {"norm", "ln"} - -SUPPORTED_MODELS = OrderedDict() - -DEFAULT_MODULE = defaultdict(str) - -DEFAULT_TEMPLATE = defaultdict(str) - def register_model_group( models: Dict[str, str], @@ -116,10 +120,12 @@ register_model_group( register_model_group( models={ - "ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b", - "ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b", - "ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b", - "ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b" + "ChineseLLaMA2-1.3B": "hfl/chinese-llama-2-1.3b", + "ChineseLLaMA2-7B": "hfl/chinese-llama-2-7b", + "ChineseLLaMA2-13B": "hfl/chinese-llama-2-13b", + "ChineseLLaMA2-1.3B-Chat": "hfl/chinese-alpaca-2-1.3b", + "ChineseLLaMA2-7B-Chat": "hfl/chinese-alpaca-2-7b", + "ChineseLLaMA2-13B-Chat": "hfl/chinese-alpaca-2-13b" }, template="llama2_zh" ) @@ -190,6 +196,14 @@ register_model_group( ) +register_model_group( + models={ + "OpenChat3.5-7B-Chat": "openchat/openchat_3.5" + }, + template="openchat" +) + + register_model_group( models={ "Phi1.5-1.3B": "microsoft/phi-1_5" @@ -217,6 +231,15 @@ register_model_group( ) +register_model_group( + models={ + "Vicuna1.5-7B-Chat": "lmsys/vicuna-7b-v1.5", + "Vicuna1.5-13B-Chat": "lmsys/vicuna-13b-v1.5" + }, + template="vicuna" +) + + register_model_group( models={ "XVERSE-7B": "xverse/XVERSE-7B", @@ -229,9 +252,27 @@ register_model_group( ) +register_model_group( + models={ + "Yayi-7B": "wenge-research/yayi-7b-llama2", + "Yayi-13B": "wenge-research/yayi-13b-llama2" + }, + template="yayi" +) + + register_model_group( models={ "Yi-6B": "01-ai/Yi-6B", "Yi-34B": "01-ai/Yi-34B" } ) + + +register_model_group( + models={ + "Zephyr-7B-Alpha-Chat": "HuggingFaceH4/zephyr-7b-alpha", + "Zephyr-7B-Beta-Chat": "HuggingFaceH4/zephyr-7b-beta" + }, + template="zephyr" +) diff --git a/src/llmtuner/extras/logging.py b/src/llmtuner/extras/logging.py index d6f185e6..d01c14a4 100644 --- a/src/llmtuner/extras/logging.py +++ b/src/llmtuner/extras/logging.py @@ -3,6 +3,9 @@ import logging class LoggerHandler(logging.Handler): + r""" + Logger handler used in Web UI. + """ def __init__(self): super().__init__() @@ -19,16 +22,10 @@ class LoggerHandler(logging.Handler): self.log += "\n\n" -def reset_logging(): - r""" - Removes basic config of root logger - """ - root = logging.getLogger() - list(map(root.removeHandler, root.handlers)) - list(map(root.removeFilter, root.filters)) - - def get_logger(name: str) -> logging.Logger: + r""" + Gets a standard logger with a stream hander to stdout. + """ formatter = logging.Formatter( fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S" @@ -41,3 +38,12 @@ def get_logger(name: str) -> logging.Logger: logger.addHandler(handler) return logger + + +def reset_logging() -> None: + r""" + Removes basic config of root logger. (unused in script) + """ + root = logging.getLogger() + list(map(root.removeHandler, root.handlers)) + list(map(root.removeFilter, root.filters)) diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py index 6300bc75..6a906c74 100644 --- a/src/llmtuner/extras/misc.py +++ b/src/llmtuner/extras/misc.py @@ -13,14 +13,13 @@ try: is_torch_npu_available ) _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available() - _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available + _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available() except ImportError: _is_fp16_available = torch.cuda.is_available() _is_bf16_available = torch.cuda.is_bf16_supported() if TYPE_CHECKING: from transformers import HfArgumentParser - from transformers.modeling_utils import PreTrainedModel class AverageMeter: @@ -65,16 +64,14 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]: return trainable_params, all_param -def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype: - r""" - Infers the optimal dtype according to the model_dtype and device compatibility. - """ - if _is_bf16_available and model_dtype == torch.bfloat16: - return torch.bfloat16 - elif _is_fp16_available: - return torch.float16 +def get_current_device() -> str: + import accelerate + from accelerate import Accelerator + dummy_accelerator = Accelerator() + if accelerate.utils.is_xpu_available(): + return "xpu:{}".format(dummy_accelerator.local_process_index) else: - return torch.float32 + return dummy_accelerator.local_process_index if torch.cuda.is_available() else "cpu" def get_logits_processor() -> "LogitsProcessorList": @@ -86,14 +83,16 @@ def get_logits_processor() -> "LogitsProcessorList": return logits_processor -def torch_gc() -> None: +def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype: r""" - Collects GPU memory. + Infers the optimal dtype according to the model_dtype and device compatibility. """ - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() + if _is_bf16_available and model_dtype == torch.bfloat16: + return torch.bfloat16 + elif _is_fp16_available: + return torch.float16 + else: + return torch.float32 def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]: @@ -107,26 +106,11 @@ def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None return parser.parse_args_into_dataclasses() -def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel": +def torch_gc() -> None: r""" - Dispatches a pre-trained model to GPUs with balanced memory. - Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803 + Collects GPU memory. """ - if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing - return model - - if torch.cuda.device_count() > 1: - from accelerate import dispatch_model - from accelerate.utils import infer_auto_device_map, get_balanced_memory - - if model._no_split_modules is None: - raise ValueError("The model class needs to implement the `_no_split_modules` attribute.") - - kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules} - max_memory = get_balanced_memory(model, **kwargs) - # Make sure tied weights are tied before creating the device map. - model.tie_weights() - device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs) - return dispatch_model(model, device_map) - else: - return model.cuda() + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() diff --git a/src/llmtuner/extras/packages.py b/src/llmtuner/extras/packages.py new file mode 100644 index 00000000..26df247b --- /dev/null +++ b/src/llmtuner/extras/packages.py @@ -0,0 +1,55 @@ +import importlib.metadata +import importlib.util + + +def is_package_available(name: str) -> bool: + return importlib.util.find_spec(name) is not None + + +def get_package_version(name: str) -> str: + try: + return importlib.metadata.version(name) + except: + return "0.0.0" + + +_fastapi_available = is_package_available("fastapi") +_flash_attn2_available = is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2") +_jieba_available = is_package_available("jieba") +_matplotlib_available = is_package_available("matplotlib") +_nltk_available = is_package_available("nltk") +_rouge_available = is_package_available("rouge-chinese") +_starlette_available = is_package_available("sse-starlette") +_uvicorn_available = is_package_available("uvicorn") + + +def is_fastapi_availble(): + return _fastapi_available + + +def is_flash_attn2_available(): + return _flash_attn2_available + + +def is_jieba_available(): + return _jieba_available + + +def is_matplotlib_available(): + return _matplotlib_available + + +def is_nltk_available(): + return _nltk_available + + +def is_rouge_available(): + return _rouge_available + + +def is_starlette_available(): + return _starlette_available + + +def is_uvicorn_available(): + return _uvicorn_available diff --git a/src/llmtuner/extras/patches/llama_patch.py b/src/llmtuner/extras/patches/llama_patch.py index bf3e5d57..1fb7ed3b 100644 --- a/src/llmtuner/extras/patches/llama_patch.py +++ b/src/llmtuner/extras/patches/llama_patch.py @@ -3,16 +3,19 @@ import torch import torch.nn as nn from typing import Optional, Tuple from transformers.utils import logging -from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv - -is_flash_attn_2_available = False +from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb try: + from transformers.models.llama.modeling_llama import repeat_kv +except ImportError: + print("Please upgrade `transformers`.") + +from llmtuner.extras.packages import is_flash_attn2_available + + +if is_flash_attn2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore from flash_attn.bert_padding import pad_input, unpad_input # type: ignore - is_flash_attn_2_available = True -except ImportError: - is_flash_attn_2_available = False logger = logging.get_logger(__name__) diff --git a/src/llmtuner/extras/ploting.py b/src/llmtuner/extras/ploting.py index 82530e45..cf2c72ac 100644 --- a/src/llmtuner/extras/ploting.py +++ b/src/llmtuner/extras/ploting.py @@ -1,11 +1,14 @@ import os import math import json -import matplotlib.pyplot as plt from typing import List, Optional from transformers.trainer import TRAINER_STATE_NAME from llmtuner.extras.logging import get_logger +from llmtuner.extras.packages import is_matplotlib_available + +if is_matplotlib_available(): + import matplotlib.pyplot as plt logger = get_logger(__name__) diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py index eb949626..cfdc8b24 100644 --- a/src/llmtuner/hparams/finetuning_args.py +++ b/src/llmtuner/hparams/finetuning_args.py @@ -4,38 +4,38 @@ from dataclasses import asdict, dataclass, field @dataclass -class FinetuningArguments: +class FreezeArguments: r""" - Arguments pertaining to which techniques we are going to fine-tuning with. + Arguments pertaining to the freeze (partial-parameter) training. """ - stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field( - default="sft", - metadata={"help": "Which stage will be performed in training."} - ) - finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field( - default="lora", - metadata={"help": "Which fine-tuning method to use."} - ) num_layer_trainable: Optional[int] = field( default=3, metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."} ) - name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field( + name_module_trainable: Optional[str] = field( default="mlp", metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \ + Use commas to separate multiple modules. \ LLaMA choices: [\"mlp\", \"self_attn\"], \ BLOOM & Falcon & ChatGLM choices: [\"mlp\", \"self_attention\"], \ Qwen choices: [\"mlp\", \"attn\"], \ Phi-1.5 choices: [\"mlp\", \"mixer\"], \ - LLaMA-2, BlueLM, Baichuan, InternLM, Mistral, Skywork, XVERSE, Yi choices: the same as LLaMA."} + Others choices: the same as LLaMA."} ) + + +@dataclass +class LoraArguments: + r""" + Arguments pertaining to the LoRA training. + """ lora_rank: Optional[int] = field( default=8, metadata={"help": "The intrinsic dimension for LoRA fine-tuning."} ) lora_alpha: Optional[float] = field( - default=32.0, - metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."} + default=None, + metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2.0)."} ) lora_dropout: Optional[float] = field( default=0.1, @@ -49,7 +49,7 @@ class FinetuningArguments: Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \ Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \ - LLaMA-2, BlueLM, InternLM, Mistral, Skywork, XVERSE, Yi choices: the same as LLaMA."} + Others choices: the same as LLaMA."} ) additional_target: Optional[str] = field( default=None, @@ -59,30 +59,76 @@ class FinetuningArguments: default=True, metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."} ) - ppo_score_norm: Optional[bool] = field( - default=False, - metadata={"help": "Use score normalization in PPO training."} + + +@dataclass +class RLHFArguments: + r""" + Arguments pertaining to the PPO and DPO training. + """ + dpo_beta: Optional[float] = field( + default=0.1, + metadata={"help": "The beta parameter for the DPO loss."} ) ppo_logger: Optional[str] = field( default=None, metadata={"help": "Log with either 'wandb' or 'tensorboard' in PPO training."} ) + ppo_score_norm: Optional[bool] = field( + default=False, + metadata={"help": "Use score normalization in PPO training."} + ) ppo_target: Optional[float] = field( default=6.0, metadata={"help": "Target KL value for adaptive KL control in PPO training."} ) - dpo_beta: Optional[float] = field( - default=0.1, - metadata={"help": "The beta parameter for the DPO loss."} + ppo_whiten_rewards: Optional[bool] = field( + default=False, + metadata={"help": "Whiten the rewards before compute advantages in PPO training."} ) - dpo_ref_model: Optional[str] = field( + ref_model: Optional[str] = field( default=None, - metadata={"help": "Path to the reference model used for the DPO training."} + metadata={"help": "Path to the reference model used for the PPO or DPO training."} ) - dpo_ref_model_checkpoint: Optional[str] = field( + ref_model_checkpoint: Optional[str] = field( default=None, metadata={"help": "Path to the directory(s) containing the model checkpoints of the reference model."} ) + ref_model_quantization_bit: Optional[int] = field( + default=None, + metadata={"help": "The number of bits to quantize the reference model."} + ) + reward_model: Optional[str] = field( + default=None, + metadata={"help": "Path to the directory containing the checkpoints of the reward model."} + ) + reward_model_checkpoint: Optional[str] = field( + default=None, + metadata={"help": "Path to the directory(s) containing the model checkpoints of the reward model."} + ) + reward_model_quantization_bit: Optional[int] = field( + default=None, + metadata={"help": "The number of bits to quantize the reward model."} + ) + reward_model_type: Optional[Literal["lora", "full"]] = field( + default="lora", + metadata={"help": "The checkpoint type of the reward model. The lora type only supports lora training."} + ) + + +@dataclass +class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments): + r""" + Arguments pertaining to which techniques we are going to fine-tuning with. + """ + stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field( + default="sft", + metadata={"help": "Which stage will be performed in training."} + ) + finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field( + default="lora", + metadata={"help": "Which fine-tuning method to use."} + ) upcast_layernorm: Optional[bool] = field( default=False, metadata={"help": "Whether to upcast the layernorm weights in fp32."} @@ -91,15 +137,37 @@ class FinetuningArguments: default=0, metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."} ) + export_dir: Optional[str] = field( + default=None, + metadata={"help": "Path to the directory to save the exported model."} + ) + plot_loss: Optional[bool] = field( + default=False, + metadata={"help": "Whether to plot the training loss after fine-tuning or not."} + ) def __post_init__(self): - if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA - self.lora_target = [target.strip() for target in self.lora_target.split(",")] + def split_arg(arg): + if isinstance(arg, str): + return [item.strip() for item in arg.split(",")] + return arg - if isinstance(self.additional_target, str): - self.additional_target = [target.strip() for target in self.additional_target.split(",")] + self.name_module_trainable = split_arg(self.name_module_trainable) + self.lora_alpha = self.lora_alpha or float(self.lora_rank * 2.0) + self.lora_target = split_arg(self.lora_target) + self.additional_target = split_arg(self.additional_target) + self.ref_model_checkpoint = split_arg(self.ref_model_checkpoint) + self.reward_model_checkpoint = split_arg(self.reward_model_checkpoint) assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method." + assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." + assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." + + if self.stage == "ppo" and self.reward_model is None: + raise ValueError("Reward model is necessary for PPO training.") + + if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora": + raise ValueError("Lora reward model only supports lora training.") def save_to_json(self, json_path: str): r"""Saves the content of this instance in JSON format inside `json_path`.""" @@ -112,4 +180,5 @@ class FinetuningArguments: r"""Creates an instance from the content of `json_path`.""" with open(json_path, "r", encoding="utf-8") as f: text = f.read() + return cls(**json.loads(text)) diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py index 4b17c272..4bda39d5 100644 --- a/src/llmtuner/hparams/model_args.py +++ b/src/llmtuner/hparams/model_args.py @@ -54,22 +54,10 @@ class ModelArguments: default=False, metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."} ) - reward_model: Optional[str] = field( # TODO: move it to FinetuningArguments - default=None, - metadata={"help": "Path to the directory containing the checkpoints of the reward model."} - ) - plot_loss: Optional[bool] = field( # TODO: move it to FinetuningArguments - default=False, - metadata={"help": "Whether to plot the training loss after fine-tuning or not."} - ) hf_hub_token: Optional[str] = field( default=None, metadata={"help": "Auth token to log in with Hugging Face Hub."} ) - export_dir: Optional[str] = field( - default=None, - metadata={"help": "Path to the directory to save the exported model."} - ) def __post_init__(self): self.compute_dtype = None @@ -81,8 +69,7 @@ class ModelArguments: if self.checkpoint_dir is not None: # support merging multiple lora weights self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")] - if self.quantization_bit is not None: - assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization." + assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." def to_dict(self) -> Dict[str, Any]: return asdict(self) diff --git a/src/llmtuner/model/__init__.py b/src/llmtuner/model/__init__.py new file mode 100644 index 00000000..fb9a05e7 --- /dev/null +++ b/src/llmtuner/model/__init__.py @@ -0,0 +1,5 @@ +# Level: loader > adapter > parser, utils + +from llmtuner.model.loader import load_model_and_tokenizer +from llmtuner.model.parser import get_train_args, get_infer_args, get_eval_args +from llmtuner.model.utils import dispatch_model, generate_model_card, load_valuehead_params diff --git a/src/llmtuner/tuner/core/adapter.py b/src/llmtuner/model/adapter.py similarity index 69% rename from src/llmtuner/tuner/core/adapter.py rename to src/llmtuner/model/adapter.py index d3799f24..b7fe78a0 100644 --- a/src/llmtuner/tuner/core/adapter.py +++ b/src/llmtuner/model/adapter.py @@ -1,18 +1,9 @@ -import os import torch from typing import TYPE_CHECKING - -from transformers.utils import cached_file -from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME -from peft import ( - PeftModel, - TaskType, - LoraConfig, - get_peft_model -) +from peft import PeftModel, TaskType, LoraConfig, get_peft_model from llmtuner.extras.logging import get_logger -from llmtuner.tuner.core.utils import find_all_linear_modules +from llmtuner.model.utils import find_all_linear_modules if TYPE_CHECKING: from transformers.modeling_utils import PreTrainedModel @@ -46,13 +37,23 @@ def init_adapter( if finetuning_args.finetuning_type == "freeze" and is_trainable: logger.info("Fine-tuning method: Freeze") - num_layers = getattr(model.config, "num_layers") + num_layers = ( + getattr(model.config, "num_hidden_layers", None) + or getattr(model.config, "num_layers", None) + or getattr(model.config, "n_layer", None) + ) + if not num_layers: + raise ValueError("Current model does not support freeze tuning.") if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0 trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)] else: # fine-tuning the first n layers if num_layer_trainable < 0 trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)] - trainable_layers = ["{:d}.{}".format(idx, finetuning_args.name_module_trainable) for idx in trainable_layer_ids] + trainable_layers = [] + for module_name in finetuning_args.name_module_trainable: + for idx in trainable_layer_ids: + trainable_layers.append("{:d}.{}".format(idx, module_name)) + for name, param in model.named_parameters(): if not any(trainable_layer in name for trainable_layer in trainable_layers): param.requires_grad_(False) @@ -100,30 +101,3 @@ def init_adapter( logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir))) return model - - -def load_valuehead_params( - model: "PreTrainedModel", - model_args: "ModelArguments" -) -> bool: - kwargs = { - "path_or_repo_id": model_args.reward_model, - "cache_dir": model_args.cache_dir, - "token": model_args.hf_hub_token, - "revision": model_args.model_revision - } - try: - vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs) - except: - try: - vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs) - except: - logger.warning("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model)) - return False - - vhead_params = torch.load(vhead_file, map_location="cpu") - model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False) - model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False) - model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False) - model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False) - return True diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/model/loader.py similarity index 84% rename from src/llmtuner/tuner/core/loader.py rename to src/llmtuner/model/loader.py index 38d5f71e..20b9b5d4 100644 --- a/src/llmtuner/tuner/core/loader.py +++ b/src/llmtuner/model/loader.py @@ -15,7 +15,6 @@ from transformers import ( ) from transformers.models.llama import modeling_llama as LlamaModule from transformers.utils.versions import require_version -from peft import PeftModel from trl import AutoModelForCausalLMWithValueHead try: @@ -24,11 +23,12 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v from transformers.deepspeed import is_deepspeed_zero3_enabled from llmtuner.extras.logging import reset_logging, get_logger -from llmtuner.extras.misc import count_parameters, infer_optim_dtype +from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype +from llmtuner.extras.packages import is_flash_attn2_available from llmtuner.extras.patches import llama_patch as LlamaPatches from llmtuner.hparams import FinetuningArguments -from llmtuner.tuner.core.adapter import init_adapter, load_valuehead_params -from llmtuner.tuner.core.utils import prepare_model_for_training +from llmtuner.model.adapter import init_adapter +from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training if TYPE_CHECKING: from transformers import PreTrainedTokenizer @@ -73,6 +73,7 @@ def load_model_and_tokenizer( ) if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None: + logger.info("Use `model_name_or_path` to specify the model trained with full/freeze method.") model_to_load = model_args.checkpoint_dir[0] else: model_to_load = model_args.model_name_or_path @@ -122,7 +123,7 @@ def load_model_and_tokenizer( # Set FlashAttention-2 if model_args.flash_attn: if getattr(config, "model_type", None) == "llama": - if LlamaPatches.is_flash_attn_2_available: + if is_flash_attn2_available(): LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2 LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask logger.info("Using FlashAttention-2 for faster training and inference.") @@ -131,7 +132,7 @@ def load_model_and_tokenizer( elif getattr(config, "model_type", None) in ["qwen", "Yi"]: logger.info("Current model automatically enables FlashAttention if installed.") else: - logger.warning("Current model does not support FlashAttention-2.") + logger.warning("Current model does not support FlashAttention.") elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama": LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention logger.warning("Using `--flash_attn` for faster training in large context length.") @@ -144,7 +145,7 @@ def load_model_and_tokenizer( else: logger.warning("Current model does not support shift short attention.") - # Quantization configurations (using bitsandbytes library). + # Quantization configurations (using bitsandbytes library) if model_args.quantization_bit is not None: if is_deepspeed_zero3_enabled(): raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.") @@ -164,10 +165,10 @@ def load_model_and_tokenizer( bnb_4bit_quant_type=model_args.quantization_type ) - config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto" + config_kwargs["device_map"] = {"": get_current_device()} logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit)) - # Load and prepare pre-trained models (without valuehead). + # Load pre-trained models (without valuehead) model = AutoModelForCausalLM.from_pretrained( model_to_load, config=config, @@ -185,7 +186,7 @@ def load_model_and_tokenizer( setattr(model, "lm_head", model.transformer.output_layer) setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"]) - # Register auto class to save the custom code files. + # Register auto class to save the custom code files if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}): config.__class__.register_for_auto_class() if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}): @@ -199,25 +200,15 @@ def load_model_and_tokenizer( model = model.train() if is_trainable else model.eval() # Prepare model with valuehead for RLHF - if stage == "rm" or stage == "ppo": + if stage in ["rm", "ppo"]: model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model) - reset_logging() - if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model - logger.warning("Only the last checkpoint containing valuehead will be loaded.") - if load_valuehead_params(model, model_args): - model.v_head.load_state_dict({ - "summary.weight": getattr(model, "reward_head_weight"), - "summary.bias": getattr(model, "reward_head_bias") - }) - - if stage == "ppo": # load reward model - logger.info("Load reward model from {}".format(model_args.reward_model)) - if isinstance(model.pretrained_model, PeftModel): - model.pretrained_model.load_adapter(model_args.reward_model, "reward") - for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090 - if "default" in name: - param.data = param.data.to(torch.float32) # trainable params should in fp32 - assert load_valuehead_params(model, model_args), "Reward model is not correctly loaded." + vhead_path = ( + model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path + ) + vhead_params = load_valuehead_params(vhead_path, model_args) + if vhead_params is not None: + model.load_state_dict(vhead_params, strict=False) + logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path)) # Prepare model for inference if not is_trainable: diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/model/parser.py similarity index 78% rename from src/llmtuner/tuner/core/parser.py rename to src/llmtuner/model/parser.py index 04fc884b..051978b8 100644 --- a/src/llmtuner/tuner/core/parser.py +++ b/src/llmtuner/model/parser.py @@ -11,6 +11,7 @@ from llmtuner.extras.misc import parse_args from llmtuner.hparams import ( ModelArguments, DataArguments, + EvaluationArguments, FinetuningArguments, GeneratingArguments ) @@ -19,51 +20,42 @@ from llmtuner.hparams import ( logger = get_logger(__name__) -def parse_train_args( - args: Optional[Dict[str, Any]] = None -) -> Tuple[ - ModelArguments, - DataArguments, - Seq2SeqTrainingArguments, - FinetuningArguments, - GeneratingArguments -]: - parser = HfArgumentParser(( - ModelArguments, - DataArguments, - Seq2SeqTrainingArguments, - FinetuningArguments, - GeneratingArguments - )) +_TRAIN_ARGS = [ + ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments +] +_TRAIN_CLS = Tuple[ + ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments +] +_INFER_ARGS = [ + ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments +] +_INFER_CLS = Tuple[ + ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments +] +_EVAL_ARGS = [ + ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments +] +_EVAL_CLS = Tuple[ + ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments +] + + +def parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: + parser = HfArgumentParser(_TRAIN_ARGS) return parse_args(parser, args) -def parse_infer_args( - args: Optional[Dict[str, Any]] = None -) -> Tuple[ - ModelArguments, - DataArguments, - FinetuningArguments, - GeneratingArguments -]: - parser = HfArgumentParser(( - ModelArguments, - DataArguments, - FinetuningArguments, - GeneratingArguments - )) +def parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: + parser = HfArgumentParser(_INFER_ARGS) return parse_args(parser, args) -def get_train_args( - args: Optional[Dict[str, Any]] = None -) -> Tuple[ - ModelArguments, - DataArguments, - Seq2SeqTrainingArguments, - FinetuningArguments, - GeneratingArguments -]: +def parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS: + parser = HfArgumentParser(_EVAL_ARGS) + return parse_args(parser, args) + + +def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args) # Setup logging @@ -90,24 +82,19 @@ def get_train_args( raise ValueError("Please enable `predict_with_generate` to save model predictions.") if finetuning_args.stage in ["rm", "ppo"]: - if finetuning_args.finetuning_type != "lora": - raise ValueError("RM and PPO stages can only be performed with the LoRA method.") if training_args.resume_from_checkpoint is not None: raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.") if training_args.load_best_model_at_end: raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.") if finetuning_args.stage == "ppo" and not training_args.do_train: - raise ValueError("PPO training does not support evaluation.") + raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.") if finetuning_args.stage in ["rm", "dpo"]: for dataset_attr in data_args.dataset_list: if not dataset_attr.ranking: raise ValueError("Please use ranked datasets for reward modeling or DPO training.") - if finetuning_args.stage == "ppo" and model_args.reward_model is None: - raise ValueError("Reward model is necessary for PPO training.") - if finetuning_args.stage == "ppo" and model_args.shift_attn: raise ValueError("PPO training is incompatible with S^2-Attn.") @@ -139,6 +126,9 @@ def get_train_args( if (not training_args.do_train) and model_args.quantization_bit is not None: logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.") + if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None: + logger.warning("Specify `ref_model` for computing rewards at evaluation.") + # postprocess training_args if ( training_args.local_rank != -1 @@ -187,14 +177,7 @@ def get_train_args( return model_args, data_args, training_args, finetuning_args, generating_args -def get_infer_args( - args: Optional[Dict[str, Any]] = None -) -> Tuple[ - ModelArguments, - DataArguments, - FinetuningArguments, - GeneratingArguments -]: +def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: model_args, data_args, finetuning_args, generating_args = parse_infer_args(args) if data_args.template is None: @@ -211,3 +194,17 @@ def get_infer_args( raise ValueError("Only LoRA tuning accepts multiple checkpoints.") return model_args, data_args, finetuning_args, generating_args + + +def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS: + model_args, data_args, eval_args, finetuning_args = parse_eval_args(args) + + if data_args.template is None: + raise ValueError("Please specify which `template` to use.") + + if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora": + raise ValueError("Quantization is only compatible with the LoRA method.") + + transformers.set_seed(eval_args.seed) + + return model_args, data_args, eval_args, finetuning_args diff --git a/src/llmtuner/tuner/core/utils.py b/src/llmtuner/model/utils.py similarity index 66% rename from src/llmtuner/tuner/core/utils.py rename to src/llmtuner/model/utils.py index 5e56513c..7badc905 100644 --- a/src/llmtuner/tuner/core/utils.py +++ b/src/llmtuner/model/utils.py @@ -1,21 +1,53 @@ import torch from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple +from transformers.utils import cached_file +from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME + from llmtuner.extras.constants import LAYERNORM_NAMES from llmtuner.extras.logging import get_logger +from llmtuner.hparams import ModelArguments, FinetuningArguments if TYPE_CHECKING: from transformers.modeling_utils import PreTrainedModel - from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments + from llmtuner.hparams import DataArguments logger = get_logger(__name__) +def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel": + r""" + Dispatches a pre-trained model to GPUs with balanced memory. + Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803 + """ + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing + return model + + if torch.cuda.device_count() > 1: + from accelerate import dispatch_model + from accelerate.utils import infer_auto_device_map, get_balanced_memory + + if model._no_split_modules is None: + raise ValueError("The model class needs to implement the `_no_split_modules` attribute.") + + kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules} + max_memory = get_balanced_memory(model, **kwargs) + # Make sure tied weights are tied before creating the device map. + model.tie_weights() + device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs) + return dispatch_model(model, device_map) + else: + return model.cuda() + + def find_all_linear_modules( model: "PreTrainedModel", quantization_bit: Optional[int] = None ) -> List[str]: + r""" + Finds all available modules to apply lora. + """ if quantization_bit is not None: import bitsandbytes as bnb linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt @@ -51,6 +83,32 @@ def generate_model_card( } +def load_valuehead_params( + path_or_repo_id: str, + model_args: "ModelArguments" +) -> Dict[str, torch.Tensor]: + r""" + Loads value head parameters from Hugging Face Hub or local disk. + + Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`. + """ + kwargs = { + "path_or_repo_id": path_or_repo_id, + "cache_dir": model_args.cache_dir, + "token": model_args.hf_hub_token + } + try: + vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs) + except: + try: + vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs) + except: + logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id)) + return None + + return torch.load(vhead_file, map_location="cpu") + + def prepare_model_for_training( model: "PreTrainedModel", finetuning_args: "FinetuningArguments", diff --git a/src/llmtuner/train/__init__.py b/src/llmtuner/train/__init__.py new file mode 100644 index 00000000..e57c163b --- /dev/null +++ b/src/llmtuner/train/__init__.py @@ -0,0 +1 @@ +from llmtuner.train.tuner import export_model, run_exp diff --git a/src/llmtuner/train/dpo/__init__.py b/src/llmtuner/train/dpo/__init__.py new file mode 100644 index 00000000..96c8ed09 --- /dev/null +++ b/src/llmtuner/train/dpo/__init__.py @@ -0,0 +1 @@ +from llmtuner.train.dpo.workflow import run_dpo diff --git a/src/llmtuner/tuner/dpo/collator.py b/src/llmtuner/train/dpo/collator.py similarity index 100% rename from src/llmtuner/tuner/dpo/collator.py rename to src/llmtuner/train/dpo/collator.py diff --git a/src/llmtuner/tuner/dpo/trainer.py b/src/llmtuner/train/dpo/trainer.py similarity index 89% rename from src/llmtuner/tuner/dpo/trainer.py rename to src/llmtuner/train/dpo/trainer.py index c2b0b581..ccf49a7f 100644 --- a/src/llmtuner/tuner/dpo/trainer.py +++ b/src/llmtuner/train/dpo/trainer.py @@ -43,7 +43,11 @@ class CustomDPOTrainer(DPOTrainer): if ref_model is not None: if self.is_deepspeed_enabled: - self.ref_model = self._prepare_deepspeed(self.ref_model) + if not ( + getattr(ref_model, "is_loaded_in_8bit", False) + or getattr(ref_model, "is_loaded_in_4bit", False) + ): # quantized models are already set on the correct device + self.ref_model = self._prepare_deepspeed(self.ref_model) else: self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) diff --git a/src/llmtuner/tuner/dpo/workflow.py b/src/llmtuner/train/dpo/workflow.py similarity index 68% rename from src/llmtuner/tuner/dpo/workflow.py rename to src/llmtuner/train/dpo/workflow.py index 240d34c5..5281f4e4 100644 --- a/src/llmtuner/tuner/dpo/workflow.py +++ b/src/llmtuner/train/dpo/workflow.py @@ -4,23 +4,20 @@ from peft import PeftModel from typing import TYPE_CHECKING, Optional, List from transformers import Seq2SeqTrainingArguments -from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset +from llmtuner.data import get_dataset, preprocess_dataset, split_dataset from llmtuner.extras.constants import IGNORE_INDEX -from llmtuner.extras.logging import get_logger from llmtuner.extras.ploting import plot_loss from llmtuner.hparams import ModelArguments -from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer -from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding -from llmtuner.tuner.dpo.trainer import CustomDPOTrainer +from llmtuner.model import generate_model_card, load_model_and_tokenizer +from llmtuner.train.utils import create_ref_model +from llmtuner.train.dpo.collator import DPODataCollatorWithPadding +from llmtuner.train.dpo.trainer import CustomDPOTrainer if TYPE_CHECKING: from transformers import TrainerCallback from llmtuner.hparams import DataArguments, FinetuningArguments -logger = get_logger(__name__) - - def run_dpo( model_args: "ModelArguments", data_args: "DataArguments", @@ -38,23 +35,10 @@ def run_dpo( ) # Create reference model - if finetuning_args.dpo_ref_model is not None: - ref_model_args_dict = model_args.to_dict() - ref_model_args_dict.update(dict( - model_name_or_path=finetuning_args.dpo_ref_model, - checkpoint_dir=finetuning_args.dpo_ref_model_checkpoint - )) - ref_model_args = ModelArguments(**ref_model_args_dict) - ref_model, _ = load_model_and_tokenizer(ref_model_args, finetuning_args, is_trainable=False, stage="sft") - logger.info("Created reference model from {}".format(finetuning_args.dpo_ref_model)) - elif training_args.do_train: - if isinstance(model, PeftModel): - ref_model = None - else: - ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft") - logger.info("Created reference model from the model itself.") - else: + if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself ref_model = model + else: + ref_model = create_ref_model(model_args, finetuning_args, stage="dpo") # Update arguments training_args_dict = training_args.to_dict() @@ -80,14 +64,13 @@ def run_dpo( trainer.log_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics) trainer.save_state() - if trainer.is_world_process_zero() and model_args.plot_loss: + if trainer.is_world_process_zero() and finetuning_args.plot_loss: plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) # Evaluation if training_args.do_eval: metrics = trainer.evaluate(metric_key_prefix="eval") if id(model) == id(ref_model): # unable to compute rewards without a reference model - logger.warning("Pass `dpo_ref_model` for computing rewards at evaluation.") remove_keys = [key for key in metrics.keys() if "rewards" in key] for key in remove_keys: metrics.pop(key) diff --git a/src/llmtuner/train/ppo/__init__.py b/src/llmtuner/train/ppo/__init__.py new file mode 100644 index 00000000..c32b23fa --- /dev/null +++ b/src/llmtuner/train/ppo/__init__.py @@ -0,0 +1 @@ +from llmtuner.train.ppo.workflow import run_ppo diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/train/ppo/trainer.py similarity index 90% rename from src/llmtuner/tuner/ppo/trainer.py rename to src/llmtuner/train/ppo/trainer.py index 3d591615..949e2ce8 100644 --- a/src/llmtuner/tuner/ppo/trainer.py +++ b/src/llmtuner/train/ppo/trainer.py @@ -3,7 +3,7 @@ import sys import math import torch from tqdm import tqdm -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR @@ -14,7 +14,7 @@ from trl.core import PPODecorators, logprobs_from_logits from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback from llmtuner.extras.logging import get_logger from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor -from llmtuner.tuner.ppo.utils import dump_layernorm, restore_layernorm, replace_model +from llmtuner.train.ppo.utils import dump_layernorm, restore_layernorm, replace_model if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments, TrainerCallback @@ -37,24 +37,43 @@ class CustomPPOTrainer(PPOTrainer, Trainer): finetuning_args: "FinetuningArguments", generating_args: "GeneratingArguments", callbacks: List["TrainerCallback"], + reward_model: "AutoModelForCausalLMWithValueHead", **kwargs ): PPOTrainer.__init__(self, **kwargs) + self.args = training_args self.model_args = model_args self.finetuning_args = finetuning_args + self.reward_model = reward_model + self.generation_config = GenerationConfig( pad_token_id=self.tokenizer.pad_token_id, eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids, **generating_args.to_dict() ) + self.state = TrainerState() self.control = TrainerControl() self.log_callback, self.save_callback = callbacks[0], callbacks[1] assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback) + if self.args.max_steps > 0: logger.info("max_steps is given, it will override any value given in num_train_epochs") + if reward_model is not None: + is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr( + self.accelerator.state, "deepspeed_plugin" + ) + if is_deepspeed_enabled: + if not ( + getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False) + or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False) + ): # quantized models are already set on the correct device + self.reward_model = self._prepare_deepspeed(self.reward_model) + else: + self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True) + def ppo_train(self) -> None: r""" Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer. @@ -213,11 +232,14 @@ class CustomPPOTrainer(PPOTrainer, Trainer): r""" Computes scores using given reward model. """ - replace_model(unwrapped_model, target="reward") + if self.reward_model is None: + replace_model(unwrapped_model, target="reward") + batch = self.prepare_model_inputs(queries, responses) with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16 - _, _, values = self.model(**batch, output_hidden_states=True, return_dict=True) + reward_model = self.reward_model if self.reward_model is not None else self.model + _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True) if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2 values = torch.transpose(values, 0, 1) @@ -228,7 +250,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer): end_index = end_indexes[-1].item() if len(end_indexes) else 0 rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type - replace_model(unwrapped_model, target="default") + if self.reward_model is None: + replace_model(unwrapped_model, target="default") + return rewards @PPODecorators.empty_device_cache() diff --git a/src/llmtuner/tuner/ppo/utils.py b/src/llmtuner/train/ppo/utils.py similarity index 100% rename from src/llmtuner/tuner/ppo/utils.py rename to src/llmtuner/train/ppo/utils.py diff --git a/src/llmtuner/tuner/ppo/workflow.py b/src/llmtuner/train/ppo/workflow.py similarity index 83% rename from src/llmtuner/tuner/ppo/workflow.py rename to src/llmtuner/train/ppo/workflow.py index 9e5a5979..41a99e2c 100644 --- a/src/llmtuner/tuner/ppo/workflow.py +++ b/src/llmtuner/train/ppo/workflow.py @@ -7,11 +7,12 @@ from typing import TYPE_CHECKING, Optional, List from transformers import DataCollatorWithPadding from transformers.optimization import get_scheduler -from llmtuner.dsets import get_dataset, preprocess_dataset +from llmtuner.data import get_dataset, preprocess_dataset from llmtuner.extras.callbacks import SavePeftModelCallback from llmtuner.extras.ploting import plot_loss -from llmtuner.tuner.core import load_model_and_tokenizer -from llmtuner.tuner.ppo.trainer import CustomPPOTrainer +from llmtuner.model import load_model_and_tokenizer +from llmtuner.train.utils import create_ref_model, create_reward_model +from llmtuner.train.ppo.trainer import CustomPPOTrainer if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments, TrainerCallback @@ -33,6 +34,11 @@ def run_ppo( tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training data_collator = DataCollatorWithPadding(tokenizer=tokenizer) + # Create reference model and reward model + ref_model = create_ref_model(model_args, finetuning_args, stage="ppo") + reward_model = create_reward_model(model, model_args, finetuning_args) + + # Create ppo config ppo_config = PPOConfig( model_name=model_args.model_name_or_path, learning_rate=training_args.learning_rate, @@ -47,9 +53,11 @@ def run_ppo( log_with=finetuning_args.ppo_logger, use_score_scaling=finetuning_args.ppo_score_norm, use_score_norm=finetuning_args.ppo_score_norm, + whiten_rewards=finetuning_args.ppo_whiten_rewards, accelerator_kwargs={"step_scheduler_with_optimizer": False} ) + # Create optimizer and scheduler optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate) if training_args.max_steps > 0: num_training_steps = training_args.max_steps @@ -73,9 +81,10 @@ def run_ppo( finetuning_args=finetuning_args, generating_args=generating_args, callbacks=callbacks + [SavePeftModelCallback()], + reward_model=reward_model, config=ppo_config, model=model, - ref_model=None, + ref_model=ref_model, tokenizer=tokenizer, dataset=dataset, data_collator=data_collator, @@ -88,5 +97,5 @@ def run_ppo( ppo_trainer.ppo_train() ppo_trainer.save_model() ppo_trainer.save_state() # must be called after save_model to have a folder - if ppo_trainer.is_world_process_zero() and model_args.plot_loss: + if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss: plot_loss(training_args.output_dir, keys=["loss", "reward"]) diff --git a/src/llmtuner/train/pt/__init__.py b/src/llmtuner/train/pt/__init__.py new file mode 100644 index 00000000..eacbeadb --- /dev/null +++ b/src/llmtuner/train/pt/__init__.py @@ -0,0 +1 @@ +from llmtuner.train.pt.workflow import run_pt diff --git a/src/llmtuner/tuner/pt/workflow.py b/src/llmtuner/train/pt/workflow.py similarity index 91% rename from src/llmtuner/tuner/pt/workflow.py rename to src/llmtuner/train/pt/workflow.py index ab0e0206..41bf31ba 100644 --- a/src/llmtuner/tuner/pt/workflow.py +++ b/src/llmtuner/train/pt/workflow.py @@ -4,9 +4,9 @@ import math from typing import TYPE_CHECKING, Optional, List from transformers import DataCollatorForLanguageModeling, Trainer -from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset +from llmtuner.data import get_dataset, preprocess_dataset, split_dataset from llmtuner.extras.ploting import plot_loss -from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer +from llmtuner.model import generate_model_card, load_model_and_tokenizer if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments, TrainerCallback @@ -42,7 +42,7 @@ def run_pt( trainer.log_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics) trainer.save_state() - if trainer.is_world_process_zero() and model_args.plot_loss: + if trainer.is_world_process_zero() and finetuning_args.plot_loss: plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) # Evaluation diff --git a/src/llmtuner/train/rm/__init__.py b/src/llmtuner/train/rm/__init__.py new file mode 100644 index 00000000..c80ccfb9 --- /dev/null +++ b/src/llmtuner/train/rm/__init__.py @@ -0,0 +1 @@ +from llmtuner.train.rm.workflow import run_rm diff --git a/src/llmtuner/tuner/rm/collator.py b/src/llmtuner/train/rm/collator.py similarity index 100% rename from src/llmtuner/tuner/rm/collator.py rename to src/llmtuner/train/rm/collator.py diff --git a/src/llmtuner/tuner/rm/metric.py b/src/llmtuner/train/rm/metric.py similarity index 100% rename from src/llmtuner/tuner/rm/metric.py rename to src/llmtuner/train/rm/metric.py diff --git a/src/llmtuner/tuner/rm/trainer.py b/src/llmtuner/train/rm/trainer.py similarity index 100% rename from src/llmtuner/tuner/rm/trainer.py rename to src/llmtuner/train/rm/trainer.py diff --git a/src/llmtuner/tuner/rm/workflow.py b/src/llmtuner/train/rm/workflow.py similarity index 87% rename from src/llmtuner/tuner/rm/workflow.py rename to src/llmtuner/train/rm/workflow.py index 3e59c5c6..06f39702 100644 --- a/src/llmtuner/tuner/rm/workflow.py +++ b/src/llmtuner/train/rm/workflow.py @@ -3,13 +3,13 @@ from typing import TYPE_CHECKING, Optional, List from transformers import Seq2SeqTrainingArguments -from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset +from llmtuner.data import get_dataset, preprocess_dataset, split_dataset from llmtuner.extras.callbacks import SavePeftModelCallback from llmtuner.extras.ploting import plot_loss -from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer -from llmtuner.tuner.rm.metric import compute_accuracy -from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding -from llmtuner.tuner.rm.trainer import PairwiseTrainer +from llmtuner.model import generate_model_card, load_model_and_tokenizer +from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding +from llmtuner.train.rm.metric import compute_accuracy +from llmtuner.train.rm.trainer import PairwiseTrainer if TYPE_CHECKING: from transformers import TrainerCallback @@ -51,7 +51,7 @@ def run_rm( trainer.log_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics) trainer.save_state() - if trainer.is_world_process_zero() and model_args.plot_loss: + if trainer.is_world_process_zero() and finetuning_args.plot_loss: plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) # Evaluation diff --git a/src/llmtuner/train/sft/__init__.py b/src/llmtuner/train/sft/__init__.py new file mode 100644 index 00000000..cb5448f4 --- /dev/null +++ b/src/llmtuner/train/sft/__init__.py @@ -0,0 +1 @@ +from llmtuner.train.sft.workflow import run_sft diff --git a/src/llmtuner/tuner/sft/metric.py b/src/llmtuner/train/sft/metric.py similarity index 86% rename from src/llmtuner/tuner/sft/metric.py rename to src/llmtuner/train/sft/metric.py index 812896ee..18db0b88 100644 --- a/src/llmtuner/tuner/sft/metric.py +++ b/src/llmtuner/train/sft/metric.py @@ -2,15 +2,23 @@ import numpy as np from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union -import jieba -from rouge_chinese import Rouge -from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction - from llmtuner.extras.constants import IGNORE_INDEX +from llmtuner.extras.packages import ( + is_jieba_available, is_nltk_available, is_rouge_available +) if TYPE_CHECKING: from transformers.tokenization_utils import PreTrainedTokenizer +if is_jieba_available(): + import jieba + +if is_nltk_available(): + from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction + +if is_rouge_available(): + from rouge_chinese import Rouge + @dataclass class ComputeMetrics: diff --git a/src/llmtuner/tuner/sft/trainer.py b/src/llmtuner/train/sft/trainer.py similarity index 100% rename from src/llmtuner/tuner/sft/trainer.py rename to src/llmtuner/train/sft/trainer.py diff --git a/src/llmtuner/tuner/sft/workflow.py b/src/llmtuner/train/sft/workflow.py similarity index 92% rename from src/llmtuner/tuner/sft/workflow.py rename to src/llmtuner/train/sft/workflow.py index ef902fe7..8a802c9b 100644 --- a/src/llmtuner/tuner/sft/workflow.py +++ b/src/llmtuner/train/sft/workflow.py @@ -3,13 +3,13 @@ from typing import TYPE_CHECKING, Optional, List from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments -from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset +from llmtuner.data import get_dataset, preprocess_dataset, split_dataset from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.misc import get_logits_processor from llmtuner.extras.ploting import plot_loss -from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer -from llmtuner.tuner.sft.metric import ComputeMetrics -from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer +from llmtuner.model import generate_model_card, load_model_and_tokenizer +from llmtuner.train.sft.metric import ComputeMetrics +from llmtuner.train.sft.trainer import CustomSeq2SeqTrainer if TYPE_CHECKING: from transformers import TrainerCallback @@ -69,7 +69,7 @@ def run_sft( trainer.log_metrics("train", train_result.metrics) trainer.save_metrics("train", train_result.metrics) trainer.save_state() - if trainer.is_world_process_zero() and model_args.plot_loss: + if trainer.is_world_process_zero() and finetuning_args.plot_loss: plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) # Evaluation diff --git a/src/llmtuner/tuner/tune.py b/src/llmtuner/train/tuner.py similarity index 80% rename from src/llmtuner/tuner/tune.py rename to src/llmtuner/train/tuner.py index 4eb7f78f..361dafb8 100644 --- a/src/llmtuner/tuner/tune.py +++ b/src/llmtuner/train/tuner.py @@ -2,12 +2,12 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.logging import get_logger -from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer -from llmtuner.tuner.pt import run_pt -from llmtuner.tuner.sft import run_sft -from llmtuner.tuner.rm import run_rm -from llmtuner.tuner.ppo import run_ppo -from llmtuner.tuner.dpo import run_dpo +from llmtuner.model import get_train_args, get_infer_args, load_model_and_tokenizer +from llmtuner.train.pt import run_pt +from llmtuner.train.sft import run_sft +from llmtuner.train.rm import run_rm +from llmtuner.train.ppo import run_ppo +from llmtuner.train.dpo import run_dpo if TYPE_CHECKING: from transformers import TrainerCallback @@ -38,11 +38,11 @@ def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional model_args, _, finetuning_args, _ = get_infer_args(args) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args) model.config.use_cache = True - model.save_pretrained(model_args.export_dir, max_shard_size=max_shard_size) + model.save_pretrained(finetuning_args.export_dir, max_shard_size=max_shard_size) try: tokenizer.padding_side = "left" # restore padding side tokenizer.init_kwargs["padding_side"] = "left" - tokenizer.save_pretrained(model_args.export_dir) + tokenizer.save_pretrained(finetuning_args.export_dir) except: logger.warning("Cannot save tokenizer, please copy the files manually.") diff --git a/src/llmtuner/train/utils.py b/src/llmtuner/train/utils.py new file mode 100644 index 00000000..f41c7cc7 --- /dev/null +++ b/src/llmtuner/train/utils.py @@ -0,0 +1,80 @@ +import torch +from typing import TYPE_CHECKING, Literal, Union + +from llmtuner.extras.logging import get_logger +from llmtuner.hparams import ModelArguments, FinetuningArguments +from llmtuner.model import load_model_and_tokenizer, load_valuehead_params + +if TYPE_CHECKING: + from transformers.modeling_utils import PreTrainedModel + from trl import AutoModelForCausalLMWithValueHead + + +logger = get_logger(__name__) + + +def create_ref_model( + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + stage: Literal["ppo", "dpo"] +) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]: + r""" + Creates reference model for PPO/DPO training. Evaluation mode is not supported. + + The valuehead parameter is randomly initialized since it is useless for PPO training. + """ + if finetuning_args.ref_model is not None: + ref_model_args_dict = model_args.to_dict() + ref_model_args_dict.update(dict( + model_name_or_path=finetuning_args.ref_model, + checkpoint_dir=finetuning_args.ref_model_checkpoint, + quantization_bit=finetuning_args.ref_model_quantization_bit + )) + ref_model_args = ModelArguments(**ref_model_args_dict) + ref_finetuning_args = FinetuningArguments(finetuning_type="lora") + ref_model, _ = load_model_and_tokenizer(ref_model_args, ref_finetuning_args, is_trainable=False, stage=stage) + logger.info("Created reference model from {}".format(finetuning_args.ref_model)) + else: + if finetuning_args.finetuning_type == "lora": + ref_model = None + else: + ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage=stage) + logger.info("Created reference model from the model itself.") + + return ref_model + + +def create_reward_model( + model: "AutoModelForCausalLMWithValueHead", + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments" +) -> "AutoModelForCausalLMWithValueHead": + r""" + Creates reward model for PPO training. + """ + if finetuning_args.reward_model_type == "lora": + model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward") + for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090 + if "default" in name: + param.data = param.data.to(torch.float32) # trainable params should in fp32 + vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args) + assert vhead_params is not None, "Reward model is not correctly loaded." + model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False) + model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False) + model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False) + model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False) + logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model)) + return None + else: + reward_model_args_dict = model_args.to_dict() + reward_model_args_dict.update(dict( + model_name_or_path=finetuning_args.reward_model, + checkpoint_dir=finetuning_args.reward_model_checkpoint, + quantization_bit=finetuning_args.reward_model_quantization_bit + )) + reward_model_args = ModelArguments(**reward_model_args_dict) + reward_finetuning_args = FinetuningArguments(finetuning_type="lora") + reward_model, _ = load_model_and_tokenizer(reward_model_args, reward_finetuning_args, is_trainable=False, stage="ppo") + logger.info("Load full weights of reward model from {}".format(finetuning_args.reward_model)) + logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.") + return reward_model diff --git a/src/llmtuner/tuner/__init__.py b/src/llmtuner/tuner/__init__.py deleted file mode 100644 index 4d5a83e4..00000000 --- a/src/llmtuner/tuner/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from llmtuner.tuner.tune import export_model, run_exp diff --git a/src/llmtuner/tuner/core/__init__.py b/src/llmtuner/tuner/core/__init__.py deleted file mode 100644 index ac621f7c..00000000 --- a/src/llmtuner/tuner/core/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from llmtuner.tuner.core.parser import get_train_args, get_infer_args -from llmtuner.tuner.core.loader import load_model_and_tokenizer -from llmtuner.tuner.core.utils import generate_model_card diff --git a/src/llmtuner/tuner/dpo/__init__.py b/src/llmtuner/tuner/dpo/__init__.py deleted file mode 100644 index f2b5cfb5..00000000 --- a/src/llmtuner/tuner/dpo/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from llmtuner.tuner.dpo.workflow import run_dpo diff --git a/src/llmtuner/tuner/ppo/__init__.py b/src/llmtuner/tuner/ppo/__init__.py deleted file mode 100644 index 11519bab..00000000 --- a/src/llmtuner/tuner/ppo/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from llmtuner.tuner.ppo.workflow import run_ppo diff --git a/src/llmtuner/tuner/pt/__init__.py b/src/llmtuner/tuner/pt/__init__.py deleted file mode 100644 index 8ce509db..00000000 --- a/src/llmtuner/tuner/pt/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from llmtuner.tuner.pt.workflow import run_pt diff --git a/src/llmtuner/tuner/rm/__init__.py b/src/llmtuner/tuner/rm/__init__.py deleted file mode 100644 index 54d3d943..00000000 --- a/src/llmtuner/tuner/rm/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from llmtuner.tuner.rm.workflow import run_rm diff --git a/src/llmtuner/tuner/sft/__init__.py b/src/llmtuner/tuner/sft/__init__.py deleted file mode 100644 index 493dd1a7..00000000 --- a/src/llmtuner/tuner/sft/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from llmtuner.tuner.sft.workflow import run_sft diff --git a/src/llmtuner/webui/chatter.py b/src/llmtuner/webui/chatter.py index 57eadb01..92f4bcb1 100644 --- a/src/llmtuner/webui/chatter.py +++ b/src/llmtuner/webui/chatter.py @@ -2,7 +2,7 @@ import gradio as gr from gradio.components import Component # cannot use TYPE_CHECKING here from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple -from llmtuner.chat.stream_chat import ChatModel +from llmtuner.chat import ChatModel from llmtuner.extras.misc import torch_gc from llmtuner.hparams import GeneratingArguments from llmtuner.webui.common import get_save_dir @@ -14,14 +14,24 @@ if TYPE_CHECKING: class WebChatModel(ChatModel): - def __init__(self, manager: "Manager", lazy_init: Optional[bool] = True) -> None: + def __init__( + self, + manager: "Manager", + demo_mode: Optional[bool] = False, + lazy_init: Optional[bool] = True + ) -> None: self.manager = manager + self.demo_mode = demo_mode self.model = None self.tokenizer = None self.generating_args = GeneratingArguments() - if not lazy_init: + + if not lazy_init: # read arguments from command line super().__init__() + if demo_mode: # load openchat 3.5 by default + super().__init__(dict(model_name_or_path="openchat/openchat_3.5", template="openchat")) + @property def loaded(self) -> bool: return self.model is not None @@ -36,6 +46,8 @@ class WebChatModel(ChatModel): error = ALERTS["err_no_model"][lang] elif not get("top.model_path"): error = ALERTS["err_no_path"][lang] + elif self.demo_mode: + error = ALERTS["err_demo"][lang] if error: gr.Warning(error) @@ -67,6 +79,11 @@ class WebChatModel(ChatModel): def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]: lang = data[self.manager.get_elem_by_name("top.lang")] + + if self.demo_mode: + yield ALERTS["err_demo"][lang] + return + yield ALERTS["info_unloading"][lang] self.model = None self.tokenizer = None diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py index 6663254c..55d8942b 100644 --- a/src/llmtuner/webui/common.py +++ b/src/llmtuner/webui/common.py @@ -70,7 +70,7 @@ def get_module(model_name: str) -> str: def get_template(model_name: str) -> str: - if model_name.endswith("Chat") and get_prefix(model_name) in DEFAULT_TEMPLATE: + if model_name and model_name.endswith("Chat") and get_prefix(model_name) in DEFAULT_TEMPLATE: return DEFAULT_TEMPLATE[get_prefix(model_name)] return "default" diff --git a/src/llmtuner/webui/components/export.py b/src/llmtuner/webui/components/export.py index d16fa3d1..b78dc831 100644 --- a/src/llmtuner/webui/components/export.py +++ b/src/llmtuner/webui/components/export.py @@ -1,7 +1,7 @@ import gradio as gr from typing import TYPE_CHECKING, Dict, Generator, List -from llmtuner.tuner import export_model +from llmtuner.train import export_model from llmtuner.webui.common import get_save_dir from llmtuner.webui.locales import ALERTS diff --git a/src/llmtuner/webui/components/top.py b/src/llmtuner/webui/components/top.py index c6299cab..0cbd291a 100644 --- a/src/llmtuner/webui/components/top.py +++ b/src/llmtuner/webui/components/top.py @@ -1,8 +1,8 @@ import gradio as gr from typing import TYPE_CHECKING, Dict +from llmtuner.data.template import templates from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS -from llmtuner.extras.template import templates from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, save_config from llmtuner.webui.utils import can_quantize diff --git a/src/llmtuner/webui/css.py b/src/llmtuner/webui/css.py index c86fb96b..36e3d4c2 100644 --- a/src/llmtuner/webui/css.py +++ b/src/llmtuner/webui/css.py @@ -1,4 +1,11 @@ CSS = r""" +.duplicate-button { + margin: auto !important; + color: white !important; + background: black !important; + border-radius: 100vh !important; +} + .modal-box { position: fixed !important; top: 50%; diff --git a/src/llmtuner/webui/engine.py b/src/llmtuner/webui/engine.py index 661dfb48..3e9f077d 100644 --- a/src/llmtuner/webui/engine.py +++ b/src/llmtuner/webui/engine.py @@ -12,11 +12,11 @@ from llmtuner.webui.utils import get_time class Engine: - def __init__(self, pure_chat: Optional[bool] = False) -> None: + def __init__(self, demo_mode: Optional[bool] = False, pure_chat: Optional[bool] = False) -> None: self.pure_chat = pure_chat - self.manager: "Manager" = Manager() - self.runner: "Runner" = Runner(self.manager) - self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not pure_chat)) + self.manager = Manager() + self.runner = Runner(self.manager, demo_mode=demo_mode) + self.chatter = WebChatModel(manager=self.manager, demo_mode=demo_mode, lazy_init=(not pure_chat)) def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]): return {self.manager.get_elem_by_name(k): gr.update(**v) for k, v in resume_dict.items()} diff --git a/src/llmtuner/webui/interface.py b/src/llmtuner/webui/interface.py index ba663f24..74ac59a0 100644 --- a/src/llmtuner/webui/interface.py +++ b/src/llmtuner/webui/interface.py @@ -1,4 +1,5 @@ import gradio as gr +from typing import Optional from transformers.utils.versions import require_version from llmtuner.webui.components import ( @@ -17,24 +18,35 @@ from llmtuner.webui.engine import Engine require_version("gradio>=3.38.0,<4.0.0", "To fix: pip install \"gradio>=3.38.0,<4.0.0\"") -def create_ui() -> gr.Blocks: - engine = Engine(pure_chat=False) +def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks: + engine = Engine(demo_mode=demo_mode, pure_chat=False) with gr.Blocks(title="LLaMA Board", css=CSS) as demo: + if demo_mode: + gr.HTML( + "