Merge #1525 from hiyouga/dev, fix #224 #336 #931 #936 #1011

Refactor llmtuner, support full-parameter RLHF

Former-commit-id: f04bc2a42815d02c22935fec8bf7b81438c8ba79
This commit is contained in:
hoshi-hiyouga 2023-11-16 15:47:13 +08:00 committed by GitHub
commit 550293badc
74 changed files with 806 additions and 560 deletions

View File

@ -14,7 +14,9 @@
## LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory
Launch **LLaMA Board** via `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`. (multiple GPUs are not supported yet)
Preview LLaMA Board at **[Hugging Face Space](https://huggingface.co/spaces/hiyouga/LLaMA-Board)**.
Launch LLaMA Board via `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`. (multiple GPUs are not supported yet)
Here is an example of altering the self-cognition of an instruction-tuned language model within 10 minutes on a single GPU.
@ -71,7 +73,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
>
> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "chat" models.
Please refer to [template.py](src/llmtuner/extras/template.py) for a full list of models we supported.
Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list of models we supported.
## Supported Training Approaches
@ -79,9 +81,9 @@ Please refer to [template.py](src/llmtuner/extras/template.py) for a full list o
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
| Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| Reward Modeling | | | :white_check_mark: | :white_check_mark: |
| PPO Training | | | :white_check_mark: | :white_check_mark: |
| DPO Training | :white_check_mark: | | :white_check_mark: | :white_check_mark: |
| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
> [!NOTE]
> Use `--quantization_bit 4/8` argument to enable QLoRA.
@ -122,6 +124,7 @@ Please refer to [template.py](src/llmtuner/extras/template.py) for a full list o
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)

View File

@ -14,7 +14,9 @@
## LLaMA Board: 通过一站式网页界面快速上手 LLaMA Factory
使用 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 启动 **LLaMA Board**。(该界面目前仅支持单卡训练)
通过 **[Hugging Face Space](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** 预览 LLaMA Board。
使用 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 启动 LLaMA Board。该界面目前仅支持单卡训练
下面是使用单张 GPU 在 10 分钟内更改对话式大型语言模型自我认知的示例。
@ -71,7 +73,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
>
> 对于所有“基座”Base模型`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”Chat模型请务必使用**对应的模板**。
项目所支持模型的完整列表请参阅 [template.py](src/llmtuner/extras/template.py)。
项目所支持模型的完整列表请参阅 [constants.py](src/llmtuner/extras/constants.py)。
## 训练方法
@ -79,9 +81,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
| 预训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| 指令监督微调 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| 奖励模型训练 | | | :white_check_mark: | :white_check_mark: |
| PPO 训练 | | | :white_check_mark: | :white_check_mark: |
| DPO 训练 | :white_check_mark: | | :white_check_mark: | :white_check_mark: |
| 奖励模型训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| PPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| DPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
> [!NOTE]
> 请使用 `--quantization_bit 4/8` 参数来启用 QLoRA 训练。
@ -122,6 +124,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)

View File

@ -24,9 +24,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
def _info(self):
features = datasets.Features({
"instruction": datasets.Value("string"),
"output": datasets.Value("string"),
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]
})
return datasets.DatasetInfo(
description=_DESCRIPTION,
@ -51,6 +49,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
with open(filepath, "r", encoding="utf-8") as f:
for key, row in enumerate(f):
data = json.loads(row)
conversations = []
prompt = data["instruction"].strip()
response = data["output"].strip()
@ -58,7 +57,8 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
human_idx = prompt.rfind("Human:")
query = prompt[human_idx+6:assist_idx].strip()
prompt = prompt[:human_idx].strip()
history = []
conversations.insert(0, {"from": "gpt", "value": response})
conversations.insert(0, {"from": "human", "value": query})
while prompt.rfind("Assistant:") != -1:
assist_idx = prompt.rfind("Assistant:")
@ -66,13 +66,10 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
if human_idx != -1:
old_query = prompt[human_idx+6:assist_idx].strip()
old_resp = prompt[assist_idx+10:].strip()
history.insert(0, (old_query, old_resp))
conversations.insert(0, {"from": "gpt", "value": old_resp})
conversations.insert(0, {"from": "human", "value": old_query})
else:
break
prompt = prompt[:human_idx].strip()
yield key, {
"instruction": query,
"output": response,
"history": history
}
yield key, {"conversations": conversations}

View File

@ -88,11 +88,7 @@
},
"belle_multiturn": {
"script_url": "belle_multiturn",
"columns": {
"prompt": "instruction",
"response": "output",
"history": "history"
}
"formatting": "sharegpt"
},
"ultra_chat": {
"script_url": "ultra_chat",
@ -107,6 +103,13 @@
"alpaca_cot": {
"hf_hub_url": "QingyiSi/Alpaca-CoT"
},
"openorca": {
"hf_hub_url": "Open-Orca/OpenOrca",
"columns": {
"prompt": "question",
"response": "response"
}
},
"mathinstruct": {
"hf_hub_url": "TIGER-Lab/MathInstruct",
"columns": {

View File

@ -66,6 +66,4 @@ class UltraChat(datasets.GeneratorBasedBuilder):
"from": "human" if i % 2 == 0 else "gpt",
"value": content[i]
} for i in range(len(content))]
yield key, {
"conversations": conversations
}
yield key, {"conversations": conversations}

View File

@ -1,9 +1,9 @@
# Level: api, webui > chat, eval > tuner > dsets > extras, hparams
# Level: api, webui > chat, eval, train > data, model > extras, hparams
from llmtuner.api import create_app
from llmtuner.chat import ChatModel
from llmtuner.eval import Evaluator
from llmtuner.tuner import export_model, run_exp
from llmtuner.train import export_model, run_exp
from llmtuner.webui import create_ui, create_web_demo

View File

@ -1,14 +1,8 @@
import json
import uvicorn
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from sse_starlette import EventSourceResponse
from typing import List, Tuple
from pydantic import BaseModel
from contextlib import asynccontextmanager
from llmtuner.extras.misc import torch_gc
from llmtuner.chat import ChatModel
from llmtuner.api.protocol import (
Role,
Finish,
@ -23,10 +17,28 @@ from llmtuner.api.protocol import (
ChatCompletionResponseStreamChoice,
ChatCompletionResponseUsage
)
from llmtuner.chat import ChatModel
from llmtuner.extras.misc import torch_gc
from llmtuner.extras.packages import (
is_fastapi_availble, is_starlette_available, is_uvicorn_available
)
if is_fastapi_availble():
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
if is_starlette_available():
from sse_starlette import EventSourceResponse
if is_uvicorn_available():
import uvicorn
@asynccontextmanager
async def lifespan(app: FastAPI): # collects GPU memory
async def lifespan(app: "FastAPI"): # collects GPU memory
yield
torch_gc()
@ -38,7 +50,7 @@ def to_json(data: BaseModel) -> str:
return data.json(exclude_unset=True, ensure_ascii=False)
def create_app(chat_model: ChatModel) -> FastAPI:
def create_app(chat_model: "ChatModel") -> "FastAPI":
app = FastAPI(lifespan=lifespan)
app.add_middleware(
@ -56,12 +68,12 @@ def create_app(chat_model: ChatModel) -> FastAPI:
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
async def create_chat_completion(request: ChatCompletionRequest):
if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
query = request.messages[-1].content
prev_messages = request.messages[:-1]
if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
if len(prev_messages) and prev_messages[0].role == Role.SYSTEM:
system = prev_messages.pop(0).content
else:
system = None
@ -73,12 +85,14 @@ def create_app(chat_model: ChatModel) -> FastAPI:
history.append([prev_messages[i].content, prev_messages[i+1].content])
else:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
else:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
if request.stream:
generate = predict(query, history, system, request)
return EventSourceResponse(generate, media_type="text/event-stream")
response, (prompt_length, response_length) = chat_model.chat(
responses = chat_model.chat(
query, history, system,
do_sample=request.do_sample,
temperature=request.temperature,
@ -87,18 +101,23 @@ def create_app(chat_model: ChatModel) -> FastAPI:
num_return_sequences=request.n
)
prompt_length, response_length = 0, 0
choices = []
for i, response in enumerate(responses):
choices.append(ChatCompletionResponseChoice(
index=i,
message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
))
prompt_length = response.prompt_length
response_length += response.response_length
usage = ChatCompletionResponseUsage(
prompt_tokens=prompt_length,
completion_tokens=response_length,
total_tokens=prompt_length+response_length
)
choices = [ChatCompletionResponseChoice(
index=i,
message=ChatMessage(role=Role.ASSISTANT, content=choice),
finish_reason=Finish.STOP
) for i, choice in enumerate(response)]
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):

View File

@ -1 +1 @@
from llmtuner.chat.stream_chat import ChatModel
from llmtuner.chat.chat_model import ChatModel

View File

@ -1,11 +1,21 @@
import torch
from typing import Any, Dict, Generator, List, Optional, Tuple
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
from threading import Thread
from transformers import GenerationConfig, TextIteratorStreamer
from llmtuner.extras.misc import dispatch_model, get_logits_processor
from llmtuner.extras.template import get_template_and_fix_tokenizer
from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.extras.misc import get_logits_processor
from llmtuner.model import dispatch_model, get_infer_args, load_model_and_tokenizer
@dataclass
class Response:
response_text: str
response_length: int
prompt_length: int
finish_reason: Literal["stop", "length"]
class ChatModel:
@ -18,7 +28,7 @@ class ChatModel:
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
self.system_prompt = data_args.system_prompt
def process_args(
def _process_args(
self,
query: str,
history: Optional[List[Tuple[str, str]]] = None,
@ -79,17 +89,30 @@ class ChatModel:
history: Optional[List[Tuple[str, str]]] = None,
system: Optional[str] = None,
**input_kwargs
) -> Tuple[List[str], Tuple[int, int]]:
gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
) -> List[Response]:
r"""
Args: query, history, system, **input_kwargs
Returns: [(response_text, prompt_length, response_length)] * n (default n=1)
"""
gen_kwargs, prompt_length = self._process_args(query, history, system, **input_kwargs)
generate_output = self.model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:]
response = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
response_length = 0
for i in range(len(response_ids)):
response = self.tokenizer.batch_decode(
response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
results = []
for i in range(len(response)):
eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
response_length += eos_index[0].item() if len(eos_index) else len(response_ids[i])
response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
results.append(Response(
response_text=response[i],
response_length=response_length,
prompt_length=prompt_length,
finish_reason="stop" if len(eos_index) else "length"
))
return response, (prompt_length, response_length)
return results
@torch.inference_mode()
def stream_chat(
@ -99,7 +122,7 @@ class ChatModel:
system: Optional[str] = None,
**input_kwargs
) -> Generator[str, None, None]:
gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs)
gen_kwargs, _ = self._process_args(query, history, system, **input_kwargs)
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer

View File

@ -0,0 +1,4 @@
from llmtuner.data.loader import get_dataset
from llmtuner.data.preprocess import preprocess_dataset
from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.data.utils import split_dataset

View File

@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
from datasets import concatenate_datasets, interleave_datasets, load_dataset
from llmtuner.dsets.utils import checksum, EXT2TYPE
from llmtuner.data.utils import checksum, EXT2TYPE
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:

View File

@ -5,9 +5,9 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Un
from datasets import load_from_disk
from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.logging import get_logger
from llmtuner.extras.template import get_template_and_fix_tokenizer
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset

View File

@ -225,9 +225,6 @@ def get_template_and_fix_tokenizer(
return template
r"""
Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
"""
register_template(
name="alpaca",
prefix=[
@ -246,11 +243,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/BAAI/AquilaChat-7B
https://huggingface.co/BAAI/AquilaChat2-7B
https://huggingface.co/BAAI/AquilaChat2-34B
"""
register_template(
name="aquila",
prefix=[
@ -273,9 +265,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
"""
register_template(
name="baichuan",
prefix=[
@ -292,10 +281,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
"""
register_template(
name="baichuan2",
prefix=[
@ -312,9 +297,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
"""
register_template(
name="belle",
prefix=[
@ -330,9 +312,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/vivo-ai/BlueLM-7B-Chat
"""
register_template(
name="bluelm",
prefix=[
@ -348,9 +327,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/THUDM/chatglm2-6b
"""
register_template(
name="chatglm2",
prefix=[
@ -369,9 +345,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/THUDM/chatglm3-6b
"""
register_template(
name="chatglm3",
prefix=[
@ -395,11 +368,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/deepseek-ai/deepseek-coder-1.3b-instruct
https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct
https://huggingface.co/deepseek-ai/deepseek-coder-33b-instruct
"""
register_template(
name="deepseek",
prefix=[
@ -426,9 +394,6 @@ register_template(
)
r"""
Default template.
"""
register_template(
name="default",
prefix=[
@ -447,9 +412,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/tiiuae/falcon-180B-chat
"""
register_template(
name="falcon",
prefix=[
@ -466,10 +428,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/internlm/internlm-chat-7b
https://huggingface.co/internlm/internlm-chat-20b
"""
register_template(
name="intern",
prefix=[
@ -492,11 +450,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
https://huggingface.co/meta-llama/Llama-2-13b-chat-hf
https://huggingface.co/meta-llama/Llama-2-70b-chat-hf
"""
register_template(
name="llama2",
prefix=[
@ -519,10 +472,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
https://huggingface.co/ziqingyang/chinese-alpaca-2-13b
"""
register_template(
name="llama2_zh",
prefix=[
@ -536,9 +485,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
"""
register_template(
name="mistral",
prefix=[
@ -552,9 +498,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/openchat/openchat_3.5
"""
register_template(
name="openchat",
prefix=[
@ -576,10 +519,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
https://huggingface.co/Qwen/Qwen-14B-Chat
"""
register_template(
name="qwen",
prefix=[
@ -606,10 +545,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
https://huggingface.co/HuggingFaceH4/starchat-beta
"""
register_template(
name="starchat",
prefix=[
@ -650,10 +585,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/lmsys/vicuna-7b-v1.5
https://huggingface.co/lmsys/vicuna-13b-v1.5
"""
register_template(
name="vicuna",
prefix=[
@ -670,10 +601,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/xverse/XVERSE-7B-Chat
https://huggingface.co/xverse/XVERSE-13B-Chat
"""
register_template(
name="xverse",
prefix=[
@ -687,11 +614,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/wenge-research/yayi-7b
https://huggingface.co/wenge-research/yayi-7b-llama2
https://huggingface.co/wenge-research/yayi-13b-llama2
"""
register_template(
name="yayi",
prefix=[
@ -724,10 +646,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha
https://huggingface.co/HuggingFaceH4/zephyr-7b-beta
"""
register_template(
name="zephyr",
prefix=[
@ -746,11 +664,6 @@ register_template(
)
r"""
Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1.1
https://huggingface.co/IDEA-CCNL/Ziya2-13B-Chat
"""
register_template(
name="ziya",
prefix=[

View File

@ -1,3 +0,0 @@
from llmtuner.dsets.loader import get_dataset
from llmtuner.dsets.preprocess import preprocess_dataset
from llmtuner.dsets.utils import split_dataset

View File

@ -1 +1 @@
from llmtuner.eval.engine import Evaluator
from llmtuner.eval.evaluator import Evaluator

View File

@ -1,3 +0,0 @@
CHOICES = ["A", "B", "C", "D"]
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]

View File

@ -11,12 +11,10 @@ from typing import Any, Dict, List, Optional
from datasets import load_dataset
from transformers.utils import cached_file
from llmtuner.eval.constants import CHOICES, SUBJECTS
from llmtuner.eval.parser import get_eval_args
from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.eval.template import get_eval_template
from llmtuner.extras.misc import dispatch_model
from llmtuner.extras.template import get_template_and_fix_tokenizer
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.extras.constants import CHOICES, SUBJECTS
from llmtuner.model import dispatch_model, get_eval_args, load_model_and_tokenizer
class Evaluator:

View File

@ -1,49 +0,0 @@
import transformers
from typing import Any, Dict, Optional, Tuple
from transformers import HfArgumentParser
from llmtuner.extras.misc import parse_args
from llmtuner.hparams import (
ModelArguments,
DataArguments,
EvaluationArguments,
FinetuningArguments
)
def parse_eval_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[
ModelArguments,
DataArguments,
EvaluationArguments,
FinetuningArguments
]:
parser = HfArgumentParser((
ModelArguments,
DataArguments,
EvaluationArguments,
FinetuningArguments
))
return parse_args(parser, args)
def get_eval_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[
ModelArguments,
DataArguments,
EvaluationArguments,
FinetuningArguments
]:
model_args, data_args, eval_args, finetuning_args = parse_eval_args(args)
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
transformers.set_seed(eval_args.seed)
return model_args, data_args, eval_args, finetuning_args

View File

@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Tuple
from llmtuner.eval.constants import CHOICES
from llmtuner.extras.constants import CHOICES
if TYPE_CHECKING:
from datasets import Dataset

View File

@ -12,6 +12,7 @@ from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from transformers import TrainingArguments, TrainerState, TrainerControl
from trl import AutoModelForCausalLMWithValueHead
logger = get_logger(__name__)
@ -25,18 +26,24 @@ class SavePeftModelCallback(TrainerCallback):
"""
if args.should_save:
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
model = kwargs.pop("model")
model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
model.pretrained_model.config.save_pretrained(output_dir)
if model.pretrained_model.can_generate():
model.pretrained_model.generation_config.save_pretrained(output_dir)
if getattr(model, "is_peft_model", False):
getattr(model, "pretrained_model").save_pretrained(output_dir)
model.pretrained_model.save_pretrained(output_dir)
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
"""
if args.should_save:
model = kwargs.pop("model")
model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
model.pretrained_model.config.save_pretrained(args.output_dir)
if model.pretrained_model.can_generate():
model.pretrained_model.generation_config.save_pretrained(args.output_dir)
if getattr(model, "is_peft_model", False):
getattr(model, "pretrained_model").save_pretrained(args.output_dir)
model.pretrained_model.save_pretrained(args.output_dir)
class LogCallback(TrainerCallback):

View File

@ -2,12 +2,24 @@ from collections import defaultdict, OrderedDict
from typing import Dict, Optional
CHOICES = ["A", "B", "C", "D"]
DEFAULT_MODULE = defaultdict(str)
DEFAULT_TEMPLATE = defaultdict(str)
IGNORE_INDEX = -100
LAYERNORM_NAMES = {"norm", "ln"}
LOG_FILE_NAME = "trainer_log.jsonl"
METHODS = ["full", "freeze", "lora"]
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
SUPPORTED_MODELS = OrderedDict()
TRAINING_STAGES = {
"Supervised Fine-Tuning": "sft",
"Reward Modeling": "rm",
@ -16,14 +28,6 @@ TRAINING_STAGES = {
"Pre-Training": "pt"
}
LAYERNORM_NAMES = {"norm", "ln"}
SUPPORTED_MODELS = OrderedDict()
DEFAULT_MODULE = defaultdict(str)
DEFAULT_TEMPLATE = defaultdict(str)
def register_model_group(
models: Dict[str, str],
@ -116,10 +120,12 @@ register_model_group(
register_model_group(
models={
"ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
"ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
"ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
"ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b"
"ChineseLLaMA2-1.3B": "hfl/chinese-llama-2-1.3b",
"ChineseLLaMA2-7B": "hfl/chinese-llama-2-7b",
"ChineseLLaMA2-13B": "hfl/chinese-llama-2-13b",
"ChineseLLaMA2-1.3B-Chat": "hfl/chinese-alpaca-2-1.3b",
"ChineseLLaMA2-7B-Chat": "hfl/chinese-alpaca-2-7b",
"ChineseLLaMA2-13B-Chat": "hfl/chinese-alpaca-2-13b"
},
template="llama2_zh"
)
@ -190,6 +196,14 @@ register_model_group(
)
register_model_group(
models={
"OpenChat3.5-7B-Chat": "openchat/openchat_3.5"
},
template="openchat"
)
register_model_group(
models={
"Phi1.5-1.3B": "microsoft/phi-1_5"
@ -217,6 +231,15 @@ register_model_group(
)
register_model_group(
models={
"Vicuna1.5-7B-Chat": "lmsys/vicuna-7b-v1.5",
"Vicuna1.5-13B-Chat": "lmsys/vicuna-13b-v1.5"
},
template="vicuna"
)
register_model_group(
models={
"XVERSE-7B": "xverse/XVERSE-7B",
@ -229,9 +252,27 @@ register_model_group(
)
register_model_group(
models={
"Yayi-7B": "wenge-research/yayi-7b-llama2",
"Yayi-13B": "wenge-research/yayi-13b-llama2"
},
template="yayi"
)
register_model_group(
models={
"Yi-6B": "01-ai/Yi-6B",
"Yi-34B": "01-ai/Yi-34B"
}
)
register_model_group(
models={
"Zephyr-7B-Alpha-Chat": "HuggingFaceH4/zephyr-7b-alpha",
"Zephyr-7B-Beta-Chat": "HuggingFaceH4/zephyr-7b-beta"
},
template="zephyr"
)

View File

@ -3,6 +3,9 @@ import logging
class LoggerHandler(logging.Handler):
r"""
Logger handler used in Web UI.
"""
def __init__(self):
super().__init__()
@ -19,16 +22,10 @@ class LoggerHandler(logging.Handler):
self.log += "\n\n"
def reset_logging():
r"""
Removes basic config of root logger
"""
root = logging.getLogger()
list(map(root.removeHandler, root.handlers))
list(map(root.removeFilter, root.filters))
def get_logger(name: str) -> logging.Logger:
r"""
Gets a standard logger with a stream hander to stdout.
"""
formatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S"
@ -41,3 +38,12 @@ def get_logger(name: str) -> logging.Logger:
logger.addHandler(handler)
return logger
def reset_logging() -> None:
r"""
Removes basic config of root logger. (unused in script)
"""
root = logging.getLogger()
list(map(root.removeHandler, root.handlers))
list(map(root.removeFilter, root.filters))

View File

@ -13,14 +13,13 @@ try:
is_torch_npu_available
)
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
except ImportError:
_is_fp16_available = torch.cuda.is_available()
_is_bf16_available = torch.cuda.is_bf16_supported()
if TYPE_CHECKING:
from transformers import HfArgumentParser
from transformers.modeling_utils import PreTrainedModel
class AverageMeter:
@ -65,16 +64,14 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
return trainable_params, all_param
def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
r"""
Infers the optimal dtype according to the model_dtype and device compatibility.
"""
if _is_bf16_available and model_dtype == torch.bfloat16:
return torch.bfloat16
elif _is_fp16_available:
return torch.float16
def get_current_device() -> str:
import accelerate
from accelerate import Accelerator
dummy_accelerator = Accelerator()
if accelerate.utils.is_xpu_available():
return "xpu:{}".format(dummy_accelerator.local_process_index)
else:
return torch.float32
return dummy_accelerator.local_process_index if torch.cuda.is_available() else "cpu"
def get_logits_processor() -> "LogitsProcessorList":
@ -86,14 +83,16 @@ def get_logits_processor() -> "LogitsProcessorList":
return logits_processor
def torch_gc() -> None:
def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
r"""
Collects GPU memory.
Infers the optimal dtype according to the model_dtype and device compatibility.
"""
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
if _is_bf16_available and model_dtype == torch.bfloat16:
return torch.bfloat16
elif _is_fp16_available:
return torch.float16
else:
return torch.float32
def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
@ -107,26 +106,11 @@ def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None
return parser.parse_args_into_dataclasses()
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
def torch_gc() -> None:
r"""
Dispatches a pre-trained model to GPUs with balanced memory.
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
Collects GPU memory.
"""
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
return model
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory
if model._no_split_modules is None:
raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
max_memory = get_balanced_memory(model, **kwargs)
# Make sure tied weights are tied before creating the device map.
model.tie_weights()
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
return dispatch_model(model, device_map)
else:
return model.cuda()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

View File

@ -0,0 +1,55 @@
import importlib.metadata
import importlib.util
def is_package_available(name: str) -> bool:
return importlib.util.find_spec(name) is not None
def get_package_version(name: str) -> str:
try:
return importlib.metadata.version(name)
except:
return "0.0.0"
_fastapi_available = is_package_available("fastapi")
_flash_attn2_available = is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
_jieba_available = is_package_available("jieba")
_matplotlib_available = is_package_available("matplotlib")
_nltk_available = is_package_available("nltk")
_rouge_available = is_package_available("rouge-chinese")
_starlette_available = is_package_available("sse-starlette")
_uvicorn_available = is_package_available("uvicorn")
def is_fastapi_availble():
return _fastapi_available
def is_flash_attn2_available():
return _flash_attn2_available
def is_jieba_available():
return _jieba_available
def is_matplotlib_available():
return _matplotlib_available
def is_nltk_available():
return _nltk_available
def is_rouge_available():
return _rouge_available
def is_starlette_available():
return _starlette_available
def is_uvicorn_available():
return _uvicorn_available

View File

@ -3,16 +3,19 @@ import torch
import torch.nn as nn
from typing import Optional, Tuple
from transformers.utils import logging
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
is_flash_attn_2_available = False
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
try:
from transformers.models.llama.modeling_llama import repeat_kv
except ImportError:
print("Please upgrade `transformers`.")
from llmtuner.extras.packages import is_flash_attn2_available
if is_flash_attn2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore
from flash_attn.bert_padding import pad_input, unpad_input # type: ignore
is_flash_attn_2_available = True
except ImportError:
is_flash_attn_2_available = False
logger = logging.get_logger(__name__)

View File

@ -1,11 +1,14 @@
import os
import math
import json
import matplotlib.pyplot as plt
from typing import List, Optional
from transformers.trainer import TRAINER_STATE_NAME
from llmtuner.extras.logging import get_logger
from llmtuner.extras.packages import is_matplotlib_available
if is_matplotlib_available():
import matplotlib.pyplot as plt
logger = get_logger(__name__)

View File

@ -4,38 +4,38 @@ from dataclasses import asdict, dataclass, field
@dataclass
class FinetuningArguments:
class FreezeArguments:
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
Arguments pertaining to the freeze (partial-parameter) training.
"""
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."}
)
finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
default="lora",
metadata={"help": "Which fine-tuning method to use."}
)
num_layer_trainable: Optional[int] = field(
default=3,
metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
)
name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
name_module_trainable: Optional[str] = field(
default="mlp",
metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
Use commas to separate multiple modules. \
LLaMA choices: [\"mlp\", \"self_attn\"], \
BLOOM & Falcon & ChatGLM choices: [\"mlp\", \"self_attention\"], \
Qwen choices: [\"mlp\", \"attn\"], \
Phi-1.5 choices: [\"mlp\", \"mixer\"], \
LLaMA-2, BlueLM, Baichuan, InternLM, Mistral, Skywork, XVERSE, Yi choices: the same as LLaMA."}
Others choices: the same as LLaMA."}
)
@dataclass
class LoraArguments:
r"""
Arguments pertaining to the LoRA training.
"""
lora_rank: Optional[int] = field(
default=8,
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
)
lora_alpha: Optional[float] = field(
default=32.0,
metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."}
default=None,
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2.0)."}
)
lora_dropout: Optional[float] = field(
default=0.1,
@ -49,7 +49,7 @@ class FinetuningArguments:
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
LLaMA-2, BlueLM, InternLM, Mistral, Skywork, XVERSE, Yi choices: the same as LLaMA."}
Others choices: the same as LLaMA."}
)
additional_target: Optional[str] = field(
default=None,
@ -59,30 +59,76 @@ class FinetuningArguments:
default=True,
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
)
ppo_score_norm: Optional[bool] = field(
default=False,
metadata={"help": "Use score normalization in PPO training."}
@dataclass
class RLHFArguments:
r"""
Arguments pertaining to the PPO and DPO training.
"""
dpo_beta: Optional[float] = field(
default=0.1,
metadata={"help": "The beta parameter for the DPO loss."}
)
ppo_logger: Optional[str] = field(
default=None,
metadata={"help": "Log with either 'wandb' or 'tensorboard' in PPO training."}
)
ppo_score_norm: Optional[bool] = field(
default=False,
metadata={"help": "Use score normalization in PPO training."}
)
ppo_target: Optional[float] = field(
default=6.0,
metadata={"help": "Target KL value for adaptive KL control in PPO training."}
)
dpo_beta: Optional[float] = field(
default=0.1,
metadata={"help": "The beta parameter for the DPO loss."}
ppo_whiten_rewards: Optional[bool] = field(
default=False,
metadata={"help": "Whiten the rewards before compute advantages in PPO training."}
)
dpo_ref_model: Optional[str] = field(
ref_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the reference model used for the DPO training."}
metadata={"help": "Path to the reference model used for the PPO or DPO training."}
)
dpo_ref_model_checkpoint: Optional[str] = field(
ref_model_checkpoint: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory(s) containing the model checkpoints of the reference model."}
)
ref_model_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the reference model."}
)
reward_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
)
reward_model_checkpoint: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory(s) containing the model checkpoints of the reward model."}
)
reward_model_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the reward model."}
)
reward_model_type: Optional[Literal["lora", "full"]] = field(
default="lora",
metadata={"help": "The checkpoint type of the reward model. The lora type only supports lora training."}
)
@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."}
)
finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
default="lora",
metadata={"help": "Which fine-tuning method to use."}
)
upcast_layernorm: Optional[bool] = field(
default=False,
metadata={"help": "Whether to upcast the layernorm weights in fp32."}
@ -91,15 +137,37 @@ class FinetuningArguments:
default=0,
metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."}
)
export_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory to save the exported model."}
)
plot_loss: Optional[bool] = field(
default=False,
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
)
def __post_init__(self):
if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
self.lora_target = [target.strip() for target in self.lora_target.split(",")]
def split_arg(arg):
if isinstance(arg, str):
return [item.strip() for item in arg.split(",")]
return arg
if isinstance(self.additional_target, str):
self.additional_target = [target.strip() for target in self.additional_target.split(",")]
self.name_module_trainable = split_arg(self.name_module_trainable)
self.lora_alpha = self.lora_alpha or float(self.lora_rank * 2.0)
self.lora_target = split_arg(self.lora_target)
self.additional_target = split_arg(self.additional_target)
self.ref_model_checkpoint = split_arg(self.ref_model_checkpoint)
self.reward_model_checkpoint = split_arg(self.reward_model_checkpoint)
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
if self.stage == "ppo" and self.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.")
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
raise ValueError("Lora reward model only supports lora training.")
def save_to_json(self, json_path: str):
r"""Saves the content of this instance in JSON format inside `json_path`."""
@ -112,4 +180,5 @@ class FinetuningArguments:
r"""Creates an instance from the content of `json_path`."""
with open(json_path, "r", encoding="utf-8") as f:
text = f.read()
return cls(**json.loads(text))

View File

@ -54,22 +54,10 @@ class ModelArguments:
default=False,
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
)
reward_model: Optional[str] = field( # TODO: move it to FinetuningArguments
default=None,
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
)
plot_loss: Optional[bool] = field( # TODO: move it to FinetuningArguments
default=False,
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
)
hf_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."}
)
export_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory to save the exported model."}
)
def __post_init__(self):
self.compute_dtype = None
@ -81,8 +69,7 @@ class ModelArguments:
if self.checkpoint_dir is not None: # support merging multiple lora weights
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
if self.quantization_bit is not None:
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
def to_dict(self) -> Dict[str, Any]:
return asdict(self)

View File

@ -0,0 +1,5 @@
# Level: loader > adapter > parser, utils
from llmtuner.model.loader import load_model_and_tokenizer
from llmtuner.model.parser import get_train_args, get_infer_args, get_eval_args
from llmtuner.model.utils import dispatch_model, generate_model_card, load_valuehead_params

View File

@ -1,18 +1,9 @@
import os
import torch
from typing import TYPE_CHECKING
from transformers.utils import cached_file
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
from peft import (
PeftModel,
TaskType,
LoraConfig,
get_peft_model
)
from peft import PeftModel, TaskType, LoraConfig, get_peft_model
from llmtuner.extras.logging import get_logger
from llmtuner.tuner.core.utils import find_all_linear_modules
from llmtuner.model.utils import find_all_linear_modules
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
@ -46,13 +37,23 @@ def init_adapter(
if finetuning_args.finetuning_type == "freeze" and is_trainable:
logger.info("Fine-tuning method: Freeze")
num_layers = getattr(model.config, "num_layers")
num_layers = (
getattr(model.config, "num_hidden_layers", None)
or getattr(model.config, "num_layers", None)
or getattr(model.config, "n_layer", None)
)
if not num_layers:
raise ValueError("Current model does not support freeze tuning.")
if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
else: # fine-tuning the first n layers if num_layer_trainable < 0
trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)]
trainable_layers = ["{:d}.{}".format(idx, finetuning_args.name_module_trainable) for idx in trainable_layer_ids]
trainable_layers = []
for module_name in finetuning_args.name_module_trainable:
for idx in trainable_layer_ids:
trainable_layers.append("{:d}.{}".format(idx, module_name))
for name, param in model.named_parameters():
if not any(trainable_layer in name for trainable_layer in trainable_layers):
param.requires_grad_(False)
@ -100,30 +101,3 @@ def init_adapter(
logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
return model
def load_valuehead_params(
model: "PreTrainedModel",
model_args: "ModelArguments"
) -> bool:
kwargs = {
"path_or_repo_id": model_args.reward_model,
"cache_dir": model_args.cache_dir,
"token": model_args.hf_hub_token,
"revision": model_args.model_revision
}
try:
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
except:
try:
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
except:
logger.warning("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model))
return False
vhead_params = torch.load(vhead_file, map_location="cpu")
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
return True

View File

@ -15,7 +15,6 @@ from transformers import (
)
from transformers.models.llama import modeling_llama as LlamaModule
from transformers.utils.versions import require_version
from peft import PeftModel
from trl import AutoModelForCausalLMWithValueHead
try:
@ -24,11 +23,12 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v
from transformers.deepspeed import is_deepspeed_zero3_enabled
from llmtuner.extras.logging import reset_logging, get_logger
from llmtuner.extras.misc import count_parameters, infer_optim_dtype
from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype
from llmtuner.extras.packages import is_flash_attn2_available
from llmtuner.extras.patches import llama_patch as LlamaPatches
from llmtuner.hparams import FinetuningArguments
from llmtuner.tuner.core.adapter import init_adapter, load_valuehead_params
from llmtuner.tuner.core.utils import prepare_model_for_training
from llmtuner.model.adapter import init_adapter
from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
@ -73,6 +73,7 @@ def load_model_and_tokenizer(
)
if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
logger.info("Use `model_name_or_path` to specify the model trained with full/freeze method.")
model_to_load = model_args.checkpoint_dir[0]
else:
model_to_load = model_args.model_name_or_path
@ -122,7 +123,7 @@ def load_model_and_tokenizer(
# Set FlashAttention-2
if model_args.flash_attn:
if getattr(config, "model_type", None) == "llama":
if LlamaPatches.is_flash_attn_2_available:
if is_flash_attn2_available():
LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
logger.info("Using FlashAttention-2 for faster training and inference.")
@ -131,7 +132,7 @@ def load_model_and_tokenizer(
elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
logger.info("Current model automatically enables FlashAttention if installed.")
else:
logger.warning("Current model does not support FlashAttention-2.")
logger.warning("Current model does not support FlashAttention.")
elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
logger.warning("Using `--flash_attn` for faster training in large context length.")
@ -144,7 +145,7 @@ def load_model_and_tokenizer(
else:
logger.warning("Current model does not support shift short attention.")
# Quantization configurations (using bitsandbytes library).
# Quantization configurations (using bitsandbytes library)
if model_args.quantization_bit is not None:
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
@ -164,10 +165,10 @@ def load_model_and_tokenizer(
bnb_4bit_quant_type=model_args.quantization_type
)
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
config_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
# Load and prepare pre-trained models (without valuehead).
# Load pre-trained models (without valuehead)
model = AutoModelForCausalLM.from_pretrained(
model_to_load,
config=config,
@ -185,7 +186,7 @@ def load_model_and_tokenizer(
setattr(model, "lm_head", model.transformer.output_layer)
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
# Register auto class to save the custom code files.
# Register auto class to save the custom code files
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()
if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
@ -199,25 +200,15 @@ def load_model_and_tokenizer(
model = model.train() if is_trainable else model.eval()
# Prepare model with valuehead for RLHF
if stage == "rm" or stage == "ppo":
if stage in ["rm", "ppo"]:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
reset_logging()
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
logger.warning("Only the last checkpoint containing valuehead will be loaded.")
if load_valuehead_params(model, model_args):
model.v_head.load_state_dict({
"summary.weight": getattr(model, "reward_head_weight"),
"summary.bias": getattr(model, "reward_head_bias")
})
if stage == "ppo": # load reward model
logger.info("Load reward model from {}".format(model_args.reward_model))
if isinstance(model.pretrained_model, PeftModel):
model.pretrained_model.load_adapter(model_args.reward_model, "reward")
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
if "default" in name:
param.data = param.data.to(torch.float32) # trainable params should in fp32
assert load_valuehead_params(model, model_args), "Reward model is not correctly loaded."
vhead_path = (
model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
)
vhead_params = load_valuehead_params(vhead_path, model_args)
if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False)
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
# Prepare model for inference
if not is_trainable:

View File

@ -11,6 +11,7 @@ from llmtuner.extras.misc import parse_args
from llmtuner.hparams import (
ModelArguments,
DataArguments,
EvaluationArguments,
FinetuningArguments,
GeneratingArguments
)
@ -19,51 +20,42 @@ from llmtuner.hparams import (
logger = get_logger(__name__)
def parse_train_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments
]:
parser = HfArgumentParser((
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments
))
_TRAIN_ARGS = [
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
]
_TRAIN_CLS = Tuple[
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
]
_INFER_ARGS = [
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
]
_INFER_CLS = Tuple[
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
]
_EVAL_ARGS = [
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
]
_EVAL_CLS = Tuple[
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
]
def parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
return parse_args(parser, args)
def parse_infer_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[
ModelArguments,
DataArguments,
FinetuningArguments,
GeneratingArguments
]:
parser = HfArgumentParser((
ModelArguments,
DataArguments,
FinetuningArguments,
GeneratingArguments
))
def parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS)
return parse_args(parser, args)
def get_train_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments
]:
def parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS)
return parse_args(parser, args)
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args)
# Setup logging
@ -90,24 +82,19 @@ def get_train_args(
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
if finetuning_args.stage in ["rm", "ppo"]:
if finetuning_args.finetuning_type != "lora":
raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
if training_args.resume_from_checkpoint is not None:
raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
if training_args.load_best_model_at_end:
raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
if finetuning_args.stage == "ppo" and not training_args.do_train:
raise ValueError("PPO training does not support evaluation.")
raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
if finetuning_args.stage in ["rm", "dpo"]:
for dataset_attr in data_args.dataset_list:
if not dataset_attr.ranking:
raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
if finetuning_args.stage == "ppo" and model_args.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.")
if finetuning_args.stage == "ppo" and model_args.shift_attn:
raise ValueError("PPO training is incompatible with S^2-Attn.")
@ -139,6 +126,9 @@ def get_train_args(
if (not training_args.do_train) and model_args.quantization_bit is not None:
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
logger.warning("Specify `ref_model` for computing rewards at evaluation.")
# postprocess training_args
if (
training_args.local_rank != -1
@ -187,14 +177,7 @@ def get_train_args(
return model_args, data_args, training_args, finetuning_args, generating_args
def get_infer_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[
ModelArguments,
DataArguments,
FinetuningArguments,
GeneratingArguments
]:
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
if data_args.template is None:
@ -211,3 +194,17 @@ def get_infer_args(
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
return model_args, data_args, finetuning_args, generating_args
def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = parse_eval_args(args)
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
transformers.set_seed(eval_args.seed)
return model_args, data_args, eval_args, finetuning_args

View File

@ -1,21 +1,53 @@
import torch
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from transformers.utils import cached_file
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
from llmtuner.extras.constants import LAYERNORM_NAMES
from llmtuner.extras.logging import get_logger
from llmtuner.hparams import ModelArguments, FinetuningArguments
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.hparams import DataArguments
logger = get_logger(__name__)
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
r"""
Dispatches a pre-trained model to GPUs with balanced memory.
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
"""
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
return model
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory
if model._no_split_modules is None:
raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
max_memory = get_balanced_memory(model, **kwargs)
# Make sure tied weights are tied before creating the device map.
model.tie_weights()
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
return dispatch_model(model, device_map)
else:
return model.cuda()
def find_all_linear_modules(
model: "PreTrainedModel",
quantization_bit: Optional[int] = None
) -> List[str]:
r"""
Finds all available modules to apply lora.
"""
if quantization_bit is not None:
import bitsandbytes as bnb
linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt
@ -51,6 +83,32 @@ def generate_model_card(
}
def load_valuehead_params(
path_or_repo_id: str,
model_args: "ModelArguments"
) -> Dict[str, torch.Tensor]:
r"""
Loads value head parameters from Hugging Face Hub or local disk.
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
"""
kwargs = {
"path_or_repo_id": path_or_repo_id,
"cache_dir": model_args.cache_dir,
"token": model_args.hf_hub_token
}
try:
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
except:
try:
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
except:
logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
return None
return torch.load(vhead_file, map_location="cpu")
def prepare_model_for_training(
model: "PreTrainedModel",
finetuning_args: "FinetuningArguments",

View File

@ -0,0 +1 @@
from llmtuner.train.tuner import export_model, run_exp

View File

@ -0,0 +1 @@
from llmtuner.train.dpo.workflow import run_dpo

View File

@ -43,7 +43,11 @@ class CustomDPOTrainer(DPOTrainer):
if ref_model is not None:
if self.is_deepspeed_enabled:
self.ref_model = self._prepare_deepspeed(self.ref_model)
if not (
getattr(ref_model, "is_loaded_in_8bit", False)
or getattr(ref_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device
self.ref_model = self._prepare_deepspeed(self.ref_model)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

View File

@ -4,23 +4,20 @@ from peft import PeftModel
from typing import TYPE_CHECKING, Optional, List
from transformers import Seq2SeqTrainingArguments
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.logging import get_logger
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
from llmtuner.tuner.dpo.trainer import CustomDPOTrainer
from llmtuner.model import generate_model_card, load_model_and_tokenizer
from llmtuner.train.utils import create_ref_model
from llmtuner.train.dpo.collator import DPODataCollatorWithPadding
from llmtuner.train.dpo.trainer import CustomDPOTrainer
if TYPE_CHECKING:
from transformers import TrainerCallback
from llmtuner.hparams import DataArguments, FinetuningArguments
logger = get_logger(__name__)
def run_dpo(
model_args: "ModelArguments",
data_args: "DataArguments",
@ -38,23 +35,10 @@ def run_dpo(
)
# Create reference model
if finetuning_args.dpo_ref_model is not None:
ref_model_args_dict = model_args.to_dict()
ref_model_args_dict.update(dict(
model_name_or_path=finetuning_args.dpo_ref_model,
checkpoint_dir=finetuning_args.dpo_ref_model_checkpoint
))
ref_model_args = ModelArguments(**ref_model_args_dict)
ref_model, _ = load_model_and_tokenizer(ref_model_args, finetuning_args, is_trainable=False, stage="sft")
logger.info("Created reference model from {}".format(finetuning_args.dpo_ref_model))
elif training_args.do_train:
if isinstance(model, PeftModel):
ref_model = None
else:
ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
logger.info("Created reference model from the model itself.")
else:
if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
ref_model = model
else:
ref_model = create_ref_model(model_args, finetuning_args, stage="dpo")
# Update arguments
training_args_dict = training_args.to_dict()
@ -80,14 +64,13 @@ def run_dpo(
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and model_args.plot_loss:
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")
if id(model) == id(ref_model): # unable to compute rewards without a reference model
logger.warning("Pass `dpo_ref_model` for computing rewards at evaluation.")
remove_keys = [key for key in metrics.keys() if "rewards" in key]
for key in remove_keys:
metrics.pop(key)

View File

@ -0,0 +1 @@
from llmtuner.train.ppo.workflow import run_ppo

View File

@ -3,7 +3,7 @@ import sys
import math
import torch
from tqdm import tqdm
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple
from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
@ -14,7 +14,7 @@ from trl.core import PPODecorators, logprobs_from_logits
from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
from llmtuner.tuner.ppo.utils import dump_layernorm, restore_layernorm, replace_model
from llmtuner.train.ppo.utils import dump_layernorm, restore_layernorm, replace_model
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
@ -37,24 +37,43 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: List["TrainerCallback"],
reward_model: "AutoModelForCausalLMWithValueHead",
**kwargs
):
PPOTrainer.__init__(self, **kwargs)
self.args = training_args
self.model_args = model_args
self.finetuning_args = finetuning_args
self.reward_model = reward_model
self.generation_config = GenerationConfig(
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
**generating_args.to_dict()
)
self.state = TrainerState()
self.control = TrainerControl()
self.log_callback, self.save_callback = callbacks[0], callbacks[1]
assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
if self.args.max_steps > 0:
logger.info("max_steps is given, it will override any value given in num_train_epochs")
if reward_model is not None:
is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
self.accelerator.state, "deepspeed_plugin"
)
if is_deepspeed_enabled:
if not (
getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device
self.reward_model = self._prepare_deepspeed(self.reward_model)
else:
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
def ppo_train(self) -> None:
r"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
@ -213,11 +232,14 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
r"""
Computes scores using given reward model.
"""
replace_model(unwrapped_model, target="reward")
if self.reward_model is None:
replace_model(unwrapped_model, target="reward")
batch = self.prepare_model_inputs(queries, responses)
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
_, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
reward_model = self.reward_model if self.reward_model is not None else self.model
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2
values = torch.transpose(values, 0, 1)
@ -228,7 +250,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
end_index = end_indexes[-1].item() if len(end_indexes) else 0
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
replace_model(unwrapped_model, target="default")
if self.reward_model is None:
replace_model(unwrapped_model, target="default")
return rewards
@PPODecorators.empty_device_cache()

View File

@ -7,11 +7,12 @@ from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorWithPadding
from transformers.optimization import get_scheduler
from llmtuner.dsets import get_dataset, preprocess_dataset
from llmtuner.data import get_dataset, preprocess_dataset
from llmtuner.extras.callbacks import SavePeftModelCallback
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.ppo.trainer import CustomPPOTrainer
from llmtuner.model import load_model_and_tokenizer
from llmtuner.train.utils import create_ref_model, create_reward_model
from llmtuner.train.ppo.trainer import CustomPPOTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
@ -33,6 +34,11 @@ def run_ppo(
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# Create reference model and reward model
ref_model = create_ref_model(model_args, finetuning_args, stage="ppo")
reward_model = create_reward_model(model, model_args, finetuning_args)
# Create ppo config
ppo_config = PPOConfig(
model_name=model_args.model_name_or_path,
learning_rate=training_args.learning_rate,
@ -47,9 +53,11 @@ def run_ppo(
log_with=finetuning_args.ppo_logger,
use_score_scaling=finetuning_args.ppo_score_norm,
use_score_norm=finetuning_args.ppo_score_norm,
whiten_rewards=finetuning_args.ppo_whiten_rewards,
accelerator_kwargs={"step_scheduler_with_optimizer": False}
)
# Create optimizer and scheduler
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
if training_args.max_steps > 0:
num_training_steps = training_args.max_steps
@ -73,9 +81,10 @@ def run_ppo(
finetuning_args=finetuning_args,
generating_args=generating_args,
callbacks=callbacks + [SavePeftModelCallback()],
reward_model=reward_model,
config=ppo_config,
model=model,
ref_model=None,
ref_model=ref_model,
tokenizer=tokenizer,
dataset=dataset,
data_collator=data_collator,
@ -88,5 +97,5 @@ def run_ppo(
ppo_trainer.ppo_train()
ppo_trainer.save_model()
ppo_trainer.save_state() # must be called after save_model to have a folder
if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "reward"])

View File

@ -0,0 +1 @@
from llmtuner.train.pt.workflow import run_pt

View File

@ -4,9 +4,9 @@ import math
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForLanguageModeling, Trainer
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
from llmtuner.model import generate_model_card, load_model_and_tokenizer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
@ -42,7 +42,7 @@ def run_pt(
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and model_args.plot_loss:
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation

View File

@ -0,0 +1 @@
from llmtuner.train.rm.workflow import run_rm

View File

@ -3,13 +3,13 @@
from typing import TYPE_CHECKING, Optional, List
from transformers import Seq2SeqTrainingArguments
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import SavePeftModelCallback
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
from llmtuner.tuner.rm.metric import compute_accuracy
from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
from llmtuner.tuner.rm.trainer import PairwiseTrainer
from llmtuner.model import generate_model_card, load_model_and_tokenizer
from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding
from llmtuner.train.rm.metric import compute_accuracy
from llmtuner.train.rm.trainer import PairwiseTrainer
if TYPE_CHECKING:
from transformers import TrainerCallback
@ -51,7 +51,7 @@ def run_rm(
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and model_args.plot_loss:
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation

View File

@ -0,0 +1 @@
from llmtuner.train.sft.workflow import run_sft

View File

@ -2,15 +2,23 @@ import numpy as np
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
import jieba
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.packages import (
is_jieba_available, is_nltk_available, is_rouge_available
)
if TYPE_CHECKING:
from transformers.tokenization_utils import PreTrainedTokenizer
if is_jieba_available():
import jieba
if is_nltk_available():
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
if is_rouge_available():
from rouge_chinese import Rouge
@dataclass
class ComputeMetrics:

View File

@ -3,13 +3,13 @@
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.misc import get_logits_processor
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
from llmtuner.tuner.sft.metric import ComputeMetrics
from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer
from llmtuner.model import generate_model_card, load_model_and_tokenizer
from llmtuner.train.sft.metric import ComputeMetrics
from llmtuner.train.sft.trainer import CustomSeq2SeqTrainer
if TYPE_CHECKING:
from transformers import TrainerCallback
@ -69,7 +69,7 @@ def run_sft(
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and model_args.plot_loss:
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation

View File

@ -2,12 +2,12 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.logging import get_logger
from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
from llmtuner.tuner.pt import run_pt
from llmtuner.tuner.sft import run_sft
from llmtuner.tuner.rm import run_rm
from llmtuner.tuner.ppo import run_ppo
from llmtuner.tuner.dpo import run_dpo
from llmtuner.model import get_train_args, get_infer_args, load_model_and_tokenizer
from llmtuner.train.pt import run_pt
from llmtuner.train.sft import run_sft
from llmtuner.train.rm import run_rm
from llmtuner.train.ppo import run_ppo
from llmtuner.train.dpo import run_dpo
if TYPE_CHECKING:
from transformers import TrainerCallback
@ -38,11 +38,11 @@ def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional
model_args, _, finetuning_args, _ = get_infer_args(args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
model.config.use_cache = True
model.save_pretrained(model_args.export_dir, max_shard_size=max_shard_size)
model.save_pretrained(finetuning_args.export_dir, max_shard_size=max_shard_size)
try:
tokenizer.padding_side = "left" # restore padding side
tokenizer.init_kwargs["padding_side"] = "left"
tokenizer.save_pretrained(model_args.export_dir)
tokenizer.save_pretrained(finetuning_args.export_dir)
except:
logger.warning("Cannot save tokenizer, please copy the files manually.")

View File

@ -0,0 +1,80 @@
import torch
from typing import TYPE_CHECKING, Literal, Union
from llmtuner.extras.logging import get_logger
from llmtuner.hparams import ModelArguments, FinetuningArguments
from llmtuner.model import load_model_and_tokenizer, load_valuehead_params
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from trl import AutoModelForCausalLMWithValueHead
logger = get_logger(__name__)
def create_ref_model(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
stage: Literal["ppo", "dpo"]
) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
r"""
Creates reference model for PPO/DPO training. Evaluation mode is not supported.
The valuehead parameter is randomly initialized since it is useless for PPO training.
"""
if finetuning_args.ref_model is not None:
ref_model_args_dict = model_args.to_dict()
ref_model_args_dict.update(dict(
model_name_or_path=finetuning_args.ref_model,
checkpoint_dir=finetuning_args.ref_model_checkpoint,
quantization_bit=finetuning_args.ref_model_quantization_bit
))
ref_model_args = ModelArguments(**ref_model_args_dict)
ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
ref_model, _ = load_model_and_tokenizer(ref_model_args, ref_finetuning_args, is_trainable=False, stage=stage)
logger.info("Created reference model from {}".format(finetuning_args.ref_model))
else:
if finetuning_args.finetuning_type == "lora":
ref_model = None
else:
ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage=stage)
logger.info("Created reference model from the model itself.")
return ref_model
def create_reward_model(
model: "AutoModelForCausalLMWithValueHead",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments"
) -> "AutoModelForCausalLMWithValueHead":
r"""
Creates reward model for PPO training.
"""
if finetuning_args.reward_model_type == "lora":
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
if "default" in name:
param.data = param.data.to(torch.float32) # trainable params should in fp32
vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
assert vhead_params is not None, "Reward model is not correctly loaded."
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
return None
else:
reward_model_args_dict = model_args.to_dict()
reward_model_args_dict.update(dict(
model_name_or_path=finetuning_args.reward_model,
checkpoint_dir=finetuning_args.reward_model_checkpoint,
quantization_bit=finetuning_args.reward_model_quantization_bit
))
reward_model_args = ModelArguments(**reward_model_args_dict)
reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
reward_model, _ = load_model_and_tokenizer(reward_model_args, reward_finetuning_args, is_trainable=False, stage="ppo")
logger.info("Load full weights of reward model from {}".format(finetuning_args.reward_model))
logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
return reward_model

View File

@ -1 +0,0 @@
from llmtuner.tuner.tune import export_model, run_exp

View File

@ -1,3 +0,0 @@
from llmtuner.tuner.core.parser import get_train_args, get_infer_args
from llmtuner.tuner.core.loader import load_model_and_tokenizer
from llmtuner.tuner.core.utils import generate_model_card

View File

@ -1 +0,0 @@
from llmtuner.tuner.dpo.workflow import run_dpo

View File

@ -1 +0,0 @@
from llmtuner.tuner.ppo.workflow import run_ppo

View File

@ -1 +0,0 @@
from llmtuner.tuner.pt.workflow import run_pt

View File

@ -1 +0,0 @@
from llmtuner.tuner.rm.workflow import run_rm

View File

@ -1 +0,0 @@
from llmtuner.tuner.sft.workflow import run_sft

View File

@ -2,7 +2,7 @@ import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from llmtuner.chat.stream_chat import ChatModel
from llmtuner.chat import ChatModel
from llmtuner.extras.misc import torch_gc
from llmtuner.hparams import GeneratingArguments
from llmtuner.webui.common import get_save_dir
@ -14,14 +14,24 @@ if TYPE_CHECKING:
class WebChatModel(ChatModel):
def __init__(self, manager: "Manager", lazy_init: Optional[bool] = True) -> None:
def __init__(
self,
manager: "Manager",
demo_mode: Optional[bool] = False,
lazy_init: Optional[bool] = True
) -> None:
self.manager = manager
self.demo_mode = demo_mode
self.model = None
self.tokenizer = None
self.generating_args = GeneratingArguments()
if not lazy_init:
if not lazy_init: # read arguments from command line
super().__init__()
if demo_mode: # load openchat 3.5 by default
super().__init__(dict(model_name_or_path="openchat/openchat_3.5", template="openchat"))
@property
def loaded(self) -> bool:
return self.model is not None
@ -36,6 +46,8 @@ class WebChatModel(ChatModel):
error = ALERTS["err_no_model"][lang]
elif not get("top.model_path"):
error = ALERTS["err_no_path"][lang]
elif self.demo_mode:
error = ALERTS["err_demo"][lang]
if error:
gr.Warning(error)
@ -67,6 +79,11 @@ class WebChatModel(ChatModel):
def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
lang = data[self.manager.get_elem_by_name("top.lang")]
if self.demo_mode:
yield ALERTS["err_demo"][lang]
return
yield ALERTS["info_unloading"][lang]
self.model = None
self.tokenizer = None

View File

@ -70,7 +70,7 @@ def get_module(model_name: str) -> str:
def get_template(model_name: str) -> str:
if model_name.endswith("Chat") and get_prefix(model_name) in DEFAULT_TEMPLATE:
if model_name and model_name.endswith("Chat") and get_prefix(model_name) in DEFAULT_TEMPLATE:
return DEFAULT_TEMPLATE[get_prefix(model_name)]
return "default"

View File

@ -1,7 +1,7 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict, Generator, List
from llmtuner.tuner import export_model
from llmtuner.train import export_model
from llmtuner.webui.common import get_save_dir
from llmtuner.webui.locales import ALERTS

View File

@ -1,8 +1,8 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict
from llmtuner.data.template import templates
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
from llmtuner.extras.template import templates
from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, save_config
from llmtuner.webui.utils import can_quantize

View File

@ -1,4 +1,11 @@
CSS = r"""
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
.modal-box {
position: fixed !important;
top: 50%;

View File

@ -12,11 +12,11 @@ from llmtuner.webui.utils import get_time
class Engine:
def __init__(self, pure_chat: Optional[bool] = False) -> None:
def __init__(self, demo_mode: Optional[bool] = False, pure_chat: Optional[bool] = False) -> None:
self.pure_chat = pure_chat
self.manager: "Manager" = Manager()
self.runner: "Runner" = Runner(self.manager)
self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not pure_chat))
self.manager = Manager()
self.runner = Runner(self.manager, demo_mode=demo_mode)
self.chatter = WebChatModel(manager=self.manager, demo_mode=demo_mode, lazy_init=(not pure_chat))
def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
return {self.manager.get_elem_by_name(k): gr.update(**v) for k, v in resume_dict.items()}

View File

@ -1,4 +1,5 @@
import gradio as gr
from typing import Optional
from transformers.utils.versions import require_version
from llmtuner.webui.components import (
@ -17,24 +18,35 @@ from llmtuner.webui.engine import Engine
require_version("gradio>=3.38.0,<4.0.0", "To fix: pip install \"gradio>=3.38.0,<4.0.0\"")
def create_ui() -> gr.Blocks:
engine = Engine(pure_chat=False)
def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks:
engine = Engine(demo_mode=demo_mode, pure_chat=False)
with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
if demo_mode:
gr.HTML(
"<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>"
)
gr.HTML(
"<h3><center>Visit <a href=\"https://github.com/hiyouga/LLaMA-Factory\" target=\"_blank\">"
"LLaMA Factory</a> for details.</center></h3>"
)
gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
engine.manager.all_elems["top"] = create_top()
lang: "gr.Dropdown" = engine.manager.get_elem_by_name("top.lang")
with gr.Tab("Train"):
engine.manager.all_elems["train"] = create_train_tab(engine)
with gr.Tab("Evaluate"):
with gr.Tab("Evaluate & Predict"):
engine.manager.all_elems["eval"] = create_eval_tab(engine)
with gr.Tab("Chat"):
engine.manager.all_elems["infer"] = create_infer_tab(engine)
with gr.Tab("Export"):
engine.manager.all_elems["export"] = create_export_tab(engine)
if not demo_mode:
with gr.Tab("Export"):
engine.manager.all_elems["export"] = create_export_tab(engine)
demo.load(engine.resume, outputs=engine.manager.list_elems())
lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)

View File

@ -659,6 +659,10 @@ ALERTS = {
"en": "Failed.",
"zh": "训练出错。"
},
"err_demo": {
"en": "Training is unavailable in demo mode, duplicate the space to a private one first.",
"zh": "展示模式不支持训练,请先复制到私人空间。"
},
"info_aborting": {
"en": "Aborted, wait for terminating...",
"zh": "训练中断,正在等待线程结束……"

View File

@ -4,7 +4,7 @@ import logging
import gradio as gr
from threading import Thread
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple
import transformers
from transformers.trainer import TRAINING_ARGS_NAME
@ -13,7 +13,7 @@ from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import TRAINING_STAGES
from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import run_exp
from llmtuner.train import run_exp
from llmtuner.webui.common import get_module, get_save_dir, load_config
from llmtuner.webui.locales import ALERTS
from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
@ -24,13 +24,13 @@ if TYPE_CHECKING:
class Runner:
def __init__(self, manager: "Manager") -> None:
def __init__(self, manager: "Manager", demo_mode: Optional[bool] = False) -> None:
self.manager = manager
self.demo_mode = demo_mode
""" Resume """
self.thread: "Thread" = None
self.do_train = True
self.running_data: Dict["Component", Any] = None
self.monitor_inputs: Dict[str, str] = None
""" State """
self.aborted = False
self.running = False
@ -46,9 +46,8 @@ class Runner:
def set_abort(self) -> None:
self.aborted = True
self.running = False
def _initialize(self, data: Dict[Component, Any], do_train: bool) -> str:
def _initialize(self, data: Dict[Component, Any], do_train: bool, from_preview: bool) -> str:
get = lambda name: data[self.manager.get_elem_by_name(name)]
lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
dataset = get("train.dataset") if do_train else get("eval.dataset")
@ -65,6 +64,9 @@ class Runner:
if len(dataset) == 0:
return ALERTS["err_no_dataset"][lang]
if self.demo_mode and (not from_preview):
return ALERTS["err_demo"][lang]
self.aborted = False
self.logger_handler.reset()
self.trainer_callback = LogCallback(self)
@ -72,6 +74,7 @@ class Runner:
def _finalize(self, lang: str, finish_info: str) -> str:
self.thread = None
self.running_data = None
self.running = False
torch_gc()
if self.aborted:
@ -84,9 +87,9 @@ class Runner:
user_config = load_config()
if get("top.checkpoints"):
checkpoint_dir = ",".join([
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
])
checkpoint_dir = ",".join([get_save_dir(
get("top.model_name"), get("top.finetuning_type"), ckpt
) for ckpt in get("top.checkpoints")])
else:
checkpoint_dir = None
@ -136,7 +139,10 @@ class Runner:
args["upcast_layernorm"] = True
if args["stage"] == "ppo":
args["reward_model"] = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.reward_model"))
args["reward_model"] = get_save_dir(
get("top.model_name"), get("top.finetuning_type"), get("train.reward_model")
)
args["reward_model_type"] = "lora" if get("top.finetuning_type") == "lora" else "full"
if args["stage"] == "dpo":
args["dpo_beta"] = get("train.dpo_beta")
@ -154,9 +160,9 @@ class Runner:
user_config = load_config()
if get("top.checkpoints"):
checkpoint_dir = ",".join([
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
])
checkpoint_dir = ",".join([get_save_dir(
get("top.model_name"), get("top.finetuning_type"), ckpt
) for ckpt in get("top.checkpoints")])
output_dir = get_save_dir(
get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints"))
)
@ -196,7 +202,7 @@ class Runner:
return args
def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
error = self._initialize(data, do_train)
error = self._initialize(data, do_train, from_preview=True)
if error:
gr.Warning(error)
yield error, gr.update(visible=False)
@ -205,16 +211,14 @@ class Runner:
yield gen_cmd(args), gr.update(visible=False)
def _launch(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
error = self._initialize(data, do_train)
error = self._initialize(data, do_train, from_preview=False)
if error:
gr.Warning(error)
yield error, gr.update(visible=False)
else:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
self.running = True
self.do_train, self.running_data = do_train, data
self.monitor_inputs = dict(lang=data[self.manager.get_elem_by_name("top.lang")], output_dir=args["output_dir"])
self.thread = Thread(target=run_exp, kwargs=run_kwargs)
self.thread.start()
yield from self.monitor()
@ -232,7 +236,12 @@ class Runner:
yield from self._launch(data, do_train=False)
def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
lang, output_dir = self.monitor_inputs["lang"], self.monitor_inputs["output_dir"]
get = lambda name: self.running_data[self.manager.get_elem_by_name(name)]
self.running = True
lang = get("top.lang")
output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get(
"{}.output_dir".format("train" if self.do_train else "eval")
))
while self.thread.is_alive():
time.sleep(2)
if self.aborted:

View File

@ -1,17 +1,20 @@
import os
import json
import gradio as gr
import matplotlib.figure
import matplotlib.pyplot as plt
from typing import TYPE_CHECKING, Any, Dict
from datetime import datetime
from llmtuner.extras.packages import is_matplotlib_available
from llmtuner.extras.ploting import smooth
from llmtuner.webui.common import get_save_dir
if TYPE_CHECKING:
from llmtuner.extras.callbacks import LogCallback
if is_matplotlib_available():
import matplotlib.figure
import matplotlib.pyplot as plt
def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
if not callback.max_steps:
@ -56,7 +59,7 @@ def get_eval_results(path: os.PathLike) -> str:
return "```json\n{}\n```\n".format(result)
def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> "matplotlib.figure.Figure":
if not base_model:
return
log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl")

View File

@ -7,12 +7,13 @@ import fire
import math
import torch
from tqdm import tqdm
from typing import Optional
from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq
from llmtuner.dsets import get_dataset, preprocess_dataset
from llmtuner.data import get_dataset, preprocess_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer
from llmtuner.model import get_train_args, load_model_and_tokenizer
BASE_LR = 3e-4 # 1.5e-4 for 30B-70B models
@ -22,14 +23,16 @@ BASE_BS = 4_000_000 # from llama paper
def calculate_lr(
model_name_or_path: str,
dataset: str,
cutoff_len: int, # i.e. maximum input length during training
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
is_mistral: bool # mistral model uses a smaller learning rate
cutoff_len: int, # i.e. maximum input length during training
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
is_mistral: bool, # mistral model uses a smaller learning rate,
dataset_dir: Optional[str] = "data"
):
model_args, data_args, training_args, finetuning_args, _ = get_train_args(dict(
stage="sft",
model_name_or_path=model_name_or_path,
dataset=dataset,
dataset_dir=dataset_dir,
template="default",
cutoff_len=cutoff_len,
output_dir="dummy_dir"