Merge #1525 from hiyouga/dev, fix #224 #336 #931 #936 #1011

Refactor llmtuner, support full-parameter RLHF

Former-commit-id: f04bc2a42815d02c22935fec8bf7b81438c8ba79
hoshi-hiyouga 2023-11-16 15:47:13 +08:00 committed by GitHub
commit 550293badc
74 changed files with 806 additions and 560 deletions

View File

@@ -14,7 +14,9 @@
## LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory
-Launch LLaMA Board via `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`. (multiple GPUs are not supported yet)
+Launch **LLaMA Board** via `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`. (multiple GPUs are not supported yet)
+
+Preview LLaMA Board at **[Hugging Face Space](https://huggingface.co/spaces/hiyouga/LLaMA-Board)**.
Here is an example of altering the self-cognition of an instruction-tuned language model within 10 minutes on a single GPU.
@@ -71,7 +73,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
>
> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "chat" models.
-Please refer to [template.py](src/llmtuner/extras/template.py) for a full list of models we supported.
+Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list of models we supported.
## Supported Training Approaches
@@ -79,9 +81,9 @@ Please refer to [template.py](src/llmtuner/extras/template.py) for a full list o
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
| Pre-Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
-| Reward Modeling        |                    |                    | :white_check_mark: | :white_check_mark: |
+| Reward Modeling        | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
-| PPO Training           |                    |                    | :white_check_mark: | :white_check_mark: |
+| PPO Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
-| DPO Training           | :white_check_mark: |                    | :white_check_mark: | :white_check_mark: |
+| DPO Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
> [!NOTE]
> Use `--quantization_bit 4/8` argument to enable QLoRA.
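Editor's note: the table above now marks every stage, including Reward Modeling and PPO, as available under full-parameter tuning. As a rough illustration of what the commit title ("support full-parameter RLHF") enables, here is a hypothetical PPO run driven through the package's Python entry point; the model and dataset values are placeholders, and the assumption that `run_exp` accepts a plain argument dict is not taken from this diff.

```python
from llmtuner import run_exp

# Hypothetical full-parameter PPO sketch; all paths and dataset keys are placeholders.
run_exp(dict(
    stage="ppo",                          # pt / sft / rm / ppo / dpo
    do_train=True,
    model_name_or_path="path_to_sft_model",
    dataset="alpaca_en",                  # placeholder dataset key
    template="default",
    finetuning_type="full",               # full-parameter RLHF added by this PR
    reward_model="path_to_reward_model",
    reward_model_type="full",             # reward model saved as a full checkpoint
    output_dir="path_to_ppo_checkpoint"
))
```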
@@ -122,6 +124,7 @@ Please refer to [template.py](src/llmtuner/extras/template.py) for a full list o
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
+- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)

View File

@@ -14,7 +14,9 @@
## LLaMA Board: 通过一站式网页界面快速上手 LLaMA Factory
-使用 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 启动 LLaMA Board。（该界面目前仅支持单卡训练）
+使用 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 启动 **LLaMA Board**。（该界面目前仅支持单卡训练）
+
+通过 **[Hugging Face Space](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** 预览 LLaMA Board。
下面是使用单张 GPU 在 10 分钟内更改对话式大型语言模型自我认知的示例。
@@ -71,7 +73,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
>
> 对于所有“基座”（Base）模型，`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”（Chat）模型请务必使用**对应的模板**。
-项目所支持模型的完整列表请参阅 [template.py](src/llmtuner/extras/template.py)。
+项目所支持模型的完整列表请参阅 [constants.py](src/llmtuner/extras/constants.py)。
## 训练方法
@@ -79,9 +81,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
| 预训练                 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
| 指令监督微调           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
-| 奖励模型训练           |                    |                    | :white_check_mark: | :white_check_mark: |
+| 奖励模型训练           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
-| PPO 训练               |                    |                    | :white_check_mark: | :white_check_mark: |
+| PPO 训练               | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
-| DPO 训练               | :white_check_mark: |                    | :white_check_mark: | :white_check_mark: |
+| DPO 训练               | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
> [!NOTE]
> 请使用 `--quantization_bit 4/8` 参数来启用 QLoRA 训练。
@@ -122,6 +124,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
+- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)

View File

@@ -24,9 +24,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
    def _info(self):
        features = datasets.Features({
-            "instruction": datasets.Value("string"),
-            "output": datasets.Value("string"),
-            "history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
+            "conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]
        })
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
@@ -51,6 +49,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
        with open(filepath, "r", encoding="utf-8") as f:
            for key, row in enumerate(f):
                data = json.loads(row)
+                conversations = []
                prompt = data["instruction"].strip()
                response = data["output"].strip()
@@ -58,7 +57,8 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
                human_idx = prompt.rfind("Human:")
                query = prompt[human_idx+6:assist_idx].strip()
                prompt = prompt[:human_idx].strip()
-                history = []
+                conversations.insert(0, {"from": "gpt", "value": response})
+                conversations.insert(0, {"from": "human", "value": query})

                while prompt.rfind("Assistant:") != -1:
                    assist_idx = prompt.rfind("Assistant:")
@@ -66,13 +66,10 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
                    if human_idx != -1:
                        old_query = prompt[human_idx+6:assist_idx].strip()
                        old_resp = prompt[assist_idx+10:].strip()
-                        history.insert(0, (old_query, old_resp))
+                        conversations.insert(0, {"from": "gpt", "value": old_resp})
+                        conversations.insert(0, {"from": "human", "value": old_query})
                    else:
                        break
                    prompt = prompt[:human_idx].strip()

-                yield key, {
-                    "instruction": query,
-                    "output": response,
-                    "history": history
-                }
+                yield key, {"conversations": conversations}
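Editor's note: after this rewrite the generator yields sharegpt-style records, a single `conversations` list of alternating human/gpt turns instead of the old instruction/output/history triple. A hand-written illustration of one yielded record (values invented for clarity):

```python
# Illustrative only: the shape of a record produced by the rewritten generator.
key, example = 0, {
    "conversations": [
        {"from": "human", "value": "What is the capital of France?"},
        {"from": "gpt", "value": "The capital of France is Paris."},
        {"from": "human", "value": "And its population?"},
        {"from": "gpt", "value": "Roughly two million people live in the city proper."}
    ]
}
```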

View File

@@ -88,11 +88,7 @@
    },
    "belle_multiturn": {
        "script_url": "belle_multiturn",
-        "columns": {
-            "prompt": "instruction",
-            "response": "output",
-            "history": "history"
-        }
+        "formatting": "sharegpt"
    },
    "ultra_chat": {
        "script_url": "ultra_chat",
@@ -107,6 +103,13 @@
    "alpaca_cot": {
        "hf_hub_url": "QingyiSi/Alpaca-CoT"
    },
+    "openorca": {
+        "hf_hub_url": "Open-Orca/OpenOrca",
+        "columns": {
+            "prompt": "question",
+            "response": "response"
+        }
+    },
    "mathinstruct": {
        "hf_hub_url": "TIGER-Lab/MathInstruct",
        "columns": {

View File

@@ -66,6 +66,4 @@ class UltraChat(datasets.GeneratorBasedBuilder):
                    "from": "human" if i % 2 == 0 else "gpt",
                    "value": content[i]
                } for i in range(len(content))]
-                yield key, {
-                    "conversations": conversations
-                }
+                yield key, {"conversations": conversations}

View File

@@ -1,9 +1,9 @@
-# Level: api, webui > chat, eval > tuner > dsets > extras, hparams
+# Level: api, webui > chat, eval, train > data, model > extras, hparams

from llmtuner.api import create_app
from llmtuner.chat import ChatModel
from llmtuner.eval import Evaluator
-from llmtuner.tuner import export_model, run_exp
+from llmtuner.train import export_model, run_exp
from llmtuner.webui import create_ui, create_web_demo

View File

@@ -1,14 +1,8 @@
import json
-import uvicorn
-from fastapi import FastAPI, HTTPException, status
-from fastapi.middleware.cors import CORSMiddleware
-from contextlib import asynccontextmanager
-from sse_starlette import EventSourceResponse
from typing import List, Tuple
from pydantic import BaseModel
+from contextlib import asynccontextmanager
-from llmtuner.extras.misc import torch_gc
-from llmtuner.chat import ChatModel
from llmtuner.api.protocol import (
    Role,
    Finish,
@@ -23,10 +17,28 @@ from llmtuner.api.protocol import (
    ChatCompletionResponseStreamChoice,
    ChatCompletionResponseUsage
)
+from llmtuner.chat import ChatModel
+from llmtuner.extras.misc import torch_gc
+from llmtuner.extras.packages import (
+    is_fastapi_availble, is_starlette_available, is_uvicorn_available
+)
+
+if is_fastapi_availble():
+    from fastapi import FastAPI, HTTPException, status
+    from fastapi.middleware.cors import CORSMiddleware
+
+if is_starlette_available():
+    from sse_starlette import EventSourceResponse
+
+if is_uvicorn_available():
+    import uvicorn


@asynccontextmanager
-async def lifespan(app: FastAPI): # collects GPU memory
+async def lifespan(app: "FastAPI"): # collects GPU memory
    yield
    torch_gc()
@@ -38,7 +50,7 @@ def to_json(data: BaseModel) -> str:
    return data.json(exclude_unset=True, ensure_ascii=False)


-def create_app(chat_model: ChatModel) -> FastAPI:
+def create_app(chat_model: "ChatModel") -> "FastAPI":
    app = FastAPI(lifespan=lifespan)

    app.add_middleware(
@@ -56,12 +68,12 @@ def create_app(chat_model: ChatModel) -> FastAPI:
    @app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
    async def create_chat_completion(request: ChatCompletionRequest):
-        if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
+        if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
        query = request.messages[-1].content

        prev_messages = request.messages[:-1]
-        if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
+        if len(prev_messages) and prev_messages[0].role == Role.SYSTEM:
            system = prev_messages.pop(0).content
        else:
            system = None
@@ -73,12 +85,14 @@ def create_app(chat_model: ChatModel) -> FastAPI:
                history.append([prev_messages[i].content, prev_messages[i+1].content])
            else:
                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
+        else:
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")

        if request.stream:
            generate = predict(query, history, system, request)
            return EventSourceResponse(generate, media_type="text/event-stream")

-        response, (prompt_length, response_length) = chat_model.chat(
+        responses = chat_model.chat(
            query, history, system,
            do_sample=request.do_sample,
            temperature=request.temperature,
@@ -87,18 +101,23 @@ def create_app(chat_model: ChatModel) -> FastAPI:
            num_return_sequences=request.n
        )

+        prompt_length, response_length = 0, 0
+        choices = []
+        for i, response in enumerate(responses):
+            choices.append(ChatCompletionResponseChoice(
+                index=i,
+                message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
+                finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
+            ))
+            prompt_length = response.prompt_length
+            response_length += response.response_length
+
        usage = ChatCompletionResponseUsage(
            prompt_tokens=prompt_length,
            completion_tokens=response_length,
            total_tokens=prompt_length+response_length
        )

-        choices = [ChatCompletionResponseChoice(
-            index=i,
-            message=ChatMessage(role=Role.ASSISTANT, content=choice),
-            finish_reason=Finish.STOP
-        ) for i, choice in enumerate(response)]
-
        return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)

    async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):

View File

@@ -1 +1 @@
-from llmtuner.chat.stream_chat import ChatModel
+from llmtuner.chat.chat_model import ChatModel

View File

@@ -1,11 +1,21 @@
import torch
-from typing import Any, Dict, Generator, List, Optional, Tuple
+from dataclasses import dataclass
+from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
from threading import Thread
from transformers import GenerationConfig, TextIteratorStreamer

-from llmtuner.extras.misc import dispatch_model, get_logits_processor
-from llmtuner.extras.template import get_template_and_fix_tokenizer
-from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
+from llmtuner.data.template import get_template_and_fix_tokenizer
+from llmtuner.extras.misc import get_logits_processor
+from llmtuner.model import dispatch_model, get_infer_args, load_model_and_tokenizer
+
+
+@dataclass
+class Response:
+    response_text: str
+    response_length: int
+    prompt_length: int
+    finish_reason: Literal["stop", "length"]


class ChatModel:
@@ -18,7 +28,7 @@ class ChatModel:
        self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
        self.system_prompt = data_args.system_prompt

-    def process_args(
+    def _process_args(
        self,
        query: str,
        history: Optional[List[Tuple[str, str]]] = None,
@@ -79,17 +89,30 @@ class ChatModel:
        history: Optional[List[Tuple[str, str]]] = None,
        system: Optional[str] = None,
        **input_kwargs
-    ) -> Tuple[List[str], Tuple[int, int]]:
-        gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
+    ) -> List[Response]:
+        r"""
+        Args: query, history, system, **input_kwargs
+
+        Returns: [(response_text, prompt_length, response_length)] * n (default n=1)
+        """
+        gen_kwargs, prompt_length = self._process_args(query, history, system, **input_kwargs)
        generate_output = self.model.generate(**gen_kwargs)
        response_ids = generate_output[:, prompt_length:]
-        response = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
-        response_length = 0
-        for i in range(len(response_ids)):
+        response = self.tokenizer.batch_decode(
+            response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
+        )
+        results = []
+        for i in range(len(response)):
            eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
-            response_length += eos_index[0].item() if len(eos_index) else len(response_ids[i])
-        return response, (prompt_length, response_length)
+            response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
+            results.append(Response(
+                response_text=response[i],
+                response_length=response_length,
+                prompt_length=prompt_length,
+                finish_reason="stop" if len(eos_index) else "length"
+            ))
+        return results

    @torch.inference_mode()
    def stream_chat(
@@ -99,7 +122,7 @@ class ChatModel:
        system: Optional[str] = None,
        **input_kwargs
    ) -> Generator[str, None, None]:
-        gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs)
+        gen_kwargs, _ = self._process_args(query, history, system, **input_kwargs)
        streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs["streamer"] = streamer
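Editor's note: `chat()` now returns a list of `Response` dataclasses rather than the old `(responses, (prompt_length, response_length))` tuple. A minimal usage sketch, under the assumption that `ChatModel` accepts a plain argument dict (the model path and template are placeholders):

```python
from llmtuner import ChatModel

# Hypothetical setup; ChatModel taking a plain dict of arguments is an assumption here.
chat_model = ChatModel(dict(
    model_name_or_path="path_to_model",
    template="default"
))

# chat() returns one Response per generated sequence.
for resp in chat_model.chat("Hello!", history=[], system=None, num_return_sequences=2):
    print(resp.response_text)
    print(resp.prompt_length, resp.response_length, resp.finish_reason)
```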

View File

@@ -0,0 +1,4 @@
from llmtuner.data.loader import get_dataset
from llmtuner.data.preprocess import preprocess_dataset
from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.data.utils import split_dataset

View File

@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
from datasets import concatenate_datasets, interleave_datasets, load_dataset

-from llmtuner.dsets.utils import checksum, EXT2TYPE
+from llmtuner.data.utils import checksum, EXT2TYPE
from llmtuner.extras.logging import get_logger

if TYPE_CHECKING:

View File

@@ -5,9 +5,9 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Un
from datasets import load_from_disk

+from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.logging import get_logger
-from llmtuner.extras.template import get_template_and_fix_tokenizer

if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset

View File

@@ -225,9 +225,6 @@ def get_template_and_fix_tokenizer(
    return template

-r"""
-Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
-"""
register_template(
    name="alpaca",
    prefix=[
@@ -246,11 +243,6 @@
)

-r"""
-Supports: https://huggingface.co/BAAI/AquilaChat-7B
-          https://huggingface.co/BAAI/AquilaChat2-7B
-          https://huggingface.co/BAAI/AquilaChat2-34B
-"""
register_template(
    name="aquila",
    prefix=[
@@ -273,9 +265,6 @@
)

-r"""
-Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
-"""
register_template(
    name="baichuan",
    prefix=[
@@ -292,10 +281,6 @@
)

-r"""
-Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
-          https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
-"""
register_template(
    name="baichuan2",
    prefix=[
@@ -312,9 +297,6 @@
)

-r"""
-Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
-"""
register_template(
    name="belle",
    prefix=[
@@ -330,9 +312,6 @@
)

-r"""
-Supports: https://huggingface.co/vivo-ai/BlueLM-7B-Chat
-"""
register_template(
    name="bluelm",
    prefix=[
@@ -348,9 +327,6 @@
)

-r"""
-Supports: https://huggingface.co/THUDM/chatglm2-6b
-"""
register_template(
    name="chatglm2",
    prefix=[
@@ -369,9 +345,6 @@
)

-r"""
-Supports: https://huggingface.co/THUDM/chatglm3-6b
-"""
register_template(
    name="chatglm3",
    prefix=[
@@ -395,11 +368,6 @@
)

-r"""
-Supports: https://huggingface.co/deepseek-ai/deepseek-coder-1.3b-instruct
-          https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct
-          https://huggingface.co/deepseek-ai/deepseek-coder-33b-instruct
-"""
register_template(
    name="deepseek",
    prefix=[
@@ -426,9 +394,6 @@
)

-r"""
-Default template.
-"""
register_template(
    name="default",
    prefix=[
@@ -447,9 +412,6 @@
)

-r"""
-Supports: https://huggingface.co/tiiuae/falcon-180B-chat
-"""
register_template(
    name="falcon",
    prefix=[
@@ -466,10 +428,6 @@
)

-r"""
-Supports: https://huggingface.co/internlm/internlm-chat-7b
-          https://huggingface.co/internlm/internlm-chat-20b
-"""
register_template(
    name="intern",
    prefix=[
@@ -492,11 +450,6 @@
)

-r"""
-Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
-          https://huggingface.co/meta-llama/Llama-2-13b-chat-hf
-          https://huggingface.co/meta-llama/Llama-2-70b-chat-hf
-"""
register_template(
    name="llama2",
    prefix=[
@@ -519,10 +472,6 @@
)

-r"""
-Supports: https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
-          https://huggingface.co/ziqingyang/chinese-alpaca-2-13b
-"""
register_template(
    name="llama2_zh",
    prefix=[
@@ -536,9 +485,6 @@
)

-r"""
-Supports: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
-"""
register_template(
    name="mistral",
    prefix=[
@@ -552,9 +498,6 @@
)

-r"""
-Supports: https://huggingface.co/openchat/openchat_3.5
-"""
register_template(
    name="openchat",
    prefix=[
@@ -576,10 +519,6 @@
)

-r"""
-Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
-          https://huggingface.co/Qwen/Qwen-14B-Chat
-"""
register_template(
    name="qwen",
    prefix=[
@@ -606,10 +545,6 @@
)

-r"""
-Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
-          https://huggingface.co/HuggingFaceH4/starchat-beta
-"""
register_template(
    name="starchat",
    prefix=[
@@ -650,10 +585,6 @@
)

-r"""
-Supports: https://huggingface.co/lmsys/vicuna-7b-v1.5
-          https://huggingface.co/lmsys/vicuna-13b-v1.5
-"""
register_template(
    name="vicuna",
    prefix=[
@@ -670,10 +601,6 @@
)

-r"""
-Supports: https://huggingface.co/xverse/XVERSE-7B-Chat
-          https://huggingface.co/xverse/XVERSE-13B-Chat
-"""
register_template(
    name="xverse",
    prefix=[
@@ -687,11 +614,6 @@
)

-r"""
-Supports: https://huggingface.co/wenge-research/yayi-7b
-          https://huggingface.co/wenge-research/yayi-7b-llama2
-          https://huggingface.co/wenge-research/yayi-13b-llama2
-"""
register_template(
    name="yayi",
    prefix=[
@@ -724,10 +646,6 @@
)

-r"""
-Supports: https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha
-          https://huggingface.co/HuggingFaceH4/zephyr-7b-beta
-"""
register_template(
    name="zephyr",
    prefix=[
@@ -746,11 +664,6 @@
)

-r"""
-Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
-          https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1.1
-          https://huggingface.co/IDEA-CCNL/Ziya2-13B-Chat
-"""
register_template(
    name="ziya",
    prefix=[

View File

@@ -1,3 +0,0 @@
from llmtuner.dsets.loader import get_dataset
from llmtuner.dsets.preprocess import preprocess_dataset
from llmtuner.dsets.utils import split_dataset

View File

@@ -1 +1 @@
-from llmtuner.eval.engine import Evaluator
+from llmtuner.eval.evaluator import Evaluator

View File

@@ -1,3 +0,0 @@
CHOICES = ["A", "B", "C", "D"]
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]

View File

@@ -11,12 +11,10 @@ from typing import Any, Dict, List, Optional
from datasets import load_dataset
from transformers.utils import cached_file

-from llmtuner.eval.constants import CHOICES, SUBJECTS
-from llmtuner.eval.parser import get_eval_args
+from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.eval.template import get_eval_template
-from llmtuner.extras.misc import dispatch_model
-from llmtuner.extras.template import get_template_and_fix_tokenizer
-from llmtuner.tuner.core import load_model_and_tokenizer
+from llmtuner.extras.constants import CHOICES, SUBJECTS
+from llmtuner.model import dispatch_model, get_eval_args, load_model_and_tokenizer


class Evaluator:

View File

@@ -1,49 +0,0 @@
import transformers
from typing import Any, Dict, Optional, Tuple
from transformers import HfArgumentParser
from llmtuner.extras.misc import parse_args
from llmtuner.hparams import (
ModelArguments,
DataArguments,
EvaluationArguments,
FinetuningArguments
)
def parse_eval_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[
ModelArguments,
DataArguments,
EvaluationArguments,
FinetuningArguments
]:
parser = HfArgumentParser((
ModelArguments,
DataArguments,
EvaluationArguments,
FinetuningArguments
))
return parse_args(parser, args)
def get_eval_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[
ModelArguments,
DataArguments,
EvaluationArguments,
FinetuningArguments
]:
model_args, data_args, eval_args, finetuning_args = parse_eval_args(args)
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
transformers.set_seed(eval_args.seed)
return model_args, data_args, eval_args, finetuning_args

View File

@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Tuple

-from llmtuner.eval.constants import CHOICES
+from llmtuner.extras.constants import CHOICES

if TYPE_CHECKING:
    from datasets import Dataset

View File

@@ -12,6 +12,7 @@ from llmtuner.extras.logging import get_logger

if TYPE_CHECKING:
    from transformers import TrainingArguments, TrainerState, TrainerControl
+    from trl import AutoModelForCausalLMWithValueHead

logger = get_logger(__name__)
@@ -25,18 +26,24 @@ class SavePeftModelCallback(TrainerCallback):
        """
        if args.should_save:
            output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
-            model = kwargs.pop("model")
+            model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
+            model.pretrained_model.config.save_pretrained(output_dir)
+            if model.pretrained_model.can_generate():
+                model.pretrained_model.generation_config.save_pretrained(output_dir)
            if getattr(model, "is_peft_model", False):
-                getattr(model, "pretrained_model").save_pretrained(output_dir)
+                model.pretrained_model.save_pretrained(output_dir)

    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of training.
        """
        if args.should_save:
-            model = kwargs.pop("model")
+            model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
+            model.pretrained_model.config.save_pretrained(args.output_dir)
+            if model.pretrained_model.can_generate():
+                model.pretrained_model.generation_config.save_pretrained(args.output_dir)
            if getattr(model, "is_peft_model", False):
-                getattr(model, "pretrained_model").save_pretrained(args.output_dir)
+                model.pretrained_model.save_pretrained(args.output_dir)


class LogCallback(TrainerCallback):

View File

@@ -2,12 +2,24 @@ from collections import defaultdict, OrderedDict
from typing import Dict, Optional

+CHOICES = ["A", "B", "C", "D"]
+
+DEFAULT_MODULE = defaultdict(str)
+
+DEFAULT_TEMPLATE = defaultdict(str)
+
IGNORE_INDEX = -100

+LAYERNORM_NAMES = {"norm", "ln"}
+
LOG_FILE_NAME = "trainer_log.jsonl"

METHODS = ["full", "freeze", "lora"]

+SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
+
+SUPPORTED_MODELS = OrderedDict()
+
TRAINING_STAGES = {
    "Supervised Fine-Tuning": "sft",
    "Reward Modeling": "rm",
@@ -16,14 +28,6 @@ TRAINING_STAGES = {
    "Pre-Training": "pt"
}

-LAYERNORM_NAMES = {"norm", "ln"}
-
-SUPPORTED_MODELS = OrderedDict()
-
-DEFAULT_MODULE = defaultdict(str)
-
-DEFAULT_TEMPLATE = defaultdict(str)
-

def register_model_group(
    models: Dict[str, str],
@@ -116,10 +120,12 @@ register_model_group(
register_model_group(
    models={
-        "ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
-        "ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
-        "ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
-        "ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b"
+        "ChineseLLaMA2-1.3B": "hfl/chinese-llama-2-1.3b",
+        "ChineseLLaMA2-7B": "hfl/chinese-llama-2-7b",
+        "ChineseLLaMA2-13B": "hfl/chinese-llama-2-13b",
+        "ChineseLLaMA2-1.3B-Chat": "hfl/chinese-alpaca-2-1.3b",
+        "ChineseLLaMA2-7B-Chat": "hfl/chinese-alpaca-2-7b",
+        "ChineseLLaMA2-13B-Chat": "hfl/chinese-alpaca-2-13b"
    },
    template="llama2_zh"
)
@@ -190,6 +196,14 @@ register_model_group(
)


+register_model_group(
+    models={
+        "OpenChat3.5-7B-Chat": "openchat/openchat_3.5"
+    },
+    template="openchat"
+)
+
+
register_model_group(
    models={
        "Phi1.5-1.3B": "microsoft/phi-1_5"
@@ -217,6 +231,15 @@ register_model_group(
)


+register_model_group(
+    models={
+        "Vicuna1.5-7B-Chat": "lmsys/vicuna-7b-v1.5",
+        "Vicuna1.5-13B-Chat": "lmsys/vicuna-13b-v1.5"
+    },
+    template="vicuna"
+)
+
+
register_model_group(
    models={
        "XVERSE-7B": "xverse/XVERSE-7B",
@@ -229,9 +252,27 @@ register_model_group(
)


+register_model_group(
+    models={
+        "Yayi-7B": "wenge-research/yayi-7b-llama2",
+        "Yayi-13B": "wenge-research/yayi-13b-llama2"
+    },
+    template="yayi"
+)
+
+
register_model_group(
    models={
        "Yi-6B": "01-ai/Yi-6B",
        "Yi-34B": "01-ai/Yi-34B"
    }
)
+
+
+register_model_group(
+    models={
+        "Zephyr-7B-Alpha-Chat": "HuggingFaceH4/zephyr-7b-alpha",
+        "Zephyr-7B-Beta-Chat": "HuggingFaceH4/zephyr-7b-beta"
+    },
+    template="zephyr"
+)

View File

@@ -3,6 +3,9 @@ import logging


class LoggerHandler(logging.Handler):
+    r"""
+    Logger handler used in Web UI.
+    """

    def __init__(self):
        super().__init__()
@@ -19,16 +22,10 @@ class LoggerHandler(logging.Handler):
        self.log += "\n\n"


-def reset_logging():
-    r"""
-    Removes basic config of root logger
-    """
-    root = logging.getLogger()
-    list(map(root.removeHandler, root.handlers))
-    list(map(root.removeFilter, root.filters))
-
-
def get_logger(name: str) -> logging.Logger:
+    r"""
+    Gets a standard logger with a stream hander to stdout.
+    """
    formatter = logging.Formatter(
        fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S"
@@ -41,3 +38,12 @@ def get_logger(name: str) -> logging.Logger:
    logger.addHandler(handler)

    return logger
+
+
+def reset_logging() -> None:
+    r"""
+    Removes basic config of root logger. (unused in script)
+    """
+    root = logging.getLogger()
+    list(map(root.removeHandler, root.handlers))
+    list(map(root.removeFilter, root.filters))

View File

@@ -13,14 +13,13 @@ try:
        is_torch_npu_available
    )
    _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
-    _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available
+    _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
except ImportError:
    _is_fp16_available = torch.cuda.is_available()
    _is_bf16_available = torch.cuda.is_bf16_supported()

if TYPE_CHECKING:
    from transformers import HfArgumentParser
-    from transformers.modeling_utils import PreTrainedModel


class AverageMeter:
@@ -65,16 +64,14 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
    return trainable_params, all_param


-def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
-    r"""
-    Infers the optimal dtype according to the model_dtype and device compatibility.
-    """
-    if _is_bf16_available and model_dtype == torch.bfloat16:
-        return torch.bfloat16
-    elif _is_fp16_available:
-        return torch.float16
-    else:
-        return torch.float32
+def get_current_device() -> str:
+    import accelerate
+    from accelerate import Accelerator
+    dummy_accelerator = Accelerator()
+    if accelerate.utils.is_xpu_available():
+        return "xpu:{}".format(dummy_accelerator.local_process_index)
+    else:
+        return dummy_accelerator.local_process_index if torch.cuda.is_available() else "cpu"


def get_logits_processor() -> "LogitsProcessorList":
@@ -86,14 +83,16 @@ def get_logits_processor() -> "LogitsProcessorList":
    return logits_processor


-def torch_gc() -> None:
-    r"""
-    Collects GPU memory.
-    """
-    gc.collect()
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        torch.cuda.ipc_collect()
+def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
+    r"""
+    Infers the optimal dtype according to the model_dtype and device compatibility.
+    """
+    if _is_bf16_available and model_dtype == torch.bfloat16:
+        return torch.bfloat16
+    elif _is_fp16_available:
+        return torch.float16
+    else:
+        return torch.float32


def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
@@ -107,26 +106,11 @@ def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None
    return parser.parse_args_into_dataclasses()


-def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
-    r"""
-    Dispatches a pre-trained model to GPUs with balanced memory.
-    Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
-    """
-    if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
-        return model
-
-    if torch.cuda.device_count() > 1:
-        from accelerate import dispatch_model
-        from accelerate.utils import infer_auto_device_map, get_balanced_memory
-
-        if model._no_split_modules is None:
-            raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
-
-        kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
-        max_memory = get_balanced_memory(model, **kwargs)
-        # Make sure tied weights are tied before creating the device map.
-        model.tie_weights()
-        device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
-        return dispatch_model(model, device_map)
-    else:
-        return model.cuda()
+def torch_gc() -> None:
+    r"""
+    Collects GPU memory.
+    """
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
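Editor's note: the balanced-memory `dispatch_model` helper removed here now lives in `llmtuner.model.utils`, while the new `get_current_device()` reports the local process's device. A hedged sketch of the kind of single-device placement this enables (illustrative only; not code from this commit):

```python
from transformers import AutoModelForCausalLM
from llmtuner.extras.misc import get_current_device

# Illustrative only: map the whole model onto the device returned for this process
# (a CUDA index under torch.cuda, "xpu:N" on Intel XPUs, or "cpu" otherwise).
model = AutoModelForCausalLM.from_pretrained(
    "path_to_model",                       # placeholder model path
    device_map={"": get_current_device()}
)
```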

View File

@@ -0,0 +1,55 @@
import importlib.metadata
import importlib.util
def is_package_available(name: str) -> bool:
return importlib.util.find_spec(name) is not None
def get_package_version(name: str) -> str:
try:
return importlib.metadata.version(name)
except:
return "0.0.0"
_fastapi_available = is_package_available("fastapi")
_flash_attn2_available = is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
_jieba_available = is_package_available("jieba")
_matplotlib_available = is_package_available("matplotlib")
_nltk_available = is_package_available("nltk")
_rouge_available = is_package_available("rouge-chinese")
_starlette_available = is_package_available("sse-starlette")
_uvicorn_available = is_package_available("uvicorn")
def is_fastapi_availble():
return _fastapi_available
def is_flash_attn2_available():
return _flash_attn2_available
def is_jieba_available():
return _jieba_available
def is_matplotlib_available():
return _matplotlib_available
def is_nltk_available():
return _nltk_available
def is_rouge_available():
return _rouge_available
def is_starlette_available():
return _starlette_available
def is_uvicorn_available():
return _uvicorn_available

View File

@@ -3,16 +3,19 @@ import torch
import torch.nn as nn
from typing import Optional, Tuple
from transformers.utils import logging
-from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
+from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

-is_flash_attn_2_available = False
try:
+    from transformers.models.llama.modeling_llama import repeat_kv
+except ImportError:
+    print("Please upgrade `transformers`.")
+
+from llmtuner.extras.packages import is_flash_attn2_available
+
+if is_flash_attn2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore
    from flash_attn.bert_padding import pad_input, unpad_input # type: ignore
-    is_flash_attn_2_available = True
-except ImportError:
-    is_flash_attn_2_available = False

logger = logging.get_logger(__name__)

View File

@@ -1,11 +1,14 @@
import os
import math
import json
-import matplotlib.pyplot as plt
from typing import List, Optional
from transformers.trainer import TRAINER_STATE_NAME

from llmtuner.extras.logging import get_logger
+from llmtuner.extras.packages import is_matplotlib_available
+
+if is_matplotlib_available():
+    import matplotlib.pyplot as plt

logger = get_logger(__name__)

View File

@@ -4,38 +4,38 @@ from dataclasses import asdict, dataclass, field

@dataclass
-class FinetuningArguments:
+class FreezeArguments:
    r"""
-    Arguments pertaining to which techniques we are going to fine-tuning with.
+    Arguments pertaining to the freeze (partial-parameter) training.
    """
-    stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
-        default="sft",
-        metadata={"help": "Which stage will be performed in training."}
-    )
-    finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
-        default="lora",
-        metadata={"help": "Which fine-tuning method to use."}
-    )
    num_layer_trainable: Optional[int] = field(
        default=3,
        metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
    )
-    name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
+    name_module_trainable: Optional[str] = field(
        default="mlp",
        metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
+                  Use commas to separate multiple modules. \
                  LLaMA choices: [\"mlp\", \"self_attn\"], \
                  BLOOM & Falcon & ChatGLM choices: [\"mlp\", \"self_attention\"], \
                  Qwen choices: [\"mlp\", \"attn\"], \
                  Phi-1.5 choices: [\"mlp\", \"mixer\"], \
-                  LLaMA-2, BlueLM, Baichuan, InternLM, Mistral, Skywork, XVERSE, Yi choices: the same as LLaMA."}
+                  Others choices: the same as LLaMA."}
    )
+
+
+@dataclass
+class LoraArguments:
+    r"""
+    Arguments pertaining to the LoRA training.
+    """
    lora_rank: Optional[int] = field(
        default=8,
        metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
    )
    lora_alpha: Optional[float] = field(
-        default=32.0,
-        metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."}
+        default=None,
+        metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2.0)."}
    )
    lora_dropout: Optional[float] = field(
        default=0.1,
@@ -49,7 +49,7 @@ class FinetuningArguments:
                  Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
                  Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
                  Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
-                  LLaMA-2, BlueLM, InternLM, Mistral, Skywork, XVERSE, Yi choices: the same as LLaMA."}
+                  Others choices: the same as LLaMA."}
    )
    additional_target: Optional[str] = field(
        default=None,
@@ -59,30 +59,76 @@ class FinetuningArguments:
        default=True,
        metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
    )
-    ppo_score_norm: Optional[bool] = field(
-        default=False,
-        metadata={"help": "Use score normalization in PPO training."}
+
+
+@dataclass
+class RLHFArguments:
+    r"""
+    Arguments pertaining to the PPO and DPO training.
+    """
+    dpo_beta: Optional[float] = field(
+        default=0.1,
+        metadata={"help": "The beta parameter for the DPO loss."}
    )
    ppo_logger: Optional[str] = field(
        default=None,
        metadata={"help": "Log with either 'wandb' or 'tensorboard' in PPO training."}
    )
+    ppo_score_norm: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Use score normalization in PPO training."}
+    )
    ppo_target: Optional[float] = field(
        default=6.0,
        metadata={"help": "Target KL value for adaptive KL control in PPO training."}
    )
-    dpo_beta: Optional[float] = field(
-        default=0.1,
-        metadata={"help": "The beta parameter for the DPO loss."}
-    )
-    dpo_ref_model: Optional[str] = field(
+    ppo_whiten_rewards: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whiten the rewards before compute advantages in PPO training."}
+    )
+    ref_model: Optional[str] = field(
        default=None,
-        metadata={"help": "Path to the reference model used for the DPO training."}
+        metadata={"help": "Path to the reference model used for the PPO or DPO training."}
    )
-    dpo_ref_model_checkpoint: Optional[str] = field(
+    ref_model_checkpoint: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the directory(s) containing the model checkpoints of the reference model."}
    )
+    ref_model_quantization_bit: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of bits to quantize the reference model."}
+    )
+    reward_model: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
+    )
+    reward_model_checkpoint: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the directory(s) containing the model checkpoints of the reward model."}
+    )
+    reward_model_quantization_bit: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of bits to quantize the reward model."}
+    )
+    reward_model_type: Optional[Literal["lora", "full"]] = field(
+        default="lora",
+        metadata={"help": "The checkpoint type of the reward model. The lora type only supports lora training."}
+    )
+
+
+@dataclass
+class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
+    r"""
+    Arguments pertaining to which techniques we are going to fine-tuning with.
+    """
+    stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
+        default="sft",
+        metadata={"help": "Which stage will be performed in training."}
+    )
+    finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
+        default="lora",
+        metadata={"help": "Which fine-tuning method to use."}
+    )
    upcast_layernorm: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to upcast the layernorm weights in fp32."}
@@ -91,15 +137,37 @@ class FinetuningArguments:
        default=0,
        metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."}
    )
+    export_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the directory to save the exported model."}
+    )
+    plot_loss: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
+    )

    def __post_init__(self):
-        if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
-            self.lora_target = [target.strip() for target in self.lora_target.split(",")]
-
-        if isinstance(self.additional_target, str):
-            self.additional_target = [target.strip() for target in self.additional_target.split(",")]
+        def split_arg(arg):
+            if isinstance(arg, str):
+                return [item.strip() for item in arg.split(",")]
+            return arg
+
+        self.name_module_trainable = split_arg(self.name_module_trainable)
+        self.lora_alpha = self.lora_alpha or float(self.lora_rank * 2.0)
+        self.lora_target = split_arg(self.lora_target)
+        self.additional_target = split_arg(self.additional_target)
+        self.ref_model_checkpoint = split_arg(self.ref_model_checkpoint)
+        self.reward_model_checkpoint = split_arg(self.reward_model_checkpoint)

        assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
+        assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
+        assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
+
+        if self.stage == "ppo" and self.reward_model is None:
+            raise ValueError("Reward model is necessary for PPO training.")
+
+        if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
+            raise ValueError("Lora reward model only supports lora training.")

    def save_to_json(self, json_path: str):
        r"""Saves the content of this instance in JSON format inside `json_path`."""
@@ -112,4 +180,5 @@ class FinetuningArguments:
        r"""Creates an instance from the content of `json_path`."""
        with open(json_path, "r", encoding="utf-8") as f:
            text = f.read()
        return cls(**json.loads(text))
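Editor's note: the reward-model and reference-model options now live in `RLHFArguments` and are validated in `__post_init__`. A small sketch of constructing the merged dataclass for full-parameter PPO (all values are placeholders):

```python
from llmtuner.hparams import FinetuningArguments

# Placeholders throughout; __post_init__ rejects a PPO run without a reward model,
# and a "lora" reward_model_type is only allowed together with lora fine-tuning.
finetuning_args = FinetuningArguments(
    stage="ppo",
    finetuning_type="full",
    reward_model="path_to_reward_model",
    reward_model_type="full",
    ref_model=None                       # optional reference model for PPO or DPO
)
assert finetuning_args.lora_alpha == 16.0  # defaults to lora_rank * 2.0 when left unset
```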

View File

@@ -54,22 +54,10 @@ class ModelArguments:
        default=False,
        metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
    )
-    reward_model: Optional[str] = field( # TODO: move it to FinetuningArguments
-        default=None,
-        metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
-    )
-    plot_loss: Optional[bool] = field( # TODO: move it to FinetuningArguments
-        default=False,
-        metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
-    )
    hf_hub_token: Optional[str] = field(
        default=None,
        metadata={"help": "Auth token to log in with Hugging Face Hub."}
    )
-    export_dir: Optional[str] = field(
-        default=None,
-        metadata={"help": "Path to the directory to save the exported model."}
-    )

    def __post_init__(self):
        self.compute_dtype = None
@@ -81,8 +69,7 @@ class ModelArguments:
        if self.checkpoint_dir is not None: # support merging multiple lora weights
            self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]

-        if self.quantization_bit is not None:
-            assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
+        assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

View File

@@ -0,0 +1,5 @@
# Level: loader > adapter > parser, utils
from llmtuner.model.loader import load_model_and_tokenizer
from llmtuner.model.parser import get_train_args, get_infer_args, get_eval_args
from llmtuner.model.utils import dispatch_model, generate_model_card, load_valuehead_params

View File

@@ -1,18 +1,9 @@
-import os
import torch
from typing import TYPE_CHECKING
+from peft import PeftModel, TaskType, LoraConfig, get_peft_model

-from transformers.utils import cached_file
-from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
-from peft import (
-    PeftModel,
-    TaskType,
-    LoraConfig,
-    get_peft_model
-)
from llmtuner.extras.logging import get_logger
-from llmtuner.tuner.core.utils import find_all_linear_modules
+from llmtuner.model.utils import find_all_linear_modules

if TYPE_CHECKING:
    from transformers.modeling_utils import PreTrainedModel
@@ -46,13 +37,23 @@ def init_adapter(
    if finetuning_args.finetuning_type == "freeze" and is_trainable:
        logger.info("Fine-tuning method: Freeze")

-        num_layers = getattr(model.config, "num_layers")
+        num_layers = (
+            getattr(model.config, "num_hidden_layers", None)
+            or getattr(model.config, "num_layers", None)
+            or getattr(model.config, "n_layer", None)
+        )
+        if not num_layers:
+            raise ValueError("Current model does not support freeze tuning.")
+
        if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
            trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
        else: # fine-tuning the first n layers if num_layer_trainable < 0
            trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)]

-        trainable_layers = ["{:d}.{}".format(idx, finetuning_args.name_module_trainable) for idx in trainable_layer_ids]
+        trainable_layers = []
+        for module_name in finetuning_args.name_module_trainable:
+            for idx in trainable_layer_ids:
+                trainable_layers.append("{:d}.{}".format(idx, module_name))
+
        for name, param in model.named_parameters():
            if not any(trainable_layer in name for trainable_layer in trainable_layers):
                param.requires_grad_(False)
@@ -100,30 +101,3 @@ def init_adapter(
        logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))

    return model
def load_valuehead_params(
model: "PreTrainedModel",
model_args: "ModelArguments"
) -> bool:
kwargs = {
"path_or_repo_id": model_args.reward_model,
"cache_dir": model_args.cache_dir,
"token": model_args.hf_hub_token,
"revision": model_args.model_revision
}
try:
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
except:
try:
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
except:
logger.warning("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model))
return False
vhead_params = torch.load(vhead_file, map_location="cpu")
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
return True
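To make the freeze-tuning branch above concrete, here is a small illustration of how the trainable layer ids and module patterns are derived. The 32-layer count and the "mlp" module name are assumed example values, not defaults read from this commit:

```python
# Illustration with assumed values: a 32-layer model, unfreezing the last 3 layers.
num_layers, num_layer_trainable = 32, 3
if num_layer_trainable > 0:   # fine-tune the last n layers
    trainable_layer_ids = [num_layers - k - 1 for k in range(num_layer_trainable)]  # [31, 30, 29]
else:                         # fine-tune the first |n| layers
    trainable_layer_ids = list(range(-num_layer_trainable))

name_module_trainable = ["mlp"]  # example module name
trainable_layers = ["{:d}.{}".format(idx, name) for name in name_module_trainable for idx in trainable_layer_ids]
# -> ["31.mlp", "30.mlp", "29.mlp"]; parameters whose names match none of these stay frozen
```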

View File

@ -15,7 +15,6 @@ from transformers import (
)
from transformers.models.llama import modeling_llama as LlamaModule
from transformers.utils.versions import require_version
from peft import PeftModel
from trl import AutoModelForCausalLMWithValueHead
try:
@ -24,11 +23,12 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v
from transformers.deepspeed import is_deepspeed_zero3_enabled
from llmtuner.extras.logging import reset_logging, get_logger
from llmtuner.extras.misc import count_parameters, infer_optim_dtype
from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype
from llmtuner.extras.packages import is_flash_attn2_available
from llmtuner.extras.patches import llama_patch as LlamaPatches
from llmtuner.hparams import FinetuningArguments
from llmtuner.tuner.core.adapter import init_adapter, load_valuehead_params
from llmtuner.model.adapter import init_adapter
from llmtuner.tuner.core.utils import prepare_model_for_training
from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
@ -73,6 +73,7 @@ def load_model_and_tokenizer(
)
if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
logger.info("Use `model_name_or_path` to specify the model trained with full/freeze method.")
model_to_load = model_args.checkpoint_dir[0]
else:
model_to_load = model_args.model_name_or_path
@ -122,7 +123,7 @@ def load_model_and_tokenizer(
# Set FlashAttention-2
if model_args.flash_attn:
if getattr(config, "model_type", None) == "llama":
if LlamaPatches.is_flash_attn_2_available:
if is_flash_attn2_available():
LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
logger.info("Using FlashAttention-2 for faster training and inference.")
@ -131,7 +132,7 @@ def load_model_and_tokenizer(
elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
logger.info("Current model automatically enables FlashAttention if installed.")
else:
logger.warning("Current model does not support FlashAttention.")
elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
logger.warning("Using `--flash_attn` for faster training in large context length.")
@ -144,7 +145,7 @@ def load_model_and_tokenizer(
else:
logger.warning("Current model does not support shift short attention.")
# Quantization configurations (using bitsandbytes library)
if model_args.quantization_bit is not None:
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
@ -164,10 +165,10 @@ def load_model_and_tokenizer(
bnb_4bit_quant_type=model_args.quantization_type
)
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
config_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
# Load pre-trained models (without valuehead)
model = AutoModelForCausalLM.from_pretrained(
model_to_load,
config=config,
@ -185,7 +186,7 @@ def load_model_and_tokenizer(
setattr(model, "lm_head", model.transformer.output_layer)
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
# Register auto class to save the custom code files
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()
if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
@ -199,25 +200,15 @@ def load_model_and_tokenizer(
model = model.train() if is_trainable else model.eval()
# Prepare model with valuehead for RLHF
if stage == "rm" or stage == "ppo":
if stage in ["rm", "ppo"]:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
reset_logging()
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
logger.warning("Only the last checkpoint containing valuehead will be loaded.")
if load_valuehead_params(model, model_args):
model.v_head.load_state_dict({
"summary.weight": getattr(model, "reward_head_weight"),
"summary.bias": getattr(model, "reward_head_bias")
})
if stage == "ppo": # load reward model
logger.info("Load reward model from {}".format(model_args.reward_model))
if isinstance(model.pretrained_model, PeftModel):
model.pretrained_model.load_adapter(model_args.reward_model, "reward")
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
if "default" in name:
param.data = param.data.to(torch.float32) # trainable params should in fp32
assert load_valuehead_params(model, model_args), "Reward model is not correctly loaded."
vhead_path = (
model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
)
vhead_params = load_valuehead_params(vhead_path, model_args)
if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False)
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
# Prepare model for inference
if not is_trainable:

View File

@ -11,6 +11,7 @@ from llmtuner.extras.misc import parse_args
from llmtuner.hparams import (
ModelArguments,
DataArguments,
EvaluationArguments,
FinetuningArguments,
GeneratingArguments
)
@ -19,51 +20,42 @@ from llmtuner.hparams import (
logger = get_logger(__name__)
def parse_train_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments
]:
parser = HfArgumentParser((
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments
))
return parse_args(parser, args)
def parse_infer_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[
ModelArguments,
DataArguments,
FinetuningArguments,
GeneratingArguments
]:
parser = HfArgumentParser((
ModelArguments,
DataArguments,
FinetuningArguments,
GeneratingArguments
))
return parse_args(parser, args)
def get_train_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[
ModelArguments,
DataArguments,
Seq2SeqTrainingArguments,
FinetuningArguments,
GeneratingArguments
]:
_TRAIN_ARGS = [
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
]
_TRAIN_CLS = Tuple[
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
]
_INFER_ARGS = [
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
]
_INFER_CLS = Tuple[
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
]
_EVAL_ARGS = [
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
]
_EVAL_CLS = Tuple[
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
]
def parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)
return parse_args(parser, args)
def parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS)
return parse_args(parser, args)
def parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
parser = HfArgumentParser(_EVAL_ARGS)
return parse_args(parser, args)
def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args)
# Setup logging
@ -90,24 +82,19 @@ def get_train_args(
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
if finetuning_args.stage in ["rm", "ppo"]:
if finetuning_args.finetuning_type != "lora":
raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
if training_args.resume_from_checkpoint is not None:
raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
if training_args.load_best_model_at_end:
raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
if finetuning_args.stage == "ppo" and not training_args.do_train:
raise ValueError("PPO training does not support evaluation.")
raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
if finetuning_args.stage in ["rm", "dpo"]:
for dataset_attr in data_args.dataset_list:
if not dataset_attr.ranking:
raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
if finetuning_args.stage == "ppo" and model_args.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.")
if finetuning_args.stage == "ppo" and model_args.shift_attn:
raise ValueError("PPO training is incompatible with S^2-Attn.")
@ -139,6 +126,9 @@ def get_train_args(
if (not training_args.do_train) and model_args.quantization_bit is not None:
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
logger.warning("Specify `ref_model` for computing rewards at evaluation.")
# postprocess training_args
if (
training_args.local_rank != -1
@ -187,14 +177,7 @@ def get_train_args(
return model_args, data_args, training_args, finetuning_args, generating_args
def get_infer_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[
ModelArguments,
DataArguments,
FinetuningArguments,
GeneratingArguments
]:
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
if data_args.template is None:
@ -211,3 +194,17 @@ def get_infer_args(
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
return model_args, data_args, finetuning_args, generating_args
def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
model_args, data_args, eval_args, finetuning_args = parse_eval_args(args)
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
transformers.set_seed(eval_args.seed)
return model_args, data_args, eval_args, finetuning_args
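Under the new flat argument lists, evaluation arguments can be parsed either from the command line or from a plain dict; a hedged usage sketch (the model path is a placeholder):

```python
from llmtuner.model import get_eval_args

# Passing a dict mirrors parse_args(parser, args); omitting `template` raises the
# ValueError added above.
model_args, data_args, eval_args, finetuning_args = get_eval_args(dict(
    model_name_or_path="meta-llama/Llama-2-7b-hf",  # placeholder model
    template="default",
))
```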

View File

@ -1,21 +1,53 @@
import torch
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from transformers.utils import cached_file
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
from llmtuner.extras.constants import LAYERNORM_NAMES
from llmtuner.extras.logging import get_logger
from llmtuner.hparams import ModelArguments, FinetuningArguments
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.hparams import DataArguments
logger = get_logger(__name__)
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
r"""
Dispatches a pre-trained model to GPUs with balanced memory.
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
"""
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
return model
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory
if model._no_split_modules is None:
raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
max_memory = get_balanced_memory(model, **kwargs)
# Make sure tied weights are tied before creating the device map.
model.tie_weights()
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
return dispatch_model(model, device_map)
else:
return model.cuda()
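A short sketch of how `dispatch_model` is meant to be combined with the loader for inference; the call pattern follows the functions defined in this commit, while the surrounding variables are illustrative:

```python
from llmtuner.model import dispatch_model, load_model_and_tokenizer

# model_args / finetuning_args would come from get_infer_args() in a real script.
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
model = dispatch_model(model)  # balanced device_map on multi-GPU, plain .cuda() otherwise
```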
def find_all_linear_modules(
model: "PreTrainedModel",
quantization_bit: Optional[int] = None
) -> List[str]:
r"""
Finds all available modules to apply lora.
"""
if quantization_bit is not None:
import bitsandbytes as bnb
linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt
@ -51,6 +83,32 @@ def generate_model_card(
}
def load_valuehead_params(
path_or_repo_id: str,
model_args: "ModelArguments"
) -> Dict[str, torch.Tensor]:
r"""
Loads value head parameters from Hugging Face Hub or local disk.
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
"""
kwargs = {
"path_or_repo_id": path_or_repo_id,
"cache_dir": model_args.cache_dir,
"token": model_args.hf_hub_token
}
try:
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
except:
try:
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
except:
logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
return None
return torch.load(vhead_file, map_location="cpu")
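For reference, the loader above consumes the returned dict directly; a minimal sketch (the checkpoint path is a placeholder):

```python
vhead_params = load_valuehead_params("path/to/value_head_checkpoint", model_args)  # placeholder path
if vhead_params is not None:
    # keys are "v_head.summary.weight" and "v_head.summary.bias"
    model.load_state_dict(vhead_params, strict=False)
```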
def prepare_model_for_training(
model: "PreTrainedModel",
finetuning_args: "FinetuningArguments",

View File

@ -0,0 +1 @@
from llmtuner.train.tuner import export_model, run_exp

View File

@ -0,0 +1 @@
from llmtuner.train.dpo.workflow import run_dpo

View File

@ -43,6 +43,10 @@ class CustomDPOTrainer(DPOTrainer):
if ref_model is not None:
if self.is_deepspeed_enabled:
if not (
getattr(ref_model, "is_loaded_in_8bit", False)
or getattr(ref_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device
self.ref_model = self._prepare_deepspeed(self.ref_model)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

View File

@ -4,23 +4,20 @@ from peft import PeftModel
from typing import TYPE_CHECKING, Optional, List
from transformers import Seq2SeqTrainingArguments
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.logging import get_logger
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
from llmtuner.model import generate_model_card, load_model_and_tokenizer
from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
from llmtuner.tuner.dpo.trainer import CustomDPOTrainer
from llmtuner.train.utils import create_ref_model
from llmtuner.train.dpo.collator import DPODataCollatorWithPadding
from llmtuner.train.dpo.trainer import CustomDPOTrainer
if TYPE_CHECKING:
from transformers import TrainerCallback
from llmtuner.hparams import DataArguments, FinetuningArguments
logger = get_logger(__name__)
def run_dpo(
model_args: "ModelArguments",
data_args: "DataArguments",
@ -38,23 +35,10 @@ def run_dpo(
)
# Create reference model
if finetuning_args.dpo_ref_model is not None:
ref_model_args_dict = model_args.to_dict()
ref_model_args_dict.update(dict(
model_name_or_path=finetuning_args.dpo_ref_model,
checkpoint_dir=finetuning_args.dpo_ref_model_checkpoint
))
ref_model_args = ModelArguments(**ref_model_args_dict)
ref_model, _ = load_model_and_tokenizer(ref_model_args, finetuning_args, is_trainable=False, stage="sft")
logger.info("Created reference model from {}".format(finetuning_args.dpo_ref_model))
elif training_args.do_train:
if isinstance(model, PeftModel):
ref_model = None
else:
ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
logger.info("Created reference model from the model itself.")
else:
ref_model = model
if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
ref_model = model
else:
ref_model = create_ref_model(model_args, finetuning_args, stage="dpo")
# Update arguments
training_args_dict = training_args.to_dict()
@ -80,14 +64,13 @@ def run_dpo(
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and model_args.plot_loss:
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")
if id(model) == id(ref_model): # unable to compute rewards without a reference model
logger.warning("Pass `dpo_ref_model` for computing rewards at evaluation.")
remove_keys = [key for key in metrics.keys() if "rewards" in key]
for key in remove_keys:
metrics.pop(key)

View File

@ -0,0 +1 @@
from llmtuner.train.ppo.workflow import run_ppo

View File

@ -3,7 +3,7 @@ import sys
import math
import torch
from tqdm import tqdm
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple
from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
@ -14,7 +14,7 @@ from trl.core import PPODecorators, logprobs_from_logits
from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
from llmtuner.tuner.ppo.utils import dump_layernorm, restore_layernorm, replace_model
from llmtuner.train.ppo.utils import dump_layernorm, restore_layernorm, replace_model
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
@ -37,24 +37,43 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: List["TrainerCallback"],
reward_model: "AutoModelForCausalLMWithValueHead",
**kwargs
):
PPOTrainer.__init__(self, **kwargs)
self.args = training_args
self.model_args = model_args
self.finetuning_args = finetuning_args
self.reward_model = reward_model
self.generation_config = GenerationConfig(
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
**generating_args.to_dict()
)
self.state = TrainerState()
self.control = TrainerControl()
self.log_callback, self.save_callback = callbacks[0], callbacks[1]
assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
if self.args.max_steps > 0:
logger.info("max_steps is given, it will override any value given in num_train_epochs")
if reward_model is not None:
is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
self.accelerator.state, "deepspeed_plugin"
)
if is_deepspeed_enabled:
if not (
getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
): # quantized models are already set on the correct device
self.reward_model = self._prepare_deepspeed(self.reward_model)
else:
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
def ppo_train(self) -> None:
r"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
@ -213,11 +232,14 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
r"""
Computes scores using given reward model.
"""
if self.reward_model is None:
replace_model(unwrapped_model, target="reward")
batch = self.prepare_model_inputs(queries, responses)
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
_, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
reward_model = self.reward_model if self.reward_model is not None else self.model
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2
values = torch.transpose(values, 0, 1)
@ -228,7 +250,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
end_index = end_indexes[-1].item() if len(end_indexes) else 0
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
if self.reward_model is None:
replace_model(unwrapped_model, target="default")
return rewards
@PPODecorators.empty_device_cache()
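Condensed, the reward scoring above takes one of two paths; a sketch using the same names, assuming `batch` and `unwrapped_model` are prepared as in the surrounding method:

```python
# Path 1: reward model merged as a LoRA adapter -> swap in the "reward" value head,
# score with the policy model itself, then restore the "default" head.
# Path 2: standalone reward model prepared in __init__ -> call it directly.
scorer = self.reward_model if self.reward_model is not None else self.model
if self.reward_model is None:
    replace_model(unwrapped_model, target="reward")
_, _, values = scorer(**batch, output_hidden_states=True, return_dict=True)
if self.reward_model is None:
    replace_model(unwrapped_model, target="default")
```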

View File

@ -7,11 +7,12 @@ from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorWithPadding
from transformers.optimization import get_scheduler
from llmtuner.dsets import get_dataset, preprocess_dataset
from llmtuner.data import get_dataset, preprocess_dataset
from llmtuner.extras.callbacks import SavePeftModelCallback
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.model import load_model_and_tokenizer
from llmtuner.tuner.ppo.trainer import CustomPPOTrainer
from llmtuner.train.utils import create_ref_model, create_reward_model
from llmtuner.train.ppo.trainer import CustomPPOTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
@ -33,6 +34,11 @@ def run_ppo(
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# Create reference model and reward model
ref_model = create_ref_model(model_args, finetuning_args, stage="ppo")
reward_model = create_reward_model(model, model_args, finetuning_args)
# Create ppo config
ppo_config = PPOConfig(
model_name=model_args.model_name_or_path,
learning_rate=training_args.learning_rate,
@ -47,9 +53,11 @@ def run_ppo(
log_with=finetuning_args.ppo_logger,
use_score_scaling=finetuning_args.ppo_score_norm,
use_score_norm=finetuning_args.ppo_score_norm,
whiten_rewards=finetuning_args.ppo_whiten_rewards,
accelerator_kwargs={"step_scheduler_with_optimizer": False}
)
# Create optimizer and scheduler
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
if training_args.max_steps > 0:
num_training_steps = training_args.max_steps
@ -73,9 +81,10 @@ def run_ppo(
finetuning_args=finetuning_args,
generating_args=generating_args,
callbacks=callbacks + [SavePeftModelCallback()],
reward_model=reward_model,
config=ppo_config,
model=model,
ref_model=None,
ref_model=ref_model,
tokenizer=tokenizer,
dataset=dataset,
data_collator=data_collator,
@ -88,5 +97,5 @@ def run_ppo(
ppo_trainer.ppo_train()
ppo_trainer.save_model()
ppo_trainer.save_state() # must be called after save_model to have a folder
if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "reward"])

View File

@ -0,0 +1 @@
from llmtuner.train.pt.workflow import run_pt

View File

@ -4,9 +4,9 @@ import math
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForLanguageModeling, Trainer
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
from llmtuner.model import generate_model_card, load_model_and_tokenizer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
@ -42,7 +42,7 @@ def run_pt(
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and model_args.plot_loss:
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation

View File

@ -0,0 +1 @@
from llmtuner.train.rm.workflow import run_rm

View File

@ -3,13 +3,13 @@
from typing import TYPE_CHECKING, Optional, List
from transformers import Seq2SeqTrainingArguments
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import SavePeftModelCallback
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
from llmtuner.model import generate_model_card, load_model_and_tokenizer
from llmtuner.tuner.rm.metric import compute_accuracy
from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
from llmtuner.tuner.rm.trainer import PairwiseTrainer
from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding
from llmtuner.train.rm.metric import compute_accuracy
from llmtuner.train.rm.trainer import PairwiseTrainer
if TYPE_CHECKING:
from transformers import TrainerCallback
@ -51,7 +51,7 @@ def run_rm(
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and model_args.plot_loss:
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation

View File

@ -0,0 +1 @@
from llmtuner.train.sft.workflow import run_sft

View File

@ -2,15 +2,23 @@ import numpy as np
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
import jieba
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.packages import (
is_jieba_available, is_nltk_available, is_rouge_available
)
if TYPE_CHECKING:
from transformers.tokenization_utils import PreTrainedTokenizer
if is_jieba_available():
import jieba
if is_nltk_available():
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
if is_rouge_available():
from rouge_chinese import Rouge
@dataclass
class ComputeMetrics:

View File

@ -3,13 +3,13 @@
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.misc import get_logits_processor
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
from llmtuner.model import generate_model_card, load_model_and_tokenizer
from llmtuner.tuner.sft.metric import ComputeMetrics
from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer
from llmtuner.train.sft.metric import ComputeMetrics
from llmtuner.train.sft.trainer import CustomSeq2SeqTrainer
if TYPE_CHECKING:
from transformers import TrainerCallback
@ -69,7 +69,7 @@ def run_sft(
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and model_args.plot_loss:
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
# Evaluation

View File

@ -2,12 +2,12 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.logging import get_logger
from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
from llmtuner.model import get_train_args, get_infer_args, load_model_and_tokenizer
from llmtuner.tuner.pt import run_pt
from llmtuner.tuner.sft import run_sft
from llmtuner.tuner.rm import run_rm
from llmtuner.tuner.ppo import run_ppo
from llmtuner.tuner.dpo import run_dpo
from llmtuner.train.pt import run_pt
from llmtuner.train.sft import run_sft
from llmtuner.train.rm import run_rm
from llmtuner.train.ppo import run_ppo
from llmtuner.train.dpo import run_dpo
if TYPE_CHECKING:
from transformers import TrainerCallback
@ -38,11 +38,11 @@ def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional
model_args, _, finetuning_args, _ = get_infer_args(args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
model.config.use_cache = True
model.save_pretrained(model_args.export_dir, max_shard_size=max_shard_size)
model.save_pretrained(finetuning_args.export_dir, max_shard_size=max_shard_size)
try:
tokenizer.padding_side = "left" # restore padding side
tokenizer.init_kwargs["padding_side"] = "left"
tokenizer.save_pretrained(model_args.export_dir)
tokenizer.save_pretrained(finetuning_args.export_dir)
except:
logger.warning("Cannot save tokenizer, please copy the files manually.")

View File

@ -0,0 +1,80 @@
import torch
from typing import TYPE_CHECKING, Literal, Union
from llmtuner.extras.logging import get_logger
from llmtuner.hparams import ModelArguments, FinetuningArguments
from llmtuner.model import load_model_and_tokenizer, load_valuehead_params
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from trl import AutoModelForCausalLMWithValueHead
logger = get_logger(__name__)
def create_ref_model(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
stage: Literal["ppo", "dpo"]
) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
r"""
Creates reference model for PPO/DPO training. Evaluation mode is not supported.
The valuehead parameter is randomly initialized since it is useless for PPO training.
"""
if finetuning_args.ref_model is not None:
ref_model_args_dict = model_args.to_dict()
ref_model_args_dict.update(dict(
model_name_or_path=finetuning_args.ref_model,
checkpoint_dir=finetuning_args.ref_model_checkpoint,
quantization_bit=finetuning_args.ref_model_quantization_bit
))
ref_model_args = ModelArguments(**ref_model_args_dict)
ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
ref_model, _ = load_model_and_tokenizer(ref_model_args, ref_finetuning_args, is_trainable=False, stage=stage)
logger.info("Created reference model from {}".format(finetuning_args.ref_model))
else:
if finetuning_args.finetuning_type == "lora":
ref_model = None
else:
ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage=stage)
logger.info("Created reference model from the model itself.")
return ref_model
def create_reward_model(
model: "AutoModelForCausalLMWithValueHead",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments"
) -> "AutoModelForCausalLMWithValueHead":
r"""
Creates reward model for PPO training.
"""
if finetuning_args.reward_model_type == "lora":
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
if "default" in name:
param.data = param.data.to(torch.float32) # trainable params should in fp32
vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
assert vhead_params is not None, "Reward model is not correctly loaded."
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
return None
else:
reward_model_args_dict = model_args.to_dict()
reward_model_args_dict.update(dict(
model_name_or_path=finetuning_args.reward_model,
checkpoint_dir=finetuning_args.reward_model_checkpoint,
quantization_bit=finetuning_args.reward_model_quantization_bit
))
reward_model_args = ModelArguments(**reward_model_args_dict)
reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
reward_model, _ = load_model_and_tokenizer(reward_model_args, reward_finetuning_args, is_trainable=False, stage="ppo")
logger.info("Load full weights of reward model from {}".format(finetuning_args.reward_model))
logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
return reward_model
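How the PPO workflow earlier in this commit wires these helpers together (same call signatures as above; nothing beyond the diff is assumed):

```python
# In run_ppo: build both auxiliary models before constructing the trainer.
ref_model = create_ref_model(model_args, finetuning_args, stage="ppo")
reward_model = create_reward_model(model, model_args, finetuning_args)
# reward_model is None when reward_model_type == "lora": the adapter and the reward-head
# buffers were attached to `model`, so the trainer scores with the policy model directly.
```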

View File

@ -1 +0,0 @@
from llmtuner.tuner.tune import export_model, run_exp

View File

@ -1,3 +0,0 @@
from llmtuner.tuner.core.parser import get_train_args, get_infer_args
from llmtuner.tuner.core.loader import load_model_and_tokenizer
from llmtuner.tuner.core.utils import generate_model_card

View File

@ -1 +0,0 @@
from llmtuner.tuner.dpo.workflow import run_dpo

View File

@ -1 +0,0 @@
from llmtuner.tuner.ppo.workflow import run_ppo

View File

@ -1 +0,0 @@
from llmtuner.tuner.pt.workflow import run_pt

View File

@ -1 +0,0 @@
from llmtuner.tuner.rm.workflow import run_rm

View File

@ -1 +0,0 @@
from llmtuner.tuner.sft.workflow import run_sft

View File

@ -2,7 +2,7 @@ import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from llmtuner.chat.stream_chat import ChatModel
from llmtuner.chat import ChatModel
from llmtuner.extras.misc import torch_gc
from llmtuner.hparams import GeneratingArguments
from llmtuner.webui.common import get_save_dir
@ -14,14 +14,24 @@ if TYPE_CHECKING:
class WebChatModel(ChatModel):
def __init__(self, manager: "Manager", lazy_init: Optional[bool] = True) -> None:
def __init__(
self,
manager: "Manager",
demo_mode: Optional[bool] = False,
lazy_init: Optional[bool] = True
) -> None:
self.manager = manager
self.demo_mode = demo_mode
self.model = None
self.tokenizer = None
self.generating_args = GeneratingArguments()
if not lazy_init:
if not lazy_init: # read arguments from command line
super().__init__()
if demo_mode: # load openchat 3.5 by default
super().__init__(dict(model_name_or_path="openchat/openchat_3.5", template="openchat"))
@property
def loaded(self) -> bool:
return self.model is not None
@ -36,6 +46,8 @@ class WebChatModel(ChatModel):
error = ALERTS["err_no_model"][lang]
elif not get("top.model_path"):
error = ALERTS["err_no_path"][lang]
elif self.demo_mode:
error = ALERTS["err_demo"][lang]
if error:
gr.Warning(error)
@ -67,6 +79,11 @@ class WebChatModel(ChatModel):
def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
lang = data[self.manager.get_elem_by_name("top.lang")]
if self.demo_mode:
yield ALERTS["err_demo"][lang]
return
yield ALERTS["info_unloading"][lang]
self.model = None
self.tokenizer = None

View File

@ -70,7 +70,7 @@ def get_module(model_name: str) -> str:
def get_template(model_name: str) -> str:
if model_name.endswith("Chat") and get_prefix(model_name) in DEFAULT_TEMPLATE:
if model_name and model_name.endswith("Chat") and get_prefix(model_name) in DEFAULT_TEMPLATE:
return DEFAULT_TEMPLATE[get_prefix(model_name)]
return "default"

View File

@ -1,7 +1,7 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict, Generator, List
from llmtuner.tuner import export_model
from llmtuner.train import export_model
from llmtuner.webui.common import get_save_dir
from llmtuner.webui.locales import ALERTS

View File

@ -1,8 +1,8 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict
from llmtuner.data.template import templates
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
from llmtuner.extras.template import templates
from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, save_config
from llmtuner.webui.utils import can_quantize

View File

@ -1,4 +1,11 @@
CSS = r"""
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
.modal-box {
position: fixed !important;
top: 50%;

View File

@ -12,11 +12,11 @@ from llmtuner.webui.utils import get_time
class Engine:
def __init__(self, pure_chat: Optional[bool] = False) -> None:
def __init__(self, demo_mode: Optional[bool] = False, pure_chat: Optional[bool] = False) -> None:
self.pure_chat = pure_chat
self.manager: "Manager" = Manager()
self.runner: "Runner" = Runner(self.manager)
self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not pure_chat))
self.manager = Manager()
self.runner = Runner(self.manager, demo_mode=demo_mode)
self.chatter = WebChatModel(manager=self.manager, demo_mode=demo_mode, lazy_init=(not pure_chat))
def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
return {self.manager.get_elem_by_name(k): gr.update(**v) for k, v in resume_dict.items()}

View File

@ -1,4 +1,5 @@
import gradio as gr
from typing import Optional
from transformers.utils.versions import require_version
from llmtuner.webui.components import (
@ -17,22 +18,33 @@ from llmtuner.webui.engine import Engine
require_version("gradio>=3.38.0,<4.0.0", "To fix: pip install \"gradio>=3.38.0,<4.0.0\"")
def create_ui() -> gr.Blocks:
def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks:
engine = Engine(pure_chat=False)
engine = Engine(demo_mode=demo_mode, pure_chat=False)
with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
if demo_mode:
gr.HTML(
"<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>"
)
gr.HTML(
"<h3><center>Visit <a href=\"https://github.com/hiyouga/LLaMA-Factory\" target=\"_blank\">"
"LLaMA Factory</a> for details.</center></h3>"
)
gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
engine.manager.all_elems["top"] = create_top()
lang: "gr.Dropdown" = engine.manager.get_elem_by_name("top.lang")
with gr.Tab("Train"):
engine.manager.all_elems["train"] = create_train_tab(engine)
with gr.Tab("Evaluate"):
with gr.Tab("Evaluate & Predict"):
engine.manager.all_elems["eval"] = create_eval_tab(engine)
with gr.Tab("Chat"):
engine.manager.all_elems["infer"] = create_infer_tab(engine)
if not demo_mode:
with gr.Tab("Export"):
engine.manager.all_elems["export"] = create_export_tab(engine)
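A hypothetical launch snippet for the demo-mode UI; the entry point and server options are assumptions, not taken from this commit:

```python
# Assumed entry point, e.g. for a hosted demo space.
if __name__ == "__main__":
    create_ui(demo_mode=True).queue().launch(server_name="0.0.0.0")
```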

View File

@ -659,6 +659,10 @@ ALERTS = {
"en": "Failed.",
"zh": "训练出错。"
},
"err_demo": {
"en": "Training is unavailable in demo mode, duplicate the space to a private one first.",
"zh": "展示模式不支持训练,请先复制到私人空间。"
},
"info_aborting": {
"en": "Aborted, wait for terminating...",
"zh": "训练中断,正在等待线程结束……"

View File

@ -4,7 +4,7 @@ import logging
import gradio as gr
from threading import Thread
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple
import transformers
from transformers.trainer import TRAINING_ARGS_NAME
@ -13,7 +13,7 @@ from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import TRAINING_STAGES
from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import run_exp
from llmtuner.train import run_exp
from llmtuner.webui.common import get_module, get_save_dir, load_config
from llmtuner.webui.locales import ALERTS
from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
@ -24,13 +24,13 @@ if TYPE_CHECKING:
class Runner: class Runner:
def __init__(self, manager: "Manager") -> None: def __init__(self, manager: "Manager", demo_mode: Optional[bool] = False) -> None:
self.manager = manager self.manager = manager
self.demo_mode = demo_mode
""" Resume """ """ Resume """
self.thread: "Thread" = None self.thread: "Thread" = None
self.do_train = True self.do_train = True
self.running_data: Dict["Component", Any] = None self.running_data: Dict["Component", Any] = None
self.monitor_inputs: Dict[str, str] = None
""" State """ """ State """
self.aborted = False self.aborted = False
self.running = False self.running = False
@@ -46,9 +46,8 @@ class Runner:
     def set_abort(self) -> None:
         self.aborted = True
-        self.running = False

-    def _initialize(self, data: Dict[Component, Any], do_train: bool) -> str:
+    def _initialize(self, data: Dict[Component, Any], do_train: bool, from_preview: bool) -> str:
         get = lambda name: data[self.manager.get_elem_by_name(name)]
         lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
         dataset = get("train.dataset") if do_train else get("eval.dataset")
@@ -65,6 +64,9 @@ class Runner:
         if len(dataset) == 0:
             return ALERTS["err_no_dataset"][lang]

+        if self.demo_mode and (not from_preview):
+            return ALERTS["err_demo"][lang]
+
         self.aborted = False
         self.logger_handler.reset()
         self.trainer_callback = LogCallback(self)
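
The new `from_preview` flag makes the demo guard selective: previewing the generated command is still allowed, but actually launching a run returns the `err_demo` alert. A self-contained sketch of that behaviour, with `_initialize` reduced to the demo check alone (the real method also validates the model, path, and dataset, as shown above):

```python
ALERTS = {
    "err_demo": {
        "en": "Training is unavailable in demo mode, duplicate the space to a private one first."
    }
}

class DemoGuard:
    """Reduced model of Runner._initialize: only the demo-mode branch."""

    def __init__(self, demo_mode: bool = False) -> None:
        self.demo_mode = demo_mode

    def _initialize(self, do_train: bool, from_preview: bool, lang: str = "en") -> str:
        if self.demo_mode and (not from_preview):
            return ALERTS["err_demo"][lang]  # launch refused
        return ""  # empty string: no error, proceed

guard = DemoGuard(demo_mode=True)
assert guard._initialize(do_train=True, from_preview=True) == ""            # preview allowed
assert "demo mode" in guard._initialize(do_train=True, from_preview=False)  # launch blocked
```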
@@ -72,6 +74,7 @@ class Runner:
     def _finalize(self, lang: str, finish_info: str) -> str:
         self.thread = None
+        self.running_data = None
         self.running = False
         torch_gc()
         if self.aborted:
@@ -84,9 +87,9 @@ class Runner:
         user_config = load_config()

         if get("top.checkpoints"):
-            checkpoint_dir = ",".join([
-                get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
-            ])
+            checkpoint_dir = ",".join([get_save_dir(
+                get("top.model_name"), get("top.finetuning_type"), ckpt
+            ) for ckpt in get("top.checkpoints")])
         else:
             checkpoint_dir = None
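
Both call sites build the same comma-separated `checkpoint_dir`; the change only reflows the list comprehension around `get_save_dir`. Assuming `get_save_dir` simply joins its arguments under a fixed save root (an assumption; its implementation is not part of this diff), the produced value looks like this:

```python
import os

SAVE_ROOT = "saves"  # assumed root used by llmtuner.webui.common.get_save_dir

def get_save_dir(*paths: str) -> str:
    # Assumed behaviour: join all path components under the save root.
    return os.path.join(SAVE_ROOT, *paths)

checkpoints = ["checkpoint-100", "checkpoint-200"]  # placeholder checkpoint names
checkpoint_dir = ",".join([get_save_dir(
    "LLaMA2-7B", "lora", ckpt  # placeholder model name and finetuning type
) for ckpt in checkpoints])

print(checkpoint_dir)
# saves/LLaMA2-7B/lora/checkpoint-100,saves/LLaMA2-7B/lora/checkpoint-200
```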
@@ -136,7 +139,10 @@ class Runner:
             args["upcast_layernorm"] = True

         if args["stage"] == "ppo":
-            args["reward_model"] = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.reward_model"))
+            args["reward_model"] = get_save_dir(
+                get("top.model_name"), get("top.finetuning_type"), get("train.reward_model")
+            )
+            args["reward_model_type"] = "lora" if get("top.finetuning_type") == "lora" else "full"

         if args["stage"] == "dpo":
             args["dpo_beta"] = get("train.dpo_beta")
@@ -154,9 +160,9 @@ class Runner:
         user_config = load_config()

         if get("top.checkpoints"):
-            checkpoint_dir = ",".join([
-                get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
-            ])
+            checkpoint_dir = ",".join([get_save_dir(
+                get("top.model_name"), get("top.finetuning_type"), ckpt
+            ) for ckpt in get("top.checkpoints")])
             output_dir = get_save_dir(
                 get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints"))
             )
@@ -196,7 +202,7 @@ class Runner:
         return args

     def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
-        error = self._initialize(data, do_train)
+        error = self._initialize(data, do_train, from_preview=True)
         if error:
             gr.Warning(error)
             yield error, gr.update(visible=False)
@@ -205,16 +211,14 @@ class Runner:
             yield gen_cmd(args), gr.update(visible=False)

     def _launch(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
-        error = self._initialize(data, do_train)
+        error = self._initialize(data, do_train, from_preview=False)
         if error:
             gr.Warning(error)
             yield error, gr.update(visible=False)
         else:
             args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
             run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
-            self.running = True
             self.do_train, self.running_data = do_train, data
-            self.monitor_inputs = dict(lang=data[self.manager.get_elem_by_name("top.lang")], output_dir=args["output_dir"])
             self.thread = Thread(target=run_exp, kwargs=run_kwargs)
             self.thread.start()
             yield from self.monitor()
@@ -232,7 +236,12 @@ class Runner:
             yield from self._launch(data, do_train=False)

     def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
-        lang, output_dir = self.monitor_inputs["lang"], self.monitor_inputs["output_dir"]
+        get = lambda name: self.running_data[self.manager.get_elem_by_name(name)]
+        self.running = True
+        lang = get("top.lang")
+        output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get(
+            "{}.output_dir".format("train" if self.do_train else "eval")
+        ))
         while self.thread.is_alive():
             time.sleep(2)
             if self.aborted:
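
`monitor()` no longer depends on a `monitor_inputs` dict filled in at launch time; it re-derives the language and output directory from `running_data`, while the generator itself still just polls the training thread. A stripped-down, self-contained version of that polling pattern:

```python
import time
from threading import Thread
from typing import Generator

def run_job() -> None:
    time.sleep(5)  # stands in for run_exp(args, callbacks=...)

def monitor(thread: Thread, poll_interval: float = 2.0) -> Generator[str, None, None]:
    # Yield a status string while the worker thread is alive, then a final one;
    # the web UI consumes these yields to refresh the log box and progress bar.
    while thread.is_alive():
        time.sleep(poll_interval)
        yield "running..."
    yield "finished"

worker = Thread(target=run_job)
worker.start()
for status in monitor(worker):
    print(status)
```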

View File

@@ -1,17 +1,20 @@
 import os
 import json
 import gradio as gr
-import matplotlib.figure
-import matplotlib.pyplot as plt
 from typing import TYPE_CHECKING, Any, Dict
 from datetime import datetime

+from llmtuner.extras.packages import is_matplotlib_available
 from llmtuner.extras.ploting import smooth
 from llmtuner.webui.common import get_save_dir

 if TYPE_CHECKING:
     from llmtuner.extras.callbacks import LogCallback

+if is_matplotlib_available():
+    import matplotlib.figure
+    import matplotlib.pyplot as plt
+

 def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
     if not callback.max_steps:
@@ -56,7 +59,7 @@ def get_eval_results(path: os.PathLike) -> str:
     return "```json\n{}\n```\n".format(result)

-def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
+def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> "matplotlib.figure.Figure":
     if not base_model:
         return
     log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl")
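
matplotlib becomes an optional dependency: the unconditional imports are replaced by a guarded import behind `is_matplotlib_available()`, and the `gen_plot` return annotation is quoted so it stays a plain string when matplotlib is absent. A minimal sketch of such an availability check; the real helper lives in `llmtuner.extras.packages` and this implementation is an assumption:

```python
import importlib.util

def is_matplotlib_available() -> bool:
    # Assumed implementation: check whether matplotlib is importable
    # without importing it at module load time.
    return importlib.util.find_spec("matplotlib") is not None

if is_matplotlib_available():
    import matplotlib.figure
    import matplotlib.pyplot as plt
```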

View File

@@ -7,12 +7,13 @@ import fire
 import math
 import torch
 from tqdm import tqdm
+from typing import Optional
 from torch.utils.data import DataLoader
 from transformers import DataCollatorForSeq2Seq
-from llmtuner.dsets import get_dataset, preprocess_dataset
+from llmtuner.data import get_dataset, preprocess_dataset
 from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer
+from llmtuner.model import get_train_args, load_model_and_tokenizer

 BASE_LR = 3e-4  # 1.5e-4 for 30B-70B models
@@ -24,12 +25,14 @@ def calculate_lr(
     dataset: str,
     cutoff_len: int,  # i.e. maximum input length during training
     batch_size: int,  # total batch size, namely (batch size * gradient accumulation * world size)
-    is_mistral: bool  # mistral model uses a smaller learning rate
+    is_mistral: bool,  # mistral model uses a smaller learning rate
+    dataset_dir: Optional[str] = "data"
 ):
     model_args, data_args, training_args, finetuning_args, _ = get_train_args(dict(
         stage="sft",
         model_name_or_path=model_name_or_path,
         dataset=dataset,
+        dataset_dir=dataset_dir,
         template="default",
         cutoff_len=cutoff_len,
         output_dir="dummy_dir"