mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-14 23:58:11 +08:00
modify code structure
Former-commit-id: 0682ed357210897e0b67c4a6eb31a94b3eb929f1
This commit is contained in:
  parent fa06b168ab
  commit 6261fb362a
README.md (19 lines changed)

@@ -95,7 +95,7 @@ huggingface-cli login
 - Python 3.8+ and PyTorch 1.13.1+
 - 🤗Transformers, Datasets, Accelerate, PEFT and TRL
 - jieba, rouge-chinese and nltk (used at evaluation)
-- gradio and mdtex2html (used in web_demo.py)
+- gradio and matplotlib (used in web_demo.py)
 - uvicorn, fastapi and sse-starlette (used in api_demo.py)

 And **powerful GPUs**!

@@ -137,7 +137,8 @@ python -m transformers.models.llama.convert_llama_weights_to_hf \
 ### (Continually) Pre-Training

 ```bash
-CUDA_VISIBLE_DEVICES=0 python src/train_pt.py \
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+    --stage pt \
     --model_name_or_path path_to_your_model \
     --do_train \
     --dataset wiki_demo \

@@ -158,7 +159,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_pt.py \
 ### Supervised Fine-Tuning

 ```bash
-CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+    --stage sft \
     --model_name_or_path path_to_your_model \
     --do_train \
     --dataset alpaca_gpt4_en \

@@ -179,7 +181,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
 ### Reward Model Training

 ```bash
-CUDA_VISIBLE_DEVICES=0 python src/train_rm.py \
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+    --stage rm \
     --model_name_or_path path_to_your_model \
     --do_train \
     --dataset comparison_gpt4_en \

@@ -199,7 +202,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_rm.py \
 ### PPO Training (RLHF)

 ```bash
-CUDA_VISIBLE_DEVICES=0 python src/train_ppo.py \
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+    --stage ppo \
     --model_name_or_path path_to_your_model \
     --do_train \
     --dataset alpaca_gpt4_en \

@@ -222,7 +226,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_ppo.py \
 ```bash
 accelerate config # configure the environment
-accelerate launch src/train_XX.py # arguments (same as above)
+accelerate launch src/train_bash.py # arguments (same as above)
 ```

 <details><summary>Example configuration for full-tuning with DeepSpeed ZeRO-2</summary>

@@ -256,7 +260,8 @@ use_cpu: false
 ### Evaluation (BLEU and ROUGE_CHINESE)

 ```bash
-CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+    --stage pt \
     --model_name_or_path path_to_your_model \
     --do_eval \
     --dataset alpaca_gpt4_en \
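The thread running through these README changes is that the separate `train_pt.py` / `train_sft.py` / `train_rm.py` / `train_ppo.py` entry points are replaced by a single `src/train_bash.py` selected via `--stage`. The dispatcher itself is not part of this diff; below is a hypothetical sketch of how such an entry point could route to the `run_*` functions exported by the new `llmtuner` package (the runner signatures and the argument order are assumptions, not confirmed by this commit):

```python
# Hypothetical sketch of a train_bash.py-style dispatcher. The real file is not
# shown in this diff; runner signatures below are assumed.
from llmtuner import get_train_args, run_pt, run_sft, run_rm, run_ppo

RUNNERS = {"pt": run_pt, "sft": run_sft, "rm": run_rm, "ppo": run_ppo}


def main():
    # export_model.py unpacks five values from get_train_args(); the fifth is
    # assumed here to carry the --stage choice.
    model_args, data_args, training_args, finetuning_args, general_args = get_train_args()
    RUNNERS[general_args.stage](model_args, data_args, training_args, finetuning_args)


if __name__ == "__main__":
    main()
```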
pyproject.toml (new file, 3 lines)

[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
requirements.txt

@@ -8,8 +8,9 @@ sentencepiece
 jieba
 rouge-chinese
 nltk
-gradio
-mdtex2html
+gradio>=3.36.0
 uvicorn
+pydantic==1.10.7
 fastapi
 sse-starlette
+matplotlib
setup.py (new file, 55 lines)

import os
import re
from setuptools import setup, find_packages


def get_version():
    with open(os.path.join("src", "llmtuner", "__init__.py"), "r", encoding="utf-8") as f:
        file_content = f.read()
        pattern = r"{0}\W*=\W*\"([^\"]+)\"".format("__version__")
        version, = re.findall(pattern, file_content)
        return version


def get_requires():
    with open("requirements.txt", "r", encoding="utf-8") as f:
        file_content = f.read()
        lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
        return lines


def main():

    setup(
        name="llmtuner",
        version=get_version(),
        author="hiyouga",
        author_email="hiyouga" "@" "buaa.edu.cn",
        description="Easy-to-use fine-tuning framework using PEFT",
        long_description=open("README.md", "r", encoding="utf-8").read(),
        long_description_content_type="text/markdown",
        keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
        license="Apache 2.0 License",
        url="https://github.com/hiyouga/LLaMA-Efficient-Tuning",
        package_dir={"": "src"},
        packages=find_packages("src"),
        python_requires=">=3.8.0",
        install_requires=get_requires(),
        classifiers=[
            "Development Status :: 3 - Alpha",
            "Intended Audience :: Developers",
            "Intended Audience :: Education",
            "Intended Audience :: Science/Research",
            "License :: OSI Approved :: Apache Software License",
            "Operating System :: OS Independent",
            "Programming Language :: Python :: 3",
            "Programming Language :: Python :: 3.8",
            "Programming Language :: Python :: 3.9",
            "Programming Language :: Python :: 3.10",
            "Topic :: Scientific/Engineering :: Artificial Intelligence",
        ],
    )


if __name__ == "__main__":
    main()
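With `pyproject.toml` and `setup.py` in place, the project becomes pip-installable (for example, `pip install -e .` from the repository root). The `get_version` helper above scrapes the version string out of `src/llmtuner/__init__.py` with a regex; here is a quick self-contained check of that pattern against the `__version__` line this commit adds:

```python
# Sanity check of the version-extraction pattern used in setup.py; the sample
# string mirrors the __version__ line added in src/llmtuner/__init__.py.
import re

pattern = r"{0}\W*=\W*\"([^\"]+)\"".format("__version__")
version, = re.findall(pattern, '__version__ = "0.0.9"')
assert version == "0.0.9"
```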
src/api_demo.py (218 lines changed)

@@ -4,225 +4,11 @@
 # Visit http://localhost:8000/docs for document.

-import time
-import torch
 import uvicorn
-from threading import Thread
-from pydantic import BaseModel, Field
-from fastapi import FastAPI, HTTPException
-from fastapi.middleware.cors import CORSMiddleware
-from contextlib import asynccontextmanager
-from transformers import TextIteratorStreamer
-from sse_starlette import EventSourceResponse
-from typing import Any, Dict, List, Literal, Optional

-from utils import (
-    Template,
-    load_pretrained,
-    prepare_infer_args,
-    get_logits_processor
-)
+from llmtuner import create_app


-@asynccontextmanager
-async def lifespan(app: FastAPI): # collects GPU memory
-    yield
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        torch.cuda.ipc_collect()
-
-
-app = FastAPI(lifespan=lifespan)
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-
-class ModelCard(BaseModel):
-    id: str
-    object: Optional[str] = "model"
-    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
-    owned_by: Optional[str] = "owner"
-    root: Optional[str] = None
-    parent: Optional[str] = None
-    permission: Optional[list] = []
-
-
-class ModelList(BaseModel):
-    object: Optional[str] = "list"
-    data: Optional[List[ModelCard]] = []
-
-
-class ChatMessage(BaseModel):
-    role: Literal["user", "assistant", "system"]
-    content: str
-
-
-class DeltaMessage(BaseModel):
-    role: Optional[Literal["user", "assistant", "system"]] = None
-    content: Optional[str] = None
-
-
-class ChatCompletionRequest(BaseModel):
-    model: str
-    messages: List[ChatMessage]
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
-    n: Optional[int] = 1
-    max_tokens: Optional[int] = None
-    stream: Optional[bool] = False
-
-
-class ChatCompletionResponseChoice(BaseModel):
-    index: int
-    message: ChatMessage
-    finish_reason: Literal["stop", "length"]
-
-
-class ChatCompletionResponseStreamChoice(BaseModel):
-    index: int
-    delta: DeltaMessage
-    finish_reason: Optional[Literal["stop", "length"]] = None
-
-
-class ChatCompletionResponseUsage(BaseModel):
-    prompt_tokens: int
-    completion_tokens: int
-    total_tokens: int
-
-
-class ChatCompletionResponse(BaseModel):
-    id: Optional[str] = "chatcmpl-default"
-    object: Literal["chat.completion"]
-    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
-    model: str
-    choices: List[ChatCompletionResponseChoice]
-    usage: ChatCompletionResponseUsage
-
-
-class ChatCompletionStreamResponse(BaseModel):
-    id: Optional[str] = "chatcmpl-default"
-    object: Literal["chat.completion.chunk"]
-    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
-    model: str
-    choices: List[ChatCompletionResponseStreamChoice]
-
-
-@app.get("/v1/models", response_model=ModelList)
-async def list_models():
-    global model_args
-    model_card = ModelCard(id="gpt-3.5-turbo")
-    return ModelList(data=[model_card])
-
-
-@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
-async def create_chat_completion(request: ChatCompletionRequest):
-    global model, tokenizer, source_prefix, generating_args
-
-    if request.messages[-1].role != "user":
-        raise HTTPException(status_code=400, detail="Invalid request")
-    query = request.messages[-1].content
-
-    prev_messages = request.messages[:-1]
-    if len(prev_messages) > 0 and prev_messages[0].role == "system":
-        prefix = prev_messages.pop(0).content
-    else:
-        prefix = source_prefix
-
-    history = []
-    if len(prev_messages) % 2 == 0:
-        for i in range(0, len(prev_messages), 2):
-            if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
-                history.append([prev_messages[i].content, prev_messages[i+1].content])
-
-    inputs = tokenizer([prompt_template.get_prompt(query, history, prefix)], return_tensors="pt")
-    inputs = inputs.to(model.device)
-
-    gen_kwargs = generating_args.to_dict()
-    gen_kwargs.update({
-        "input_ids": inputs["input_ids"],
-        "temperature": request.temperature if request.temperature else gen_kwargs["temperature"],
-        "top_p": request.top_p if request.top_p else gen_kwargs["top_p"],
-        "logits_processor": get_logits_processor()
-    })
-
-    if request.max_tokens:
-        gen_kwargs.pop("max_length", None)
-        gen_kwargs["max_new_tokens"] = request.max_tokens
-
-    if request.stream:
-        generate = predict(gen_kwargs, request.model)
-        return EventSourceResponse(generate, media_type="text/event-stream")
-
-    generation_output = model.generate(**gen_kwargs)
-    outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
-    response = tokenizer.decode(outputs, skip_special_tokens=True)
-
-    usage = ChatCompletionResponseUsage(
-        prompt_tokens=len(inputs["input_ids"][0]),
-        completion_tokens=len(outputs),
-        total_tokens=len(inputs["input_ids"][0]) + len(outputs)
-    )
-
-    choice_data = ChatCompletionResponseChoice(
-        index=0,
-        message=ChatMessage(role="assistant", content=response),
-        finish_reason="stop"
-    )
-
-    return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage, object="chat.completion")
-
-
-async def predict(gen_kwargs: Dict[str, Any], model_id: str):
-    global model, tokenizer
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
-    gen_kwargs["streamer"] = streamer
-
-    thread = Thread(target=model.generate, kwargs=gen_kwargs)
-    thread.start()
-
-    choice_data = ChatCompletionResponseStreamChoice(
-        index=0,
-        delta=DeltaMessage(role="assistant"),
-        finish_reason=None
-    )
-    chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-    yield chunk.json(exclude_unset=True, ensure_ascii=False)
-
-    for new_text in streamer:
-        if len(new_text) == 0:
-            continue
-
-        choice_data = ChatCompletionResponseStreamChoice(
-            index=0,
-            delta=DeltaMessage(content=new_text),
-            finish_reason=None
-        )
-        chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-        yield chunk.json(exclude_unset=True, ensure_ascii=False)
-
-    choice_data = ChatCompletionResponseStreamChoice(
-        index=0,
-        delta=DeltaMessage(),
-        finish_reason="stop"
-    )
-    chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-    yield chunk.json(exclude_unset=True, ensure_ascii=False)
-    yield "[DONE]"
-
-
 if __name__ == "__main__":
-    model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
-    model, tokenizer = load_pretrained(model_args, finetuning_args)
-
-    prompt_template = Template(data_args.prompt_template)
-    source_prefix = data_args.source_prefix if data_args.source_prefix else ""
-
+    app = create_app()
     uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
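After this change `api_demo.py` is a thin wrapper: all of the OpenAI-style endpoint logic moves into `llmtuner.api.app` (shown below). A minimal client sketch against the `/v1/chat/completions` route this commit defines, assuming the server is already running on localhost:8000 and `requests` is installed:

```python
# Minimal client sketch against the /v1/chat/completions route defined in this
# commit; assumes the server from api_demo.py is running on localhost:8000.
import requests

payload = {
    "model": "gpt-3.5-turbo",  # the model id advertised by /v1/models
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 128,
    "stream": False,
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])
```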
src/cli_demo.py

@@ -2,21 +2,15 @@
 # Implements stream chat in command line for fine-tuned models.
 # Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint

-from utils import (
-    Template,
-    load_pretrained,
-    prepare_infer_args,
-    get_logits_processor
-)
 from threading import Thread
 from transformers import TextIteratorStreamer

+from llmtuner import Template, get_infer_args, load_model_and_tokenizer, get_logits_processor
+

 def main():
-    model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
-    model, tokenizer = load_pretrained(model_args, finetuning_args)
+    model_args, data_args, finetuning_args, generating_args = get_infer_args()
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)

     prompt_template = Template(data_args.prompt_template)
     source_prefix = data_args.source_prefix if data_args.source_prefix else ""
src/export_model.py

@@ -2,14 +2,12 @@
 # Exports the fine-tuned model.
 # Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model

-from utils import load_pretrained, prepare_args
+from llmtuner import get_train_args, load_model_and_tokenizer


 def main():
-    model_args, _, training_args, finetuning_args = prepare_args(stage="sft")
-    model, tokenizer = load_pretrained(model_args, finetuning_args)
+    model_args, _, training_args, finetuning_args, _ = get_train_args()
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
     model.save_pretrained(training_args.output_dir, max_shard_size="10GB")
     tokenizer.save_pretrained(training_args.output_dir)
     print("model and tokenizer have been saved at:", training_args.output_dir)
src/llmtuner/__init__.py (new file, 7 lines)

from llmtuner.api import create_app
from llmtuner.extras.misc import get_logits_processor
from llmtuner.extras.template import Template
from llmtuner.tuner import get_train_args, get_infer_args, load_model_and_tokenizer, run_pt, run_sft, run_rm, run_ppo


__version__ = "0.0.9"
src/llmtuner/api/__init__.py (new file, 1 line)

from llmtuner.api.app import create_app
src/llmtuner/api/app.py (new file, 152 lines)

import uvicorn
from threading import Thread
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from transformers import TextIteratorStreamer
from contextlib import asynccontextmanager
from sse_starlette import EventSourceResponse
from typing import Any, Dict

from llmtuner.tuner import get_infer_args, load_model_and_tokenizer
from llmtuner.extras.misc import get_logits_processor, torch_gc
from llmtuner.extras.template import Template
from llmtuner.api.protocol import (
    ModelCard,
    ModelList,
    ChatMessage,
    DeltaMessage,
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionStreamResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice,
    ChatCompletionResponseUsage
)


@asynccontextmanager
async def lifespan(app: FastAPI): # collects GPU memory
    yield
    torch_gc()


def create_app():
    model_args, data_args, finetuning_args, generating_args = get_infer_args()
    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)

    prompt_template = Template(data_args.prompt_template)
    source_prefix = data_args.source_prefix if data_args.source_prefix else ""

    app = FastAPI(lifespan=lifespan)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @app.get("/v1/models", response_model=ModelList)
    async def list_models():
        global model_args
        model_card = ModelCard(id="gpt-3.5-turbo")
        return ModelList(data=[model_card])

    @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
    async def create_chat_completion(request: ChatCompletionRequest):
        if request.messages[-1].role != "user":
            raise HTTPException(status_code=400, detail="Invalid request")
        query = request.messages[-1].content

        prev_messages = request.messages[:-1]
        if len(prev_messages) > 0 and prev_messages[0].role == "system":
            prefix = prev_messages.pop(0).content
        else:
            prefix = source_prefix

        history = []
        if len(prev_messages) % 2 == 0:
            for i in range(0, len(prev_messages), 2):
                if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
                    history.append([prev_messages[i].content, prev_messages[i+1].content])

        inputs = tokenizer([prompt_template.get_prompt(query, history, prefix)], return_tensors="pt")
        inputs = inputs.to(model.device)

        gen_kwargs = generating_args.to_dict()
        gen_kwargs.update({
            "input_ids": inputs["input_ids"],
            "temperature": request.temperature if request.temperature else gen_kwargs["temperature"],
            "top_p": request.top_p if request.top_p else gen_kwargs["top_p"],
            "logits_processor": get_logits_processor()
        })

        if request.max_tokens:
            gen_kwargs.pop("max_length", None)
            gen_kwargs["max_new_tokens"] = request.max_tokens

        if request.stream:
            generate = predict(gen_kwargs, request.model)
            return EventSourceResponse(generate, media_type="text/event-stream")

        generation_output = model.generate(**gen_kwargs)
        outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
        response = tokenizer.decode(outputs, skip_special_tokens=True)

        usage = ChatCompletionResponseUsage(
            prompt_tokens=len(inputs["input_ids"][0]),
            completion_tokens=len(outputs),
            total_tokens=len(inputs["input_ids"][0]) + len(outputs)
        )

        choice_data = ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(role="assistant", content=response),
            finish_reason="stop"
        )

        return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage, object="chat.completion")

    async def predict(gen_kwargs: Dict[str, Any], model_id: str):
        streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs["streamer"] = streamer

        thread = Thread(target=model.generate, kwargs=gen_kwargs)
        thread.start()

        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(role="assistant"),
            finish_reason=None
        )
        chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
        yield chunk.json(exclude_unset=True, ensure_ascii=False)

        for new_text in streamer:
            if len(new_text) == 0:
                continue

            choice_data = ChatCompletionResponseStreamChoice(
                index=0,
                delta=DeltaMessage(content=new_text),
                finish_reason=None
            )
            chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
            yield chunk.json(exclude_unset=True, ensure_ascii=False)

        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(),
            finish_reason="stop"
        )
        chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
        yield chunk.json(exclude_unset=True, ensure_ascii=False)
        yield "[DONE]"

    return app


if __name__ == "__main__":
    app = create_app()
    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
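When a request sets `stream: true`, the endpoint returns server-sent events whose `data:` payloads are `ChatCompletionStreamResponse` chunks, terminated by a literal `[DONE]`. A sketch of consuming that stream from the client side (host and port assumed from the `uvicorn.run` defaults above; the exact SSE framing is an assumption about how sse-starlette serializes events):

```python
# Sketch of consuming the SSE stream returned when stream=True; assumes the
# server from create_app() is running on localhost:8000.
import json
import requests

payload = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": True,
}
with requests.post("http://localhost:8000/v1/chat/completions", json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue  # skip blank separators and non-data SSE fields
        data = line[len("data:"):].strip()
        if data == "[DONE]":
            break  # the generator above yields this sentinel last
        delta = json.loads(data)["choices"][0]["delta"]
        print(delta.get("content", ""), end="", flush=True)
```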
src/llmtuner/api/protocol.py (new file, 73 lines)

import time
from pydantic import BaseModel, Field
from typing import List, Literal, Optional


class ModelCard(BaseModel):
    id: str
    object: Optional[str] = "model"
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    owned_by: Optional[str] = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = []


class ModelList(BaseModel):
    object: Optional[str] = "list"
    data: Optional[List[ModelCard]] = []


class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system"]
    content: str


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: Optional[int] = 1
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length"]


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionResponseUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionResponse(BaseModel):
    id: Optional[str] = "chatcmpl-default"
    object: Literal["chat.completion"]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: ChatCompletionResponseUsage


class ChatCompletionStreamResponse(BaseModel):
    id: Optional[str] = "chatcmpl-default"
    object: Literal["chat.completion.chunk"]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
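These models target pydantic v1, which is why `requirements.txt` now pins `pydantic==1.10.7`: the streaming code calls `.json(exclude_unset=True, ensure_ascii=False)`, a v1-style signature. A small validation sketch using the v1 `parse_obj` API:

```python
# Validation sketch for the request model above, using pydantic v1 APIs
# (parse_obj, .json) as pinned by requirements.txt.
from llmtuner.api.protocol import ChatCompletionRequest

request = ChatCompletionRequest.parse_obj({
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "Hi"}],
})
print(request.stream)                     # False, the declared default
print(request.json(exclude_unset=True))   # only the fields that were set
```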
src/llmtuner/dsets/__init__.py (new file, 2 lines)

from llmtuner.dsets.loader import get_dataset
from llmtuner.dsets.preprocess import preprocess_dataset
src/llmtuner/dsets/callbacks.py (new file, 63 lines)

import os
import json
import time
from datetime import timedelta

from transformers import (
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments
)


class LogCallback(TrainerCallback):

    def __init__(self, runner=None):
        self.runner = runner
        self.start_time = time.time()
        self.tracker = {}

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        r"""
        Event called at the beginning of a training step. If using gradient accumulation, one training step
        might take several inputs.
        """
        if self.runner is not None and self.runner.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        r"""
        Event called at the end of an substep during gradient accumulation.
        """
        if self.runner is not None and self.runner.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
        r"""
        Event called after logging the last logs.
        """
        if "loss" not in state.log_history[-1]:
            return
        cur_time = time.time()
        cur_steps = state.log_history[-1].get("step")
        elapsed_time = cur_time - self.start_time
        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
        remaining_steps = state.max_steps - cur_steps
        remaining_time = remaining_steps * avg_time_per_step
        self.tracker = {
            "current_steps": cur_steps,
            "total_steps": state.max_steps,
            "loss": state.log_history[-1].get("loss", None),
            "reward": state.log_history[-1].get("reward", None),
            "learning_rate": state.log_history[-1].get("learning_rate", None),
            "epoch": state.log_history[-1].get("epoch", None),
            "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
            "elapsed_time": str(timedelta(seconds=int(elapsed_time))),
            "remaining_time": str(timedelta(seconds=int(remaining_time)))
        }
        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
            f.write(json.dumps(self.tracker) + "\n")
src/llmtuner/dsets/loader.py (new file, 106 lines)

import os
import hashlib
from typing import List

from datasets import Dataset, concatenate_datasets, load_dataset

from llmtuner.extras.logging import get_logger
from llmtuner.hparams import ModelArguments, DataArguments


logger = get_logger(__name__)


def get_dataset(
    model_args: ModelArguments,
    data_args: DataArguments
) -> Dataset:

    def checksum(file_path, hash):
        with open(file_path, "rb") as datafile:
            binary_data = datafile.read()
        sha1 = hashlib.sha1(binary_data).hexdigest()
        if sha1 != hash:
            logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))

    ext2type = {
        "csv": "csv",
        "json": "json",
        "jsonl": "json",
        "txt": "text"
    }

    max_samples = data_args.max_samples
    all_datasets: List[Dataset] = [] # support multiple datasets

    for dataset_attr in data_args.dataset_list:

        logger.info("Loading dataset {}...".format(dataset_attr))

        if dataset_attr.load_from == "hf_hub":
            data_path = dataset_attr.dataset_name
            data_files = None
        elif dataset_attr.load_from == "script":
            data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
            data_files = None
        elif dataset_attr.load_from == "file":
            data_path = None
            data_files: List[str] = []

            if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
                for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
                    data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))

                    if data_path is None:
                        data_path = ext2type.get(data_files[0].split(".")[-1], None)
                    else:
                        assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match."
            elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
                data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
                data_path = ext2type.get(data_files[0].split(".")[-1], None)
            else:
                raise ValueError("File not found.")

            assert data_path, "File extension must be txt, csv, json or jsonl."

            if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
                checksum(data_files[0], dataset_attr.dataset_sha1)
            else:
                logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
        else:
            raise NotImplementedError

        raw_datasets = load_dataset(
            data_path,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None
        )
        dataset = raw_datasets[data_args.split]

        if max_samples is not None:
            max_samples_temp = min(len(dataset), max_samples)
            dataset = dataset.select(range(max_samples_temp))

        dummy_data = [None] * len(dataset)
        prefix_data = [dataset_attr.source_prefix] * len(dataset)
        for column_name, target_name in [
            ("prompt_column", "prompt"),
            ("query_column", "query"),
            ("response_column", "response"),
            ("history_column", "history")
        ]: # every dataset will have 4 columns same as each other
            if getattr(dataset_attr, column_name) != target_name:
                if getattr(dataset_attr, column_name):
                    dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
                else: # None or empty string
                    dataset = dataset.add_column(target_name, dummy_data)
        dataset = dataset.add_column("prefix", prefix_data)
        all_datasets.append(dataset)

    if len(data_args.dataset_list) == 1:
        all_datasets = all_datasets[0]
    else:
        all_datasets = concatenate_datasets(all_datasets)

    return all_datasets
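`get_dataset` verifies single-file datasets against the SHA-1 recorded in `dataset_info.json`. A sketch of computing that digest for a local file (the file path below is illustrative):

```python
# Sketch of computing the SHA-1 digest that get_dataset() checks via checksum().
import hashlib

def file_sha1(path: str) -> str:
    with open(path, "rb") as f:
        return hashlib.sha1(f.read()).hexdigest()

# the file name is illustrative; paste the result into the dataset_sha1
# field of the matching dataset_info.json entry
print(file_sha1("data/my_dataset.json"))
```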
src/llmtuner/dsets/preprocess.py (new file, 172 lines)

from typing import Literal
from itertools import chain
from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer

from datasets import Dataset

from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.template import Template
from llmtuner.hparams import DataArguments


def preprocess_dataset(
    dataset: Dataset,
    tokenizer: PreTrainedTokenizer,
    data_args: DataArguments,
    training_args: Seq2SeqTrainingArguments,
    stage: Literal["pt", "sft", "rm", "ppo"]
) -> Dataset:

    column_names = list(dataset.column_names)
    prompt_template = Template(data_args.prompt_template)

    # support question with a single answer or multiple answers
    def get_dialog(examples):
        for i in range(len(examples["prompt"])):
            if examples["prompt"][i] and examples["response"][i]:
                query, answer = examples["prompt"][i], examples["response"][i]
                query = query + "\n" + examples["query"][i] if examples["query"][i] else query
                prefix = examples["prefix"][i] if examples["prefix"][i] else ""
                dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
                yield dialog

    def preprocess_pretrain_dataset(examples):
        # build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
        text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
        concatenated_ids = list(chain(*text_ids))
        total_length = len(concatenated_ids)
        block_size = data_args.max_source_length - 1
        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
        total_length = (total_length // block_size) * block_size
        # split by chunks of max_source_length
        result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size]
                  for i in range(0, total_length, block_size)]
        return {
            "input_ids": result,
            "labels": result.copy()
        }

    def preprocess_supervised_dataset(examples):
        # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
        # for input with history, we build multiple input-label pairs just like:
        # https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
        model_inputs = {"input_ids": [], "labels": []}
        max_length = data_args.max_source_length + data_args.max_target_length

        for dialog in get_dialog(examples):
            input_ids, labels = [], []

            for i in range(len(dialog) // 2):
                source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=(i == 0))
                target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)

                if len(source_ids) > data_args.max_source_length:
                    source_ids = source_ids[:data_args.max_source_length]
                if len(target_ids) > data_args.max_target_length - 1: # eos token
                    target_ids = target_ids[:data_args.max_target_length - 1]

                if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
                    break

                input_ids += source_ids + target_ids + [tokenizer.eos_token_id]
                labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]

            model_inputs["input_ids"].append(input_ids)
            model_inputs["labels"].append(labels)

        return model_inputs

    def preprocess_unsupervised_dataset(examples):
        # build inputs with format `<bos> X` and labels with format `<bos> Y`
        model_inputs = {"input_ids": [], "labels": []}

        for dialog in get_dialog(examples):
            prompt, answer = "".join(dialog[:-1]), dialog[-1]

            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
            target_ids = tokenizer.encode(text=answer, add_special_tokens=True)

            if len(source_ids) > data_args.max_source_length:
                source_ids = source_ids[:data_args.max_source_length]
            if len(target_ids) > data_args.max_target_length:
                target_ids = target_ids[:data_args.max_target_length]

            model_inputs["input_ids"].append(source_ids)
            model_inputs["labels"].append(target_ids)

        return model_inputs

    def preprocess_pairwise_dataset(examples):
        # build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
        model_inputs = {"accept_ids": [], "reject_ids": []}
        for dialog in get_dialog(examples):
            prompt, answer = "".join(dialog[:-1]), dialog[-1]

            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
            accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
            reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)

            if len(source_ids) > data_args.max_source_length:
                source_ids = source_ids[:data_args.max_source_length]
            if len(accept_ids) > data_args.max_target_length - 1: # eos token
                accept_ids = accept_ids[:data_args.max_target_length - 1]
            if len(reject_ids) > data_args.max_target_length - 1: # eos token
                reject_ids = reject_ids[:data_args.max_target_length - 1]

            accept_ids = source_ids + accept_ids + [tokenizer.eos_token_id]
            reject_ids = source_ids + reject_ids + [tokenizer.eos_token_id]

            model_inputs["accept_ids"].append(accept_ids)
            model_inputs["reject_ids"].append(reject_ids)
        return model_inputs

    def print_supervised_dataset_example(example):
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
        print("label_ids:\n{}".format(example["labels"]))
        print("labels:\n{}".format(
            tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]],
                             skip_special_tokens=False)
        ))

    def print_pairwise_dataset_example(example):
        print("accept_ids:\n{}".format(example["accept_ids"]))
        print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
        print("reject_ids:\n{}".format(example["reject_ids"]))
        print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))

    def print_unsupervised_dataset_example(example):
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

    if stage == "pt":
        preprocess_function = preprocess_pretrain_dataset
    elif stage == "sft":
        preprocess_function = preprocess_unsupervised_dataset \
            if training_args.predict_with_generate else preprocess_supervised_dataset
    elif stage == "rm":
        preprocess_function = preprocess_pairwise_dataset
    elif stage == "ppo":
        preprocess_function = preprocess_unsupervised_dataset

    with training_args.main_process_first(desc="dataset map pre-processing"):
        dataset = dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on dataset"
        )

        if stage == "pt":
            print_unsupervised_dataset_example(dataset[0])
        elif stage == "sft":
            print_supervised_dataset_example(dataset[0])
        elif stage == "rm":
            print_pairwise_dataset_example(dataset[0])
        elif stage == "ppo":
            print_unsupervised_dataset_example(dataset[0])

        return dataset
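A toy illustration of the supervised format built by `preprocess_supervised_dataset`: the concatenated ids follow `<bos> X Y <eos>`, while the labels mask the prompt span with `IGNORE_INDEX` so only response tokens contribute to the loss (the token ids below are made up):

```python
# Toy illustration of the label masking performed above; token ids are made up.
IGNORE_INDEX = -100
bos, eos = 1, 2
source_ids = [bos, 101, 102, 103]        # <bos> + prompt tokens
target_ids = [201, 202]                  # response tokens

input_ids = source_ids + target_ids + [eos]
labels = [IGNORE_INDEX] * len(source_ids) + target_ids + [eos]

assert len(input_ids) == len(labels)
print(input_ids)  # [1, 101, 102, 103, 201, 202, 2]
print(labels)     # [-100, -100, -100, -100, 201, 202, 2]
```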
src/llmtuner/extras/callbacks.py (new file, 72 lines)

import os
import json
import time
from datetime import timedelta

from transformers import (
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments
)
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.training_args import TrainingArguments


class LogCallback(TrainerCallback):

    def __init__(self, runner=None):
        self.runner = runner
        self.tracker = {}

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        r"""
        Event called at the beginning of training.
        """
        self.start_time = time.time()

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        r"""
        Event called at the beginning of a training step. If using gradient accumulation, one training step
        might take several inputs.
        """
        if self.runner is not None and self.runner.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        r"""
        Event called at the end of an substep during gradient accumulation.
        """
        if self.runner is not None and self.runner.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
        r"""
        Event called after logging the last logs.
        """
        if "current_steps" not in state.log_history[-1]:
            return
        cur_time = time.time()
        cur_steps = state.log_history[-1].get("step")
        elapsed_time = cur_time - self.start_time
        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
        remaining_steps = state.max_steps - cur_steps
        remaining_time = remaining_steps * avg_time_per_step
        self.tracker = {
            "current_steps": cur_steps,
            "total_steps": state.max_steps,
            "loss": state.log_history[-1].get("loss", None),
            "eval_loss": state.log_history[-1].get("eval_loss", None),
            "predict_loss": state.log_history[-1].get("predict_loss", None),
            "reward": state.log_history[-1].get("reward", None),
            "learning_rate": state.log_history[-1].get("learning_rate", None),
            "epoch": state.log_history[-1].get("epoch", None),
            "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
            "elapsed_time": str(timedelta(seconds=int(elapsed_time))),
            "remaining_time": str(timedelta(seconds=int(remaining_time)))
        }
        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
            f.write(json.dumps(self.tracker) + "\n")
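`LogCallback` appends one JSON object per logging step to `trainer_log.jsonl` in the output directory. A sketch of tailing that file to watch progress (the directory name below is illustrative):

```python
# Sketch of reading the trainer_log.jsonl records that LogCallback appends;
# "output" is an illustrative output_dir.
import json

with open("output/trainer_log.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)
        print("{percentage:6.2f}%  step {current_steps}/{total_steps}  loss={loss}".format(**record))
```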
src/llmtuner/extras/constants.py (new file, 7 lines)

IGNORE_INDEX = -100

VALUE_HEAD_FILE_NAME = "value_head.bin"

FINETUNING_ARGS_NAME = "finetuning_args.json"

LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings
src/llmtuner/extras/logging.py (new file, 18 lines)

import sys
import logging


def get_logger(name: str) -> logging.Logger:

    formatter = logging.Formatter(
        fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S"
    )
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(formatter)

    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.addHandler(handler)

    return logger
src/llmtuner/extras/misc.py (new file, 105 lines)

import torch
from typing import List, Optional

from transformers.modeling_utils import PreTrainedModel
from transformers.generation.utils import LogitsProcessorList
from transformers.generation.logits_process import LogitsProcessor

from llmtuner.extras.constants import LAYERNORM_NAMES


class AverageMeter:
    r"""
    Computes and stores the average and current value.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


# Avoid runtime error in model.generate(do_sample=True).
class InvalidScoreLogitsProcessor(LogitsProcessor):

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 0] = 1.0
        return scores


def get_logits_processor() -> LogitsProcessorList:
    logits_processor = LogitsProcessorList()
    logits_processor.append(InvalidScoreLogitsProcessor())
    return logits_processor


def print_trainable_params(model: torch.nn.Module) -> None:
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        # if using DS Zero 3 and the weights are initialized empty
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
        trainable_params, all_param, 100 * trainable_params / all_param))


# Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
def prepare_model_for_training(
    model: PreTrainedModel,
    finetuning_type: str,
    output_embedding_layer_name: Optional[str] = "lm_head",
    use_gradient_checkpointing: Optional[bool] = True,
    layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
) -> PreTrainedModel:

    for name, param in model.named_parameters():
        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
            param.data = param.data.to(torch.float32)

    if use_gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)
            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        model.gradient_checkpointing_enable()
        model.config.use_cache = False # turn off when gradient checkpointing is enabled

    if finetuning_type != "full" and hasattr(model, output_embedding_layer_name):
        output_embedding_layer: torch.nn.Linear = getattr(model, output_embedding_layer_name)
        input_dtype = output_embedding_layer.weight.dtype

        class CastOutputToFloat(torch.nn.Sequential):

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return super().forward(x.to(input_dtype)).to(torch.float32)

        setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))

    return model


def torch_gc() -> None:
    r"""
    Collects GPU memory.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
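A minimal usage sketch of `AverageMeter`, which a training loop might use to keep a running mean of the loss (the values below are made up):

```python
# Running-mean sketch with AverageMeter; loss values are made up.
from llmtuner.extras.misc import AverageMeter

meter = AverageMeter()
for loss in (0.8, 0.6, 0.4):
    meter.update(loss, n=1)
print(meter.avg)  # ≈ 0.6, the running mean
```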
src/llmtuner/extras/ploting.py (new file, 50 lines)

import os
import json
import matplotlib.pyplot as plt
from typing import List, Optional
from transformers.trainer import TRAINER_STATE_NAME

from llmtuner.extras.logging import get_logger


logger = get_logger(__name__)


def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]:
    r"""
    EMA implementation according to TensorBoard.
    """
    last = scalars[0]
    smoothed = list()
    for next_val in scalars:
        smoothed_val = last * weight + (1 - weight) * next_val
        smoothed.append(smoothed_val)
        last = smoothed_val
    return smoothed


def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:

    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
        data = json.load(f)

    for key in keys:
        steps, metrics = [], []
        for i in range(len(data["log_history"])):
            if key in data["log_history"][i]:
                steps.append(data["log_history"][i]["step"])
                metrics.append(data["log_history"][i][key])

        if len(metrics) == 0:
            logger.warning(f"No metric {key} to plot.")
            continue

        plt.figure()
        plt.plot(steps, metrics, alpha=0.4, label="original")
        plt.plot(steps, smooth(metrics), label="smoothed")
        plt.title("training {} of {}".format(key, save_dictionary))
        plt.xlabel("step")
        plt.ylabel(key)
        plt.legend()
        plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100)
        print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key)))
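A worked example of the exponential moving average that `smooth` applies before plotting: with the default `weight=0.9`, each output is `0.9 * previous + 0.1 * current`, seeded with the first value:

```python
from llmtuner.extras.ploting import smooth

# step 0: 0.9 * 1.0 + 0.1 * 1.0 = 1.0
# step 1: 0.9 * 1.0 + 0.1 * 0.0 = 0.9
# step 2: 0.9 * 0.9 + 0.1 * 0.0 = 0.81
print(smooth([1.0, 0.0, 0.0]))  # ≈ [1.0, 0.9, 0.81]
```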
49
src/llmtuner/extras/save_and_load.py
Normal file
49
src/llmtuner/extras/save_and_load.py
Normal file
@ -0,0 +1,49 @@
import os
import torch
from typing import Dict

from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from transformers.modeling_utils import load_sharded_checkpoint

from llmtuner.extras.constants import VALUE_HEAD_FILE_NAME
from llmtuner.extras.logging import get_logger


logger = get_logger(__name__)


def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get state dict containing trainable parameters
    state_dict = model.state_dict()
    filtered_state_dict = {}

    for k, v in model.named_parameters():
        if v.requires_grad:
            filtered_state_dict[k] = state_dict[k].cpu().clone().detach()

    return filtered_state_dict


def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
    weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
    if os.path.exists(weights_file):
        model_state_dict = torch.load(weights_file, map_location="cpu")
        model.load_state_dict(model_state_dict, strict=False) # skip missing keys
    elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
        load_sharded_checkpoint(model, checkpoint_dir, strict=False)
    else:
        logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
        return False
    return True


def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
    valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
    if not os.path.exists(valuehead_file):
        logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
        return False
    valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
    model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
    model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
    model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]))
    model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
    return True
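A tiny illustration of what `get_state_dict` keeps, using a stand-in module (the real call sites pass the LoRA-wrapped model, so only the small adapter weights survive into the checkpoint):

```python
import torch
from llmtuner.extras.save_and_load import get_state_dict

model = torch.nn.Linear(4, 4)
model.weight.requires_grad_(False)  # frozen parameters are filtered out

print(list(get_state_dict(model).keys()))  # ['bias'] — only trainable params remain
```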
5 src/llmtuner/hparams/__init__.py Normal file
@@ -0,0 +1,5 @@
from .data_args import DataArguments
from .finetuning_args import FinetuningArguments
from .general_args import GeneralArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
119 src/llmtuner/hparams/data_args.py Normal file
@@ -0,0 +1,119 @@
import os
import json
from typing import List, Optional
from dataclasses import dataclass, field


@dataclass
class DatasetAttr:

    load_from: str
    dataset_name: Optional[str] = None
    dataset_sha1: Optional[str] = None
    source_prefix: Optional[str] = None

    def __repr__(self) -> str:
        return self.dataset_name

    def __post_init__(self):
        self.prompt_column = "instruction"
        self.query_column = "input"
        self.response_column = "output"
        self.history_column = None


@dataclass
class DataArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and evaluation.
    """
    dataset: Optional[str] = field(
        default="alpaca_zh",
        metadata={"help": "The name of provided dataset(s) to use. Use comma to separate multiple datasets."}
    )
    dataset_dir: Optional[str] = field(
        default="data",
        metadata={"help": "The name of the folder containing datasets."}
    )
    split: Optional[str] = field(
        default="train",
        metadata={"help": "Which dataset split to use for training and evaluation."}
    )
    overwrite_cache: Optional[bool] = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets."}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."}
    )
    max_source_length: Optional[int] = field(
        default=512,
        metadata={"help": "The maximum total input sequence length after tokenization."}
    )
    max_target_length: Optional[int] = field(
        default=512,
        metadata={"help": "The maximum total output sequence length after tokenization."}
    )
    max_samples: Optional[int] = field(
        default=None,
        metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
    )
    eval_num_beams: Optional[int] = field(
        default=None,
        metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
    )
    ignore_pad_token_for_loss: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
    )
    source_prefix: Optional[str] = field(
        default=None,
        metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
    )
    dev_ratio: Optional[float] = field(
        default=0,
        metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
    )
    prompt_template: Optional[str] = field(
        default="default",
        metadata={"help": "Which template to use for constructing prompts in training and inference."}
    )

    def init_for_training(self): # support mixing multiple datasets
        dataset_names = [ds.strip() for ds in self.dataset.split(",")]
        with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
            dataset_info = json.load(f)

        if self.source_prefix is not None:
            prefix_list = self.source_prefix.split("|")
            prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
            assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical with datasets or 1."
        else:
            prefix_list = [None] * len(dataset_names)

        self.dataset_list: List[DatasetAttr] = []
        for i, name in enumerate(dataset_names):
            if name not in dataset_info:
                raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))

            if "hf_hub_url" in dataset_info[name]:
                dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
            elif "script_url" in dataset_info[name]:
                dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
            else:
                dataset_attr = DatasetAttr(
                    "file",
                    dataset_name=dataset_info[name]["file_name"],
                    dataset_sha1=dataset_info[name].get("file_sha1", None)
                )

            dataset_attr.source_prefix = prefix_list[i]

            if "columns" in dataset_info[name]:
                dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
                dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
                dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
                dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)

            self.dataset_list.append(dataset_attr)
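For orientation, a hedged sketch of the kinds of `dataset_info.json` entries that `init_for_training` can resolve, written as a Python literal; every name and value here is illustrative, not a real entry:

```python
# One entry per branch of init_for_training (all names hypothetical):
dataset_info = {
    "hub_dataset": {"hf_hub_url": "org/dataset"},        # loaded from the Hugging Face Hub
    "script_dataset": {"script_url": "loading_script"},  # loaded via a dataset script
    "file_dataset": {                                    # loaded from a local file
        "file_name": "data.json",
        "file_sha1": "hypothetical_sha1",
        "columns": {"prompt": "instruction", "query": "input", "response": "output"}
    }
}
```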
78 src/llmtuner/hparams/finetuning_args.py Normal file
@@ -0,0 +1,78 @@
import json
from typing import Literal, Optional
from dataclasses import asdict, dataclass, field


@dataclass
class FinetuningArguments:
    """
    Arguments pertaining to which techniques we are going to fine-tune with.
    """
    finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
        default="lora",
        metadata={"help": "Which fine-tuning method to use."}
    )
    num_hidden_layers: Optional[int] = field(
        default=32,
        metadata={"help": "Number of decoder blocks in the model. \
                  LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
                  BLOOM choices: [\"24\", \"30\", \"70\"], \
                  Falcon choices: [\"32\", \"60\"], \
                  Baichuan choices: [\"32\"]"}
    )
    num_layer_trainable: Optional[int] = field(
        default=3,
        metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
    )
    name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
        default="mlp",
        metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
                  LLaMA choices: [\"mlp\", \"self_attn\"], \
                  BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
                  Baichuan choices: [\"mlp\", \"self_attn\"]"}
    )
    lora_rank: Optional[int] = field(
        default=8,
        metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
    )
    lora_alpha: Optional[float] = field(
        default=32.0,
        metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."}
    )
    lora_dropout: Optional[float] = field(
        default=0.1,
        metadata={"help": "Dropout rate for the LoRA fine-tuning."}
    )
    lora_target: Optional[str] = field(
        default="q_proj,v_proj",
        metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \
                  LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
                  BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
                  Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
    )

    def __post_init__(self):
        if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
            self.lora_target = [target.strip() for target in self.lora_target.split(",")]

        if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
            trainable_layer_ids = [self.num_hidden_layers - k - 1 for k in range(self.num_layer_trainable)]
        else: # fine-tuning the first n layers if num_layer_trainable < 0
            trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]

        self.trainable_layers = ["{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]

        assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."

    def save_to_json(self, json_path: str):
        """Saves the content of this instance in JSON format inside `json_path`."""
        json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
        with open(json_path, "w", encoding="utf-8") as f:
            f.write(json_string)

    @classmethod
    def load_from_json(cls, json_path: str):
        """Creates an instance from the content of `json_path`."""
        with open(json_path, "r", encoding="utf-8") as f:
            text = f.read()
        return cls(**json.loads(text))
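A quick worked example of the Freeze bookkeeping in `__post_init__` (the values shown are the defaults):

```python
from llmtuner.hparams import FinetuningArguments

# With num_hidden_layers=32, num_layer_trainable=3, name_module_trainable="mlp",
# __post_init__ marks the last three decoder blocks' MLPs as trainable:
args = FinetuningArguments(finetuning_type="freeze", num_layer_trainable=3)
print(args.trainable_layers)  # ['31.mlp', '30.mlp', '29.mlp']
```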
13 src/llmtuner/hparams/general_args.py Normal file
@@ -0,0 +1,13 @@
from typing import Literal, Optional
from dataclasses import dataclass, field


@dataclass
class GeneralArguments:
    """
    Arguments pertaining to which stage we are going to perform in training.
    """
    stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = field(
        default="sft",
        metadata={"help": "Which stage will be performed in training."}
    )
51 src/llmtuner/hparams/generating_args.py Normal file
@@ -0,0 +1,51 @@
from typing import Any, Dict, Optional
from dataclasses import asdict, dataclass, field


@dataclass
class GeneratingArguments:
    """
    Arguments pertaining to the decoding parameters.
    """
    do_sample: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
    )
    temperature: Optional[float] = field(
        default=0.95,
        metadata={"help": "The value used to modulate the next token probabilities."}
    )
    top_p: Optional[float] = field(
        default=0.7,
        metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
    )
    top_k: Optional[int] = field(
        default=50,
        metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
    )
    num_beams: Optional[int] = field(
        default=1,
        metadata={"help": "Number of beams for beam search. 1 means no beam search."}
    )
    max_length: Optional[int] = field(
        default=None,
        metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
    )
    max_new_tokens: Optional[int] = field(
        default=512,
        metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
    )
    repetition_penalty: Optional[float] = field(
        default=1.0,
        metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
    )
    length_penalty: Optional[float] = field(
        default=1.0,
        metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
    )

    def to_dict(self) -> Dict[str, Any]:
        args = asdict(self)
        if args.get("max_new_tokens", None):
            args.pop("max_length", None)
        return args
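A short usage sketch: `to_dict` drops `max_length` whenever `max_new_tokens` is set, so the two never conflict when the result is forwarded to generation:

```python
from llmtuner.hparams import GeneratingArguments

gen_args = GeneratingArguments(max_new_tokens=256)
kwargs = gen_args.to_dict()
assert "max_length" not in kwargs    # max_new_tokens takes precedence
# model.generate(**inputs, **kwargs) # forwarded to generation (illustrative)
```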
72 src/llmtuner/hparams/model_args.py Normal file
@@ -0,0 +1,72 @@
import torch
from typing import Literal, Optional
from dataclasses import dataclass, field


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
    """
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
    )
    use_fast_tokenizer: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
    )
    use_auth_token: Optional[bool] = field(
        default=False,
        metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
    )
    model_revision: Optional[str] = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
    )
    padding_side: Optional[Literal["left", "right"]] = field(
        default="left",
        metadata={"help": "The side on which the model should have padding applied."}
    )
    quantization_bit: Optional[int] = field(
        default=None,
        metadata={"help": "The number of bits to quantize the model."}
    )
    quantization_type: Optional[Literal["fp4", "nf4"]] = field(
        default="nf4",
        metadata={"help": "Quantization data type to use in int4 training."}
    )
    double_quantization: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to use double quantization in int4 training or not."}
    )
    compute_dtype: Optional[torch.dtype] = field(
        default=None,
        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
    )
    checkpoint_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
    )
    reward_model: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
    )
    resume_lora_training: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
    )
    plot_loss: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
    )

    def __post_init__(self):
        if self.checkpoint_dir is not None: # support merging multiple lora weights
            self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]

        if self.quantization_bit is not None:
            assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
5 src/llmtuner/tuner/__init__.py Normal file
@@ -0,0 +1,5 @@
from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
from llmtuner.tuner.pt import run_pt
from llmtuner.tuner.sft import run_sft
from llmtuner.tuner.rm import run_rm
from llmtuner.tuner.ppo import run_ppo
2 src/llmtuner/tuner/core/__init__.py Normal file
@@ -0,0 +1,2 @@
from llmtuner.tuner.core.parser import get_train_args, get_infer_args
from llmtuner.tuner.core.loader import load_model_and_tokenizer
94 src/llmtuner/tuner/core/adapter.py Normal file
@@ -0,0 +1,94 @@
import os
import torch

from transformers.modeling_utils import PreTrainedModel
from peft import (
    PeftModel,
    TaskType,
    LoraConfig,
    get_peft_model
)
from peft.utils import CONFIG_NAME, WEIGHTS_NAME

from llmtuner.extras.logging import get_logger
from llmtuner.extras.save_and_load import load_trainable_params
from llmtuner.hparams import ModelArguments, FinetuningArguments


logger = get_logger(__name__)


def init_adapter(
    model: PreTrainedModel,
    model_args: ModelArguments,
    finetuning_args: FinetuningArguments,
    is_trainable: bool,
    is_mergeable: bool
) -> PreTrainedModel:
    r"""
    Initializes the adapters.

    Support full-parameter, freeze and LoRA training.

    Note that the trainable parameters must be cast to float32.
    """

    if finetuning_args.finetuning_type == "none" and is_trainable:
        raise ValueError("You cannot use finetuning_type=none while training.")

    if finetuning_args.finetuning_type == "full":
        logger.info("Fine-tuning method: Full")
        model = model.float()

    if finetuning_args.finetuning_type == "freeze":
        logger.info("Fine-tuning method: Freeze")

        for name, param in model.named_parameters():
            if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
                param.requires_grad_(False)
            else:
                param.data = param.data.to(torch.float32)

        if model_args.checkpoint_dir is not None:
            assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."

    if finetuning_args.finetuning_type == "lora":
        logger.info("Fine-tuning method: LoRA")
        latest_checkpoint = None

        if model_args.checkpoint_dir is not None:
            assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)), \
                "Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0])
            assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
                "The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."

            if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
                checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
            else:
                checkpoints_to_merge = model_args.checkpoint_dir

            for checkpoint in checkpoints_to_merge:
                model = PeftModel.from_pretrained(model, checkpoint)
                model = model.merge_and_unload()

            if len(checkpoints_to_merge) > 0:
                logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))

            if latest_checkpoint is not None: # resume lora training or quantized inference
                model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)

        if is_trainable and latest_checkpoint is None: # create new lora weights while training
            lora_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                r=finetuning_args.lora_rank,
                lora_alpha=finetuning_args.lora_alpha,
                lora_dropout=finetuning_args.lora_dropout,
                target_modules=finetuning_args.lora_target
            )
            model = get_peft_model(model, lora_config)

    if model_args.checkpoint_dir is not None:
        logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))

    return model
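To make the checkpoint bookkeeping above concrete, a hedged illustration (the paths are hypothetical): with `--checkpoint_dir a,b,c` and `resume_lora_training=True`, the first two adapters are merged into the base weights and the last one is reloaded as the active, trainable adapter.

```python
# Hypothetical: checkpoint_dir as parsed from "a,b,c" by ModelArguments.__post_init__
checkpoint_dir = ["a", "b", "c"]
checkpoints_to_merge, latest_checkpoint = checkpoint_dir[:-1], checkpoint_dir[-1]
print(checkpoints_to_merge)  # ['a', 'b'] -> merged via merge_and_unload()
print(latest_checkpoint)     # 'c'        -> resumed with is_trainable=True
```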
151 src/llmtuner/tuner/core/loader.py Normal file
@@ -0,0 +1,151 @@
import os
import torch
from typing import Literal, Optional, Tuple

from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig
)
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from trl import AutoModelForCausalLMWithValueHead

from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import prepare_model_for_training, print_trainable_params
from llmtuner.extras.save_and_load import load_valuehead_params
from llmtuner.hparams import ModelArguments, FinetuningArguments
from llmtuner.tuner.core.adapter import init_adapter


logger = get_logger(__name__)


check_min_version("4.29.1")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
require_version("trl>=0.4.4", "To fix: pip install trl>=0.4.4")


def load_model_and_tokenizer(
    model_args: ModelArguments,
    finetuning_args: FinetuningArguments,
    is_trainable: Optional[bool] = False,
    stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    r"""
    Loads pretrained model and tokenizer.

    Support both training and inference.
    """
    if (not is_trainable) and model_args.checkpoint_dir is None:
        logger.warning("Checkpoint is not found at evaluation, load the original model.")
        finetuning_args = FinetuningArguments(finetuning_type="none")

    assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
        "RM and PPO training can only be performed with the LoRA method."

    config_kwargs = {
        "trust_remote_code": True,
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        use_fast=model_args.use_fast_tokenizer,
        padding_side=model_args.padding_side,
        **config_kwargs
    )
    if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000: # 64000 for baichuan model (older version)
        tokenizer.pad_token_id = 0 # set as the <unk> token

    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    is_mergeable = True

    # Quantization configurations (using bitsandbytes library).
    if model_args.quantization_bit is not None:
        if model_args.quantization_bit == 8:
            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
            config_kwargs["load_in_8bit"] = True
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=6.0
            )

        elif model_args.quantization_bit == 4:
            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
            require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
            require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
            require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
            config_kwargs["load_in_4bit"] = True
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=model_args.compute_dtype,
                bnb_4bit_use_double_quant=model_args.double_quantization,
                bnb_4bit_quant_type=model_args.quantization_type
            )

        is_mergeable = False
        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

    if not is_trainable: # `device_map=auto` should be used for inference only
        config_kwargs["device_map"] = "auto"

    if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
        model_to_load = model_args.checkpoint_dir[0]
    else:
        model_to_load = model_args.model_name_or_path

    # Load and prepare pretrained models (without valuehead).
    model = AutoModelForCausalLM.from_pretrained(
        model_to_load,
        config=config,
        torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
        low_cpu_mem_usage=True,
        **config_kwargs
    )

    # Register auto class to save the custom code files.
    if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
        config.__class__.register_for_auto_class()
    if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
        tokenizer.__class__.register_for_auto_class()
    if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map:
        model.__class__.register_for_auto_class()

    # Initialize adapters
    model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
    model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)

    if stage == "rm" or stage == "ppo": # add value head
        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)

        if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
            logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
            if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
                model.v_head.load_state_dict({
                    "summary.weight": getattr(model, "reward_head_weight"),
                    "summary.bias": getattr(model, "reward_head_bias")
                })

        if stage == "ppo": # load reward model
            assert is_trainable, "PPO stage cannot be performed at evaluation."
            assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
            logger.info("Load reward model from {}".format(model_args.reward_model))
            model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
            assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."

    if not is_trainable:
        model.requires_grad_(False) # fix all model params
        model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16

    print_trainable_params(model)

    return model, tokenizer
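A minimal sketch of calling the loader directly (the argument values are illustrative placeholders, matching the placeholders used elsewhere in this repository):

```python
from llmtuner.hparams import ModelArguments, FinetuningArguments
from llmtuner.tuner.core import load_model_and_tokenizer

model_args = ModelArguments(model_name_or_path="path_to_your_model")
finetuning_args = FinetuningArguments(finetuning_type="lora")
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=True, stage="sft")
```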
134 src/llmtuner/tuner/core/parser.py Normal file
@@ -0,0 +1,134 @@
import os
import sys
import torch
import datasets
import transformers
from typing import Any, Dict, Optional, Tuple
from transformers import HfArgumentParser, Seq2SeqTrainingArguments

from llmtuner.extras.logging import get_logger
from llmtuner.hparams import (
    ModelArguments,
    DataArguments,
    FinetuningArguments,
    GeneratingArguments,
    GeneralArguments
)


logger = get_logger(__name__)


def get_train_args(
    args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:

    parser = HfArgumentParser((ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments))

    if args is not None:
        model_args, data_args, training_args, finetuning_args, general_args = parser.parse_dict(args)
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
        model_args, data_args, training_args, finetuning_args, general_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args, finetuning_args, general_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args, finetuning_args, general_args = parser.parse_args_into_dataclasses()

    # Setup logging
    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
    data_args.init_for_training()

    assert general_args.stage == "sft" or (not training_args.predict_with_generate), \
        "`predict_with_generate` cannot be set as True at PT, RM and PPO stages."

    assert not (training_args.do_train and training_args.predict_with_generate), \
        "`predict_with_generate` cannot be set as True while training."

    assert (not training_args.do_predict) or training_args.predict_with_generate, \
        "Please enable `predict_with_generate` to save model predictions."

    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
        "Quantization is only compatible with the LoRA method."

    if model_args.checkpoint_dir is not None:
        if finetuning_args.finetuning_type != "lora":
            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
        else:
            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
                "Quantized model only accepts a single checkpoint."

    if model_args.quantization_bit is not None and (not training_args.do_train):
        logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")

    if training_args.do_train and (not training_args.fp16):
        logger.warning("We recommend enabling fp16 mixed precision training.")

    if data_args.prompt_template == "default":
        logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")

    if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
        logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
        training_args.ddp_find_unused_parameters = False

    training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning

    if model_args.quantization_bit is not None:
        if training_args.fp16:
            model_args.compute_dtype = torch.float16
        elif training_args.bf16:
            model_args.compute_dtype = torch.bfloat16
        else:
            model_args.compute_dtype = torch.float32

    # Log on each process the small summary:
    logger.info(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
        + f"  distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    transformers.set_seed(training_args.seed)

    return model_args, data_args, training_args, finetuning_args, general_args


def get_infer_args(
    args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:

    parser = HfArgumentParser((ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments))

    if args is not None:
        model_args, data_args, finetuning_args, generating_args = parser.parse_dict(args)
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
        model_args, data_args, finetuning_args, generating_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()

    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
        "Quantization is only compatible with the LoRA method."

    if model_args.checkpoint_dir is not None:
        if finetuning_args.finetuning_type != "lora":
            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
        else:
            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
                "Quantized model only accepts a single checkpoint."

    if data_args.prompt_template == "default":
        logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")

    return model_args, data_args, finetuning_args, generating_args
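A hedged sketch of the programmatic entry point, passing arguments as a dict instead of the command line (values are illustrative, and `data/dataset_info.json` must exist since `init_for_training` reads it):

```python
from llmtuner.tuner.core import get_train_args

model_args, data_args, training_args, finetuning_args, general_args = get_train_args({
    "stage": "sft",
    "model_name_or_path": "path_to_your_model",
    "do_train": True,
    "dataset": "alpaca_gpt4_en",
    "finetuning_type": "lora",
    "output_dir": "path_to_sft_checkpoint",
    "fp16": True
})
```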
@@ -1,76 +1,20 @@
 import os
-import json
-import time
 import torch
 from typing import Dict, Optional
-from datetime import timedelta
 
-from transformers import (
-    Seq2SeqTrainer,
-    TrainerCallback,
-    TrainerControl,
-    TrainerState,
-    TrainingArguments
-)
-
+from transformers import Seq2SeqTrainer
 from transformers.trainer import TRAINING_ARGS_NAME
 from transformers.modeling_utils import unwrap_model
 
-from .config import FinetuningArguments
-from .other import (
-    get_logger,
-    get_state_dict,
-    load_trainable_params,
-    load_valuehead_params,
-    FINETUNING_ARGS_NAME,
-    VALUE_HEAD_FILE_NAME
-)
+from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
+from llmtuner.extras.logging import get_logger
+from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params, load_valuehead_params
+from llmtuner.hparams import FinetuningArguments
 
 
 logger = get_logger(__name__)
 
 
-class LogCallback(TrainerCallback):
-    r"""
-    TrainerCallback includes the state function during training, for more details refer to the TrainerCallback class.
-    The on_log function primarily collects process parameters during training, such as training loss, learning rate,
-    and training epochs, as well as progress parameters like the current percentage progress and estimated remaining
-    time. Every time a log is triggered, a new record is appended to the file "messages.log" for dynamic visualization
-    purposes.
-    """
-
-    def __init__(self):
-        self.start_time = time.time()
-
-    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
-        r"""
-        Event called after logging the last logs.
-        """
-        if "loss" not in state.log_history[-1]:
-            return
-        cur_time = time.time()
-        cur_steps = state.log_history[-1].get("step")
-        elapsed_time = cur_time - self.start_time
-        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
-        remaining_steps = state.max_steps - cur_steps
-        remaining_time = remaining_steps * avg_time_per_step
-        log_dict = {
-            "current_steps": cur_steps,
-            "total_steps": state.max_steps,
-            "loss": state.log_history[-1].get("loss", None),
-            "reward": state.log_history[-1].get("reward", None),
-            "learning_rate": state.log_history[-1].get("learning_rate", None),
-            "epoch": state.log_history[-1].get("epoch", None),
-            "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
-            "elapsed_time": str(timedelta(seconds=int(elapsed_time))),
-            "remaining_time": str(timedelta(seconds=int(remaining_time)))
-        }
-        os.makedirs(args.output_dir, exist_ok=True)
-        with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a") as f:
-            f.write(json.dumps(log_dict) + "\n")
-
-
 class PeftTrainer(Seq2SeqTrainer):
     r"""
     Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
1 src/llmtuner/tuner/ppo/__init__.py Normal file
@@ -0,0 +1 @@
from llmtuner.tuner.ppo.workflow import run_ppo
@@ -2,77 +2,43 @@ import os
 import math
 import torch
 from tqdm import tqdm
-from typing import Callable, Dict, List, Literal, Optional, Tuple
+from typing import Callable, Dict, List, Optional
 
-from transformers import Seq2SeqTrainingArguments, TrainerState
+from transformers import Seq2SeqTrainingArguments, TrainerState, TrainerControl
 from transformers.modeling_utils import PreTrainedModel
 
-from trl import PPOTrainer, AutoModelForCausalLMWithValueHead
+from trl import PPOTrainer
 from trl.core import LengthSampler
 
-from .peft_trainer import PeftTrainer, LogCallback
-from .config import FinetuningArguments
-from .other import (
-    AverageMeter,
-    get_logger,
-    get_logits_processor
-)
+from llmtuner.extras.callbacks import LogCallback
+from llmtuner.extras.logging import get_logger
+from llmtuner.extras.misc import AverageMeter, get_logits_processor
+from llmtuner.hparams import FinetuningArguments
+from llmtuner.tuner.core.trainer import PeftTrainer
+from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
 
 
 logger = get_logger(__name__)
 
 
-def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
-    if target == "reward": # save default head temporarily
-        valuehead_state_dict = model.v_head.state_dict()
-        setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
-        setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"])
-
-    model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
-    model.v_head.load_state_dict({
-        "summary.weight": getattr(model, "{}_head_weight".format(target)),
-        "summary.bias": getattr(model, "{}_head_bias".format(target))
-    })
-
-
-def cast_layernorm_dtype(
-    model: AutoModelForCausalLMWithValueHead,
-    layer_norm_names: List[str] = ["norm", "ln_f", "ln_attn", "ln_mlp"], # for LLaMA, BLOOM and Falcon settings
-    layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
-) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
-
-    layer_norm_state_dict = {}
-
-    for name, param in model.named_parameters():
-        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
-            if layer_norm_params is not None:
-                param.data = layer_norm_params[name] # restore float32 weights
-            else:
-                layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability
-                param.data = param.data.to(torch.float16)
-
-    return model, layer_norm_state_dict
-
-
 class PPOPeftTrainer(PPOTrainer, PeftTrainer):
     r"""
     Inherits PPOTrainer.
     """
 
     def __init__(
         self,
         training_args: Seq2SeqTrainingArguments,
         finetuning_args: FinetuningArguments,
         callbacks: List[LogCallback],
         **kwargs
     ):
         PPOTrainer.__init__(self, **kwargs)
         self.args = training_args
         self.finetuning_args = finetuning_args
         self.log_callback = callbacks[0]
         self.state = TrainerState()
+        self.control = TrainerControl()
         self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer
 
     def ppo_train(self, max_target_length: int) -> None:
@@ -117,8 +83,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         steps_trained = 0
         loss_meter = AverageMeter()
         reward_meter = AverageMeter()
+        self.log_callback.on_train_begin(self.args, self.state, self.control)
 
-        for step in tqdm(range(max_steps), disable=not self.is_world_process_zero()):
+        for step in tqdm(range(max_steps), disable=not self.is_world_process_zero(), leave=False):
 
             for _ in range(self.config.gradient_accumulation_steps):
 
@@ -158,6 +125,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
             reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
 
+            if self.control.should_epoch_stop or self.control.should_training_stop:
+                break
+
             if steps_trained == len_dataloader:
                 dataiter = iter(self.dataloader)
                 steps_trained = 0
@@ -172,20 +142,23 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
                 print(logs)
                 logs["step"] = step
                 self.state.log_history.append(logs)
-                self.log_callback.on_log(self.args, self.state, None)
+                self.log_callback.on_log(self.args, self.state, self.control)
                 loss_meter.reset()
                 reward_meter.reset()
 
             if (step+1) % self.args.save_steps == 0: # save checkpoint
                 self.save_model(os.path.join(self.args.output_dir, f"checkpoint-{step+1}"))
 
+            if self.control.should_training_stop:
+                break
+
     @torch.no_grad()
     def generate(
         self,
         inputs: Dict[str, torch.Tensor],
         length_sampler: Optional[Callable] = None,
         return_prompt: Optional[bool] = True,
-        **generation_kwargs,
+        **generation_kwargs
     ) -> torch.Tensor:
         r"""
         Generates model's responses given queries.
37 src/llmtuner/tuner/ppo/utils.py Normal file
@@ -0,0 +1,37 @@
import torch
from typing import Dict, List, Literal, Optional, Tuple
from trl import AutoModelForCausalLMWithValueHead

from llmtuner.extras.constants import LAYERNORM_NAMES


def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
    if target == "reward": # save default head temporarily
        valuehead_state_dict = model.v_head.state_dict()
        setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
        setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"])

    model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
    model.v_head.load_state_dict({
        "summary.weight": getattr(model, "{}_head_weight".format(target)),
        "summary.bias": getattr(model, "{}_head_bias".format(target))
    })


def cast_layernorm_dtype(
    model: AutoModelForCausalLMWithValueHead,
    layer_norm_names: List[str] = LAYERNORM_NAMES,
    layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:

    layer_norm_state_dict = {}

    for name, param in model.named_parameters():
        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
            if layer_norm_params is not None:
                param.data = layer_norm_params[name] # restore float32 weights
            else:
                layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability
                param.data = param.data.to(torch.float16)

    return model, layer_norm_state_dict
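A hedged sketch of how `replace_model` is typically used inside the PPO loop: the policy temporarily borrows the reward adapter and value head to score its own responses, then switches back before optimization (the surrounding variables are assumed to be in scope):

```python
from llmtuner.tuner.ppo.utils import replace_model

# Illustrative PPO reward-scoring pattern (model and batch assumed in scope):
replace_model(model, target="reward")   # activate the reward adapter + head
# rewards = ...                         # score responses with the reward head
replace_model(model, target="default")  # restore the policy head afterwards
```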
@@ -1,36 +1,30 @@
-# coding=utf-8
-# Implements parameter-efficient PPO training of fine-tuned models.
-# This code is inspired by:
+# Inspired by:
 # https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
 
 import math
 
-from torch.optim import AdamW
-from transformers.optimization import get_scheduler
 from trl import PPOConfig
-from transformers import DataCollatorForSeq2Seq
-from utils import (
-    PPOPeftTrainer,
-    LogCallback,
-    load_pretrained,
-    prepare_args,
-    prepare_data,
-    preprocess_data,
-    plot_loss
-)
+from torch.optim import AdamW
+from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
+from transformers.optimization import get_scheduler
+
+from llmtuner.dsets import get_dataset, preprocess_dataset
+from llmtuner.extras.callbacks import LogCallback
+from llmtuner.extras.ploting import plot_loss
+from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+from llmtuner.tuner.core import load_model_and_tokenizer
+from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
 
 
-def main():
-
-    # Prepare pretrained model and dataset
-    model_args, data_args, training_args, finetuning_args = prepare_args(stage="ppo")
-    dataset = prepare_data(model_args, data_args)
-    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="ppo")
-    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="ppo")
-    data_collator = DataCollatorForSeq2Seq(
-        tokenizer=tokenizer,
-        label_pad_token_id=tokenizer.pad_token_id
-    )
+def run_ppo(
+    model_args: ModelArguments,
+    data_args: DataArguments,
+    training_args: Seq2SeqTrainingArguments,
+    finetuning_args: FinetuningArguments
+):
+    dataset = get_dataset(model_args, data_args)
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
+    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo")
+    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=tokenizer.pad_token_id)
 
     ppo_config = PPOConfig(
         model_name=model_args.model_name_or_path,
@@ -72,12 +66,3 @@ def main():
     ppo_trainer.save_state() # must be after save_model
     if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
         plot_loss(training_args.output_dir, keys=["loss", "reward"])
-
-
-def _mp_fn(index):
-    # For xla_spawn (TPUs)
-    main()
-
-
-if __name__ == "__main__":
-    main()
1 src/llmtuner/tuner/pt/__init__.py Normal file
@@ -0,0 +1 @@
from llmtuner.tuner.pt.workflow import run_pt
@@ -1,31 +1,28 @@
-# coding=utf-8
-# Implements several parameter-efficient pre-training method.
-# This code is inspired by
-# https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py
+# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py


 import math
-from transformers import DataCollatorForSeq2Seq
-from utils.other import IGNORE_INDEX
-
-from utils import (
-    PeftTrainer,
-    LogCallback,
-    load_pretrained,
-    prepare_args,
-    prepare_data,
-    preprocess_data,
-    plot_loss
-)
+from typing import Optional, List
+from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
+
+from llmtuner.dsets import get_dataset, preprocess_dataset
+from llmtuner.extras.callbacks import LogCallback
+from llmtuner.extras.constants import IGNORE_INDEX
+from llmtuner.extras.ploting import plot_loss
+from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+from llmtuner.tuner.core import load_model_and_tokenizer
+from llmtuner.tuner.core.trainer import PeftTrainer


-def main():
-    # Prepare pretrained model and dataset
-    model_args, data_args, training_args, finetuning_args = prepare_args(stage="pt")
-    dataset = prepare_data(model_args, data_args)
-    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="pt")
-    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="pt")
+def run_pt(
+    model_args: ModelArguments,
+    data_args: DataArguments,
+    training_args: Seq2SeqTrainingArguments,
+    finetuning_args: FinetuningArguments,
+    callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
+):
+    dataset = get_dataset(model_args, data_args)
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
+    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
     data_collator = DataCollatorForSeq2Seq(
         tokenizer=tokenizer,
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
@@ -48,7 +45,7 @@ def main():
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
-        callbacks=[LogCallback()],
+        callbacks=callbacks,
         **trainer_kwargs
     )
@@ -65,21 +62,12 @@ def main():
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval")

         try:
             perplexity = math.exp(metrics["eval_loss"])
         except OverflowError:
             perplexity = float("inf")

         metrics["perplexity"] = perplexity

         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
-
-
-def _mp_fn(index):
-    # For xla_spawn (TPUs)
-    main()
-
-
-if __name__ == "__main__":
-    main()
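For reference, the evaluation block kept above derives perplexity from the evaluation loss. A minimal standalone sketch of that computation (the loss values below are toy numbers, not from a real run):

```python
import math

# Perplexity as computed in the eval block above: exp(eval_loss),
# with an overflow guard for degenerate losses.
for eval_loss in (2.3, 800.0):
    try:
        perplexity = math.exp(eval_loss)
    except OverflowError:
        perplexity = float("inf")
    print(eval_loss, "->", perplexity)  # 2.3 -> ~9.97, 800.0 -> inf
```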
src/llmtuner/tuner/rm/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
+from llmtuner.tuner.rm.workflow import run_rm
src/llmtuner/tuner/rm/collator.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+import torch
+from typing import Any, Dict, Sequence
+from transformers import DataCollatorWithPadding
+
+
+class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
+    r"""
+    Data collator for pairwise data.
+    """
+
+    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
+        r"""
+        Pads batched data to the longest sequence in the batch.
+
+        We generate 2 * n examples where the first n examples represent chosen examples and
+        the last n examples represent rejected examples.
+        """
+        features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features]
+        return super().__call__(features)
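To make the 2 * n layout concrete, here is a minimal sketch of the reordering done in `__call__` above, using toy token ids instead of a real tokenizer:

```python
# Toy features mimicking the preprocessed pairwise dataset.
features = [
    {"accept_ids": [1, 2, 3], "reject_ids": [1, 2, 9]},
    {"accept_ids": [1, 5], "reject_ids": [1, 5, 6, 7]},
]
# Same comprehension as in the collator: all chosen sequences first,
# then all rejected ones, so row i always pairs with row i + n.
flat = [{"input_ids": f[key]} for key in ("accept_ids", "reject_ids") for f in features]
print([f["input_ids"] for f in flat])
# [[1, 2, 3], [1, 5], [1, 2, 9], [1, 5, 6, 7]]
```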
src/llmtuner/tuner/rm/metric.py (new file, 7 lines)
@@ -0,0 +1,7 @@
+import numpy as np
+from typing import Dict, Sequence, Tuple, Union
+
+
+def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
+    preds, _ = eval_preds
+    return {"accuracy": (preds[0] > preds[1]).sum() / len(preds[0])}
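A toy check of the convention used here: `preds[0]` holds the rewards of the chosen responses and `preds[1]` those of the rejected ones (the values below are invented for illustration):

```python
import numpy as np

chosen = np.array([1.2, 0.3, 0.8])     # rewards for accepted responses
rejected = np.array([0.4, 0.9, -0.1])  # rewards for rejected responses
print({"accuracy": (chosen > rejected).sum() / len(chosen)})
# {'accuracy': 0.666...} -> two of three pairs are ranked correctly
```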
src/llmtuner/tuner/rm/trainer.py (new file, 38 lines)
@@ -0,0 +1,38 @@
+import torch
+from typing import Dict, List, Optional, Tuple, Union
+from transformers.modeling_utils import PreTrainedModel
+
+from llmtuner.tuner.core.trainer import PeftTrainer
+
+
+class PairwisePeftTrainer(PeftTrainer):
+    r"""
+    Inherits PeftTrainer to compute pairwise loss.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.can_return_loss = True # override property to return eval_loss
+
+    def compute_loss(
+        self,
+        model: PreTrainedModel,
+        inputs: Dict[str, torch.Tensor],
+        return_outputs: Optional[bool] = False
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+        r"""
+        Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
+
+        We use score on the EOS token to represent reward of the whole sentence.
+
+        Subclass and override to inject custom behavior. It should not be directly used by external scripts.
+
+        Note that the first element will be removed from the output tuple.
+
+        See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
+        """
+        batch_size = inputs["input_ids"].size(0) // 2
+        _, _, values = model(**inputs)
+        r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
+        loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
+        return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
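The loss in `compute_loss` is the negative log-sigmoid of the reward margin. A self-contained sketch on toy reward values shows how it behaves:

```python
import torch

r_accept = torch.tensor([2.0, 0.5])  # rewards on the EOS token, chosen responses
r_reject = torch.tensor([1.0, 1.5])  # rewards on the EOS token, rejected responses
# -log(sigmoid(margin)): near zero for a large positive margin,
# large when the rejected response outscores the chosen one.
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
print(loss.item())  # ~0.813
```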
@@ -1,30 +1,28 @@
-# coding=utf-8
-# Implements parameter-efficient training of reward models.
-# This code is inspired by:
+# Inspired by:
 # https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
 # https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py

-from utils import (
-    PairwiseDataCollatorWithPadding,
-    PairwisePeftTrainer,
-    LogCallback,
-    load_pretrained,
-    prepare_args,
-    prepare_data,
-    preprocess_data,
-    compute_accuracy,
-    plot_loss
-)
+from transformers import Seq2SeqTrainingArguments
+
+from llmtuner.dsets import get_dataset, preprocess_dataset
+from llmtuner.extras.callbacks import LogCallback
+from llmtuner.extras.ploting import plot_loss
+from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+from llmtuner.tuner.core import load_model_and_tokenizer
+from llmtuner.tuner.rm.metric import compute_accuracy
+from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
+from llmtuner.tuner.rm.trainer import PairwisePeftTrainer


-def main():
-    # Prepare pretrained model and dataset
-    model_args, data_args, training_args, finetuning_args = prepare_args(stage="rm")
-    dataset = prepare_data(model_args, data_args)
-    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="rm")
-    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rm")
+def run_rm(
+    model_args: ModelArguments,
+    data_args: DataArguments,
+    training_args: Seq2SeqTrainingArguments,
+    finetuning_args: FinetuningArguments
+):
+    dataset = get_dataset(model_args, data_args)
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
+    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
     data_collator = PairwiseDataCollatorWithPadding(tokenizer)

     training_args.remove_unused_columns = False # important for pairwise dataset
@@ -66,12 +64,3 @@ def main():
     metrics = trainer.evaluate(metric_key_prefix="eval")
     trainer.log_metrics("eval", metrics)
     trainer.save_metrics("eval", metrics)
-
-
-def _mp_fn(index):
-    # For xla_spawn (TPUs)
-    main()
-
-
-if __name__ == "__main__":
-    main()
src/llmtuner/tuner/sft/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
+from llmtuner.tuner.sft.workflow import run_sft
src/llmtuner/tuner/sft/metric.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+import numpy as np
+from dataclasses import dataclass
+from typing import Dict, Sequence, Tuple, Union
+from transformers.tokenization_utils import PreTrainedTokenizer
+
+import jieba
+from rouge_chinese import Rouge
+from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
+
+from llmtuner.extras.constants import IGNORE_INDEX
+
+
+@dataclass
+class ComputeMetrics:
+    r"""
+    Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
+    """
+
+    tokenizer: PreTrainedTokenizer
+
+    def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
+        r"""
+        Uses the model predictions to compute metrics.
+        """
+        preds, labels = eval_preds
+        score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
+
+        preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
+        labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
+
+        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
+        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+        for pred, label in zip(decoded_preds, decoded_labels):
+            hypothesis = list(jieba.cut(pred))
+            reference = list(jieba.cut(label))
+
+            if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0:
+                result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
+            else:
+                rouge = Rouge()
+                scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
+                result = scores[0]
+
+            for k, v in result.items():
+                score_dict[k].append(round(v["f"] * 100, 4))
+
+            bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
+            score_dict["bleu-4"].append(round(bleu_score * 100, 4))
+
+        return {k: float(np.mean(v)) for k, v in score_dict.items()}
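For orientation, the per-sample scoring path above can be exercised in isolation. This sketch assumes jieba, rouge-chinese and nltk are installed, and uses made-up strings:

```python
import jieba
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

pred, label = "今天天气很好", "今天天气不错"
hypothesis, reference = list(jieba.cut(pred)), list(jieba.cut(label))
# ROUGE over space-joined word segments, as in ComputeMetrics.__call__.
print(Rouge().get_scores(" ".join(hypothesis), " ".join(reference))[0])
# Character-level BLEU-4 with smoothing, also as above.
print(sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3))
```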
@@ -3,65 +3,17 @@ import json
 import torch
 import numpy as np
 import torch.nn as nn
-from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 from transformers.trainer import PredictionOutput
-from transformers.tokenization_utils import PreTrainedTokenizer

-import jieba
-from rouge_chinese import Rouge
-from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
-
-from .peft_trainer import PeftTrainer
-from .other import get_logger, IGNORE_INDEX
+from llmtuner.extras.constants import IGNORE_INDEX
+from llmtuner.extras.logging import get_logger
+from llmtuner.tuner.core.trainer import PeftTrainer


 logger = get_logger(__name__)


-@dataclass
-class ComputeMetrics:
-    r"""
-    Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
-    """
-
-    tokenizer: PreTrainedTokenizer
-
-    def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
-        r"""
-        Uses the model predictions to compute metrics.
-        """
-        preds, labels = eval_preds
-        score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
-
-        preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
-        labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
-
-        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
-        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
-
-        for pred, label in zip(decoded_preds, decoded_labels):
-            hypothesis = list(jieba.cut(pred))
-            reference = list(jieba.cut(label))
-
-            if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0:
-                result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
-            else:
-                rouge = Rouge()
-                scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
-                result = scores[0]
-
-            for k, v in result.items():
-                score_dict[k].append(round(v["f"] * 100, 4))
-
-            bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
-            score_dict["bleu-4"].append(round(bleu_score * 100, 4))
-
-        return {k: float(np.mean(v)) for k, v in score_dict.items()}
-
-
 class Seq2SeqPeftTrainer(PeftTrainer):
     r"""
     Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
@@ -80,7 +32,10 @@ class Seq2SeqPeftTrainer(PeftTrainer):
         Subclass and override to inject custom behavior.
         """
         prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
-        inputs["labels"] = torch.cat((inputs["labels"], torch.zeros_like(inputs["input_ids"])[:, label_len:]), dim=-1)
+        if self.tokenizer.padding_side == "right": # pads the labels to the same length as the inputs
+            inputs["labels"] = torch.cat((inputs["labels"], torch.zeros_like(inputs["input_ids"])[:, label_len:]), dim=-1)
+        else:
+            inputs["labels"] = torch.cat((torch.zeros_like(inputs["input_ids"])[:, label_len:], inputs["labels"]), dim=-1)
         loss, generated_tokens, labels = super().prediction_step(
             model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
         )
@@ -89,8 +44,8 @@ class Seq2SeqPeftTrainer(PeftTrainer):
         return (loss, generated_tokens, labels)

     def save_predictions(
         self,
         predict_results: PredictionOutput
     ) -> None:
         r"""
         Saves model predictions to `output_dir`.
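The new `padding_side` branch in `prediction_step` exists to keep labels aligned with the (longer) `input_ids` before token-by-token comparison. A minimal tensor sketch with toy shapes shows both alignments:

```python
import torch

input_ids = torch.ones(1, 6, dtype=torch.long)  # prompt padded to length 6
labels = torch.tensor([[7, 8]])                 # reference answer, length 2
pad = torch.zeros_like(input_ids)[:, labels.size(-1):]
print(torch.cat((labels, pad), dim=-1))  # right padding: [[7, 8, 0, 0, 0, 0]]
print(torch.cat((pad, labels), dim=-1))  # left padding:  [[0, 0, 0, 0, 7, 8]]
```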
@@ -1,31 +1,29 @@
-# coding=utf-8
-# Implements several parameter-efficient supervised fine-tuning method.
-# This code is inspired by
-# https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
-
-
-from transformers import DataCollatorForSeq2Seq
-from utils.other import IGNORE_INDEX
-from utils import (
-    Seq2SeqPeftTrainer,
-    ComputeMetrics,
-    LogCallback,
-    load_pretrained,
-    prepare_args,
-    prepare_data,
-    preprocess_data,
-    get_logits_processor,
-    plot_loss
-)
+# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
+
+from typing import Optional, List
+from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
+
+from llmtuner.dsets import get_dataset, preprocess_dataset
+from llmtuner.extras.callbacks import LogCallback
+from llmtuner.extras.constants import IGNORE_INDEX
+from llmtuner.extras.misc import get_logits_processor
+from llmtuner.extras.ploting import plot_loss
+from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+from llmtuner.tuner.core import load_model_and_tokenizer
+from llmtuner.tuner.sft.metric import ComputeMetrics
+from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer


-def main():
-    # Prepare pretrained model and dataset
-    model_args, data_args, training_args, finetuning_args = prepare_args(stage="sft")
-    dataset = prepare_data(model_args, data_args)
-    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft")
-    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")
+def run_sft(
+    model_args: ModelArguments,
+    data_args: DataArguments,
+    training_args: Seq2SeqTrainingArguments,
+    finetuning_args: FinetuningArguments,
+    callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
+):
+    dataset = get_dataset(model_args, data_args)
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
+    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft")
     data_collator = DataCollatorForSeq2Seq(
         tokenizer=tokenizer,
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
@@ -54,7 +52,7 @@ def main():
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
-        callbacks=[LogCallback()],
+        callbacks=callbacks,
         compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
         **trainer_kwargs
     )
@@ -94,12 +92,3 @@ def main():
     trainer.log_metrics("predict", predict_results.metrics)
     trainer.save_metrics("predict", predict_results.metrics)
     trainer.save_predictions(predict_results)
-
-
-def _mp_fn(index):
-    # For xla_spawn (TPUs)
-    main()
-
-
-if __name__ == "__main__":
-    main()
src/train_bash.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+from llmtuner import get_train_args, run_pt, run_sft, run_rm, run_ppo
+
+
+def main():
+    model_args, data_args, training_args, finetuning_args, general_args = get_train_args()
+
+    if general_args.stage == "pt":
+        run_pt(model_args, data_args, training_args, finetuning_args)
+    elif general_args.stage == "sft":
+        run_sft(model_args, data_args, training_args, finetuning_args)
+    elif general_args.stage == "rm":
+        run_rm(model_args, data_args, training_args, finetuning_args)
+    elif general_args.stage == "ppo":
+        run_ppo(model_args, data_args, training_args, finetuning_args)
+
+
+def _mp_fn(index):
+    # For xla_spawn (TPUs)
+    main()
+
+
+if __name__ == "__main__":
+    main()
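Besides the CLI entry point above, the same dispatch can be driven programmatically. A hedged sketch: the argument values are placeholders, and it assumes `get_train_args` reads `sys.argv` the same way it does when `main` is invoked from the command line:

```python
import sys
from llmtuner import get_train_args, run_sft

# Placeholder arguments; any flags accepted by the sft stage would work here.
sys.argv = [
    "train_bash.py",
    "--stage", "sft",
    "--model_name_or_path", "path_to_your_model",
    "--do_train",
    "--output_dir", "path_to_checkpoint",
]
model_args, data_args, training_args, finetuning_args, general_args = get_train_args()
run_sft(model_args, data_args, training_args, finetuning_args)
```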
@@ -1,17 +0,0 @@
-from .common import (
-    load_pretrained,
-    prepare_args,
-    prepare_infer_args,
-    prepare_data,
-    preprocess_data
-)
-
-from .peft_trainer import PeftTrainer, LogCallback
-
-from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer
-from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer, compute_accuracy
-from .ppo import PPOPeftTrainer
-
-from .template import Template
-
-from .other import get_logits_processor, plot_loss
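The removed flat namespace maps onto the new package roughly as follows; every path below is taken from imports introduced elsewhere in this commit (the homes of `Template` and `PPOPeftTrainer` are not visible in this excerpt, so they are omitted):

```python
from llmtuner.tuner.core import load_model_and_tokenizer        # replaces load_pretrained
from llmtuner.dsets import get_dataset, preprocess_dataset      # replace prepare_data / preprocess_data
from llmtuner.tuner.core.trainer import PeftTrainer             # was .peft_trainer
from llmtuner.extras.callbacks import LogCallback               # was part of .peft_trainer
from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer       # was .seq2seq
from llmtuner.tuner.sft.metric import ComputeMetrics            # was .seq2seq
from llmtuner.tuner.rm.trainer import PairwisePeftTrainer       # was .pairwise
from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding  # was .pairwise
from llmtuner.tuner.rm.metric import compute_accuracy           # was .pairwise
from llmtuner.extras.ploting import plot_loss                   # was .other
from llmtuner.extras.misc import get_logits_processor           # was .other
```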
@@ -1,619 +0,0 @@
-import os
-import sys
-import torch
-import hashlib
-from itertools import chain
-from typing import List, Literal, Optional, Tuple
-
-import transformers
-from transformers import (
-    AutoConfig,
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    HfArgumentParser,
-    Seq2SeqTrainingArguments,
-    BitsAndBytesConfig
-)
-from transformers.utils import check_min_version
-from transformers.utils.versions import require_version
-from transformers.modeling_utils import PreTrainedModel
-from transformers.tokenization_utils import PreTrainedTokenizer
-
-import datasets
-from datasets import Dataset, concatenate_datasets, load_dataset
-
-from peft import (
-    PeftModel,
-    TaskType,
-    LoraConfig,
-    get_peft_model
-)
-
-from peft.utils import CONFIG_NAME, WEIGHTS_NAME
-
-from trl import AutoModelForCausalLMWithValueHead
-
-from .config import (
-    ModelArguments,
-    DataTrainingArguments,
-    FinetuningArguments,
-    GeneratingArguments
-)
-
-from .template import Template
-
-from .other import (
-    get_logger,
-    load_trainable_params,
-    load_valuehead_params,
-    print_trainable_params,
-    prepare_model_for_training,
-    IGNORE_INDEX
-)
-
-check_min_version("4.29.1")
-require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
-require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
-require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
-require_version("trl>=0.4.4", "To fix: pip install trl>=0.4.4")
-
-
-logger = get_logger(__name__)
-
-
-def _init_adapter(
-    model: PreTrainedModel,
-    model_args: ModelArguments,
-    finetuning_args: FinetuningArguments,
-    is_trainable: bool,
-    is_mergeable: bool
-) -> PreTrainedModel:
-    r"""
-    Initializes the adapters.
-
-    Support full-parameter, freeze and LoRA training.
-
-    Note that the trainable parameters must be cast to float32.
-    """
-
-    if finetuning_args.finetuning_type == "none" and is_trainable:
-        raise ValueError("You cannot use finetuning_type=none while training.")
-
-    if finetuning_args.finetuning_type == "full":
-        logger.info("Fine-tuning method: Full")
-        model = model.float()
-
-    if finetuning_args.finetuning_type == "freeze":
-        logger.info("Fine-tuning method: Freeze")
-
-        for name, param in model.named_parameters():
-            if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
-                param.requires_grad_(False)
-            else:
-                param.data = param.data.to(torch.float32)
-
-        if model_args.checkpoint_dir is not None:
-            assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
-
-    if finetuning_args.finetuning_type == "lora":
-        logger.info("Fine-tuning method: LoRA")
-        latest_checkpoint = None
-
-        if model_args.checkpoint_dir is not None:
-            assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)), \
-                "Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0])
-            assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
-                "The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
-
-            if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
-                checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
-            else:
-                checkpoints_to_merge = model_args.checkpoint_dir
-
-            for checkpoint in checkpoints_to_merge:
-                model = PeftModel.from_pretrained(model, checkpoint)
-                model = model.merge_and_unload()
-
-            if len(checkpoints_to_merge) > 0:
-                logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
-
-            if latest_checkpoint is not None: # resume lora training or quantized inference
-                model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)
-
-        if is_trainable and latest_checkpoint is None: # create new lora weights while training
-            lora_config = LoraConfig(
-                task_type=TaskType.CAUSAL_LM,
-                inference_mode=False,
-                r=finetuning_args.lora_rank,
-                lora_alpha=finetuning_args.lora_alpha,
-                lora_dropout=finetuning_args.lora_dropout,
-                target_modules=finetuning_args.lora_target
-            )
-            model = get_peft_model(model, lora_config)
-
-    if model_args.checkpoint_dir is not None:
-        logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
-
-    return model
-
-
-def load_pretrained(
-    model_args: ModelArguments,
-    finetuning_args: FinetuningArguments,
-    is_trainable: Optional[bool] = False,
-    stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
-) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
-    r"""
-    Loads pretrained model and tokenizer.
-
-    Support both training and inference.
-    """
-    if (not is_trainable) and model_args.checkpoint_dir is None:
-        logger.warning("Checkpoint is not found at evaluation, load the original model.")
-        finetuning_args = FinetuningArguments(finetuning_type="none")
-
-    assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
-        "RM and PPO training can only be performed with the LoRA method."
-
-    config_kwargs = {
-        "trust_remote_code": True,
-        "cache_dir": model_args.cache_dir,
-        "revision": model_args.model_revision,
-        "use_auth_token": True if model_args.use_auth_token else None,
-    }
-
-    tokenizer = AutoTokenizer.from_pretrained(
-        model_args.model_name_or_path,
-        use_fast=model_args.use_fast_tokenizer,
-        padding_side=model_args.padding_side,
-        **config_kwargs
-    )
-    if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000: # 64000 for baichuan model (older version)
-        tokenizer.pad_token_id = 0 # set as the <unk> token
-
-    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
-    is_mergeable = True
-
-    # Quantization configurations (using bitsandbytes library).
-    if model_args.quantization_bit is not None:
-        if model_args.quantization_bit == 8:
-            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
-            config_kwargs["load_in_8bit"] = True
-            config_kwargs["quantization_config"] = BitsAndBytesConfig(
-                load_in_8bit=True,
-                llm_int8_threshold=6.0
-            )
-
-        elif model_args.quantization_bit == 4:
-            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
-            require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
-            require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
-            require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
-            config_kwargs["load_in_4bit"] = True
-            config_kwargs["quantization_config"] = BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_compute_dtype=model_args.compute_dtype,
-                bnb_4bit_use_double_quant=model_args.double_quantization,
-                bnb_4bit_quant_type=model_args.quantization_type
-            )
-
-        is_mergeable = False
-        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
-        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
-
-    if not is_trainable: # `device_map=auto` should be used for inference only
-        config_kwargs["device_map"] = "auto"
-
-    if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
-        model_to_load = model_args.checkpoint_dir[0]
-    else:
-        model_to_load = model_args.model_name_or_path
-
-    # Load and prepare pretrained models (without valuehead).
-    model = AutoModelForCausalLM.from_pretrained(
-        model_to_load,
-        config=config,
-        torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
-        low_cpu_mem_usage=True,
-        **config_kwargs
-    )
-
-    # Register auto class to save the custom code files.
-    if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
-        config.__class__.register_for_auto_class()
-    if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
-        tokenizer.__class__.register_for_auto_class()
-    if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map:
-        model.__class__.register_for_auto_class()
-
-    # Initialize adapters
-    model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
-    model = _init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
-
-    if stage == "rm" or stage == "ppo": # add value head
-        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
-
-        if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
-            logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
-            if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
-                model.v_head.load_state_dict({
-                    "summary.weight": getattr(model, "reward_head_weight"),
-                    "summary.bias": getattr(model, "reward_head_bias")
-                })
-
-        if stage == "ppo": # load reward model
-            assert is_trainable, "PPO stage cannot be performed at evaluation."
-            assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
-            logger.info("Load reward model from {}".format(model_args.reward_model))
-            model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
-            assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
-
-    if not is_trainable:
-        model.requires_grad_(False) # fix all model params
-        model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
-
-    print_trainable_params(model)
-
-    return model, tokenizer
-
-
-def prepare_args(
-    stage: Literal["pt", "sft", "rm", "ppo"]
-) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:
-
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))
-
-    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
-        model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
-    else:
-        model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()
-
-    # Setup logging
-    if training_args.should_log:
-        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
-        transformers.utils.logging.set_verbosity_info()
-
-    log_level = training_args.get_process_log_level()
-    datasets.utils.logging.set_verbosity(log_level)
-    transformers.utils.logging.set_verbosity(log_level)
-    transformers.utils.logging.enable_default_handler()
-    transformers.utils.logging.enable_explicit_format()
-
-    # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
-    data_args.init_for_training()
-
-    assert stage == "sft" or (not training_args.predict_with_generate), \
-        "`predict_with_generate` cannot be set as True at PT, RM and PPO stages."
-
-    assert not (training_args.do_train and training_args.predict_with_generate), \
-        "`predict_with_generate` cannot be set as True while training."
-
-    assert (not training_args.do_predict) or training_args.predict_with_generate, \
-        "Please enable `predict_with_generate` to save model predictions."
-
-    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
-        "Quantization is only compatible with the LoRA method."
-
-    if model_args.checkpoint_dir is not None:
-        if finetuning_args.finetuning_type != "lora":
-            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
-        else:
-            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
-                "Quantized model only accepts a single checkpoint."
-
-    if model_args.quantization_bit is not None and (not training_args.do_train):
-        logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
-
-    if training_args.do_train and (not training_args.fp16):
-        logger.warning("We recommend enable fp16 mixed precision training.")
-
-    if data_args.prompt_template == "default":
-        logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
-
-    if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
-        logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
-        training_args.ddp_find_unused_parameters = False
-
-    training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
-
-    if model_args.quantization_bit is not None:
-        if training_args.fp16:
-            model_args.compute_dtype = torch.float16
-        elif training_args.bf16:
-            model_args.compute_dtype = torch.bfloat16
-        else:
-            model_args.compute_dtype = torch.float32
-
-    # Log on each process the small summary:
-    logger.info(
-        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
-        + f"  distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
-    )
-    logger.info(f"Training/evaluation parameters {training_args}")
-
-    # Set seed before initializing model.
-    transformers.set_seed(training_args.seed)
-
-    return model_args, data_args, training_args, finetuning_args
-
-
-def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments, GeneratingArguments]:
-
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FinetuningArguments, GeneratingArguments))
-
-    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
-        model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
-    else:
-        model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()
-
-    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
-        "Quantization is only compatible with the LoRA method."
-
-    if model_args.checkpoint_dir is not None:
-        if finetuning_args.finetuning_type != "lora":
-            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
-        else:
-            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
-                "Quantized model only accepts a single checkpoint."
-
-    if data_args.prompt_template == "default":
-        logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
-
-    return model_args, data_args, finetuning_args, generating_args
-
-
-def prepare_data(
-    model_args: ModelArguments,
-    data_args: DataTrainingArguments
-) -> Dataset:
-
-    def checksum(file_path, hash):
-        with open(file_path, "rb") as datafile:
-            binary_data = datafile.read()
-        sha1 = hashlib.sha1(binary_data).hexdigest()
-        if sha1 != hash:
-            logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
-
-    ext2type = {
-        "csv": "csv",
-        "json": "json",
-        "jsonl": "json",
-        "txt": "text"
-    }
-
-    max_samples = data_args.max_samples
-    all_datasets: List[Dataset] = [] # support multiple datasets
-
-    for dataset_attr in data_args.dataset_list:
-
-        logger.info("Loading dataset {}...".format(dataset_attr))
-
-        if dataset_attr.load_from == "hf_hub":
-            data_path = dataset_attr.dataset_name
-            data_files = None
-        elif dataset_attr.load_from == "script":
-            data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
-            data_files = None
-        elif dataset_attr.load_from == "file":
-            data_path = None
-            data_files: List[str] = []
-
-            if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
-                for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
-                    data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
-
-                    if data_path is None:
-                        data_path = ext2type.get(data_files[0].split(".")[-1], None)
-                    else:
-                        assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match."
-            elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
-                data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
-                data_path = ext2type.get(data_files[0].split(".")[-1], None)
-            else:
-                raise ValueError("File not found.")
-
-            assert data_path, "File extension must be txt, csv, json or jsonl."
-
-            if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
-                checksum(data_files[0], dataset_attr.dataset_sha1)
-            else:
-                logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
-        else:
-            raise NotImplementedError
-
-        raw_datasets = load_dataset(
-            data_path,
-            data_files=data_files,
-            cache_dir=model_args.cache_dir,
-            use_auth_token=True if model_args.use_auth_token else None
-        )
-        dataset = raw_datasets[data_args.split]
-
-        if max_samples is not None:
-            max_samples_temp = min(len(dataset), max_samples)
-            dataset = dataset.select(range(max_samples_temp))
-
-        dummy_data = [None] * len(dataset)
-        prefix_data = [dataset_attr.source_prefix] * len(dataset)
-        for column_name, target_name in [
-            ("prompt_column", "prompt"),
-            ("query_column", "query"),
-            ("response_column", "response"),
-            ("history_column", "history")
-        ]: # every dataset will have 4 columns same as each other
-            if getattr(dataset_attr, column_name) != target_name:
-                if getattr(dataset_attr, column_name):
-                    dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
-                else: # None or empty string
-                    dataset = dataset.add_column(target_name, dummy_data)
-        dataset = dataset.add_column("prefix", prefix_data)
-        all_datasets.append(dataset)
-
-    if len(data_args.dataset_list) == 1:
-        all_datasets = all_datasets[0]
-    else:
-        all_datasets = concatenate_datasets(all_datasets)
-
-    return all_datasets
-
-
-def preprocess_data(
-    dataset: Dataset,
-    tokenizer: PreTrainedTokenizer,
-    data_args: DataTrainingArguments,
-    training_args: Seq2SeqTrainingArguments,
-    stage: Literal["pt", "sft", "rm", "ppo"]
-) -> Dataset:
-
-    column_names = list(dataset.column_names)
-    prompt_template = Template(data_args.prompt_template)
-
-    # support question with a single answer or multiple answers
-    def get_dialog(examples):
-        for i in range(len(examples["prompt"])):
-            if examples["prompt"][i] and examples["response"][i]:
-                query, answer = examples["prompt"][i], examples["response"][i]
-                query = query + "\n" + examples["query"][i] if examples["query"][i] else query
-                prefix = examples["prefix"][i] if examples["prefix"][i] else ""
-                dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
-                yield dialog
-
-    def preprocess_pretrain_dataset(examples):
-        # build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
-        text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
-        concatenated_ids = list(chain(*text_ids))
-        total_length = len(concatenated_ids)
-        block_size = data_args.max_source_length - 1
-        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
-        total_length = (total_length // block_size) * block_size
-        # split by chunks of max_source_length
-        result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size]
-                  for i in range(0, total_length, block_size)]
-        return {
-            "input_ids": result,
-            "labels": result.copy()
-        }
-
-    def preprocess_supervised_dataset(examples):
-        # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
-        # for input with history, we build multiple input-label pairs just like:
-        # https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
-        model_inputs = {"input_ids": [], "labels": []}
-        max_length = data_args.max_source_length + data_args.max_target_length
-
-        for dialog in get_dialog(examples):
-            input_ids, labels = [], []
-
-            for i in range(len(dialog) // 2):
-                source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=(i == 0))
-                target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
-
-                if len(source_ids) > data_args.max_source_length:
-                    source_ids = source_ids[:data_args.max_source_length]
-                if len(target_ids) > data_args.max_target_length - 1: # eos token
-                    target_ids = target_ids[:data_args.max_target_length - 1]
-
-                if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
-                    break
-
-                input_ids += source_ids + target_ids + [tokenizer.eos_token_id]
-                labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
-
-            model_inputs["input_ids"].append(input_ids)
-            model_inputs["labels"].append(labels)
-
-        return model_inputs
-
-    def preprocess_unsupervised_dataset(examples):
-        # build inputs with format `<bos> X` and labels with format `<bos> Y`
-        model_inputs = {"input_ids": [], "labels": []}
-
-        for dialog in get_dialog(examples):
-            prompt, answer = "".join(dialog[:-1]), dialog[-1]
-
-            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
-            target_ids = tokenizer.encode(text=answer, add_special_tokens=True)
-
-            if len(source_ids) > data_args.max_source_length:
-                source_ids = source_ids[:data_args.max_source_length]
-            if len(target_ids) > data_args.max_target_length:
-                target_ids = target_ids[:data_args.max_target_length]
-
-            model_inputs["input_ids"].append(source_ids)
-            model_inputs["labels"].append(target_ids)
-
-        return model_inputs
-
-    def preprocess_pairwise_dataset(examples):
-        # build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
-        model_inputs = {"accept_ids": [], "reject_ids": []}
-        for dialog in get_dialog(examples):
-            prompt, answer = "".join(dialog[:-1]), dialog[-1]
-
-            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
-            accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
-            reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
-
-            if len(source_ids) > data_args.max_source_length:
-                source_ids = source_ids[:data_args.max_source_length]
-            if len(accept_ids) > data_args.max_target_length - 1: # eos token
-                accept_ids = accept_ids[:data_args.max_target_length - 1]
-            if len(reject_ids) > data_args.max_target_length - 1: # eos token
-                reject_ids = reject_ids[:data_args.max_target_length - 1]
-
-            accept_ids = source_ids + accept_ids + [tokenizer.eos_token_id]
-            reject_ids = source_ids + reject_ids + [tokenizer.eos_token_id]
-
-            model_inputs["accept_ids"].append(accept_ids)
-            model_inputs["reject_ids"].append(reject_ids)
-        return model_inputs
-
-    def print_supervised_dataset_example(example):
-        print("input_ids:\n{}".format(example["input_ids"]))
-        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
-        print("label_ids:\n{}".format(example["labels"]))
-        print("labels:\n{}".format(
-            tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]],
-                             skip_special_tokens=False)
-        ))
-
-    def print_pairwise_dataset_example(example):
-        print("accept_ids:\n{}".format(example["accept_ids"]))
-        print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
-        print("reject_ids:\n{}".format(example["reject_ids"]))
-        print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))
-
-    def print_unsupervised_dataset_example(example):
-        print("input_ids:\n{}".format(example["input_ids"]))
-        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
-
-    if stage == "pt":
-        preprocess_function = preprocess_pretrain_dataset
-    elif stage == "sft":
-        preprocess_function = preprocess_unsupervised_dataset \
-            if training_args.predict_with_generate else preprocess_supervised_dataset
-    elif stage == "rm":
-        preprocess_function = preprocess_pairwise_dataset
-    elif stage == "ppo":
-        preprocess_function = preprocess_unsupervised_dataset
-
-    with training_args.main_process_first(desc="dataset map pre-processing"):
-        dataset = dataset.map(
-            preprocess_function,
-            batched=True,
-            num_proc=data_args.preprocessing_num_workers,
-            remove_columns=column_names,
-            load_from_cache_file=not data_args.overwrite_cache,
-            desc="Running tokenizer on dataset"
-        )
-
-    if stage == "pt":
-        print_unsupervised_dataset_example(dataset[0])
-    elif stage == "sft":
-        print_supervised_dataset_example(dataset[0])
-    elif stage == "rm":
-        print_pairwise_dataset_example(dataset[0])
-    elif stage == "ppo":
-        print_unsupervised_dataset_example(dataset[0])
-
-    return dataset
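One piece of the removed module worth seeing in isolation is the grouped-text chunking used for the "pt" stage (presumably this logic now lives behind `preprocess_dataset` in `llmtuner.dsets`, which is not shown in this excerpt). A toy sketch of the same arithmetic:

```python
from itertools import chain

bos_token_id, block_size = 1, 4  # block_size = max_source_length - 1
text_ids = [[11, 12, 13], [14, 15], [16, 17, 18, 19, 20]]  # toy tokenized prompts
concatenated = list(chain(*text_ids))
total = (len(concatenated) // block_size) * block_size  # drop the small remainder
blocks = [[bos_token_id] + concatenated[i: i + block_size] for i in range(0, total, block_size)]
print(blocks)  # [[1, 11, 12, 13, 14], [1, 15, 16, 17, 18]]; ids 19 and 20 are dropped
```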
@ -1,312 +0,0 @@
|
|||||||
import os
|
|
||||||
import json
|
|
||||||
import torch
|
|
||||||
from typing import Any, Dict, List, Literal, Optional
|
|
||||||
from dataclasses import asdict, dataclass, field
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class DatasetAttr:
|
|
||||||
|
|
||||||
load_from: str
|
|
||||||
dataset_name: Optional[str] = None
|
|
||||||
dataset_sha1: Optional[str] = None
|
|
||||||
source_prefix: Optional[str] = None
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return self.dataset_name
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
self.prompt_column = "instruction"
|
|
||||||
self.query_column = "input"
|
|
||||||
self.response_column = "output"
|
|
||||||
self.history_column = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ModelArguments:
|
|
||||||
"""
|
|
||||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
|
|
||||||
"""
|
|
||||||
model_name_or_path: str = field(
|
|
||||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
|
|
||||||
)
|
|
||||||
cache_dir: Optional[str] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
|
|
||||||
)
|
|
||||||
use_fast_tokenizer: Optional[bool] = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
|
|
||||||
)
|
|
||||||
use_auth_token: Optional[bool] = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
|
|
||||||
)
|
|
||||||
model_revision: Optional[str] = field(
|
|
||||||
default="main",
|
|
||||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
|
|
||||||
)
|
|
||||||
padding_side: Optional[Literal["left", "right"]] = field(
|
|
||||||
default="left",
|
|
||||||
metadata={"help": "The side on which the model should have padding applied."}
|
|
||||||
)
|
|
||||||
quantization_bit: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "The number of bits to quantize the model."}
|
|
||||||
)
|
|
||||||
quantization_type: Optional[Literal["fp4", "nf4"]] = field(
|
|
||||||
default="nf4",
|
|
||||||
metadata={"help": "Quantization data type to use in int4 training."}
|
|
||||||
)
|
|
||||||
double_quantization: Optional[bool] = field(
|
|
||||||
default=True,
|
|
||||||
metadata={"help": "Whether to use double quantization in int4 training or not."}
|
|
||||||
)
|
|
||||||
compute_dtype: Optional[torch.dtype] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
|
|
||||||
)
|
|
||||||
checkpoint_dir: Optional[str] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
|
|
||||||
)
|
|
||||||
reward_model: Optional[str] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
|
|
||||||
)
|
|
||||||
resume_lora_training: Optional[bool] = field(
|
|
||||||
default=True,
|
|
||||||
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
|
|
||||||
)
|
|
||||||
plot_loss: Optional[bool] = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
|
|
||||||
)
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.checkpoint_dir is not None: # support merging multiple lora weights
|
|
||||||
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
|
|
||||||
|
|
||||||
if self.quantization_bit is not None:
|
|
||||||
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
|
|
||||||
|
|
||||||
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and evaluation.
    """
    dataset: Optional[str] = field(
        default="alpaca_zh",
        metadata={"help": "The name of provided dataset(s) to use. Use comma to separate multiple datasets."}
    )
    dataset_dir: Optional[str] = field(
        default="data",
        metadata={"help": "The name of the folder containing datasets."}
    )
    split: Optional[str] = field(
        default="train",
        metadata={"help": "Which dataset split to use for training and evaluation."}
    )
    overwrite_cache: Optional[bool] = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets."}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."}
    )
    max_source_length: Optional[int] = field(
        default=512,
        metadata={"help": "The maximum total input sequence length after tokenization."}
    )
    max_target_length: Optional[int] = field(
        default=512,
        metadata={"help": "The maximum total output sequence length after tokenization."}
    )
    max_samples: Optional[int] = field(
        default=None,
        metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
    )
    eval_num_beams: Optional[int] = field(
        default=None,
        metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`."}
    )
    ignore_pad_token_for_loss: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
    )
    source_prefix: Optional[str] = field(
        default=None,
        metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
    )
    dev_ratio: Optional[float] = field(
        default=0,
        metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
    )
    prompt_template: Optional[str] = field(
        default="default",
        metadata={"help": "Which template to use for constructing prompts in training and inference."}
    )
    def init_for_training(self):  # support mixing multiple datasets
        dataset_names = [ds.strip() for ds in self.dataset.split(",")]
        with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
            dataset_info = json.load(f)

        if self.source_prefix is not None:
            prefix_list = self.source_prefix.split("|")
            prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
            assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either 1 or the same as the number of datasets."
        else:
            prefix_list = [None] * len(dataset_names)

        self.dataset_list: List[DatasetAttr] = []
        for i, name in enumerate(dataset_names):
            if name not in dataset_info:
                raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))

            if "hf_hub_url" in dataset_info[name]:
                dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
            elif "script_url" in dataset_info[name]:
                dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
            else:
                dataset_attr = DatasetAttr(
                    "file",
                    dataset_name=dataset_info[name]["file_name"],
                    dataset_sha1=dataset_info[name].get("file_sha1", None)
                )

            dataset_attr.source_prefix = prefix_list[i]

            if "columns" in dataset_info[name]:
                dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
                dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
                dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
                dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)

            self.dataset_list.append(dataset_attr)
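# Illustrative dataset_info.json entry consumed by init_for_training above
# (the dataset name and file are hypothetical; the keys follow the branches above):
#
#     "my_dataset": {
#         "file_name": "my_dataset.json",
#         "file_sha1": "...",
#         "columns": {"prompt": "instruction", "query": "input", "response": "output", "history": "history"}
#     }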
@dataclass
class FinetuningArguments:
    """
    Arguments pertaining to which techniques we are going to fine-tune with.
    """
    finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
        default="lora",
        metadata={"help": "Which fine-tuning method to use."}
    )
    num_hidden_layers: Optional[int] = field(
        default=32,
        metadata={"help": "Number of decoder blocks in the model. \
                  LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
                  BLOOM choices: [\"24\", \"30\", \"70\"], \
                  Falcon choices: [\"32\", \"60\"], \
                  Baichuan choices: [\"32\"]"}
    )
    num_layer_trainable: Optional[int] = field(
        default=3,
        metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
    )
    name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
        default="mlp",
        metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
                  LLaMA choices: [\"mlp\", \"self_attn\"], \
                  BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
                  Baichuan choices: [\"mlp\", \"self_attn\"]"}
    )
    lora_rank: Optional[int] = field(
        default=8,
        metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
    )
    lora_alpha: Optional[float] = field(
        default=32.0,
        metadata={"help": "The scale factor for LoRA fine-tuning (similar to the learning rate)."}
    )
    lora_dropout: Optional[float] = field(
        default=0.1,
        metadata={"help": "Dropout rate for the LoRA fine-tuning."}
    )
    lora_target: Optional[str] = field(
        default="q_proj,v_proj",
        metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \
                  LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
                  BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
                  Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
    )
    def __post_init__(self):
        if isinstance(self.lora_target, str):  # support custom target modules/layers of LoRA
            self.lora_target = [target.strip() for target in self.lora_target.split(",")]

        if self.num_layer_trainable > 0:  # fine-tuning the last n layers if num_layer_trainable > 0
            trainable_layer_ids = [self.num_hidden_layers - k - 1 for k in range(self.num_layer_trainable)]
        else:  # fine-tuning the first n layers if num_layer_trainable < 0
            trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]

        self.trainable_layers = ["{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]

        assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
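# Worked example: with the defaults num_hidden_layers=32, num_layer_trainable=3
# and name_module_trainable="mlp", trainable_layer_ids is [31, 30, 29] and
# trainable_layers becomes ["31.mlp", "30.mlp", "29.mlp"], i.e. the MLP blocks
# of the last three decoder layers; num_layer_trainable=-3 selects layers
# [0, 1, 2] instead.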
    def save_to_json(self, json_path: str):
        """Saves the content of this instance in JSON format inside `json_path`."""
        json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
        with open(json_path, "w", encoding="utf-8") as f:
            f.write(json_string)

    @classmethod
    def load_from_json(cls, json_path: str):
        """Creates an instance from the content of `json_path`."""
        with open(json_path, "r", encoding="utf-8") as f:
            text = f.read()
        return cls(**json.loads(text))
@dataclass
class GeneratingArguments:
    """
    Arguments pertaining to the decoding parameters.
    """
    do_sample: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether or not to use sampling; use greedy decoding otherwise."}
    )
    temperature: Optional[float] = field(
        default=0.95,
        metadata={"help": "The value used to modulate the next token probabilities."}
    )
    top_p: Optional[float] = field(
        default=0.7,
        metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
    )
    top_k: Optional[int] = field(
        default=50,
        metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
    )
    num_beams: Optional[int] = field(
        default=1,
        metadata={"help": "Number of beams for beam search. 1 means no beam search."}
    )
    max_length: Optional[int] = field(
        default=None,
        metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
    )
    max_new_tokens: Optional[int] = field(
        default=512,
        metadata={"help": "The maximum number of tokens to generate, ignoring the number of tokens in the prompt."}
    )
    repetition_penalty: Optional[float] = field(
        default=1.0,
        metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
    )
    length_penalty: Optional[float] = field(
        default=1.0,
        metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
    )

    def to_dict(self) -> Dict[str, Any]:
        args = asdict(self)
        if args.get("max_new_tokens", None):
            args.pop("max_length", None)  # max_new_tokens takes precedence over max_length
        return args
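For orientation, here is a minimal sketch of how these dataclasses are typically consumed with `transformers.HfArgumentParser`; the parser call is real, but treating just these three classes as the full argument set is an assumption for illustration:

```python
from transformers import HfArgumentParser

# Parse command-line flags such as --lora_target or --max_new_tokens directly
# into the dataclasses defined above.
parser = HfArgumentParser((DataTrainingArguments, FinetuningArguments, GeneratingArguments))
data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()

print(finetuning_args.lora_target)  # e.g. ["q_proj", "v_proj"] after __post_init__
print(generating_args.to_dict())    # "max_length" is dropped when max_new_tokens is set
```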
@ -1,197 +0,0 @@
import os
import sys
import json
import torch
import logging
from typing import Dict, List, Optional

from transformers.trainer import TRAINER_STATE_NAME, WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint
from transformers.generation.utils import LogitsProcessorList
from transformers.generation.logits_process import LogitsProcessor


IGNORE_INDEX = -100
VALUE_HEAD_FILE_NAME = "value_head.bin"
FINETUNING_ARGS_NAME = "finetuning_args.json"


def get_logger(name: str) -> logging.Logger:
    return logging.getLogger(name)


logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)]
)


logger = get_logger(__name__)

class AverageMeter:
    r"""
    Computes and stores the average and current value.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

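# Illustrative usage (not part of this file): track a running loss average.
#
#     loss_meter = AverageMeter()
#     loss_meter.update(0.5, n=4)  # batch of 4 examples with mean loss 0.5
#     loss_meter.update(0.3, n=4)
#     print(loss_meter.avg)        # 0.4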
# Avoid runtime error in model.generate(do_sample=True).
class InvalidScoreLogitsProcessor(LogitsProcessor):

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 0] = 1.0
        return scores


def get_logits_processor() -> LogitsProcessorList:
    logits_processor = LogitsProcessorList()
    logits_processor.append(InvalidScoreLogitsProcessor())
    return logits_processor

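# Illustrative usage (not part of this file): pass the processor list to
# `model.generate` so sampled decoding never draws from NaN/inf scores.
#
#     outputs = model.generate(input_ids, do_sample=True, logits_processor=get_logits_processor())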
# Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
def prepare_model_for_training(
    model: PreTrainedModel,
    finetuning_type: str,
    output_embedding_layer_name: Optional[str] = "lm_head",
    use_gradient_checkpointing: Optional[bool] = True,
    layer_norm_names: Optional[List[str]] = ["norm", "ln_f", "ln_attn", "ln_mlp"]  # for LLaMA, BLOOM and Falcon settings
) -> PreTrainedModel:

    for name, param in model.named_parameters():
        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
            param.data = param.data.to(torch.float32)

    if use_gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)
            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        model.gradient_checkpointing_enable()
        model.config.use_cache = False  # turn off when gradient checkpointing is enabled

    if finetuning_type != "full" and hasattr(model, output_embedding_layer_name):
        output_embedding_layer: torch.nn.Linear = getattr(model, output_embedding_layer_name)
        input_dtype = output_embedding_layer.weight.dtype

        class CastOutputToFloat(torch.nn.Sequential):

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return super().forward(x.to(input_dtype)).to(torch.float32)

        setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))

    return model

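# Illustrative usage (not part of this file; the model path is a placeholder):
#
#     model = AutoModelForCausalLM.from_pretrained("path_to_your_model")
#     model = prepare_model_for_training(model, finetuning_type="lora")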
def print_trainable_params(model: torch.nn.Module) -> None:
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        # if using DS Zero 3 and the weights are initialized empty
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
                trainable_params, all_param, 100 * trainable_params / all_param))

def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:  # get state dict containing trainable parameters
    state_dict = model.state_dict()
    filtered_state_dict = {}

    for k, v in model.named_parameters():
        if v.requires_grad:
            filtered_state_dict[k] = state_dict[k].cpu().clone().detach()

    return filtered_state_dict

def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
    weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
    if os.path.exists(weights_file):
        model_state_dict = torch.load(weights_file, map_location="cpu")
        model.load_state_dict(model_state_dict, strict=False)  # skip missing keys
    elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
        load_sharded_checkpoint(model, checkpoint_dir, strict=False)
    else:
        logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
        return False
    return True

def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
    valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
    if not os.path.exists(valuehead_file):
        logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
        return False
    valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
    model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
    model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
    model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]))
    model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
    return True

def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]:
    r"""
    EMA implementation according to TensorBoard.
    """
    last = scalars[0]
    smoothed = list()
    for next_val in scalars:
        smoothed_val = last * weight + (1 - weight) * next_val
        smoothed.append(smoothed_val)
        last = smoothed_val
    return smoothed

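# Worked example: with the default weight of 0.9, smooth([1.0, 0.0, 0.0])
# returns [1.0, 0.9, 0.81]: each point keeps 90% of the previous smoothed
# value, matching TensorBoard's EMA rule.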
def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
    import matplotlib.pyplot as plt
    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
        data = json.load(f)

    for key in keys:
        steps, metrics = [], []
        for i in range(len(data["log_history"])):
            if key in data["log_history"][i]:
                steps.append(data["log_history"][i]["step"])
                metrics.append(data["log_history"][i][key])

        if len(metrics) == 0:
            logger.warning(f"No metric {key} to plot.")
            continue

        plt.figure()
        plt.plot(steps, metrics, alpha=0.4, label="original")
        plt.plot(steps, smooth(metrics), label="smoothed")
        plt.title("training {} of {}".format(key, save_dictionary))
        plt.xlabel("step")
        plt.ylabel(key)
        plt.legend()
        plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100)
        print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key)))
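After a run launched with `--plot_loss`, the same function can be called by hand on any checkpoint directory that contains a `trainer_state.json`; the path and keys below are placeholders:

```python
plot_loss("path_to_sft_checkpoint", keys=["loss", "eval_loss"])
# Writes training_loss.png and training_eval_loss.png into that directory.
```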
@ -1,60 +0,0 @@
import torch
import numpy as np
from typing import Dict, Sequence, Tuple, Union

from transformers import DataCollatorWithPadding

from .peft_trainer import PeftTrainer

from .other import get_logger

logger = get_logger(__name__)


def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
    preds, _ = eval_preds
    return {"accuracy": (preds[0] > preds[1]).sum() / len(preds[0])}

class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
    r"""
    Data collator for pairwise data.
    """

    def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads batched data to the longest sequence in the batch.

        We generate 2 * n examples where the first n examples represent chosen examples and
        the last n examples represent rejected examples.
        """
        features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features]
        return super().__call__(features)

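# Illustrative batch layout (not part of this file): for n = 2 feature pairs,
# the collator above yields 4 padded sequences ordered
# [accept_1, accept_2, reject_1, reject_2], which compute_loss below splits
# back into two halves of size n.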
class PairwisePeftTrainer(PeftTrainer):
    r"""
    Inherits PeftTrainer to compute pairwise loss.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.can_return_loss = True  # override property to return eval_loss

    def compute_loss(self, model, inputs, return_outputs=False):
        r"""
        Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.

        We use the score on the EOS token to represent the reward of the whole sentence.

        Subclass and override to inject custom behavior. It should not be directly used by external scripts.

        Note that the first element will be removed from the output tuple.

        See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
        """
        batch_size = inputs["input_ids"].size(0) // 2
        _, _, values = model(**inputs)
        r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
        loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
        return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
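To make the pairwise objective concrete, here is a small standalone sketch of the same loss on dummy reward values (illustrative only, independent of the trainer above):

```python
import torch

# Rewards at the EOS position for three chosen and three rejected responses (dummy values).
r_accept = torch.tensor([1.2, 0.8, 0.5])
r_reject = torch.tensor([0.3, 0.9, -0.1])

# Bradley-Terry style pairwise loss: drive chosen rewards above rejected ones.
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
print(loss.item())  # shrinks as r_accept consistently exceeds r_reject
```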
@ -2,83 +2,26 @@
 # Implements user interface in browser for fine-tuned models.
 # Usage: python web_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint

-import mdtex2html
 import gradio as gr

 from threading import Thread
-from utils import (
-    Template,
-    load_pretrained,
-    prepare_infer_args,
-    get_logits_processor
-)

 from transformers import TextIteratorStreamer
 from transformers.utils.versions import require_version

+from llmtuner import Template, get_infer_args, load_model_and_tokenizer, get_logits_processor

 require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0")

-model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
-model, tokenizer = load_pretrained(model_args, finetuning_args)
+model_args, data_args, finetuning_args, generating_args = get_infer_args()
+model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)

 prompt_template = Template(data_args.prompt_template)
 source_prefix = data_args.source_prefix if data_args.source_prefix else ""

-def postprocess(self, y):
-    r"""
-    Overrides Chatbot.postprocess
-    """
-    if y is None:
-        return []
-    for i, (message, response) in enumerate(y):
-        y[i] = (
-            None if message is None else mdtex2html.convert(message),
-            None if response is None else mdtex2html.convert(response),
-        )
-    return y
-
-
-gr.Chatbot.postprocess = postprocess
-
-
-def parse_text(text):  # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT
-    lines = text.split("\n")
-    lines = [line for line in lines if line != ""]
-    count = 0
-    for i, line in enumerate(lines):
-        if "```" in line:
-            count += 1
-            items = line.split("`")
-            if count % 2 == 1:
-                lines[i] = "<pre><code class=\"language-{}\">".format(items[-1])
-            else:
-                lines[i] = "<br /></code></pre>"
-        else:
-            if i > 0:
-                if count % 2 == 1:
-                    line = line.replace("`", "\`")
-                    line = line.replace("<", "&lt;")
-                    line = line.replace(">", "&gt;")
-                    line = line.replace(" ", "&nbsp;")
-                    line = line.replace("*", "&ast;")
-                    line = line.replace("_", "&lowbar;")
-                    line = line.replace("-", "&#45;")
-                    line = line.replace(".", "&#46;")
-                    line = line.replace("!", "&#33;")
-                    line = line.replace("(", "&#40;")
-                    line = line.replace(")", "&#41;")
-                    line = line.replace("$", "&#36;")
-                lines[i] = "<br />" + line
-    text = "".join(lines)
-    return text
-
-
 def predict(query, chatbot, max_new_tokens, top_p, temperature, history):
-    chatbot.append((parse_text(query), ""))
+    chatbot.append((query, ""))

     input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"]
     input_ids = input_ids.to(model.device)

@ -102,7 +45,7 @@ def predict(query, chatbot, max_new_tokens, top_p, temperature, history):
     for new_text in streamer:
         response += new_text
         new_history = history + [(query, response)]
-        chatbot[-1] = (parse_text(query), parse_text(response))
+        chatbot[-1] = (query, response)
         yield chatbot, new_history