mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	modity code structure
Former-commit-id: 0682ed357210897e0b67c4a6eb31a94b3eb929f1
This commit is contained in:
		
							parent
							
								
									fa06b168ab
								
							
						
					
					
						commit
						6261fb362a
					
				
							
								
								
									
										19
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										19
									
								
								README.md
									
									
									
									
									
								
							@ -95,7 +95,7 @@ huggingface-cli login
 | 
			
		||||
- Python 3.8+ and PyTorch 1.13.1+
 | 
			
		||||
- 🤗Transformers, Datasets, Accelerate, PEFT and TRL
 | 
			
		||||
- jieba, rouge-chinese and nltk (used at evaluation)
 | 
			
		||||
- gradio and mdtex2html (used in web_demo.py)
 | 
			
		||||
- gradio and matplotlib (used in web_demo.py)
 | 
			
		||||
- uvicorn, fastapi and sse-starlette (used in api_demo.py)
 | 
			
		||||
 | 
			
		||||
And **powerful GPUs**!
 | 
			
		||||
@ -137,7 +137,8 @@ python -m transformers.models.llama.convert_llama_weights_to_hf \
 | 
			
		||||
### (Continually) Pre-Training
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
CUDA_VISIBLE_DEVICES=0 python src/train_pt.py \
 | 
			
		||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 | 
			
		||||
    --stage pt \
 | 
			
		||||
    --model_name_or_path path_to_your_model \
 | 
			
		||||
    --do_train \
 | 
			
		||||
    --dataset wiki_demo \
 | 
			
		||||
@ -158,7 +159,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_pt.py \
 | 
			
		||||
### Supervised Fine-Tuning
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
 | 
			
		||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 | 
			
		||||
    --stage sft \
 | 
			
		||||
    --model_name_or_path path_to_your_model \
 | 
			
		||||
    --do_train \
 | 
			
		||||
    --dataset alpaca_gpt4_en \
 | 
			
		||||
@ -179,7 +181,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
 | 
			
		||||
### Reward Model Training
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
CUDA_VISIBLE_DEVICES=0 python src/train_rm.py \
 | 
			
		||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 | 
			
		||||
    --stage rm \
 | 
			
		||||
    --model_name_or_path path_to_your_model \
 | 
			
		||||
    --do_train \
 | 
			
		||||
    --dataset comparison_gpt4_en \
 | 
			
		||||
@ -199,7 +202,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_rm.py \
 | 
			
		||||
### PPO Training (RLHF)
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
CUDA_VISIBLE_DEVICES=0 python src/train_ppo.py \
 | 
			
		||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 | 
			
		||||
    --stage ppo \
 | 
			
		||||
    --model_name_or_path path_to_your_model \
 | 
			
		||||
    --do_train \
 | 
			
		||||
    --dataset alpaca_gpt4_en \
 | 
			
		||||
@ -222,7 +226,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_ppo.py \
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
accelerate config # configure the environment
 | 
			
		||||
accelerate launch src/train_XX.py # arguments (same as above)
 | 
			
		||||
accelerate launch src/train_bash.py # arguments (same as above)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
<details><summary>Example configuration for full-tuning with DeepSpeed ZeRO-2</summary>
 | 
			
		||||
@ -256,7 +260,8 @@ use_cpu: false
 | 
			
		||||
### Evaluation (BLEU and ROUGE_CHINESE)
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
 | 
			
		||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 | 
			
		||||
    --stage pt \
 | 
			
		||||
    --model_name_or_path path_to_your_model \
 | 
			
		||||
    --do_eval \
 | 
			
		||||
    --dataset alpaca_gpt4_en \
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										3
									
								
								pyproject.toml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								pyproject.toml
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,3 @@
 | 
			
		||||
[build-system]
 | 
			
		||||
requires = ["setuptools>=61.0"]
 | 
			
		||||
build-backend = "setuptools.build_meta"
 | 
			
		||||
@ -8,8 +8,9 @@ sentencepiece
 | 
			
		||||
jieba
 | 
			
		||||
rouge-chinese
 | 
			
		||||
nltk
 | 
			
		||||
gradio
 | 
			
		||||
mdtex2html
 | 
			
		||||
gradio>=3.36.0
 | 
			
		||||
uvicorn
 | 
			
		||||
pydantic==1.10.7
 | 
			
		||||
fastapi
 | 
			
		||||
sse-starlette
 | 
			
		||||
matplotlib
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										55
									
								
								setup.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								setup.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,55 @@
 | 
			
		||||
import os
 | 
			
		||||
import re
 | 
			
		||||
from setuptools import setup, find_packages
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_version():
 | 
			
		||||
    with open(os.path.join("src", "llmtuner", "__init__.py"), "r", encoding="utf-8") as f:
 | 
			
		||||
        file_content = f.read()
 | 
			
		||||
        pattern = r"{0}\W*=\W*\"([^\"]+)\"".format("__version__")
 | 
			
		||||
        version, = re.findall(pattern, file_content)
 | 
			
		||||
        return version
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_requires():
 | 
			
		||||
    with open("requirements.txt", "r", encoding="utf-8") as f:
 | 
			
		||||
        file_content = f.read()
 | 
			
		||||
        lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
 | 
			
		||||
        return lines
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
 | 
			
		||||
    setup(
 | 
			
		||||
        name="llmtuner",
 | 
			
		||||
        version=get_version(),
 | 
			
		||||
        author="hiyouga",
 | 
			
		||||
        author_email="hiyouga" "@" "buaa.edu.cn",
 | 
			
		||||
        description="Easy-to-use fine-tuning framework using PEFT",
 | 
			
		||||
        long_description=open("README.md", "r", encoding="utf-8").read(),
 | 
			
		||||
        long_description_content_type="text/markdown",
 | 
			
		||||
        keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
 | 
			
		||||
        license="Apache 2.0 License",
 | 
			
		||||
        url="https://github.com/hiyouga/LLaMA-Efficient-Tuning",
 | 
			
		||||
        package_dir={"": "src"},
 | 
			
		||||
        packages=find_packages("src"),
 | 
			
		||||
        python_requires=">=3.8.0",
 | 
			
		||||
        install_requires=get_requires(),
 | 
			
		||||
        classifiers=[
 | 
			
		||||
            "Development Status :: 3 - Alpha",
 | 
			
		||||
            "Intended Audience :: Developers",
 | 
			
		||||
            "Intended Audience :: Education",
 | 
			
		||||
            "Intended Audience :: Science/Research",
 | 
			
		||||
            "License :: OSI Approved :: Apache Software License",
 | 
			
		||||
            "Operating System :: OS Independent",
 | 
			
		||||
            "Programming Language :: Python :: 3",
 | 
			
		||||
            "Programming Language :: Python :: 3.8",
 | 
			
		||||
            "Programming Language :: Python :: 3.9",
 | 
			
		||||
            "Programming Language :: Python :: 3.10",
 | 
			
		||||
            "Topic :: Scientific/Engineering :: Artificial Intelligence",
 | 
			
		||||
        ]
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    main()
 | 
			
		||||
							
								
								
									
										218
									
								
								src/api_demo.py
									
									
									
									
									
								
							
							
						
						
									
										218
									
								
								src/api_demo.py
									
									
									
									
									
								
							@ -4,225 +4,11 @@
 | 
			
		||||
# Visit http://localhost:8000/docs for document.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
import time
 | 
			
		||||
import torch
 | 
			
		||||
import uvicorn
 | 
			
		||||
from threading import Thread
 | 
			
		||||
from pydantic import BaseModel, Field
 | 
			
		||||
from fastapi import FastAPI, HTTPException
 | 
			
		||||
from fastapi.middleware.cors import CORSMiddleware
 | 
			
		||||
from contextlib import asynccontextmanager
 | 
			
		||||
from transformers import TextIteratorStreamer
 | 
			
		||||
from sse_starlette import EventSourceResponse
 | 
			
		||||
from typing import Any, Dict, List, Literal, Optional
 | 
			
		||||
 | 
			
		||||
from utils import (
 | 
			
		||||
    Template,
 | 
			
		||||
    load_pretrained,
 | 
			
		||||
    prepare_infer_args,
 | 
			
		||||
    get_logits_processor
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@asynccontextmanager
 | 
			
		||||
async def lifespan(app: FastAPI): # collects GPU memory
 | 
			
		||||
    yield
 | 
			
		||||
    if torch.cuda.is_available():
 | 
			
		||||
        torch.cuda.empty_cache()
 | 
			
		||||
        torch.cuda.ipc_collect()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
app = FastAPI(lifespan=lifespan)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
app.add_middleware(
 | 
			
		||||
    CORSMiddleware,
 | 
			
		||||
    allow_origins=["*"],
 | 
			
		||||
    allow_credentials=True,
 | 
			
		||||
    allow_methods=["*"],
 | 
			
		||||
    allow_headers=["*"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ModelCard(BaseModel):
 | 
			
		||||
    id: str
 | 
			
		||||
    object: Optional[str] = "model"
 | 
			
		||||
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
 | 
			
		||||
    owned_by: Optional[str] = "owner"
 | 
			
		||||
    root: Optional[str] = None
 | 
			
		||||
    parent: Optional[str] = None
 | 
			
		||||
    permission: Optional[list] = []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ModelList(BaseModel):
 | 
			
		||||
    object: Optional[str] = "list"
 | 
			
		||||
    data: Optional[List[ModelCard]] = []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChatMessage(BaseModel):
 | 
			
		||||
    role: Literal["user", "assistant", "system"]
 | 
			
		||||
    content: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DeltaMessage(BaseModel):
 | 
			
		||||
    role: Optional[Literal["user", "assistant", "system"]] = None
 | 
			
		||||
    content: Optional[str] = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChatCompletionRequest(BaseModel):
 | 
			
		||||
    model: str
 | 
			
		||||
    messages: List[ChatMessage]
 | 
			
		||||
    temperature: Optional[float] = None
 | 
			
		||||
    top_p: Optional[float] = None
 | 
			
		||||
    n: Optional[int] = 1
 | 
			
		||||
    max_tokens: Optional[int] = None
 | 
			
		||||
    stream: Optional[bool] = False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChatCompletionResponseChoice(BaseModel):
 | 
			
		||||
    index: int
 | 
			
		||||
    message: ChatMessage
 | 
			
		||||
    finish_reason: Literal["stop", "length"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChatCompletionResponseStreamChoice(BaseModel):
 | 
			
		||||
    index: int
 | 
			
		||||
    delta: DeltaMessage
 | 
			
		||||
    finish_reason: Optional[Literal["stop", "length"]] = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChatCompletionResponseUsage(BaseModel):
 | 
			
		||||
    prompt_tokens: int
 | 
			
		||||
    completion_tokens: int
 | 
			
		||||
    total_tokens: int
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChatCompletionResponse(BaseModel):
 | 
			
		||||
    id: Optional[str] = "chatcmpl-default"
 | 
			
		||||
    object: Literal["chat.completion"]
 | 
			
		||||
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
 | 
			
		||||
    model: str
 | 
			
		||||
    choices: List[ChatCompletionResponseChoice]
 | 
			
		||||
    usage: ChatCompletionResponseUsage
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChatCompletionStreamResponse(BaseModel):
 | 
			
		||||
    id: Optional[str] = "chatcmpl-default"
 | 
			
		||||
    object: Literal["chat.completion.chunk"]
 | 
			
		||||
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
 | 
			
		||||
    model: str
 | 
			
		||||
    choices: List[ChatCompletionResponseStreamChoice]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.get("/v1/models", response_model=ModelList)
 | 
			
		||||
async def list_models():
 | 
			
		||||
    global model_args
 | 
			
		||||
    model_card = ModelCard(id="gpt-3.5-turbo")
 | 
			
		||||
    return ModelList(data=[model_card])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
 | 
			
		||||
async def create_chat_completion(request: ChatCompletionRequest):
 | 
			
		||||
    global model, tokenizer, source_prefix, generating_args
 | 
			
		||||
 | 
			
		||||
    if request.messages[-1].role != "user":
 | 
			
		||||
        raise HTTPException(status_code=400, detail="Invalid request")
 | 
			
		||||
    query = request.messages[-1].content
 | 
			
		||||
 | 
			
		||||
    prev_messages = request.messages[:-1]
 | 
			
		||||
    if len(prev_messages) > 0 and prev_messages[0].role == "system":
 | 
			
		||||
        prefix = prev_messages.pop(0).content
 | 
			
		||||
    else:
 | 
			
		||||
        prefix = source_prefix
 | 
			
		||||
 | 
			
		||||
    history = []
 | 
			
		||||
    if len(prev_messages) % 2 == 0:
 | 
			
		||||
        for i in range(0, len(prev_messages), 2):
 | 
			
		||||
            if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
 | 
			
		||||
                history.append([prev_messages[i].content, prev_messages[i+1].content])
 | 
			
		||||
 | 
			
		||||
    inputs = tokenizer([prompt_template.get_prompt(query, history, prefix)], return_tensors="pt")
 | 
			
		||||
    inputs = inputs.to(model.device)
 | 
			
		||||
 | 
			
		||||
    gen_kwargs = generating_args.to_dict()
 | 
			
		||||
    gen_kwargs.update({
 | 
			
		||||
        "input_ids": inputs["input_ids"],
 | 
			
		||||
        "temperature": request.temperature if request.temperature else gen_kwargs["temperature"],
 | 
			
		||||
        "top_p": request.top_p if request.top_p else gen_kwargs["top_p"],
 | 
			
		||||
        "logits_processor": get_logits_processor()
 | 
			
		||||
    })
 | 
			
		||||
 | 
			
		||||
    if request.max_tokens:
 | 
			
		||||
        gen_kwargs.pop("max_length", None)
 | 
			
		||||
        gen_kwargs["max_new_tokens"] = request.max_tokens
 | 
			
		||||
 | 
			
		||||
    if request.stream:
 | 
			
		||||
        generate = predict(gen_kwargs, request.model)
 | 
			
		||||
        return EventSourceResponse(generate, media_type="text/event-stream")
 | 
			
		||||
 | 
			
		||||
    generation_output = model.generate(**gen_kwargs)
 | 
			
		||||
    outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
 | 
			
		||||
    response = tokenizer.decode(outputs, skip_special_tokens=True)
 | 
			
		||||
 | 
			
		||||
    usage = ChatCompletionResponseUsage(
 | 
			
		||||
        prompt_tokens=len(inputs["input_ids"][0]),
 | 
			
		||||
        completion_tokens=len(outputs),
 | 
			
		||||
        total_tokens=len(inputs["input_ids"][0]) + len(outputs)
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    choice_data = ChatCompletionResponseChoice(
 | 
			
		||||
        index=0,
 | 
			
		||||
        message=ChatMessage(role="assistant", content=response),
 | 
			
		||||
        finish_reason="stop"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage, object="chat.completion")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def predict(gen_kwargs: Dict[str, Any], model_id: str):
 | 
			
		||||
    global model, tokenizer
 | 
			
		||||
 | 
			
		||||
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
 | 
			
		||||
    gen_kwargs["streamer"] = streamer
 | 
			
		||||
 | 
			
		||||
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
 | 
			
		||||
    thread.start()
 | 
			
		||||
 | 
			
		||||
    choice_data = ChatCompletionResponseStreamChoice(
 | 
			
		||||
        index=0,
 | 
			
		||||
        delta=DeltaMessage(role="assistant"),
 | 
			
		||||
        finish_reason=None
 | 
			
		||||
    )
 | 
			
		||||
    chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
 | 
			
		||||
    yield chunk.json(exclude_unset=True, ensure_ascii=False)
 | 
			
		||||
 | 
			
		||||
    for new_text in streamer:
 | 
			
		||||
        if len(new_text) == 0:
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        choice_data = ChatCompletionResponseStreamChoice(
 | 
			
		||||
            index=0,
 | 
			
		||||
            delta=DeltaMessage(content=new_text),
 | 
			
		||||
            finish_reason=None
 | 
			
		||||
        )
 | 
			
		||||
        chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
 | 
			
		||||
        yield chunk.json(exclude_unset=True, ensure_ascii=False)
 | 
			
		||||
 | 
			
		||||
    choice_data = ChatCompletionResponseStreamChoice(
 | 
			
		||||
        index=0,
 | 
			
		||||
        delta=DeltaMessage(),
 | 
			
		||||
        finish_reason="stop"
 | 
			
		||||
    )
 | 
			
		||||
    chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
 | 
			
		||||
    yield chunk.json(exclude_unset=True, ensure_ascii=False)
 | 
			
		||||
    yield "[DONE]"
 | 
			
		||||
from llmtuner import create_app
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
 | 
			
		||||
    model, tokenizer = load_pretrained(model_args, finetuning_args)
 | 
			
		||||
 | 
			
		||||
    prompt_template = Template(data_args.prompt_template)
 | 
			
		||||
    source_prefix = data_args.source_prefix if data_args.source_prefix else ""
 | 
			
		||||
 | 
			
		||||
    app = create_app()
 | 
			
		||||
    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
 | 
			
		||||
 | 
			
		||||
@ -2,21 +2,15 @@
 | 
			
		||||
# Implements stream chat in command line for fine-tuned models.
 | 
			
		||||
# Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from utils import (
 | 
			
		||||
    Template,
 | 
			
		||||
    load_pretrained,
 | 
			
		||||
    prepare_infer_args,
 | 
			
		||||
    get_logits_processor
 | 
			
		||||
)
 | 
			
		||||
from threading import Thread
 | 
			
		||||
from transformers import TextIteratorStreamer
 | 
			
		||||
 | 
			
		||||
from llmtuner import Template, get_infer_args, load_model_and_tokenizer, get_logits_processor
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
 | 
			
		||||
    model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
 | 
			
		||||
    model, tokenizer = load_pretrained(model_args, finetuning_args)
 | 
			
		||||
    model_args, data_args, finetuning_args, generating_args = get_infer_args()
 | 
			
		||||
    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
 | 
			
		||||
 | 
			
		||||
    prompt_template = Template(data_args.prompt_template)
 | 
			
		||||
    source_prefix = data_args.source_prefix if data_args.source_prefix else ""
 | 
			
		||||
 | 
			
		||||
@ -2,14 +2,12 @@
 | 
			
		||||
# Exports the fine-tuned model.
 | 
			
		||||
# Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from utils import load_pretrained, prepare_args
 | 
			
		||||
from llmtuner import get_train_args, load_model_and_tokenizer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
 | 
			
		||||
    model_args, _, training_args, finetuning_args = prepare_args(stage="sft")
 | 
			
		||||
    model, tokenizer = load_pretrained(model_args, finetuning_args)
 | 
			
		||||
    model_args, _, training_args, finetuning_args, _ = get_train_args()
 | 
			
		||||
    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
 | 
			
		||||
    model.save_pretrained(training_args.output_dir, max_shard_size="10GB")
 | 
			
		||||
    tokenizer.save_pretrained(training_args.output_dir)
 | 
			
		||||
    print("model and tokenizer have been saved at:", training_args.output_dir)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										7
									
								
								src/llmtuner/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								src/llmtuner/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,7 @@
 | 
			
		||||
from llmtuner.api import create_app
 | 
			
		||||
from llmtuner.extras.misc import get_logits_processor
 | 
			
		||||
from llmtuner.extras.template import Template
 | 
			
		||||
from llmtuner.tuner import get_train_args, get_infer_args, load_model_and_tokenizer, run_pt, run_sft, run_rm, run_ppo
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
__version__ = "0.0.9"
 | 
			
		||||
							
								
								
									
										1
									
								
								src/llmtuner/api/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								src/llmtuner/api/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1 @@
 | 
			
		||||
from llmtuner.api.app import create_app
 | 
			
		||||
							
								
								
									
										152
									
								
								src/llmtuner/api/app.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										152
									
								
								src/llmtuner/api/app.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,152 @@
 | 
			
		||||
import uvicorn
 | 
			
		||||
from threading import Thread
 | 
			
		||||
from fastapi import FastAPI, HTTPException
 | 
			
		||||
from fastapi.middleware.cors import CORSMiddleware
 | 
			
		||||
from transformers import TextIteratorStreamer
 | 
			
		||||
from contextlib import asynccontextmanager
 | 
			
		||||
from sse_starlette import EventSourceResponse
 | 
			
		||||
from typing import Any, Dict
 | 
			
		||||
 | 
			
		||||
from llmtuner.tuner import get_infer_args, load_model_and_tokenizer
 | 
			
		||||
from llmtuner.extras.misc import get_logits_processor, torch_gc
 | 
			
		||||
from llmtuner.extras.template import Template
 | 
			
		||||
from llmtuner.api.protocol import (
 | 
			
		||||
    ModelCard,
 | 
			
		||||
    ModelList,
 | 
			
		||||
    ChatMessage,
 | 
			
		||||
    DeltaMessage,
 | 
			
		||||
    ChatCompletionRequest,
 | 
			
		||||
    ChatCompletionResponse,
 | 
			
		||||
    ChatCompletionStreamResponse,
 | 
			
		||||
    ChatCompletionResponseChoice,
 | 
			
		||||
    ChatCompletionResponseStreamChoice,
 | 
			
		||||
    ChatCompletionResponseUsage
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@asynccontextmanager
 | 
			
		||||
async def lifespan(app: FastAPI): # collects GPU memory
 | 
			
		||||
    yield
 | 
			
		||||
    torch_gc()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_app():
 | 
			
		||||
    model_args, data_args, finetuning_args, generating_args = get_infer_args()
 | 
			
		||||
    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
 | 
			
		||||
 | 
			
		||||
    prompt_template = Template(data_args.prompt_template)
 | 
			
		||||
    source_prefix = data_args.source_prefix if data_args.source_prefix else ""
 | 
			
		||||
 | 
			
		||||
    app = FastAPI(lifespan=lifespan)
 | 
			
		||||
 | 
			
		||||
    app.add_middleware(
 | 
			
		||||
        CORSMiddleware,
 | 
			
		||||
        allow_origins=["*"],
 | 
			
		||||
        allow_credentials=True,
 | 
			
		||||
        allow_methods=["*"],
 | 
			
		||||
        allow_headers=["*"],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    @app.get("/v1/models", response_model=ModelList)
 | 
			
		||||
    async def list_models():
 | 
			
		||||
        global model_args
 | 
			
		||||
        model_card = ModelCard(id="gpt-3.5-turbo")
 | 
			
		||||
        return ModelList(data=[model_card])
 | 
			
		||||
 | 
			
		||||
    @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
 | 
			
		||||
    async def create_chat_completion(request: ChatCompletionRequest):
 | 
			
		||||
        if request.messages[-1].role != "user":
 | 
			
		||||
            raise HTTPException(status_code=400, detail="Invalid request")
 | 
			
		||||
        query = request.messages[-1].content
 | 
			
		||||
 | 
			
		||||
        prev_messages = request.messages[:-1]
 | 
			
		||||
        if len(prev_messages) > 0 and prev_messages[0].role == "system":
 | 
			
		||||
            prefix = prev_messages.pop(0).content
 | 
			
		||||
        else:
 | 
			
		||||
            prefix = source_prefix
 | 
			
		||||
 | 
			
		||||
        history = []
 | 
			
		||||
        if len(prev_messages) % 2 == 0:
 | 
			
		||||
            for i in range(0, len(prev_messages), 2):
 | 
			
		||||
                if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
 | 
			
		||||
                    history.append([prev_messages[i].content, prev_messages[i+1].content])
 | 
			
		||||
 | 
			
		||||
        inputs = tokenizer([prompt_template.get_prompt(query, history, prefix)], return_tensors="pt")
 | 
			
		||||
        inputs = inputs.to(model.device)
 | 
			
		||||
 | 
			
		||||
        gen_kwargs = generating_args.to_dict()
 | 
			
		||||
        gen_kwargs.update({
 | 
			
		||||
            "input_ids": inputs["input_ids"],
 | 
			
		||||
            "temperature": request.temperature if request.temperature else gen_kwargs["temperature"],
 | 
			
		||||
            "top_p": request.top_p if request.top_p else gen_kwargs["top_p"],
 | 
			
		||||
            "logits_processor": get_logits_processor()
 | 
			
		||||
        })
 | 
			
		||||
 | 
			
		||||
        if request.max_tokens:
 | 
			
		||||
            gen_kwargs.pop("max_length", None)
 | 
			
		||||
            gen_kwargs["max_new_tokens"] = request.max_tokens
 | 
			
		||||
 | 
			
		||||
        if request.stream:
 | 
			
		||||
            generate = predict(gen_kwargs, request.model)
 | 
			
		||||
            return EventSourceResponse(generate, media_type="text/event-stream")
 | 
			
		||||
 | 
			
		||||
        generation_output = model.generate(**gen_kwargs)
 | 
			
		||||
        outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
 | 
			
		||||
        response = tokenizer.decode(outputs, skip_special_tokens=True)
 | 
			
		||||
 | 
			
		||||
        usage = ChatCompletionResponseUsage(
 | 
			
		||||
            prompt_tokens=len(inputs["input_ids"][0]),
 | 
			
		||||
            completion_tokens=len(outputs),
 | 
			
		||||
            total_tokens=len(inputs["input_ids"][0]) + len(outputs)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        choice_data = ChatCompletionResponseChoice(
 | 
			
		||||
            index=0,
 | 
			
		||||
            message=ChatMessage(role="assistant", content=response),
 | 
			
		||||
            finish_reason="stop"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage, object="chat.completion")
 | 
			
		||||
 | 
			
		||||
    async def predict(gen_kwargs: Dict[str, Any], model_id: str):
 | 
			
		||||
        streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
 | 
			
		||||
        gen_kwargs["streamer"] = streamer
 | 
			
		||||
 | 
			
		||||
        thread = Thread(target=model.generate, kwargs=gen_kwargs)
 | 
			
		||||
        thread.start()
 | 
			
		||||
 | 
			
		||||
        choice_data = ChatCompletionResponseStreamChoice(
 | 
			
		||||
            index=0,
 | 
			
		||||
            delta=DeltaMessage(role="assistant"),
 | 
			
		||||
            finish_reason=None
 | 
			
		||||
        )
 | 
			
		||||
        chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
 | 
			
		||||
        yield chunk.json(exclude_unset=True, ensure_ascii=False)
 | 
			
		||||
 | 
			
		||||
        for new_text in streamer:
 | 
			
		||||
            if len(new_text) == 0:
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            choice_data = ChatCompletionResponseStreamChoice(
 | 
			
		||||
                index=0,
 | 
			
		||||
                delta=DeltaMessage(content=new_text),
 | 
			
		||||
                finish_reason=None
 | 
			
		||||
            )
 | 
			
		||||
            chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
 | 
			
		||||
            yield chunk.json(exclude_unset=True, ensure_ascii=False)
 | 
			
		||||
 | 
			
		||||
        choice_data = ChatCompletionResponseStreamChoice(
 | 
			
		||||
            index=0,
 | 
			
		||||
            delta=DeltaMessage(),
 | 
			
		||||
            finish_reason="stop"
 | 
			
		||||
        )
 | 
			
		||||
        chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
 | 
			
		||||
        yield chunk.json(exclude_unset=True, ensure_ascii=False)
 | 
			
		||||
        yield "[DONE]"
 | 
			
		||||
 | 
			
		||||
    return app
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    app = create_app()
 | 
			
		||||
    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
 | 
			
		||||
							
								
								
									
										73
									
								
								src/llmtuner/api/protocol.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										73
									
								
								src/llmtuner/api/protocol.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,73 @@
 | 
			
		||||
import time
 | 
			
		||||
from pydantic import BaseModel, Field
 | 
			
		||||
from typing import List, Literal, Optional
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ModelCard(BaseModel):
 | 
			
		||||
    id: str
 | 
			
		||||
    object: Optional[str] = "model"
 | 
			
		||||
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
 | 
			
		||||
    owned_by: Optional[str] = "owner"
 | 
			
		||||
    root: Optional[str] = None
 | 
			
		||||
    parent: Optional[str] = None
 | 
			
		||||
    permission: Optional[list] = []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ModelList(BaseModel):
 | 
			
		||||
    object: Optional[str] = "list"
 | 
			
		||||
    data: Optional[List[ModelCard]] = []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChatMessage(BaseModel):
 | 
			
		||||
    role: Literal["user", "assistant", "system"]
 | 
			
		||||
    content: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DeltaMessage(BaseModel):
 | 
			
		||||
    role: Optional[Literal["user", "assistant", "system"]] = None
 | 
			
		||||
    content: Optional[str] = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChatCompletionRequest(BaseModel):
 | 
			
		||||
    model: str
 | 
			
		||||
    messages: List[ChatMessage]
 | 
			
		||||
    temperature: Optional[float] = None
 | 
			
		||||
    top_p: Optional[float] = None
 | 
			
		||||
    n: Optional[int] = 1
 | 
			
		||||
    max_tokens: Optional[int] = None
 | 
			
		||||
    stream: Optional[bool] = False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChatCompletionResponseChoice(BaseModel):
 | 
			
		||||
    index: int
 | 
			
		||||
    message: ChatMessage
 | 
			
		||||
    finish_reason: Literal["stop", "length"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChatCompletionResponseStreamChoice(BaseModel):
 | 
			
		||||
    index: int
 | 
			
		||||
    delta: DeltaMessage
 | 
			
		||||
    finish_reason: Optional[Literal["stop", "length"]] = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChatCompletionResponseUsage(BaseModel):
 | 
			
		||||
    prompt_tokens: int
 | 
			
		||||
    completion_tokens: int
 | 
			
		||||
    total_tokens: int
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChatCompletionResponse(BaseModel):
 | 
			
		||||
    id: Optional[str] = "chatcmpl-default"
 | 
			
		||||
    object: Literal["chat.completion"]
 | 
			
		||||
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
 | 
			
		||||
    model: str
 | 
			
		||||
    choices: List[ChatCompletionResponseChoice]
 | 
			
		||||
    usage: ChatCompletionResponseUsage
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ChatCompletionStreamResponse(BaseModel):
 | 
			
		||||
    id: Optional[str] = "chatcmpl-default"
 | 
			
		||||
    object: Literal["chat.completion.chunk"]
 | 
			
		||||
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
 | 
			
		||||
    model: str
 | 
			
		||||
    choices: List[ChatCompletionResponseStreamChoice]
 | 
			
		||||
							
								
								
									
										2
									
								
								src/llmtuner/dsets/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								src/llmtuner/dsets/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,2 @@
 | 
			
		||||
from llmtuner.dsets.loader import get_dataset
 | 
			
		||||
from llmtuner.dsets.preprocess import preprocess_dataset
 | 
			
		||||
							
								
								
									
										63
									
								
								src/llmtuner/dsets/callbacks.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								src/llmtuner/dsets/callbacks.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,63 @@
 | 
			
		||||
import os
 | 
			
		||||
import json
 | 
			
		||||
import time
 | 
			
		||||
from datetime import timedelta
 | 
			
		||||
 | 
			
		||||
from transformers import (
 | 
			
		||||
    TrainerCallback,
 | 
			
		||||
    TrainerControl,
 | 
			
		||||
    TrainerState,
 | 
			
		||||
    TrainingArguments
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LogCallback(TrainerCallback):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, runner=None):
 | 
			
		||||
        self.runner = runner
 | 
			
		||||
        self.start_time = time.time()
 | 
			
		||||
        self.tracker = {}
 | 
			
		||||
 | 
			
		||||
    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
 | 
			
		||||
        r"""
 | 
			
		||||
        Event called at the beginning of a training step. If using gradient accumulation, one training step
 | 
			
		||||
        might take several inputs.
 | 
			
		||||
        """
 | 
			
		||||
        if self.runner is not None and self.runner.aborted:
 | 
			
		||||
            control.should_epoch_stop = True
 | 
			
		||||
            control.should_training_stop = True
 | 
			
		||||
 | 
			
		||||
    def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
 | 
			
		||||
        r"""
 | 
			
		||||
        Event called at the end of an substep during gradient accumulation.
 | 
			
		||||
        """
 | 
			
		||||
        if self.runner is not None and self.runner.aborted:
 | 
			
		||||
            control.should_epoch_stop = True
 | 
			
		||||
            control.should_training_stop = True
 | 
			
		||||
 | 
			
		||||
    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
 | 
			
		||||
        r"""
 | 
			
		||||
        Event called after logging the last logs.
 | 
			
		||||
        """
 | 
			
		||||
        if "loss" not in state.log_history[-1]:
 | 
			
		||||
            return
 | 
			
		||||
        cur_time = time.time()
 | 
			
		||||
        cur_steps = state.log_history[-1].get("step")
 | 
			
		||||
        elapsed_time = cur_time - self.start_time
 | 
			
		||||
        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
 | 
			
		||||
        remaining_steps = state.max_steps - cur_steps
 | 
			
		||||
        remaining_time = remaining_steps * avg_time_per_step
 | 
			
		||||
        self.tracker = {
 | 
			
		||||
            "current_steps": cur_steps,
 | 
			
		||||
            "total_steps": state.max_steps,
 | 
			
		||||
            "loss": state.log_history[-1].get("loss", None),
 | 
			
		||||
            "reward": state.log_history[-1].get("reward", None),
 | 
			
		||||
            "learning_rate": state.log_history[-1].get("learning_rate", None),
 | 
			
		||||
            "epoch": state.log_history[-1].get("epoch", None),
 | 
			
		||||
            "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
 | 
			
		||||
            "elapsed_time": str(timedelta(seconds=int(elapsed_time))),
 | 
			
		||||
            "remaining_time": str(timedelta(seconds=int(remaining_time)))
 | 
			
		||||
        }
 | 
			
		||||
        os.makedirs(args.output_dir, exist_ok=True)
 | 
			
		||||
        with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
 | 
			
		||||
            f.write(json.dumps(self.tracker) + "\n")
 | 
			
		||||
							
								
								
									
										106
									
								
								src/llmtuner/dsets/loader.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										106
									
								
								src/llmtuner/dsets/loader.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,106 @@
 | 
			
		||||
import os
 | 
			
		||||
import hashlib
 | 
			
		||||
from typing import List
 | 
			
		||||
 | 
			
		||||
from datasets import Dataset, concatenate_datasets, load_dataset
 | 
			
		||||
 | 
			
		||||
from llmtuner.extras.logging import get_logger
 | 
			
		||||
from llmtuner.hparams import ModelArguments, DataArguments
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_dataset(
 | 
			
		||||
    model_args: ModelArguments,
 | 
			
		||||
    data_args: DataArguments
 | 
			
		||||
) -> Dataset:
 | 
			
		||||
 | 
			
		||||
    def checksum(file_path, hash):
 | 
			
		||||
        with open(file_path, "rb") as datafile:
 | 
			
		||||
            binary_data = datafile.read()
 | 
			
		||||
        sha1 = hashlib.sha1(binary_data).hexdigest()
 | 
			
		||||
        if sha1 != hash:
 | 
			
		||||
            logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
 | 
			
		||||
 | 
			
		||||
    ext2type = {
 | 
			
		||||
        "csv": "csv",
 | 
			
		||||
        "json": "json",
 | 
			
		||||
        "jsonl": "json",
 | 
			
		||||
        "txt": "text"
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    max_samples = data_args.max_samples
 | 
			
		||||
    all_datasets: List[Dataset] = [] # support multiple datasets
 | 
			
		||||
 | 
			
		||||
    for dataset_attr in data_args.dataset_list:
 | 
			
		||||
 | 
			
		||||
        logger.info("Loading dataset {}...".format(dataset_attr))
 | 
			
		||||
 | 
			
		||||
        if dataset_attr.load_from == "hf_hub":
 | 
			
		||||
            data_path = dataset_attr.dataset_name
 | 
			
		||||
            data_files = None
 | 
			
		||||
        elif dataset_attr.load_from == "script":
 | 
			
		||||
            data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
 | 
			
		||||
            data_files = None
 | 
			
		||||
        elif dataset_attr.load_from == "file":
 | 
			
		||||
            data_path = None
 | 
			
		||||
            data_files: List[str] = []
 | 
			
		||||
 | 
			
		||||
            if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
 | 
			
		||||
                for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
 | 
			
		||||
                    data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
 | 
			
		||||
 | 
			
		||||
                    if data_path is None:
 | 
			
		||||
                        data_path = ext2type.get(data_files[0].split(".")[-1], None)
 | 
			
		||||
                    else:
 | 
			
		||||
                        assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match."
 | 
			
		||||
            elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
 | 
			
		||||
                data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
 | 
			
		||||
                data_path = ext2type.get(data_files[0].split(".")[-1], None)
 | 
			
		||||
            else:
 | 
			
		||||
                raise ValueError("File not found.")
 | 
			
		||||
 | 
			
		||||
            assert data_path, "File extension must be txt, csv, json or jsonl."
 | 
			
		||||
 | 
			
		||||
            if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
 | 
			
		||||
                checksum(data_files[0], dataset_attr.dataset_sha1)
 | 
			
		||||
            else:
 | 
			
		||||
                logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
 | 
			
		||||
        else:
 | 
			
		||||
            raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
        raw_datasets = load_dataset(
 | 
			
		||||
            data_path,
 | 
			
		||||
            data_files=data_files,
 | 
			
		||||
            cache_dir=model_args.cache_dir,
 | 
			
		||||
            use_auth_token=True if model_args.use_auth_token else None
 | 
			
		||||
        )
 | 
			
		||||
        dataset = raw_datasets[data_args.split]
 | 
			
		||||
 | 
			
		||||
        if max_samples is not None:
 | 
			
		||||
            max_samples_temp = min(len(dataset), max_samples)
 | 
			
		||||
            dataset = dataset.select(range(max_samples_temp))
 | 
			
		||||
 | 
			
		||||
        dummy_data = [None] * len(dataset)
 | 
			
		||||
        prefix_data = [dataset_attr.source_prefix] * len(dataset)
 | 
			
		||||
        for column_name, target_name in [
 | 
			
		||||
            ("prompt_column", "prompt"),
 | 
			
		||||
            ("query_column", "query"),
 | 
			
		||||
            ("response_column", "response"),
 | 
			
		||||
            ("history_column", "history")
 | 
			
		||||
        ]: # every dataset will have 4 columns same as each other
 | 
			
		||||
            if getattr(dataset_attr, column_name) != target_name:
 | 
			
		||||
                if getattr(dataset_attr, column_name):
 | 
			
		||||
                    dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
 | 
			
		||||
                else: # None or empty string
 | 
			
		||||
                    dataset = dataset.add_column(target_name, dummy_data)
 | 
			
		||||
        dataset = dataset.add_column("prefix", prefix_data)
 | 
			
		||||
        all_datasets.append(dataset)
 | 
			
		||||
 | 
			
		||||
    if len(data_args.dataset_list) == 1:
 | 
			
		||||
        all_datasets = all_datasets[0]
 | 
			
		||||
    else:
 | 
			
		||||
        all_datasets = concatenate_datasets(all_datasets)
 | 
			
		||||
 | 
			
		||||
    return all_datasets
 | 
			
		||||
							
								
								
									
										172
									
								
								src/llmtuner/dsets/preprocess.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										172
									
								
								src/llmtuner/dsets/preprocess.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,172 @@
 | 
			
		||||
from typing import Literal
 | 
			
		||||
from itertools import chain
 | 
			
		||||
from transformers import Seq2SeqTrainingArguments
 | 
			
		||||
from transformers.tokenization_utils import PreTrainedTokenizer
 | 
			
		||||
 | 
			
		||||
from datasets import Dataset
 | 
			
		||||
 | 
			
		||||
from llmtuner.extras.constants import IGNORE_INDEX
 | 
			
		||||
from llmtuner.extras.template import Template
 | 
			
		||||
from llmtuner.hparams import DataArguments
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def preprocess_dataset(
 | 
			
		||||
    dataset: Dataset,
 | 
			
		||||
    tokenizer: PreTrainedTokenizer,
 | 
			
		||||
    data_args: DataArguments,
 | 
			
		||||
    training_args: Seq2SeqTrainingArguments,
 | 
			
		||||
    stage: Literal["pt", "sft", "rm", "ppo"]
 | 
			
		||||
) -> Dataset:
 | 
			
		||||
 | 
			
		||||
    column_names = list(dataset.column_names)
 | 
			
		||||
    prompt_template = Template(data_args.prompt_template)
 | 
			
		||||
 | 
			
		||||
    # support question with a single answer or multiple answers
 | 
			
		||||
    def get_dialog(examples):
 | 
			
		||||
        for i in range(len(examples["prompt"])):
 | 
			
		||||
            if examples["prompt"][i] and examples["response"][i]:
 | 
			
		||||
                query, answer = examples["prompt"][i], examples["response"][i]
 | 
			
		||||
                query = query + "\n" + examples["query"][i] if examples["query"][i] else query
 | 
			
		||||
                prefix = examples["prefix"][i] if examples["prefix"][i] else ""
 | 
			
		||||
                dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
 | 
			
		||||
                yield dialog
 | 
			
		||||
 | 
			
		||||
    def preprocess_pretrain_dataset(examples):
 | 
			
		||||
        # build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
 | 
			
		||||
        text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
 | 
			
		||||
        concatenated_ids = list(chain(*text_ids))
 | 
			
		||||
        total_length = len(concatenated_ids)
 | 
			
		||||
        block_size = data_args.max_source_length - 1
 | 
			
		||||
        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
 | 
			
		||||
        total_length = (total_length // block_size) * block_size
 | 
			
		||||
        # split by chunks of max_source_length
 | 
			
		||||
        result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size]
 | 
			
		||||
                  for i in range(0, total_length, block_size)]
 | 
			
		||||
        return {
 | 
			
		||||
            "input_ids": result,
 | 
			
		||||
            "labels": result.copy()
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    def preprocess_supervised_dataset(examples):
 | 
			
		||||
        # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
 | 
			
		||||
        # for input with history, we build multiple input-label pairs just like:
 | 
			
		||||
        # https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
 | 
			
		||||
        model_inputs = {"input_ids": [], "labels": []}
 | 
			
		||||
        max_length = data_args.max_source_length + data_args.max_target_length
 | 
			
		||||
 | 
			
		||||
        for dialog in get_dialog(examples):
 | 
			
		||||
            input_ids, labels = [], []
 | 
			
		||||
 | 
			
		||||
            for i in range(len(dialog) // 2):
 | 
			
		||||
                source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=(i == 0))
 | 
			
		||||
                target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
 | 
			
		||||
 | 
			
		||||
                if len(source_ids) > data_args.max_source_length:
 | 
			
		||||
                    source_ids = source_ids[:data_args.max_source_length]
 | 
			
		||||
                if len(target_ids) > data_args.max_target_length - 1: # eos token
 | 
			
		||||
                    target_ids = target_ids[:data_args.max_target_length - 1]
 | 
			
		||||
 | 
			
		||||
                if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
 | 
			
		||||
                    break
 | 
			
		||||
 | 
			
		||||
                input_ids += source_ids + target_ids + [tokenizer.eos_token_id]
 | 
			
		||||
                labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
 | 
			
		||||
 | 
			
		||||
            model_inputs["input_ids"].append(input_ids)
 | 
			
		||||
            model_inputs["labels"].append(labels)
 | 
			
		||||
 | 
			
		||||
        return model_inputs
 | 
			
		||||
 | 
			
		||||
    def preprocess_unsupervised_dataset(examples):
 | 
			
		||||
        # build inputs with format `<bos> X` and labels with format `<bos> Y`
 | 
			
		||||
        model_inputs = {"input_ids": [], "labels": []}
 | 
			
		||||
 | 
			
		||||
        for dialog in get_dialog(examples):
 | 
			
		||||
            prompt, answer = "".join(dialog[:-1]), dialog[-1]
 | 
			
		||||
 | 
			
		||||
            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
 | 
			
		||||
            target_ids = tokenizer.encode(text=answer, add_special_tokens=True)
 | 
			
		||||
 | 
			
		||||
            if len(source_ids) > data_args.max_source_length:
 | 
			
		||||
                source_ids = source_ids[:data_args.max_source_length]
 | 
			
		||||
            if len(target_ids) > data_args.max_target_length:
 | 
			
		||||
                target_ids = target_ids[:data_args.max_target_length]
 | 
			
		||||
 | 
			
		||||
            model_inputs["input_ids"].append(source_ids)
 | 
			
		||||
            model_inputs["labels"].append(target_ids)
 | 
			
		||||
 | 
			
		||||
        return model_inputs
 | 
			
		||||
 | 
			
		||||
    def preprocess_pairwise_dataset(examples):
 | 
			
		||||
        # build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
 | 
			
		||||
        model_inputs = {"accept_ids": [], "reject_ids": []}
 | 
			
		||||
        for dialog in get_dialog(examples):
 | 
			
		||||
            prompt, answer = "".join(dialog[:-1]), dialog[-1]
 | 
			
		||||
 | 
			
		||||
            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
 | 
			
		||||
            accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
 | 
			
		||||
            reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
 | 
			
		||||
 | 
			
		||||
            if len(source_ids) > data_args.max_source_length:
 | 
			
		||||
                source_ids = source_ids[:data_args.max_source_length]
 | 
			
		||||
            if len(accept_ids) > data_args.max_target_length - 1: # eos token
 | 
			
		||||
                accept_ids = accept_ids[:data_args.max_target_length - 1]
 | 
			
		||||
            if len(reject_ids) > data_args.max_target_length - 1: # eos token
 | 
			
		||||
                reject_ids = reject_ids[:data_args.max_target_length - 1]
 | 
			
		||||
 | 
			
		||||
            accept_ids = source_ids + accept_ids + [tokenizer.eos_token_id]
 | 
			
		||||
            reject_ids = source_ids + reject_ids + [tokenizer.eos_token_id]
 | 
			
		||||
 | 
			
		||||
            model_inputs["accept_ids"].append(accept_ids)
 | 
			
		||||
            model_inputs["reject_ids"].append(reject_ids)
 | 
			
		||||
        return model_inputs
 | 
			
		||||
 | 
			
		||||
    def print_supervised_dataset_example(example):
 | 
			
		||||
        print("input_ids:\n{}".format(example["input_ids"]))
 | 
			
		||||
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
 | 
			
		||||
        print("label_ids:\n{}".format(example["labels"]))
 | 
			
		||||
        print("labels:\n{}".format(
 | 
			
		||||
            tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]],
 | 
			
		||||
                             skip_special_tokens=False)
 | 
			
		||||
        ))
 | 
			
		||||
 | 
			
		||||
    def print_pairwise_dataset_example(example):
 | 
			
		||||
        print("accept_ids:\n{}".format(example["accept_ids"]))
 | 
			
		||||
        print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
 | 
			
		||||
        print("reject_ids:\n{}".format(example["reject_ids"]))
 | 
			
		||||
        print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))
 | 
			
		||||
 | 
			
		||||
    def print_unsupervised_dataset_example(example):
 | 
			
		||||
        print("input_ids:\n{}".format(example["input_ids"]))
 | 
			
		||||
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
 | 
			
		||||
 | 
			
		||||
    if stage == "pt":
 | 
			
		||||
        preprocess_function = preprocess_pretrain_dataset
 | 
			
		||||
    elif stage == "sft":
 | 
			
		||||
        preprocess_function = preprocess_unsupervised_dataset \
 | 
			
		||||
            if training_args.predict_with_generate else preprocess_supervised_dataset
 | 
			
		||||
    elif stage == "rm":
 | 
			
		||||
        preprocess_function = preprocess_pairwise_dataset
 | 
			
		||||
    elif stage == "ppo":
 | 
			
		||||
        preprocess_function = preprocess_unsupervised_dataset
 | 
			
		||||
 | 
			
		||||
    with training_args.main_process_first(desc="dataset map pre-processing"):
 | 
			
		||||
        dataset = dataset.map(
 | 
			
		||||
            preprocess_function,
 | 
			
		||||
            batched=True,
 | 
			
		||||
            num_proc=data_args.preprocessing_num_workers,
 | 
			
		||||
            remove_columns=column_names,
 | 
			
		||||
            load_from_cache_file=not data_args.overwrite_cache,
 | 
			
		||||
            desc="Running tokenizer on dataset"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if stage == "pt":
 | 
			
		||||
            print_unsupervised_dataset_example(dataset[0])
 | 
			
		||||
        elif stage == "sft":
 | 
			
		||||
            print_supervised_dataset_example(dataset[0])
 | 
			
		||||
        elif stage == "rm":
 | 
			
		||||
            print_pairwise_dataset_example(dataset[0])
 | 
			
		||||
        elif stage == "ppo":
 | 
			
		||||
            print_unsupervised_dataset_example(dataset[0])
 | 
			
		||||
 | 
			
		||||
        return dataset
 | 
			
		||||
							
								
								
									
										72
									
								
								src/llmtuner/extras/callbacks.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										72
									
								
								src/llmtuner/extras/callbacks.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,72 @@
 | 
			
		||||
import os
 | 
			
		||||
import json
 | 
			
		||||
import time
 | 
			
		||||
from datetime import timedelta
 | 
			
		||||
 | 
			
		||||
from transformers import (
 | 
			
		||||
    TrainerCallback,
 | 
			
		||||
    TrainerControl,
 | 
			
		||||
    TrainerState,
 | 
			
		||||
    TrainingArguments
 | 
			
		||||
)
 | 
			
		||||
from transformers.trainer_callback import TrainerControl, TrainerState
 | 
			
		||||
from transformers.training_args import TrainingArguments
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LogCallback(TrainerCallback):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, runner=None):
 | 
			
		||||
        self.runner = runner
 | 
			
		||||
        self.tracker = {}
 | 
			
		||||
 | 
			
		||||
    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
 | 
			
		||||
        r"""
 | 
			
		||||
        Event called at the beginning of training.
 | 
			
		||||
        """
 | 
			
		||||
        self.start_time = time.time()
 | 
			
		||||
 | 
			
		||||
    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
 | 
			
		||||
        r"""
 | 
			
		||||
        Event called at the beginning of a training step. If using gradient accumulation, one training step
 | 
			
		||||
        might take several inputs.
 | 
			
		||||
        """
 | 
			
		||||
        if self.runner is not None and self.runner.aborted:
 | 
			
		||||
            control.should_epoch_stop = True
 | 
			
		||||
            control.should_training_stop = True
 | 
			
		||||
 | 
			
		||||
    def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
 | 
			
		||||
        r"""
 | 
			
		||||
        Event called at the end of an substep during gradient accumulation.
 | 
			
		||||
        """
 | 
			
		||||
        if self.runner is not None and self.runner.aborted:
 | 
			
		||||
            control.should_epoch_stop = True
 | 
			
		||||
            control.should_training_stop = True
 | 
			
		||||
 | 
			
		||||
    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
 | 
			
		||||
        r"""
 | 
			
		||||
        Event called after logging the last logs.
 | 
			
		||||
        """
 | 
			
		||||
        if "current_steps" not in state.log_history[-1]:
 | 
			
		||||
            return
 | 
			
		||||
        cur_time = time.time()
 | 
			
		||||
        cur_steps = state.log_history[-1].get("step")
 | 
			
		||||
        elapsed_time = cur_time - self.start_time
 | 
			
		||||
        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
 | 
			
		||||
        remaining_steps = state.max_steps - cur_steps
 | 
			
		||||
        remaining_time = remaining_steps * avg_time_per_step
 | 
			
		||||
        self.tracker = {
 | 
			
		||||
            "current_steps": cur_steps,
 | 
			
		||||
            "total_steps": state.max_steps,
 | 
			
		||||
            "loss": state.log_history[-1].get("loss", None),
 | 
			
		||||
            "eval_loss": state.log_history[-1].get("eval_loss", None),
 | 
			
		||||
            "predict_loss": state.log_history[-1].get("predict_loss", None),
 | 
			
		||||
            "reward": state.log_history[-1].get("reward", None),
 | 
			
		||||
            "learning_rate": state.log_history[-1].get("learning_rate", None),
 | 
			
		||||
            "epoch": state.log_history[-1].get("epoch", None),
 | 
			
		||||
            "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
 | 
			
		||||
            "elapsed_time": str(timedelta(seconds=int(elapsed_time))),
 | 
			
		||||
            "remaining_time": str(timedelta(seconds=int(remaining_time)))
 | 
			
		||||
        }
 | 
			
		||||
        os.makedirs(args.output_dir, exist_ok=True)
 | 
			
		||||
        with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
 | 
			
		||||
            f.write(json.dumps(self.tracker) + "\n")
 | 
			
		||||
							
								
								
									
										7
									
								
								src/llmtuner/extras/constants.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								src/llmtuner/extras/constants.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,7 @@
 | 
			
		||||
IGNORE_INDEX = -100
 | 
			
		||||
 | 
			
		||||
VALUE_HEAD_FILE_NAME = "value_head.bin"
 | 
			
		||||
 | 
			
		||||
FINETUNING_ARGS_NAME = "finetuning_args.json"
 | 
			
		||||
 | 
			
		||||
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings
 | 
			
		||||
							
								
								
									
										18
									
								
								src/llmtuner/extras/logging.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								src/llmtuner/extras/logging.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,18 @@
 | 
			
		||||
import sys
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_logger(name: str) -> logging.Logger:
 | 
			
		||||
 | 
			
		||||
    formatter = logging.Formatter(
 | 
			
		||||
        fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
 | 
			
		||||
        datefmt="%m/%d/%Y %H:%M:%S"
 | 
			
		||||
    )
 | 
			
		||||
    handler = logging.StreamHandler(sys.stdout)
 | 
			
		||||
    handler.setFormatter(formatter)
 | 
			
		||||
 | 
			
		||||
    logger = logging.getLogger(name)
 | 
			
		||||
    logger.setLevel(logging.INFO)
 | 
			
		||||
    logger.addHandler(handler)
 | 
			
		||||
 | 
			
		||||
    return logger
 | 
			
		||||
							
								
								
									
										105
									
								
								src/llmtuner/extras/misc.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										105
									
								
								src/llmtuner/extras/misc.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,105 @@
 | 
			
		||||
import torch
 | 
			
		||||
from typing import List, Optional
 | 
			
		||||
 | 
			
		||||
from transformers.modeling_utils import PreTrainedModel
 | 
			
		||||
from transformers.generation.utils import LogitsProcessorList
 | 
			
		||||
from transformers.generation.logits_process import LogitsProcessor
 | 
			
		||||
 | 
			
		||||
from llmtuner.extras.constants import LAYERNORM_NAMES
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AverageMeter:
 | 
			
		||||
    r"""
 | 
			
		||||
    Computes and stores the average and current value.
 | 
			
		||||
    """
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.reset()
 | 
			
		||||
 | 
			
		||||
    def reset(self):
 | 
			
		||||
        self.val = 0
 | 
			
		||||
        self.avg = 0
 | 
			
		||||
        self.sum = 0
 | 
			
		||||
        self.count = 0
 | 
			
		||||
 | 
			
		||||
    def update(self, val, n=1):
 | 
			
		||||
        self.val = val
 | 
			
		||||
        self.sum += val * n
 | 
			
		||||
        self.count += n
 | 
			
		||||
        self.avg = self.sum / self.count
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Avoid runtime error in model.generate(do_sample=True).
 | 
			
		||||
class InvalidScoreLogitsProcessor(LogitsProcessor):
 | 
			
		||||
 | 
			
		||||
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
 | 
			
		||||
        if torch.isnan(scores).any() or torch.isinf(scores).any():
 | 
			
		||||
            scores.zero_()
 | 
			
		||||
            scores[..., 0] = 1.0
 | 
			
		||||
        return scores
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_logits_processor() -> LogitsProcessorList:
 | 
			
		||||
    logits_processor = LogitsProcessorList()
 | 
			
		||||
    logits_processor.append(InvalidScoreLogitsProcessor())
 | 
			
		||||
    return logits_processor
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def print_trainable_params(model: torch.nn.Module) -> None:
 | 
			
		||||
    trainable_params, all_param = 0, 0
 | 
			
		||||
    for param in model.parameters():
 | 
			
		||||
        num_params = param.numel()
 | 
			
		||||
        # if using DS Zero 3 and the weights are initialized empty
 | 
			
		||||
        if num_params == 0 and hasattr(param, "ds_numel"):
 | 
			
		||||
            num_params = param.ds_numel
 | 
			
		||||
        all_param += num_params
 | 
			
		||||
        if param.requires_grad:
 | 
			
		||||
            trainable_params += num_params
 | 
			
		||||
    print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
 | 
			
		||||
                trainable_params, all_param, 100 * trainable_params / all_param))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
 | 
			
		||||
# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
 | 
			
		||||
def prepare_model_for_training(
 | 
			
		||||
    model: PreTrainedModel,
 | 
			
		||||
    finetuning_type: str,
 | 
			
		||||
    output_embedding_layer_name: Optional[str] = "lm_head",
 | 
			
		||||
    use_gradient_checkpointing: Optional[bool] = True,
 | 
			
		||||
    layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
 | 
			
		||||
) -> PreTrainedModel:
 | 
			
		||||
 | 
			
		||||
    for name, param in model.named_parameters():
 | 
			
		||||
        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
 | 
			
		||||
            param.data = param.data.to(torch.float32)
 | 
			
		||||
 | 
			
		||||
    if use_gradient_checkpointing:
 | 
			
		||||
        if hasattr(model, "enable_input_require_grads"):
 | 
			
		||||
            model.enable_input_require_grads()
 | 
			
		||||
        else:
 | 
			
		||||
            def make_inputs_require_grad(module, input, output):
 | 
			
		||||
                output.requires_grad_(True)
 | 
			
		||||
            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
 | 
			
		||||
 | 
			
		||||
        model.gradient_checkpointing_enable()
 | 
			
		||||
        model.config.use_cache = False # turn off when gradient checkpointing is enabled
 | 
			
		||||
 | 
			
		||||
    if finetuning_type != "full" and hasattr(model, output_embedding_layer_name):
 | 
			
		||||
        output_embedding_layer: torch.nn.Linear = getattr(model, output_embedding_layer_name)
 | 
			
		||||
        input_dtype = output_embedding_layer.weight.dtype
 | 
			
		||||
 | 
			
		||||
        class CastOutputToFloat(torch.nn.Sequential):
 | 
			
		||||
 | 
			
		||||
            def forward(self, x: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
                return super().forward(x.to(input_dtype)).to(torch.float32)
 | 
			
		||||
 | 
			
		||||
        setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
 | 
			
		||||
 | 
			
		||||
    return model
 | 
			
		||||
 | 
			
		||||
def torch_gc() -> None:
 | 
			
		||||
    r"""
 | 
			
		||||
    Collects GPU memory.
 | 
			
		||||
    """
 | 
			
		||||
    if torch.cuda.is_available():
 | 
			
		||||
        torch.cuda.empty_cache()
 | 
			
		||||
        torch.cuda.ipc_collect()
 | 
			
		||||
							
								
								
									
										50
									
								
								src/llmtuner/extras/ploting.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								src/llmtuner/extras/ploting.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,50 @@
 | 
			
		||||
import os
 | 
			
		||||
import json
 | 
			
		||||
import matplotlib.pyplot as plt
 | 
			
		||||
from typing import List, Optional
 | 
			
		||||
from transformers.trainer import TRAINER_STATE_NAME
 | 
			
		||||
 | 
			
		||||
from llmtuner.extras.logging import get_logger
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]:
 | 
			
		||||
    r"""
 | 
			
		||||
    EMA implementation according to TensorBoard.
 | 
			
		||||
    """
 | 
			
		||||
    last = scalars[0]
 | 
			
		||||
    smoothed = list()
 | 
			
		||||
    for next_val in scalars:
 | 
			
		||||
        smoothed_val = last * weight + (1 - weight) * next_val
 | 
			
		||||
        smoothed.append(smoothed_val)
 | 
			
		||||
        last = smoothed_val
 | 
			
		||||
    return smoothed
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
 | 
			
		||||
 | 
			
		||||
    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
 | 
			
		||||
        data = json.load(f)
 | 
			
		||||
 | 
			
		||||
    for key in keys:
 | 
			
		||||
        steps, metrics = [], []
 | 
			
		||||
        for i in range(len(data["log_history"])):
 | 
			
		||||
            if key in data["log_history"][i]:
 | 
			
		||||
                steps.append(data["log_history"][i]["step"])
 | 
			
		||||
                metrics.append(data["log_history"][i][key])
 | 
			
		||||
 | 
			
		||||
        if len(metrics) == 0:
 | 
			
		||||
            logger.warning(f"No metric {key} to plot.")
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        plt.figure()
 | 
			
		||||
        plt.plot(steps, metrics, alpha=0.4, label="original")
 | 
			
		||||
        plt.plot(steps, smooth(metrics), label="smoothed")
 | 
			
		||||
        plt.title("training {} of {}".format(key, save_dictionary))
 | 
			
		||||
        plt.xlabel("step")
 | 
			
		||||
        plt.ylabel(key)
 | 
			
		||||
        plt.legend()
 | 
			
		||||
        plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100)
 | 
			
		||||
        print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key)))
 | 
			
		||||
							
								
								
									
										49
									
								
								src/llmtuner/extras/save_and_load.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								src/llmtuner/extras/save_and_load.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,49 @@
 | 
			
		||||
import os
 | 
			
		||||
import torch
 | 
			
		||||
from typing import Dict
 | 
			
		||||
 | 
			
		||||
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
 | 
			
		||||
from transformers.modeling_utils import load_sharded_checkpoint
 | 
			
		||||
 | 
			
		||||
from llmtuner.extras.constants import VALUE_HEAD_FILE_NAME
 | 
			
		||||
from llmtuner.extras.logging import get_logger
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get state dict containing trainable parameters
 | 
			
		||||
    state_dict = model.state_dict()
 | 
			
		||||
    filtered_state_dict = {}
 | 
			
		||||
 | 
			
		||||
    for k, v in model.named_parameters():
 | 
			
		||||
        if v.requires_grad:
 | 
			
		||||
            filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
 | 
			
		||||
 | 
			
		||||
    return filtered_state_dict
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
 | 
			
		||||
    weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
 | 
			
		||||
    if os.path.exists(weights_file):
 | 
			
		||||
        model_state_dict = torch.load(weights_file, map_location="cpu")
 | 
			
		||||
        model.load_state_dict(model_state_dict, strict=False) # skip missing keys
 | 
			
		||||
    elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
 | 
			
		||||
        load_sharded_checkpoint(model, checkpoint_dir, strict=False)
 | 
			
		||||
    else:
 | 
			
		||||
        logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
 | 
			
		||||
        return False
 | 
			
		||||
    return True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
 | 
			
		||||
    valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
 | 
			
		||||
    if not os.path.exists(valuehead_file):
 | 
			
		||||
        logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
 | 
			
		||||
        return False
 | 
			
		||||
    valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
 | 
			
		||||
    model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
 | 
			
		||||
    model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
 | 
			
		||||
    model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]))
 | 
			
		||||
    model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
 | 
			
		||||
    return True
 | 
			
		||||
							
								
								
									
										5
									
								
								src/llmtuner/hparams/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								src/llmtuner/hparams/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,5 @@
 | 
			
		||||
from .data_args import DataArguments
 | 
			
		||||
from .finetuning_args import FinetuningArguments
 | 
			
		||||
from .general_args import GeneralArguments
 | 
			
		||||
from .generating_args import GeneratingArguments
 | 
			
		||||
from .model_args import ModelArguments
 | 
			
		||||
							
								
								
									
										119
									
								
								src/llmtuner/hparams/data_args.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										119
									
								
								src/llmtuner/hparams/data_args.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,119 @@
 | 
			
		||||
import os
 | 
			
		||||
import json
 | 
			
		||||
from typing import List, Optional
 | 
			
		||||
from dataclasses import dataclass, field
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class DatasetAttr:
 | 
			
		||||
 | 
			
		||||
    load_from: str
 | 
			
		||||
    dataset_name: Optional[str] = None
 | 
			
		||||
    dataset_sha1: Optional[str] = None
 | 
			
		||||
    source_prefix: Optional[str] = None
 | 
			
		||||
 | 
			
		||||
    def __repr__(self) -> str:
 | 
			
		||||
        return self.dataset_name
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        self.prompt_column = "instruction"
 | 
			
		||||
        self.query_column = "input"
 | 
			
		||||
        self.response_column = "output"
 | 
			
		||||
        self.history_column = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class DataArguments:
 | 
			
		||||
    """
 | 
			
		||||
    Arguments pertaining to what data we are going to input our model for training and evaluation.
 | 
			
		||||
    """
 | 
			
		||||
    dataset: Optional[str] = field(
 | 
			
		||||
        default="alpaca_zh",
 | 
			
		||||
        metadata={"help": "The name of provided dataset(s) to use. Use comma to separate multiple datasets."}
 | 
			
		||||
    )
 | 
			
		||||
    dataset_dir: Optional[str] = field(
 | 
			
		||||
        default="data",
 | 
			
		||||
        metadata={"help": "The name of the folder containing datasets."}
 | 
			
		||||
    )
 | 
			
		||||
    split: Optional[str] = field(
 | 
			
		||||
        default="train",
 | 
			
		||||
        metadata={"help": "Which dataset split to use for training and evaluation."}
 | 
			
		||||
    )
 | 
			
		||||
    overwrite_cache: Optional[bool] = field(
 | 
			
		||||
        default=False,
 | 
			
		||||
        metadata={"help": "Overwrite the cached training and evaluation sets."}
 | 
			
		||||
    )
 | 
			
		||||
    preprocessing_num_workers: Optional[int] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "The number of processes to use for the preprocessing."}
 | 
			
		||||
    )
 | 
			
		||||
    max_source_length: Optional[int] = field(
 | 
			
		||||
        default=512,
 | 
			
		||||
        metadata={"help": "The maximum total input sequence length after tokenization."}
 | 
			
		||||
    )
 | 
			
		||||
    max_target_length: Optional[int] = field(
 | 
			
		||||
        default=512,
 | 
			
		||||
        metadata={"help": "The maximum total output sequence length after tokenization."}
 | 
			
		||||
    )
 | 
			
		||||
    max_samples: Optional[int] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
 | 
			
		||||
    )
 | 
			
		||||
    eval_num_beams: Optional[int] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
 | 
			
		||||
    )
 | 
			
		||||
    ignore_pad_token_for_loss: Optional[bool] = field(
 | 
			
		||||
        default=True,
 | 
			
		||||
        metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
 | 
			
		||||
    )
 | 
			
		||||
    source_prefix: Optional[str] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
 | 
			
		||||
    )
 | 
			
		||||
    dev_ratio: Optional[float] = field(
 | 
			
		||||
        default=0,
 | 
			
		||||
        metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
 | 
			
		||||
    )
 | 
			
		||||
    prompt_template: Optional[str] = field(
 | 
			
		||||
        default="default",
 | 
			
		||||
        metadata={"help": "Which template to use for constructing prompts in training and inference."}
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def init_for_training(self): # support mixing multiple datasets
 | 
			
		||||
        dataset_names = [ds.strip() for ds in self.dataset.split(",")]
 | 
			
		||||
        with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
 | 
			
		||||
            dataset_info = json.load(f)
 | 
			
		||||
 | 
			
		||||
        if self.source_prefix is not None:
 | 
			
		||||
            prefix_list = self.source_prefix.split("|")
 | 
			
		||||
            prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
 | 
			
		||||
            assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical with datasets or 1."
 | 
			
		||||
        else:
 | 
			
		||||
            prefix_list = [None] * len(dataset_names)
 | 
			
		||||
 | 
			
		||||
        self.dataset_list: List[DatasetAttr] = []
 | 
			
		||||
        for i, name in enumerate(dataset_names):
 | 
			
		||||
            if name not in dataset_info:
 | 
			
		||||
                raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
 | 
			
		||||
 | 
			
		||||
            if "hf_hub_url" in dataset_info[name]:
 | 
			
		||||
                dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
 | 
			
		||||
            elif "script_url" in dataset_info[name]:
 | 
			
		||||
                dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
 | 
			
		||||
            else:
 | 
			
		||||
                dataset_attr = DatasetAttr(
 | 
			
		||||
                    "file",
 | 
			
		||||
                    dataset_name=dataset_info[name]["file_name"],
 | 
			
		||||
                    dataset_sha1=dataset_info[name].get("file_sha1", None)
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            dataset_attr.source_prefix = prefix_list[i]
 | 
			
		||||
 | 
			
		||||
            if "columns" in dataset_info[name]:
 | 
			
		||||
                dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
 | 
			
		||||
                dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
 | 
			
		||||
                dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
 | 
			
		||||
                dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
 | 
			
		||||
 | 
			
		||||
            self.dataset_list.append(dataset_attr)
 | 
			
		||||
							
								
								
									
										78
									
								
								src/llmtuner/hparams/finetuning_args.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										78
									
								
								src/llmtuner/hparams/finetuning_args.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,78 @@
 | 
			
		||||
import json
 | 
			
		||||
from typing import Literal, Optional
 | 
			
		||||
from dataclasses import asdict, dataclass, field
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class FinetuningArguments:
 | 
			
		||||
    """
 | 
			
		||||
    Arguments pertaining to which techniques we are going to fine-tuning with.
 | 
			
		||||
    """
 | 
			
		||||
    finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
 | 
			
		||||
        default="lora",
 | 
			
		||||
        metadata={"help": "Which fine-tuning method to use."}
 | 
			
		||||
    )
 | 
			
		||||
    num_hidden_layers: Optional[int] = field(
 | 
			
		||||
        default=32,
 | 
			
		||||
        metadata={"help": "Number of decoder blocks in the model. \
 | 
			
		||||
                  LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
 | 
			
		||||
                  BLOOM choices: [\"24\", \"30\", \"70\"], \
 | 
			
		||||
                  Falcon choices: [\"32\", \"60\"], \
 | 
			
		||||
                  Baichuan choices: [\"32\"]"}
 | 
			
		||||
    )
 | 
			
		||||
    num_layer_trainable: Optional[int] = field(
 | 
			
		||||
        default=3,
 | 
			
		||||
        metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
 | 
			
		||||
    )
 | 
			
		||||
    name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
 | 
			
		||||
        default="mlp",
 | 
			
		||||
        metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
 | 
			
		||||
                  LLaMA choices: [\"mlp\", \"self_attn\"], \
 | 
			
		||||
                  BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
 | 
			
		||||
                  Baichuan choices: [\"mlp\", \"self_attn\"]"}
 | 
			
		||||
    )
 | 
			
		||||
    lora_rank: Optional[int] = field(
 | 
			
		||||
        default=8,
 | 
			
		||||
        metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
 | 
			
		||||
    )
 | 
			
		||||
    lora_alpha: Optional[float] = field(
 | 
			
		||||
        default=32.0,
 | 
			
		||||
        metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."}
 | 
			
		||||
    )
 | 
			
		||||
    lora_dropout: Optional[float] = field(
 | 
			
		||||
        default=0.1,
 | 
			
		||||
        metadata={"help": "Dropout rate for the LoRA fine-tuning."}
 | 
			
		||||
    )
 | 
			
		||||
    lora_target: Optional[str] = field(
 | 
			
		||||
        default="q_proj,v_proj",
 | 
			
		||||
        metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \
 | 
			
		||||
                  LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
 | 
			
		||||
                  BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
 | 
			
		||||
                  Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
 | 
			
		||||
            self.lora_target = [target.strip() for target in self.lora_target.split(",")]
 | 
			
		||||
 | 
			
		||||
        if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
 | 
			
		||||
            trainable_layer_ids = [self.num_hidden_layers - k - 1 for k in range(self.num_layer_trainable)]
 | 
			
		||||
        else: # fine-tuning the first n layers if num_layer_trainable < 0
 | 
			
		||||
            trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]
 | 
			
		||||
 | 
			
		||||
        self.trainable_layers = ["{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
 | 
			
		||||
 | 
			
		||||
        assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
 | 
			
		||||
 | 
			
		||||
    def save_to_json(self, json_path: str):
 | 
			
		||||
        """Saves the content of this instance in JSON format inside `json_path`."""
 | 
			
		||||
        json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
 | 
			
		||||
        with open(json_path, "w", encoding="utf-8") as f:
 | 
			
		||||
            f.write(json_string)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def load_from_json(cls, json_path: str):
 | 
			
		||||
        """Creates an instance from the content of `json_path`."""
 | 
			
		||||
        with open(json_path, "r", encoding="utf-8") as f:
 | 
			
		||||
            text = f.read()
 | 
			
		||||
        return cls(**json.loads(text))
 | 
			
		||||
							
								
								
									
										13
									
								
								src/llmtuner/hparams/general_args.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								src/llmtuner/hparams/general_args.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,13 @@
 | 
			
		||||
from typing import Literal, Optional
 | 
			
		||||
from dataclasses import dataclass, field
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class GeneralArguments:
 | 
			
		||||
    """
 | 
			
		||||
    Arguments pertaining to which techniques we are going to fine-tuning with.
 | 
			
		||||
    """
 | 
			
		||||
    stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = field(
 | 
			
		||||
        default="sft",
 | 
			
		||||
        metadata={"help": "Which stage will be performed in training."}
 | 
			
		||||
    )
 | 
			
		||||
							
								
								
									
										51
									
								
								src/llmtuner/hparams/generating_args.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								src/llmtuner/hparams/generating_args.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,51 @@
 | 
			
		||||
from typing import Any, Dict, Optional
 | 
			
		||||
from dataclasses import asdict, dataclass, field
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class GeneratingArguments:
 | 
			
		||||
    """
 | 
			
		||||
    Arguments pertaining to specify the decoding parameters.
 | 
			
		||||
    """
 | 
			
		||||
    do_sample: Optional[bool] = field(
 | 
			
		||||
        default=True,
 | 
			
		||||
        metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
 | 
			
		||||
    )
 | 
			
		||||
    temperature: Optional[float] = field(
 | 
			
		||||
        default=0.95,
 | 
			
		||||
        metadata={"help": "The value used to modulate the next token probabilities."}
 | 
			
		||||
    )
 | 
			
		||||
    top_p: Optional[float] = field(
 | 
			
		||||
        default=0.7,
 | 
			
		||||
        metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
 | 
			
		||||
    )
 | 
			
		||||
    top_k: Optional[int] = field(
 | 
			
		||||
        default=50,
 | 
			
		||||
        metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
 | 
			
		||||
    )
 | 
			
		||||
    num_beams: Optional[int] = field(
 | 
			
		||||
        default=1,
 | 
			
		||||
        metadata={"help": "Number of beams for beam search. 1 means no beam search."}
 | 
			
		||||
    )
 | 
			
		||||
    max_length: Optional[int] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
 | 
			
		||||
    )
 | 
			
		||||
    max_new_tokens: Optional[int] = field(
 | 
			
		||||
        default=512,
 | 
			
		||||
        metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
 | 
			
		||||
    )
 | 
			
		||||
    repetition_penalty: Optional[float] = field(
 | 
			
		||||
        default=1.0,
 | 
			
		||||
        metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
 | 
			
		||||
    )
 | 
			
		||||
    length_penalty: Optional[float] = field(
 | 
			
		||||
        default=1.0,
 | 
			
		||||
        metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def to_dict(self) -> Dict[str, Any]:
 | 
			
		||||
        args = asdict(self)
 | 
			
		||||
        if args.get("max_new_tokens", None):
 | 
			
		||||
            args.pop("max_length", None)
 | 
			
		||||
        return args
 | 
			
		||||
							
								
								
									
										72
									
								
								src/llmtuner/hparams/model_args.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										72
									
								
								src/llmtuner/hparams/model_args.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,72 @@
 | 
			
		||||
import torch
 | 
			
		||||
from typing import Literal, Optional
 | 
			
		||||
from dataclasses import dataclass, field
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class ModelArguments:
 | 
			
		||||
    """
 | 
			
		||||
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
 | 
			
		||||
    """
 | 
			
		||||
    model_name_or_path: str = field(
 | 
			
		||||
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
 | 
			
		||||
    )
 | 
			
		||||
    cache_dir: Optional[str] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
 | 
			
		||||
    )
 | 
			
		||||
    use_fast_tokenizer: Optional[bool] = field(
 | 
			
		||||
        default=False,
 | 
			
		||||
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
 | 
			
		||||
    )
 | 
			
		||||
    use_auth_token: Optional[bool] = field(
 | 
			
		||||
        default=False,
 | 
			
		||||
        metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
 | 
			
		||||
    )
 | 
			
		||||
    model_revision: Optional[str] = field(
 | 
			
		||||
        default="main",
 | 
			
		||||
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
 | 
			
		||||
    )
 | 
			
		||||
    padding_side: Optional[Literal["left", "right"]] = field(
 | 
			
		||||
        default="left",
 | 
			
		||||
        metadata={"help": "The side on which the model should have padding applied."}
 | 
			
		||||
    )
 | 
			
		||||
    quantization_bit: Optional[int] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "The number of bits to quantize the model."}
 | 
			
		||||
    )
 | 
			
		||||
    quantization_type: Optional[Literal["fp4", "nf4"]] = field(
 | 
			
		||||
        default="nf4",
 | 
			
		||||
        metadata={"help": "Quantization data type to use in int4 training."}
 | 
			
		||||
    )
 | 
			
		||||
    double_quantization: Optional[bool] = field(
 | 
			
		||||
        default=True,
 | 
			
		||||
        metadata={"help": "Whether to use double quantization in int4 training or not."}
 | 
			
		||||
    )
 | 
			
		||||
    compute_dtype: Optional[torch.dtype] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
 | 
			
		||||
    )
 | 
			
		||||
    checkpoint_dir: Optional[str] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
 | 
			
		||||
    )
 | 
			
		||||
    reward_model: Optional[str] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
 | 
			
		||||
    )
 | 
			
		||||
    resume_lora_training: Optional[bool] = field(
 | 
			
		||||
        default=True,
 | 
			
		||||
        metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
 | 
			
		||||
    )
 | 
			
		||||
    plot_loss: Optional[bool] = field(
 | 
			
		||||
        default=False,
 | 
			
		||||
        metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        if self.checkpoint_dir is not None: # support merging multiple lora weights
 | 
			
		||||
            self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
 | 
			
		||||
 | 
			
		||||
        if self.quantization_bit is not None:
 | 
			
		||||
            assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
 | 
			
		||||
							
								
								
									
										5
									
								
								src/llmtuner/tuner/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								src/llmtuner/tuner/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,5 @@
 | 
			
		||||
from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
 | 
			
		||||
from llmtuner.tuner.pt import run_pt
 | 
			
		||||
from llmtuner.tuner.sft import run_sft
 | 
			
		||||
from llmtuner.tuner.rm import run_rm
 | 
			
		||||
from llmtuner.tuner.ppo import run_ppo
 | 
			
		||||
							
								
								
									
										2
									
								
								src/llmtuner/tuner/core/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								src/llmtuner/tuner/core/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,2 @@
 | 
			
		||||
from llmtuner.tuner.core.parser import get_train_args, get_infer_args
 | 
			
		||||
from llmtuner.tuner.core.loader import load_model_and_tokenizer
 | 
			
		||||
							
								
								
									
										94
									
								
								src/llmtuner/tuner/core/adapter.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										94
									
								
								src/llmtuner/tuner/core/adapter.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,94 @@
 | 
			
		||||
import os
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from transformers.modeling_utils import PreTrainedModel
 | 
			
		||||
from peft import (
 | 
			
		||||
    PeftModel,
 | 
			
		||||
    TaskType,
 | 
			
		||||
    LoraConfig,
 | 
			
		||||
    get_peft_model
 | 
			
		||||
)
 | 
			
		||||
from peft.utils import CONFIG_NAME, WEIGHTS_NAME
 | 
			
		||||
 | 
			
		||||
from llmtuner.extras.logging import get_logger
 | 
			
		||||
from llmtuner.extras.save_and_load import load_trainable_params
 | 
			
		||||
from llmtuner.hparams import ModelArguments, FinetuningArguments
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def init_adapter(
 | 
			
		||||
    model: PreTrainedModel,
 | 
			
		||||
    model_args: ModelArguments,
 | 
			
		||||
    finetuning_args: FinetuningArguments,
 | 
			
		||||
    is_trainable: bool,
 | 
			
		||||
    is_mergeable: bool
 | 
			
		||||
) -> PreTrainedModel:
 | 
			
		||||
    r"""
 | 
			
		||||
    Initializes the adapters.
 | 
			
		||||
 | 
			
		||||
    Support full-parameter, freeze and LoRA training.
 | 
			
		||||
 | 
			
		||||
    Note that the trainable parameters must be cast to float32.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    if finetuning_args.finetuning_type == "none" and is_trainable:
 | 
			
		||||
        raise ValueError("You cannot use finetuning_type=none while training.")
 | 
			
		||||
 | 
			
		||||
    if finetuning_args.finetuning_type == "full":
 | 
			
		||||
        logger.info("Fine-tuning method: Full")
 | 
			
		||||
        model = model.float()
 | 
			
		||||
 | 
			
		||||
    if finetuning_args.finetuning_type == "freeze":
 | 
			
		||||
        logger.info("Fine-tuning method: Freeze")
 | 
			
		||||
 | 
			
		||||
        for name, param in model.named_parameters():
 | 
			
		||||
            if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
 | 
			
		||||
                param.requires_grad_(False)
 | 
			
		||||
            else:
 | 
			
		||||
                param.data = param.data.to(torch.float32)
 | 
			
		||||
 | 
			
		||||
        if model_args.checkpoint_dir is not None:
 | 
			
		||||
            assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
 | 
			
		||||
 | 
			
		||||
    if finetuning_args.finetuning_type == "lora":
 | 
			
		||||
        logger.info("Fine-tuning method: LoRA")
 | 
			
		||||
        latest_checkpoint = None
 | 
			
		||||
 | 
			
		||||
        if model_args.checkpoint_dir is not None:
 | 
			
		||||
            assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)), \
 | 
			
		||||
                "Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0])
 | 
			
		||||
            assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
 | 
			
		||||
                "The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
 | 
			
		||||
 | 
			
		||||
            if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
 | 
			
		||||
                checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
 | 
			
		||||
            else:
 | 
			
		||||
                checkpoints_to_merge = model_args.checkpoint_dir
 | 
			
		||||
 | 
			
		||||
            for checkpoint in checkpoints_to_merge:
 | 
			
		||||
                model = PeftModel.from_pretrained(model, checkpoint)
 | 
			
		||||
                model = model.merge_and_unload()
 | 
			
		||||
 | 
			
		||||
            if len(checkpoints_to_merge) > 0:
 | 
			
		||||
                logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
 | 
			
		||||
 | 
			
		||||
            if latest_checkpoint is not None: # resume lora training or quantized inference
 | 
			
		||||
                model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)
 | 
			
		||||
 | 
			
		||||
        if is_trainable and latest_checkpoint is None: # create new lora weights while training
 | 
			
		||||
            lora_config = LoraConfig(
 | 
			
		||||
                task_type=TaskType.CAUSAL_LM,
 | 
			
		||||
                inference_mode=False,
 | 
			
		||||
                r=finetuning_args.lora_rank,
 | 
			
		||||
                lora_alpha=finetuning_args.lora_alpha,
 | 
			
		||||
                lora_dropout=finetuning_args.lora_dropout,
 | 
			
		||||
                target_modules=finetuning_args.lora_target
 | 
			
		||||
            )
 | 
			
		||||
            model = get_peft_model(model, lora_config)
 | 
			
		||||
 | 
			
		||||
    if model_args.checkpoint_dir is not None:
 | 
			
		||||
        logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
 | 
			
		||||
 | 
			
		||||
    return model
 | 
			
		||||
							
								
								
									
										151
									
								
								src/llmtuner/tuner/core/loader.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										151
									
								
								src/llmtuner/tuner/core/loader.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,151 @@
 | 
			
		||||
import os
 | 
			
		||||
import torch
 | 
			
		||||
from typing import Literal, Optional, Tuple
 | 
			
		||||
 | 
			
		||||
from transformers import (
 | 
			
		||||
    AutoConfig,
 | 
			
		||||
    AutoModelForCausalLM,
 | 
			
		||||
    AutoTokenizer,
 | 
			
		||||
    BitsAndBytesConfig
 | 
			
		||||
)
 | 
			
		||||
from transformers.utils import check_min_version
 | 
			
		||||
from transformers.utils.versions import require_version
 | 
			
		||||
from transformers.modeling_utils import PreTrainedModel
 | 
			
		||||
from transformers.tokenization_utils import PreTrainedTokenizer
 | 
			
		||||
from trl import AutoModelForCausalLMWithValueHead
 | 
			
		||||
 | 
			
		||||
from llmtuner.extras.logging import get_logger
 | 
			
		||||
from llmtuner.extras.misc import prepare_model_for_training, print_trainable_params
 | 
			
		||||
from llmtuner.extras.save_and_load import load_valuehead_params
 | 
			
		||||
from llmtuner.hparams import ModelArguments, FinetuningArguments
 | 
			
		||||
from llmtuner.tuner.core.adapter import init_adapter
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
check_min_version("4.29.1")
 | 
			
		||||
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
 | 
			
		||||
require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
 | 
			
		||||
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
 | 
			
		||||
require_version("trl>=0.4.4", "To fix: pip install trl>=0.4.4")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_model_and_tokenizer(
 | 
			
		||||
    model_args: ModelArguments,
 | 
			
		||||
    finetuning_args: FinetuningArguments,
 | 
			
		||||
    is_trainable: Optional[bool] = False,
 | 
			
		||||
    stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
 | 
			
		||||
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
 | 
			
		||||
    r"""
 | 
			
		||||
    Loads pretrained model and tokenizer.
 | 
			
		||||
 | 
			
		||||
    Support both training and inference.
 | 
			
		||||
    """
 | 
			
		||||
    if (not is_trainable) and model_args.checkpoint_dir is None:
 | 
			
		||||
        logger.warning("Checkpoint is not found at evaluation, load the original model.")
 | 
			
		||||
        finetuning_args = FinetuningArguments(finetuning_type="none")
 | 
			
		||||
 | 
			
		||||
    assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
 | 
			
		||||
        "RM and PPO training can only be performed with the LoRA method."
 | 
			
		||||
 | 
			
		||||
    config_kwargs = {
 | 
			
		||||
        "trust_remote_code": True,
 | 
			
		||||
        "cache_dir": model_args.cache_dir,
 | 
			
		||||
        "revision": model_args.model_revision,
 | 
			
		||||
        "use_auth_token": True if model_args.use_auth_token else None,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained(
 | 
			
		||||
        model_args.model_name_or_path,
 | 
			
		||||
        use_fast=model_args.use_fast_tokenizer,
 | 
			
		||||
        padding_side=model_args.padding_side,
 | 
			
		||||
        **config_kwargs
 | 
			
		||||
    )
 | 
			
		||||
    if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000: # 64000 for baichuan model (older version)
 | 
			
		||||
        tokenizer.pad_token_id = 0 # set as the <unk> token
 | 
			
		||||
 | 
			
		||||
    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
 | 
			
		||||
    is_mergeable = True
 | 
			
		||||
 | 
			
		||||
    # Quantization configurations (using bitsandbytes library).
 | 
			
		||||
    if model_args.quantization_bit is not None:
 | 
			
		||||
        if model_args.quantization_bit == 8:
 | 
			
		||||
            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
 | 
			
		||||
            config_kwargs["load_in_8bit"] = True
 | 
			
		||||
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
 | 
			
		||||
                load_in_8bit=True,
 | 
			
		||||
                llm_int8_threshold=6.0
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        elif model_args.quantization_bit == 4:
 | 
			
		||||
            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
 | 
			
		||||
            require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
 | 
			
		||||
            require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
 | 
			
		||||
            require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
 | 
			
		||||
            config_kwargs["load_in_4bit"] = True
 | 
			
		||||
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
 | 
			
		||||
                load_in_4bit=True,
 | 
			
		||||
                bnb_4bit_compute_dtype=model_args.compute_dtype,
 | 
			
		||||
                bnb_4bit_use_double_quant=model_args.double_quantization,
 | 
			
		||||
                bnb_4bit_quant_type=model_args.quantization_type
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        is_mergeable = False
 | 
			
		||||
        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
 | 
			
		||||
        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 | 
			
		||||
 | 
			
		||||
    if not is_trainable: # `device_map=auto` should be used for inference only
 | 
			
		||||
        config_kwargs["device_map"] = "auto"
 | 
			
		||||
 | 
			
		||||
    if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
 | 
			
		||||
        model_to_load = model_args.checkpoint_dir[0]
 | 
			
		||||
    else:
 | 
			
		||||
        model_to_load = model_args.model_name_or_path
 | 
			
		||||
 | 
			
		||||
    # Load and prepare pretrained models (without valuehead).
 | 
			
		||||
    model = AutoModelForCausalLM.from_pretrained(
 | 
			
		||||
        model_to_load,
 | 
			
		||||
        config=config,
 | 
			
		||||
        torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
 | 
			
		||||
        low_cpu_mem_usage=True,
 | 
			
		||||
        **config_kwargs
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Register auto class to save the custom code files.
 | 
			
		||||
    if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
 | 
			
		||||
        config.__class__.register_for_auto_class()
 | 
			
		||||
    if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
 | 
			
		||||
        tokenizer.__class__.register_for_auto_class()
 | 
			
		||||
    if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map:
 | 
			
		||||
        model.__class__.register_for_auto_class()
 | 
			
		||||
 | 
			
		||||
    # Initialize adapters
 | 
			
		||||
    model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
 | 
			
		||||
    model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
 | 
			
		||||
 | 
			
		||||
    if stage == "rm" or stage == "ppo": # add value head
 | 
			
		||||
        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
 | 
			
		||||
 | 
			
		||||
        if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
 | 
			
		||||
            logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
 | 
			
		||||
            if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
 | 
			
		||||
                model.v_head.load_state_dict({
 | 
			
		||||
                    "summary.weight": getattr(model, "reward_head_weight"),
 | 
			
		||||
                    "summary.bias": getattr(model, "reward_head_bias")
 | 
			
		||||
                })
 | 
			
		||||
 | 
			
		||||
        if stage == "ppo": # load reward model
 | 
			
		||||
            assert is_trainable, "PPO stage cannot be performed at evaluation."
 | 
			
		||||
            assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
 | 
			
		||||
            logger.info("Load reward model from {}".format(model_args.reward_model))
 | 
			
		||||
            model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
 | 
			
		||||
            assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
 | 
			
		||||
 | 
			
		||||
    if not is_trainable:
 | 
			
		||||
        model.requires_grad_(False) # fix all model params
 | 
			
		||||
        model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
 | 
			
		||||
 | 
			
		||||
    print_trainable_params(model)
 | 
			
		||||
 | 
			
		||||
    return model, tokenizer
 | 
			
		||||
							
								
								
									
										134
									
								
								src/llmtuner/tuner/core/parser.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										134
									
								
								src/llmtuner/tuner/core/parser.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,134 @@
 | 
			
		||||
import os
 | 
			
		||||
import sys
 | 
			
		||||
import torch
 | 
			
		||||
import datasets
 | 
			
		||||
import transformers
 | 
			
		||||
from typing import Any, Dict, Optional, Tuple
 | 
			
		||||
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
 | 
			
		||||
 | 
			
		||||
from llmtuner.extras.logging import get_logger
 | 
			
		||||
from llmtuner.hparams import (
 | 
			
		||||
    ModelArguments,
 | 
			
		||||
    DataArguments,
 | 
			
		||||
    FinetuningArguments,
 | 
			
		||||
    GeneratingArguments,
 | 
			
		||||
    GeneralArguments
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_train_args(
 | 
			
		||||
    args: Optional[Dict[str, Any]] = None
 | 
			
		||||
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
 | 
			
		||||
 | 
			
		||||
    parser = HfArgumentParser((ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments))
 | 
			
		||||
 | 
			
		||||
    if args is not None:
 | 
			
		||||
        model_args, data_args, training_args, finetuning_args, general_args = parser.parse_dict(args)
 | 
			
		||||
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
 | 
			
		||||
        model_args, data_args, training_args, finetuning_args, general_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
 | 
			
		||||
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
 | 
			
		||||
        model_args, data_args, training_args, finetuning_args, general_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
 | 
			
		||||
    else:
 | 
			
		||||
        model_args, data_args, training_args, finetuning_args, general_args = parser.parse_args_into_dataclasses()
 | 
			
		||||
 | 
			
		||||
    # Setup logging
 | 
			
		||||
    if training_args.should_log:
 | 
			
		||||
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
 | 
			
		||||
        transformers.utils.logging.set_verbosity_info()
 | 
			
		||||
 | 
			
		||||
    log_level = training_args.get_process_log_level()
 | 
			
		||||
    datasets.utils.logging.set_verbosity(log_level)
 | 
			
		||||
    transformers.utils.logging.set_verbosity(log_level)
 | 
			
		||||
    transformers.utils.logging.enable_default_handler()
 | 
			
		||||
    transformers.utils.logging.enable_explicit_format()
 | 
			
		||||
 | 
			
		||||
    # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
 | 
			
		||||
    data_args.init_for_training()
 | 
			
		||||
 | 
			
		||||
    assert general_args.stage == "sft" or (not training_args.predict_with_generate), \
 | 
			
		||||
        "`predict_with_generate` cannot be set as True at PT, RM and PPO stages."
 | 
			
		||||
 | 
			
		||||
    assert not (training_args.do_train and training_args.predict_with_generate), \
 | 
			
		||||
        "`predict_with_generate` cannot be set as True while training."
 | 
			
		||||
 | 
			
		||||
    assert (not training_args.do_predict) or training_args.predict_with_generate, \
 | 
			
		||||
        "Please enable `predict_with_generate` to save model predictions."
 | 
			
		||||
 | 
			
		||||
    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
 | 
			
		||||
        "Quantization is only compatible with the LoRA method."
 | 
			
		||||
 | 
			
		||||
    if model_args.checkpoint_dir is not None:
 | 
			
		||||
        if finetuning_args.finetuning_type != "lora":
 | 
			
		||||
            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
 | 
			
		||||
        else:
 | 
			
		||||
            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
 | 
			
		||||
                "Quantized model only accepts a single checkpoint."
 | 
			
		||||
 | 
			
		||||
    if model_args.quantization_bit is not None and (not training_args.do_train):
 | 
			
		||||
        logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
 | 
			
		||||
 | 
			
		||||
    if training_args.do_train and (not training_args.fp16):
 | 
			
		||||
        logger.warning("We recommend enable fp16 mixed precision training.")
 | 
			
		||||
 | 
			
		||||
    if data_args.prompt_template == "default":
 | 
			
		||||
        logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
 | 
			
		||||
 | 
			
		||||
    if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
 | 
			
		||||
        logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
 | 
			
		||||
        training_args.ddp_find_unused_parameters = False
 | 
			
		||||
 | 
			
		||||
    training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
 | 
			
		||||
 | 
			
		||||
    if model_args.quantization_bit is not None:
 | 
			
		||||
        if training_args.fp16:
 | 
			
		||||
            model_args.compute_dtype = torch.float16
 | 
			
		||||
        elif training_args.bf16:
 | 
			
		||||
            model_args.compute_dtype = torch.bfloat16
 | 
			
		||||
        else:
 | 
			
		||||
            model_args.compute_dtype = torch.float32
 | 
			
		||||
 | 
			
		||||
    # Log on each process the small summary:
 | 
			
		||||
    logger.info(
 | 
			
		||||
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
 | 
			
		||||
        + f"  distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
 | 
			
		||||
    )
 | 
			
		||||
    logger.info(f"Training/evaluation parameters {training_args}")
 | 
			
		||||
 | 
			
		||||
    # Set seed before initializing model.
 | 
			
		||||
    transformers.set_seed(training_args.seed)
 | 
			
		||||
 | 
			
		||||
    return model_args, data_args, training_args, finetuning_args, general_args
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_infer_args(
 | 
			
		||||
    args: Optional[Dict[str, Any]] = None
 | 
			
		||||
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
 | 
			
		||||
 | 
			
		||||
    parser = HfArgumentParser((ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments))
 | 
			
		||||
 | 
			
		||||
    if args is not None:
 | 
			
		||||
        model_args, data_args, finetuning_args, generating_args = parser.parse_dict(args)
 | 
			
		||||
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
 | 
			
		||||
        model_args, data_args, finetuning_args, generating_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
 | 
			
		||||
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
 | 
			
		||||
        model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
 | 
			
		||||
    else:
 | 
			
		||||
        model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()
 | 
			
		||||
 | 
			
		||||
    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
 | 
			
		||||
        "Quantization is only compatible with the LoRA method."
 | 
			
		||||
 | 
			
		||||
    if model_args.checkpoint_dir is not None:
 | 
			
		||||
        if finetuning_args.finetuning_type != "lora":
 | 
			
		||||
            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
 | 
			
		||||
        else:
 | 
			
		||||
            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
 | 
			
		||||
                "Quantized model only accepts a single checkpoint."
 | 
			
		||||
 | 
			
		||||
    if data_args.prompt_template == "default":
 | 
			
		||||
        logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
 | 
			
		||||
 | 
			
		||||
    return model_args, data_args, finetuning_args, generating_args
 | 
			
		||||
@ -1,76 +1,20 @@
 | 
			
		||||
import os
 | 
			
		||||
import json
 | 
			
		||||
import time
 | 
			
		||||
import torch
 | 
			
		||||
from typing import Dict, Optional
 | 
			
		||||
from datetime import timedelta
 | 
			
		||||
 | 
			
		||||
from transformers import (
 | 
			
		||||
    Seq2SeqTrainer,
 | 
			
		||||
    TrainerCallback,
 | 
			
		||||
    TrainerControl,
 | 
			
		||||
    TrainerState,
 | 
			
		||||
    TrainingArguments
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from transformers import Seq2SeqTrainer
 | 
			
		||||
from transformers.trainer import TRAINING_ARGS_NAME
 | 
			
		||||
from transformers.modeling_utils import unwrap_model
 | 
			
		||||
 | 
			
		||||
from .config import FinetuningArguments
 | 
			
		||||
 | 
			
		||||
from .other import (
 | 
			
		||||
    get_logger,
 | 
			
		||||
    get_state_dict,
 | 
			
		||||
    load_trainable_params,
 | 
			
		||||
    load_valuehead_params,
 | 
			
		||||
    FINETUNING_ARGS_NAME,
 | 
			
		||||
    VALUE_HEAD_FILE_NAME
 | 
			
		||||
)
 | 
			
		||||
from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
 | 
			
		||||
from llmtuner.extras.logging import get_logger
 | 
			
		||||
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params, load_valuehead_params
 | 
			
		||||
from llmtuner.hparams import FinetuningArguments
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LogCallback(TrainerCallback):
 | 
			
		||||
    r"""
 | 
			
		||||
    TrainerCallback includes the state function during training, for more details refer to the TrainerCallback class.
 | 
			
		||||
    The on_log function primarily collects process parameters during training, such as training loss, learning rate,
 | 
			
		||||
    and training epochs, as well as progress parameters like the current percentage progress and estimated remaining
 | 
			
		||||
    time. Every time a log is triggered, a new record is appended to the file "messages.log" for dynamic visualization
 | 
			
		||||
    purposes.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.start_time = time.time()
 | 
			
		||||
 | 
			
		||||
    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
 | 
			
		||||
        r"""
 | 
			
		||||
        Event called after logging the last logs.
 | 
			
		||||
        """
 | 
			
		||||
        if "loss" not in state.log_history[-1]:
 | 
			
		||||
            return
 | 
			
		||||
        cur_time = time.time()
 | 
			
		||||
        cur_steps = state.log_history[-1].get("step")
 | 
			
		||||
        elapsed_time = cur_time - self.start_time
 | 
			
		||||
        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
 | 
			
		||||
        remaining_steps = state.max_steps - cur_steps
 | 
			
		||||
        remaining_time = remaining_steps * avg_time_per_step
 | 
			
		||||
        log_dict = {
 | 
			
		||||
            "current_steps": cur_steps,
 | 
			
		||||
            "total_steps": state.max_steps,
 | 
			
		||||
            "loss": state.log_history[-1].get("loss", None),
 | 
			
		||||
            "reward": state.log_history[-1].get("reward", None),
 | 
			
		||||
            "learning_rate": state.log_history[-1].get("learning_rate", None),
 | 
			
		||||
            "epoch": state.log_history[-1].get("epoch", None),
 | 
			
		||||
            "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
 | 
			
		||||
            "elapsed_time": str(timedelta(seconds=int(elapsed_time))),
 | 
			
		||||
            "remaining_time": str(timedelta(seconds=int(remaining_time)))
 | 
			
		||||
        }
 | 
			
		||||
        os.makedirs(args.output_dir, exist_ok=True)
 | 
			
		||||
        with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a") as f:
 | 
			
		||||
            f.write(json.dumps(log_dict) + "\n")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PeftTrainer(Seq2SeqTrainer):
 | 
			
		||||
    r"""
 | 
			
		||||
    Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
 | 
			
		||||
							
								
								
									
										1
									
								
								src/llmtuner/tuner/ppo/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								src/llmtuner/tuner/ppo/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1 @@
 | 
			
		||||
from llmtuner.tuner.ppo.workflow import run_ppo
 | 
			
		||||
@ -2,77 +2,43 @@ import os
 | 
			
		||||
import math
 | 
			
		||||
import torch
 | 
			
		||||
from tqdm import tqdm
 | 
			
		||||
from typing import Callable, Dict, List, Literal, Optional, Tuple
 | 
			
		||||
from typing import Callable, Dict, List, Optional
 | 
			
		||||
 | 
			
		||||
from transformers import Seq2SeqTrainingArguments, TrainerState
 | 
			
		||||
from transformers import Seq2SeqTrainingArguments, TrainerState, TrainerControl
 | 
			
		||||
from transformers.modeling_utils import PreTrainedModel
 | 
			
		||||
 | 
			
		||||
from trl import PPOTrainer, AutoModelForCausalLMWithValueHead
 | 
			
		||||
from trl import PPOTrainer
 | 
			
		||||
from trl.core import LengthSampler
 | 
			
		||||
 | 
			
		||||
from .peft_trainer import PeftTrainer, LogCallback
 | 
			
		||||
 | 
			
		||||
from .config import FinetuningArguments
 | 
			
		||||
 | 
			
		||||
from .other import (
 | 
			
		||||
    AverageMeter,
 | 
			
		||||
    get_logger,
 | 
			
		||||
    get_logits_processor
 | 
			
		||||
)
 | 
			
		||||
from llmtuner.extras.callbacks import LogCallback
 | 
			
		||||
from llmtuner.extras.logging import get_logger
 | 
			
		||||
from llmtuner.extras.misc import AverageMeter, get_logits_processor
 | 
			
		||||
from llmtuner.hparams import FinetuningArguments
 | 
			
		||||
from llmtuner.tuner.core.trainer import PeftTrainer
 | 
			
		||||
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
 | 
			
		||||
    if target == "reward": # save default head temporarily
 | 
			
		||||
        valuehead_state_dict = model.v_head.state_dict()
 | 
			
		||||
        setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
 | 
			
		||||
        setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"])
 | 
			
		||||
 | 
			
		||||
    model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
 | 
			
		||||
    model.v_head.load_state_dict({
 | 
			
		||||
        "summary.weight": getattr(model, "{}_head_weight".format(target)),
 | 
			
		||||
        "summary.bias": getattr(model, "{}_head_bias".format(target))
 | 
			
		||||
    })
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def cast_layernorm_dtype(
 | 
			
		||||
        model: AutoModelForCausalLMWithValueHead,
 | 
			
		||||
        layer_norm_names: List[str] = ["norm", "ln_f", "ln_attn", "ln_mlp"], # for LLaMA, BLOOM and Falcon settings
 | 
			
		||||
        layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
 | 
			
		||||
) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
 | 
			
		||||
 | 
			
		||||
    layer_norm_state_dict = {}
 | 
			
		||||
 | 
			
		||||
    for name, param in model.named_parameters():
 | 
			
		||||
        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
 | 
			
		||||
            if layer_norm_params is not None:
 | 
			
		||||
                param.data = layer_norm_params[name] # restore float32 weights
 | 
			
		||||
            else:
 | 
			
		||||
                layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability
 | 
			
		||||
                param.data = param.data.to(torch.float16)
 | 
			
		||||
 | 
			
		||||
    return model, layer_norm_state_dict
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PPOPeftTrainer(PPOTrainer, PeftTrainer):
 | 
			
		||||
    r"""
 | 
			
		||||
    Inherits PPOTrainer.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
            self,
 | 
			
		||||
            training_args: Seq2SeqTrainingArguments,
 | 
			
		||||
            finetuning_args: FinetuningArguments,
 | 
			
		||||
            callbacks: List[LogCallback],
 | 
			
		||||
            **kwargs
 | 
			
		||||
        self,
 | 
			
		||||
        training_args: Seq2SeqTrainingArguments,
 | 
			
		||||
        finetuning_args: FinetuningArguments,
 | 
			
		||||
        callbacks: List[LogCallback],
 | 
			
		||||
        **kwargs
 | 
			
		||||
    ):
 | 
			
		||||
        PPOTrainer.__init__(self, **kwargs)
 | 
			
		||||
        self.args = training_args
 | 
			
		||||
        self.finetuning_args = finetuning_args
 | 
			
		||||
        self.log_callback = callbacks[0]
 | 
			
		||||
        self.state = TrainerState()
 | 
			
		||||
        self.control = TrainerControl()
 | 
			
		||||
        self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer
 | 
			
		||||
 | 
			
		||||
    def ppo_train(self, max_target_length: int) -> None:
 | 
			
		||||
@ -117,8 +83,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
 | 
			
		||||
        steps_trained = 0
 | 
			
		||||
        loss_meter = AverageMeter()
 | 
			
		||||
        reward_meter = AverageMeter()
 | 
			
		||||
        self.log_callback.on_train_begin(self.args, self.state, self.control)
 | 
			
		||||
 | 
			
		||||
        for step in tqdm(range(max_steps), disable=not self.is_world_process_zero()):
 | 
			
		||||
        for step in tqdm(range(max_steps), disable=not self.is_world_process_zero(), leave=False):
 | 
			
		||||
 | 
			
		||||
            for _ in range(self.config.gradient_accumulation_steps):
 | 
			
		||||
 | 
			
		||||
@ -158,6 +125,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
 | 
			
		||||
                loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
 | 
			
		||||
                reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
 | 
			
		||||
 | 
			
		||||
                if self.control.should_epoch_stop or self.control.should_training_stop:
 | 
			
		||||
                    break
 | 
			
		||||
 | 
			
		||||
                if steps_trained == len_dataloader:
 | 
			
		||||
                    dataiter = iter(self.dataloader)
 | 
			
		||||
                    steps_trained = 0
 | 
			
		||||
@ -172,20 +142,23 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
 | 
			
		||||
                print(logs)
 | 
			
		||||
                logs["step"] = step
 | 
			
		||||
                self.state.log_history.append(logs)
 | 
			
		||||
                self.log_callback.on_log(self.args, self.state, None)
 | 
			
		||||
                self.log_callback.on_log(self.args, self.state, self.control)
 | 
			
		||||
                loss_meter.reset()
 | 
			
		||||
                reward_meter.reset()
 | 
			
		||||
 | 
			
		||||
            if (step+1) % self.args.save_steps == 0: # save checkpoint
 | 
			
		||||
                self.save_model(os.path.join(self.args.output_dir, f"checkpoint-{step+1}"))
 | 
			
		||||
 | 
			
		||||
            if self.control.should_training_stop:
 | 
			
		||||
                break
 | 
			
		||||
 | 
			
		||||
    @torch.no_grad()
 | 
			
		||||
    def generate(
 | 
			
		||||
            self,
 | 
			
		||||
            inputs: Dict[str, torch.Tensor],
 | 
			
		||||
            length_sampler: Optional[Callable] = None,
 | 
			
		||||
            return_prompt: Optional[bool] = True,
 | 
			
		||||
            **generation_kwargs,
 | 
			
		||||
        self,
 | 
			
		||||
        inputs: Dict[str, torch.Tensor],
 | 
			
		||||
        length_sampler: Optional[Callable] = None,
 | 
			
		||||
        return_prompt: Optional[bool] = True,
 | 
			
		||||
        **generation_kwargs
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        r"""
 | 
			
		||||
        Generates model's responses given queries.
 | 
			
		||||
							
								
								
									
										37
									
								
								src/llmtuner/tuner/ppo/utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								src/llmtuner/tuner/ppo/utils.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,37 @@
 | 
			
		||||
import torch
 | 
			
		||||
from typing import Dict, List, Literal, Optional, Tuple
 | 
			
		||||
from trl import AutoModelForCausalLMWithValueHead
 | 
			
		||||
 | 
			
		||||
from llmtuner.extras.constants import LAYERNORM_NAMES
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
 | 
			
		||||
    if target == "reward": # save default head temporarily
 | 
			
		||||
        valuehead_state_dict = model.v_head.state_dict()
 | 
			
		||||
        setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
 | 
			
		||||
        setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"])
 | 
			
		||||
 | 
			
		||||
    model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
 | 
			
		||||
    model.v_head.load_state_dict({
 | 
			
		||||
        "summary.weight": getattr(model, "{}_head_weight".format(target)),
 | 
			
		||||
        "summary.bias": getattr(model, "{}_head_bias".format(target))
 | 
			
		||||
    })
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def cast_layernorm_dtype(
 | 
			
		||||
    model: AutoModelForCausalLMWithValueHead,
 | 
			
		||||
    layer_norm_names: List[str] = LAYERNORM_NAMES,
 | 
			
		||||
    layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
 | 
			
		||||
) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
 | 
			
		||||
 | 
			
		||||
    layer_norm_state_dict = {}
 | 
			
		||||
 | 
			
		||||
    for name, param in model.named_parameters():
 | 
			
		||||
        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
 | 
			
		||||
            if layer_norm_params is not None:
 | 
			
		||||
                param.data = layer_norm_params[name] # restore float32 weights
 | 
			
		||||
            else:
 | 
			
		||||
                layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability
 | 
			
		||||
                param.data = param.data.to(torch.float16)
 | 
			
		||||
 | 
			
		||||
    return model, layer_norm_state_dict
 | 
			
		||||
@ -1,36 +1,30 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Implements parameter-efficient PPO training of fine-tuned models.
 | 
			
		||||
# This code is inspired by:
 | 
			
		||||
# Inspired by:
 | 
			
		||||
# https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
 | 
			
		||||
 | 
			
		||||
import math
 | 
			
		||||
 | 
			
		||||
from torch.optim import AdamW
 | 
			
		||||
from transformers.optimization import get_scheduler
 | 
			
		||||
from trl import PPOConfig
 | 
			
		||||
from transformers import DataCollatorForSeq2Seq
 | 
			
		||||
from utils import (
 | 
			
		||||
    PPOPeftTrainer,
 | 
			
		||||
    LogCallback,
 | 
			
		||||
    load_pretrained,
 | 
			
		||||
    prepare_args,
 | 
			
		||||
    prepare_data,
 | 
			
		||||
    preprocess_data,
 | 
			
		||||
    plot_loss
 | 
			
		||||
)
 | 
			
		||||
from torch.optim import AdamW
 | 
			
		||||
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
 | 
			
		||||
from transformers.optimization import get_scheduler
 | 
			
		||||
 | 
			
		||||
from llmtuner.dsets import get_dataset, preprocess_dataset
 | 
			
		||||
from llmtuner.extras.callbacks import LogCallback
 | 
			
		||||
from llmtuner.extras.ploting import plot_loss
 | 
			
		||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
 | 
			
		||||
from llmtuner.tuner.core import load_model_and_tokenizer
 | 
			
		||||
from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
 | 
			
		||||
    # Prepare pretrained model and dataset
 | 
			
		||||
    model_args, data_args, training_args, finetuning_args = prepare_args(stage="ppo")
 | 
			
		||||
    dataset = prepare_data(model_args, data_args)
 | 
			
		||||
    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="ppo")
 | 
			
		||||
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="ppo")
 | 
			
		||||
    data_collator = DataCollatorForSeq2Seq(
 | 
			
		||||
        tokenizer=tokenizer,
 | 
			
		||||
        label_pad_token_id=tokenizer.pad_token_id
 | 
			
		||||
    )
 | 
			
		||||
def run_ppo(
 | 
			
		||||
    model_args: ModelArguments,
 | 
			
		||||
    data_args: DataArguments,
 | 
			
		||||
    training_args: Seq2SeqTrainingArguments,
 | 
			
		||||
    finetuning_args: FinetuningArguments
 | 
			
		||||
):
 | 
			
		||||
    dataset = get_dataset(model_args, data_args)
 | 
			
		||||
    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
 | 
			
		||||
    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo")
 | 
			
		||||
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=tokenizer.pad_token_id)
 | 
			
		||||
 | 
			
		||||
    ppo_config = PPOConfig(
 | 
			
		||||
        model_name=model_args.model_name_or_path,
 | 
			
		||||
@ -72,12 +66,3 @@ def main():
 | 
			
		||||
    ppo_trainer.save_state() # must be after save_model
 | 
			
		||||
    if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
 | 
			
		||||
        plot_loss(training_args.output_dir, keys=["loss", "reward"])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _mp_fn(index):
 | 
			
		||||
    # For xla_spawn (TPUs)
 | 
			
		||||
    main()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    main()
 | 
			
		||||
							
								
								
									
										1
									
								
								src/llmtuner/tuner/pt/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								src/llmtuner/tuner/pt/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1 @@
 | 
			
		||||
from llmtuner.tuner.pt.workflow import run_pt
 | 
			
		||||
@ -1,31 +1,28 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Implements several parameter-efficient pre-training method.
 | 
			
		||||
# This code is inspired by
 | 
			
		||||
# https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py
 | 
			
		||||
 | 
			
		||||
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py
 | 
			
		||||
 | 
			
		||||
import math
 | 
			
		||||
from transformers import DataCollatorForSeq2Seq
 | 
			
		||||
from utils.other import IGNORE_INDEX
 | 
			
		||||
from typing import Optional, List
 | 
			
		||||
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
 | 
			
		||||
 | 
			
		||||
from utils import (
 | 
			
		||||
    PeftTrainer,
 | 
			
		||||
    LogCallback,
 | 
			
		||||
    load_pretrained,
 | 
			
		||||
    prepare_args,
 | 
			
		||||
    prepare_data,
 | 
			
		||||
    preprocess_data,
 | 
			
		||||
    plot_loss
 | 
			
		||||
)
 | 
			
		||||
from llmtuner.dsets import get_dataset, preprocess_dataset
 | 
			
		||||
from llmtuner.extras.callbacks import LogCallback
 | 
			
		||||
from llmtuner.extras.constants import IGNORE_INDEX
 | 
			
		||||
from llmtuner.extras.ploting import plot_loss
 | 
			
		||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
 | 
			
		||||
from llmtuner.tuner.core import load_model_and_tokenizer
 | 
			
		||||
from llmtuner.tuner.core.trainer import PeftTrainer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
 | 
			
		||||
    # Prepare pretrained model and dataset
 | 
			
		||||
    model_args, data_args, training_args, finetuning_args = prepare_args(stage="pt")
 | 
			
		||||
    dataset = prepare_data(model_args, data_args)
 | 
			
		||||
    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="pt")
 | 
			
		||||
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="pt")
 | 
			
		||||
def run_pt(
 | 
			
		||||
    model_args: ModelArguments,
 | 
			
		||||
    data_args: DataArguments,
 | 
			
		||||
    training_args: Seq2SeqTrainingArguments,
 | 
			
		||||
    finetuning_args: FinetuningArguments,
 | 
			
		||||
    callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
 | 
			
		||||
):
 | 
			
		||||
    dataset = get_dataset(model_args, data_args)
 | 
			
		||||
    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
 | 
			
		||||
    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
 | 
			
		||||
    data_collator = DataCollatorForSeq2Seq(
 | 
			
		||||
        tokenizer=tokenizer,
 | 
			
		||||
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
 | 
			
		||||
@ -48,7 +45,7 @@ def main():
 | 
			
		||||
        args=training_args,
 | 
			
		||||
        tokenizer=tokenizer,
 | 
			
		||||
        data_collator=data_collator,
 | 
			
		||||
        callbacks=[LogCallback()],
 | 
			
		||||
        callbacks=callbacks,
 | 
			
		||||
        **trainer_kwargs
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@ -65,21 +62,12 @@ def main():
 | 
			
		||||
    # Evaluation
 | 
			
		||||
    if training_args.do_eval:
 | 
			
		||||
        metrics = trainer.evaluate(metric_key_prefix="eval")
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            perplexity = math.exp(metrics["eval_loss"])
 | 
			
		||||
        except OverflowError:
 | 
			
		||||
            perplexity = float("inf")
 | 
			
		||||
 | 
			
		||||
        metrics["perplexity"] = perplexity
 | 
			
		||||
 | 
			
		||||
        trainer.log_metrics("eval", metrics)
 | 
			
		||||
        trainer.save_metrics("eval", metrics)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _mp_fn(index):
 | 
			
		||||
    # For xla_spawn (TPUs)
 | 
			
		||||
    main()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    main()
 | 
			
		||||
							
								
								
									
										1
									
								
								src/llmtuner/tuner/rm/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								src/llmtuner/tuner/rm/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1 @@
 | 
			
		||||
from llmtuner.tuner.rm.workflow import run_rm
 | 
			
		||||
							
								
								
									
										19
									
								
								src/llmtuner/tuner/rm/collator.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								src/llmtuner/tuner/rm/collator.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,19 @@
 | 
			
		||||
import torch
 | 
			
		||||
from typing import Any, Dict, Sequence
 | 
			
		||||
from transformers import DataCollatorWithPadding
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
 | 
			
		||||
    r"""
 | 
			
		||||
    Data collator for pairwise data.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
 | 
			
		||||
        r"""
 | 
			
		||||
        Pads batched data to the longest sequence in the batch.
 | 
			
		||||
 | 
			
		||||
        We generate 2 * n examples where the first n examples represent chosen examples and
 | 
			
		||||
        the last n examples represent rejected examples.
 | 
			
		||||
        """
 | 
			
		||||
        features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features]
 | 
			
		||||
        return super().__call__(features)
 | 
			
		||||
							
								
								
									
										7
									
								
								src/llmtuner/tuner/rm/metric.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								src/llmtuner/tuner/rm/metric.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,7 @@
 | 
			
		||||
import numpy as np
 | 
			
		||||
from typing import Dict, Sequence, Tuple, Union
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
 | 
			
		||||
    preds, _ = eval_preds
 | 
			
		||||
    return {"accuracy": (preds[0] > preds[1]).sum() / len(preds[0])}
 | 
			
		||||
							
								
								
									
										38
									
								
								src/llmtuner/tuner/rm/trainer.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								src/llmtuner/tuner/rm/trainer.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,38 @@
 | 
			
		||||
import torch
 | 
			
		||||
from typing import Dict, List, Optional, Tuple, Union
 | 
			
		||||
from transformers.modeling_utils import PreTrainedModel
 | 
			
		||||
 | 
			
		||||
from llmtuner.tuner.core.trainer import PeftTrainer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PairwisePeftTrainer(PeftTrainer):
 | 
			
		||||
    r"""
 | 
			
		||||
    Inherits PeftTrainer to compute pairwise loss.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
        super().__init__(*args, **kwargs)
 | 
			
		||||
        self.can_return_loss = True # override property to return eval_loss
 | 
			
		||||
 | 
			
		||||
    def compute_loss(
 | 
			
		||||
        self,
 | 
			
		||||
        model: PreTrainedModel,
 | 
			
		||||
        inputs: Dict[str, torch.Tensor],
 | 
			
		||||
        return_outputs: Optional[bool] = False
 | 
			
		||||
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
 | 
			
		||||
        r"""
 | 
			
		||||
        Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
 | 
			
		||||
 | 
			
		||||
        We use score on the EOS token to represent reward of the whole sentence.
 | 
			
		||||
 | 
			
		||||
        Subclass and override to inject custom behavior. It should not be directly used by external scripts.
 | 
			
		||||
 | 
			
		||||
        Note that the first element will be removed from the output tuple.
 | 
			
		||||
 | 
			
		||||
        See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
 | 
			
		||||
        """
 | 
			
		||||
        batch_size = inputs["input_ids"].size(0) // 2
 | 
			
		||||
        _, _, values = model(**inputs)
 | 
			
		||||
        r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
 | 
			
		||||
        loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
 | 
			
		||||
        return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
 | 
			
		||||
@ -1,30 +1,28 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Implements parameter-efficient training of reward models.
 | 
			
		||||
# This code is inspired by:
 | 
			
		||||
# Inspired by:
 | 
			
		||||
# https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
 | 
			
		||||
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
 | 
			
		||||
 | 
			
		||||
from transformers import Seq2SeqTrainingArguments
 | 
			
		||||
 | 
			
		||||
from utils import (
 | 
			
		||||
    PairwiseDataCollatorWithPadding,
 | 
			
		||||
    PairwisePeftTrainer,
 | 
			
		||||
    LogCallback,
 | 
			
		||||
    load_pretrained,
 | 
			
		||||
    prepare_args,
 | 
			
		||||
    prepare_data,
 | 
			
		||||
    preprocess_data,
 | 
			
		||||
    compute_accuracy,
 | 
			
		||||
    plot_loss
 | 
			
		||||
)
 | 
			
		||||
from llmtuner.dsets import get_dataset, preprocess_dataset
 | 
			
		||||
from llmtuner.extras.callbacks import LogCallback
 | 
			
		||||
from llmtuner.extras.ploting import plot_loss
 | 
			
		||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
 | 
			
		||||
from llmtuner.tuner.core import load_model_and_tokenizer
 | 
			
		||||
from llmtuner.tuner.rm.metric import compute_accuracy
 | 
			
		||||
from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
 | 
			
		||||
from llmtuner.tuner.rm.trainer import PairwisePeftTrainer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
 | 
			
		||||
    # Prepare pretrained model and dataset
 | 
			
		||||
    model_args, data_args, training_args, finetuning_args = prepare_args(stage="rm")
 | 
			
		||||
    dataset = prepare_data(model_args, data_args)
 | 
			
		||||
    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="rm")
 | 
			
		||||
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rm")
 | 
			
		||||
def run_rm(
 | 
			
		||||
    model_args: ModelArguments,
 | 
			
		||||
    data_args: DataArguments,
 | 
			
		||||
    training_args: Seq2SeqTrainingArguments,
 | 
			
		||||
    finetuning_args: FinetuningArguments
 | 
			
		||||
):
 | 
			
		||||
    dataset = get_dataset(model_args, data_args)
 | 
			
		||||
    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
 | 
			
		||||
    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
 | 
			
		||||
    data_collator = PairwiseDataCollatorWithPadding(tokenizer)
 | 
			
		||||
 | 
			
		||||
    training_args.remove_unused_columns = False # important for pairwise dataset
 | 
			
		||||
@ -66,12 +64,3 @@ def main():
 | 
			
		||||
        metrics = trainer.evaluate(metric_key_prefix="eval")
 | 
			
		||||
        trainer.log_metrics("eval", metrics)
 | 
			
		||||
        trainer.save_metrics("eval", metrics)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _mp_fn(index):
 | 
			
		||||
    # For xla_spawn (TPUs)
 | 
			
		||||
    main()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    main()
 | 
			
		||||
							
								
								
									
										1
									
								
								src/llmtuner/tuner/sft/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								src/llmtuner/tuner/sft/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1 @@
 | 
			
		||||
from llmtuner.tuner.sft.workflow import run_sft
 | 
			
		||||
							
								
								
									
										51
									
								
								src/llmtuner/tuner/sft/metric.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								src/llmtuner/tuner/sft/metric.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,51 @@
 | 
			
		||||
import numpy as np
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from typing import Dict, Sequence, Tuple, Union
 | 
			
		||||
from transformers.tokenization_utils import PreTrainedTokenizer
 | 
			
		||||
 | 
			
		||||
import jieba
 | 
			
		||||
from rouge_chinese import Rouge
 | 
			
		||||
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
 | 
			
		||||
 | 
			
		||||
from llmtuner.extras.constants import IGNORE_INDEX
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class ComputeMetrics:
 | 
			
		||||
    r"""
 | 
			
		||||
    Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    tokenizer: PreTrainedTokenizer
 | 
			
		||||
 | 
			
		||||
    def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
 | 
			
		||||
        r"""
 | 
			
		||||
        Uses the model predictions to compute metrics.
 | 
			
		||||
        """
 | 
			
		||||
        preds, labels = eval_preds
 | 
			
		||||
        score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
 | 
			
		||||
 | 
			
		||||
        preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
 | 
			
		||||
        labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
 | 
			
		||||
 | 
			
		||||
        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
 | 
			
		||||
        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
 | 
			
		||||
 | 
			
		||||
        for pred, label in zip(decoded_preds, decoded_labels):
 | 
			
		||||
            hypothesis = list(jieba.cut(pred))
 | 
			
		||||
            reference = list(jieba.cut(label))
 | 
			
		||||
 | 
			
		||||
            if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0:
 | 
			
		||||
                result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
 | 
			
		||||
            else:
 | 
			
		||||
                rouge = Rouge()
 | 
			
		||||
                scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
 | 
			
		||||
                result = scores[0]
 | 
			
		||||
 | 
			
		||||
            for k, v in result.items():
 | 
			
		||||
                score_dict[k].append(round(v["f"] * 100, 4))
 | 
			
		||||
 | 
			
		||||
            bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
 | 
			
		||||
            score_dict["bleu-4"].append(round(bleu_score * 100, 4))
 | 
			
		||||
 | 
			
		||||
        return {k: float(np.mean(v)) for k, v in score_dict.items()}
 | 
			
		||||
@ -3,65 +3,17 @@ import json
 | 
			
		||||
import torch
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 | 
			
		||||
 | 
			
		||||
from typing import Any, Dict, List, Optional, Tuple, Union
 | 
			
		||||
from transformers.trainer import PredictionOutput
 | 
			
		||||
from transformers.tokenization_utils import PreTrainedTokenizer
 | 
			
		||||
 | 
			
		||||
import jieba
 | 
			
		||||
from rouge_chinese import Rouge
 | 
			
		||||
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
 | 
			
		||||
 | 
			
		||||
from .peft_trainer import PeftTrainer
 | 
			
		||||
 | 
			
		||||
from .other import get_logger, IGNORE_INDEX
 | 
			
		||||
from llmtuner.extras.constants import IGNORE_INDEX
 | 
			
		||||
from llmtuner.extras.logging import get_logger
 | 
			
		||||
from llmtuner.tuner.core.trainer import PeftTrainer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class ComputeMetrics:
 | 
			
		||||
    r"""
 | 
			
		||||
    Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    tokenizer: PreTrainedTokenizer
 | 
			
		||||
 | 
			
		||||
    def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
 | 
			
		||||
        r"""
 | 
			
		||||
        Uses the model predictions to compute metrics.
 | 
			
		||||
        """
 | 
			
		||||
        preds, labels = eval_preds
 | 
			
		||||
        score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
 | 
			
		||||
 | 
			
		||||
        preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
 | 
			
		||||
        labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
 | 
			
		||||
 | 
			
		||||
        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
 | 
			
		||||
        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
 | 
			
		||||
 | 
			
		||||
        for pred, label in zip(decoded_preds, decoded_labels):
 | 
			
		||||
            hypothesis = list(jieba.cut(pred))
 | 
			
		||||
            reference = list(jieba.cut(label))
 | 
			
		||||
 | 
			
		||||
            if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0:
 | 
			
		||||
                result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
 | 
			
		||||
            else:
 | 
			
		||||
                rouge = Rouge()
 | 
			
		||||
                scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
 | 
			
		||||
                result = scores[0]
 | 
			
		||||
 | 
			
		||||
            for k, v in result.items():
 | 
			
		||||
                score_dict[k].append(round(v["f"] * 100, 4))
 | 
			
		||||
 | 
			
		||||
            bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
 | 
			
		||||
            score_dict["bleu-4"].append(round(bleu_score * 100, 4))
 | 
			
		||||
 | 
			
		||||
        return {k: float(np.mean(v)) for k, v in score_dict.items()}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Seq2SeqPeftTrainer(PeftTrainer):
 | 
			
		||||
    r"""
 | 
			
		||||
    Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
 | 
			
		||||
@ -80,7 +32,10 @@ class Seq2SeqPeftTrainer(PeftTrainer):
 | 
			
		||||
        Subclass and override to inject custom behavior.
 | 
			
		||||
        """
 | 
			
		||||
        prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
 | 
			
		||||
        inputs["labels"] = torch.cat((inputs["labels"], torch.zeros_like(inputs["input_ids"])[:, label_len:]), dim=-1)
 | 
			
		||||
        if self.tokenizer.padding_side == "right": # pads the labels to the same length as the inputs
 | 
			
		||||
            inputs["labels"] = torch.cat((inputs["labels"], torch.zeros_like(inputs["input_ids"])[:, label_len:]), dim=-1)
 | 
			
		||||
        else:
 | 
			
		||||
            inputs["labels"] = torch.cat((torch.zeros_like(inputs["input_ids"])[:, label_len:], inputs["labels"]), dim=-1)
 | 
			
		||||
        loss, generated_tokens, labels = super().prediction_step(
 | 
			
		||||
            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
 | 
			
		||||
        )
 | 
			
		||||
@ -89,8 +44,8 @@ class Seq2SeqPeftTrainer(PeftTrainer):
 | 
			
		||||
        return (loss, generated_tokens, labels)
 | 
			
		||||
 | 
			
		||||
    def save_predictions(
 | 
			
		||||
            self,
 | 
			
		||||
            predict_results: PredictionOutput
 | 
			
		||||
        self,
 | 
			
		||||
        predict_results: PredictionOutput
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        r"""
 | 
			
		||||
        Saves model predictions to `output_dir`.
 | 
			
		||||
@ -1,31 +1,29 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Implements several parameter-efficient supervised fine-tuning method.
 | 
			
		||||
# This code is inspired by
 | 
			
		||||
# https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
 | 
			
		||||
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
 | 
			
		||||
 | 
			
		||||
from typing import Optional, List
 | 
			
		||||
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
 | 
			
		||||
 | 
			
		||||
from llmtuner.dsets import get_dataset, preprocess_dataset
 | 
			
		||||
from llmtuner.extras.callbacks import LogCallback
 | 
			
		||||
from llmtuner.extras.constants import IGNORE_INDEX
 | 
			
		||||
from llmtuner.extras.misc import get_logits_processor
 | 
			
		||||
from llmtuner.extras.ploting import plot_loss
 | 
			
		||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
 | 
			
		||||
from llmtuner.tuner.core import load_model_and_tokenizer
 | 
			
		||||
from llmtuner.tuner.sft.metric import ComputeMetrics
 | 
			
		||||
from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from transformers import DataCollatorForSeq2Seq
 | 
			
		||||
from utils.other import IGNORE_INDEX
 | 
			
		||||
from utils import (
 | 
			
		||||
    Seq2SeqPeftTrainer,
 | 
			
		||||
    ComputeMetrics,
 | 
			
		||||
    LogCallback,
 | 
			
		||||
    load_pretrained,
 | 
			
		||||
    prepare_args,
 | 
			
		||||
    prepare_data,
 | 
			
		||||
    preprocess_data,
 | 
			
		||||
    get_logits_processor,
 | 
			
		||||
    plot_loss
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
 | 
			
		||||
    # Prepare pretrained model and dataset
 | 
			
		||||
    model_args, data_args, training_args, finetuning_args = prepare_args(stage="sft")
 | 
			
		||||
    dataset = prepare_data(model_args, data_args)
 | 
			
		||||
    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft")
 | 
			
		||||
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")
 | 
			
		||||
def run_sft(
 | 
			
		||||
    model_args: ModelArguments,
 | 
			
		||||
    data_args: DataArguments,
 | 
			
		||||
    training_args: Seq2SeqTrainingArguments,
 | 
			
		||||
    finetuning_args: FinetuningArguments,
 | 
			
		||||
    callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
 | 
			
		||||
):
 | 
			
		||||
    dataset = get_dataset(model_args, data_args)
 | 
			
		||||
    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
 | 
			
		||||
    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft")
 | 
			
		||||
    data_collator = DataCollatorForSeq2Seq(
 | 
			
		||||
        tokenizer=tokenizer,
 | 
			
		||||
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
 | 
			
		||||
@ -54,7 +52,7 @@ def main():
 | 
			
		||||
        args=training_args,
 | 
			
		||||
        tokenizer=tokenizer,
 | 
			
		||||
        data_collator=data_collator,
 | 
			
		||||
        callbacks=[LogCallback()],
 | 
			
		||||
        callbacks=callbacks,
 | 
			
		||||
        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
 | 
			
		||||
        **trainer_kwargs
 | 
			
		||||
    )
 | 
			
		||||
@ -94,12 +92,3 @@ def main():
 | 
			
		||||
        trainer.log_metrics("predict", predict_results.metrics)
 | 
			
		||||
        trainer.save_metrics("predict", predict_results.metrics)
 | 
			
		||||
        trainer.save_predictions(predict_results)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _mp_fn(index):
 | 
			
		||||
    # For xla_spawn (TPUs)
 | 
			
		||||
    main()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    main()
 | 
			
		||||
							
								
								
									
										23
									
								
								src/train_bash.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								src/train_bash.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,23 @@
 | 
			
		||||
from llmtuner import get_train_args, run_pt, run_sft, run_rm, run_ppo
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
    model_args, data_args, training_args, finetuning_args, general_args = get_train_args()
 | 
			
		||||
 | 
			
		||||
    if general_args.stage == "pt":
 | 
			
		||||
        run_pt(model_args, data_args, training_args, finetuning_args)
 | 
			
		||||
    elif general_args.stage == "sft":
 | 
			
		||||
        run_sft(model_args, data_args, training_args, finetuning_args)
 | 
			
		||||
    elif general_args.stage == "rm":
 | 
			
		||||
        run_rm(model_args, data_args, training_args, finetuning_args)
 | 
			
		||||
    elif general_args.stage == "ppo":
 | 
			
		||||
        run_ppo(model_args, data_args, training_args, finetuning_args)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _mp_fn(index):
 | 
			
		||||
    # For xla_spawn (TPUs)
 | 
			
		||||
    main()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    main()
 | 
			
		||||
@ -1,17 +0,0 @@
 | 
			
		||||
from .common import (
 | 
			
		||||
    load_pretrained,
 | 
			
		||||
    prepare_args,
 | 
			
		||||
    prepare_infer_args,
 | 
			
		||||
    prepare_data,
 | 
			
		||||
    preprocess_data
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from .peft_trainer import PeftTrainer, LogCallback
 | 
			
		||||
 | 
			
		||||
from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer
 | 
			
		||||
from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer, compute_accuracy
 | 
			
		||||
from .ppo import PPOPeftTrainer
 | 
			
		||||
 | 
			
		||||
from .template import Template
 | 
			
		||||
 | 
			
		||||
from .other import get_logits_processor, plot_loss
 | 
			
		||||
@ -1,619 +0,0 @@
 | 
			
		||||
import os
 | 
			
		||||
import sys
 | 
			
		||||
import torch
 | 
			
		||||
import hashlib
 | 
			
		||||
from itertools import chain
 | 
			
		||||
from typing import List, Literal, Optional, Tuple
 | 
			
		||||
 | 
			
		||||
import transformers
 | 
			
		||||
from transformers import (
 | 
			
		||||
    AutoConfig,
 | 
			
		||||
    AutoModelForCausalLM,
 | 
			
		||||
    AutoTokenizer,
 | 
			
		||||
    HfArgumentParser,
 | 
			
		||||
    Seq2SeqTrainingArguments,
 | 
			
		||||
    BitsAndBytesConfig
 | 
			
		||||
)
 | 
			
		||||
from transformers.utils import check_min_version
 | 
			
		||||
from transformers.utils.versions import require_version
 | 
			
		||||
from transformers.modeling_utils import PreTrainedModel
 | 
			
		||||
from transformers.tokenization_utils import PreTrainedTokenizer
 | 
			
		||||
 | 
			
		||||
import datasets
 | 
			
		||||
from datasets import Dataset, concatenate_datasets, load_dataset
 | 
			
		||||
 | 
			
		||||
from peft import (
 | 
			
		||||
    PeftModel,
 | 
			
		||||
    TaskType,
 | 
			
		||||
    LoraConfig,
 | 
			
		||||
    get_peft_model
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from peft.utils import CONFIG_NAME, WEIGHTS_NAME
 | 
			
		||||
 | 
			
		||||
from trl import AutoModelForCausalLMWithValueHead
 | 
			
		||||
 | 
			
		||||
from .config import (
 | 
			
		||||
    ModelArguments,
 | 
			
		||||
    DataTrainingArguments,
 | 
			
		||||
    FinetuningArguments,
 | 
			
		||||
    GeneratingArguments
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from .template import Template
 | 
			
		||||
 | 
			
		||||
from .other import (
 | 
			
		||||
    get_logger,
 | 
			
		||||
    load_trainable_params,
 | 
			
		||||
    load_valuehead_params,
 | 
			
		||||
    print_trainable_params,
 | 
			
		||||
    prepare_model_for_training,
 | 
			
		||||
    IGNORE_INDEX
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
check_min_version("4.29.1")
 | 
			
		||||
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
 | 
			
		||||
require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
 | 
			
		||||
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
 | 
			
		||||
require_version("trl>=0.4.4", "To fix: pip install trl>=0.4.4")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _init_adapter(
 | 
			
		||||
        model: PreTrainedModel,
 | 
			
		||||
        model_args: ModelArguments,
 | 
			
		||||
        finetuning_args: FinetuningArguments,
 | 
			
		||||
        is_trainable: bool,
 | 
			
		||||
        is_mergeable: bool
 | 
			
		||||
) -> PreTrainedModel:
 | 
			
		||||
    r"""
 | 
			
		||||
    Initializes the adapters.
 | 
			
		||||
 | 
			
		||||
    Support full-parameter, freeze and LoRA training.
 | 
			
		||||
 | 
			
		||||
    Note that the trainable parameters must be cast to float32.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    if finetuning_args.finetuning_type == "none" and is_trainable:
 | 
			
		||||
        raise ValueError("You cannot use finetuning_type=none while training.")
 | 
			
		||||
 | 
			
		||||
    if finetuning_args.finetuning_type == "full":
 | 
			
		||||
        logger.info("Fine-tuning method: Full")
 | 
			
		||||
        model = model.float()
 | 
			
		||||
 | 
			
		||||
    if finetuning_args.finetuning_type == "freeze":
 | 
			
		||||
        logger.info("Fine-tuning method: Freeze")
 | 
			
		||||
 | 
			
		||||
        for name, param in model.named_parameters():
 | 
			
		||||
            if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
 | 
			
		||||
                param.requires_grad_(False)
 | 
			
		||||
            else:
 | 
			
		||||
                param.data = param.data.to(torch.float32)
 | 
			
		||||
 | 
			
		||||
        if model_args.checkpoint_dir is not None:
 | 
			
		||||
            assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
 | 
			
		||||
 | 
			
		||||
    if finetuning_args.finetuning_type == "lora":
 | 
			
		||||
        logger.info("Fine-tuning method: LoRA")
 | 
			
		||||
        latest_checkpoint = None
 | 
			
		||||
 | 
			
		||||
        if model_args.checkpoint_dir is not None:
 | 
			
		||||
            assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)), \
 | 
			
		||||
                "Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0])
 | 
			
		||||
            assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
 | 
			
		||||
                "The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
 | 
			
		||||
 | 
			
		||||
            if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
 | 
			
		||||
                checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
 | 
			
		||||
            else:
 | 
			
		||||
                checkpoints_to_merge = model_args.checkpoint_dir
 | 
			
		||||
 | 
			
		||||
            for checkpoint in checkpoints_to_merge:
 | 
			
		||||
                model = PeftModel.from_pretrained(model, checkpoint)
 | 
			
		||||
                model = model.merge_and_unload()
 | 
			
		||||
 | 
			
		||||
            if len(checkpoints_to_merge) > 0:
 | 
			
		||||
                logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
 | 
			
		||||
 | 
			
		||||
            if latest_checkpoint is not None: # resume lora training or quantized inference
 | 
			
		||||
                model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)
 | 
			
		||||
 | 
			
		||||
        if is_trainable and latest_checkpoint is None: # create new lora weights while training
 | 
			
		||||
            lora_config = LoraConfig(
 | 
			
		||||
                task_type=TaskType.CAUSAL_LM,
 | 
			
		||||
                inference_mode=False,
 | 
			
		||||
                r=finetuning_args.lora_rank,
 | 
			
		||||
                lora_alpha=finetuning_args.lora_alpha,
 | 
			
		||||
                lora_dropout=finetuning_args.lora_dropout,
 | 
			
		||||
                target_modules=finetuning_args.lora_target
 | 
			
		||||
            )
 | 
			
		||||
            model = get_peft_model(model, lora_config)
 | 
			
		||||
 | 
			
		||||
    if model_args.checkpoint_dir is not None:
 | 
			
		||||
        logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
 | 
			
		||||
 | 
			
		||||
    return model
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_pretrained(
 | 
			
		||||
        model_args: ModelArguments,
 | 
			
		||||
        finetuning_args: FinetuningArguments,
 | 
			
		||||
        is_trainable: Optional[bool] = False,
 | 
			
		||||
        stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
 | 
			
		||||
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
 | 
			
		||||
    r"""
 | 
			
		||||
    Loads pretrained model and tokenizer.
 | 
			
		||||
 | 
			
		||||
    Support both training and inference.
 | 
			
		||||
    """
 | 
			
		||||
    if (not is_trainable) and model_args.checkpoint_dir is None:
 | 
			
		||||
        logger.warning("Checkpoint is not found at evaluation, load the original model.")
 | 
			
		||||
        finetuning_args = FinetuningArguments(finetuning_type="none")
 | 
			
		||||
 | 
			
		||||
    assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
 | 
			
		||||
        "RM and PPO training can only be performed with the LoRA method."
 | 
			
		||||
 | 
			
		||||
    config_kwargs = {
 | 
			
		||||
        "trust_remote_code": True,
 | 
			
		||||
        "cache_dir": model_args.cache_dir,
 | 
			
		||||
        "revision": model_args.model_revision,
 | 
			
		||||
        "use_auth_token": True if model_args.use_auth_token else None,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained(
 | 
			
		||||
        model_args.model_name_or_path,
 | 
			
		||||
        use_fast=model_args.use_fast_tokenizer,
 | 
			
		||||
        padding_side=model_args.padding_side,
 | 
			
		||||
        **config_kwargs
 | 
			
		||||
    )
 | 
			
		||||
    if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000: # 64000 for baichuan model (older version)
 | 
			
		||||
        tokenizer.pad_token_id = 0 # set as the <unk> token
 | 
			
		||||
 | 
			
		||||
    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
 | 
			
		||||
    is_mergeable = True
 | 
			
		||||
 | 
			
		||||
    # Quantization configurations (using bitsandbytes library).
 | 
			
		||||
    if model_args.quantization_bit is not None:
 | 
			
		||||
        if model_args.quantization_bit == 8:
 | 
			
		||||
            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
 | 
			
		||||
            config_kwargs["load_in_8bit"] = True
 | 
			
		||||
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
 | 
			
		||||
                load_in_8bit=True,
 | 
			
		||||
                llm_int8_threshold=6.0
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        elif model_args.quantization_bit == 4:
 | 
			
		||||
            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
 | 
			
		||||
            require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
 | 
			
		||||
            require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
 | 
			
		||||
            require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
 | 
			
		||||
            config_kwargs["load_in_4bit"] = True
 | 
			
		||||
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
 | 
			
		||||
                load_in_4bit=True,
 | 
			
		||||
                bnb_4bit_compute_dtype=model_args.compute_dtype,
 | 
			
		||||
                bnb_4bit_use_double_quant=model_args.double_quantization,
 | 
			
		||||
                bnb_4bit_quant_type=model_args.quantization_type
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        is_mergeable = False
 | 
			
		||||
        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
 | 
			
		||||
        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 | 
			
		||||
 | 
			
		||||
    if not is_trainable: # `device_map=auto` should be used for inference only
 | 
			
		||||
        config_kwargs["device_map"] = "auto"
 | 
			
		||||
 | 
			
		||||
    if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
 | 
			
		||||
        model_to_load = model_args.checkpoint_dir[0]
 | 
			
		||||
    else:
 | 
			
		||||
        model_to_load = model_args.model_name_or_path
 | 
			
		||||
 | 
			
		||||
    # Load and prepare pretrained models (without valuehead).
 | 
			
		||||
    model = AutoModelForCausalLM.from_pretrained(
 | 
			
		||||
        model_to_load,
 | 
			
		||||
        config=config,
 | 
			
		||||
        torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
 | 
			
		||||
        low_cpu_mem_usage=True,
 | 
			
		||||
        **config_kwargs
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Register auto class to save the custom code files.
 | 
			
		||||
    if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
 | 
			
		||||
        config.__class__.register_for_auto_class()
 | 
			
		||||
    if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
 | 
			
		||||
        tokenizer.__class__.register_for_auto_class()
 | 
			
		||||
    if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map:
 | 
			
		||||
        model.__class__.register_for_auto_class()
 | 
			
		||||
 | 
			
		||||
    # Initialize adapters
 | 
			
		||||
    model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
 | 
			
		||||
    model = _init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
 | 
			
		||||
 | 
			
		||||
    if stage == "rm" or stage == "ppo": # add value head
 | 
			
		||||
        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
 | 
			
		||||
 | 
			
		||||
        if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
 | 
			
		||||
            logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
 | 
			
		||||
            if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
 | 
			
		||||
                model.v_head.load_state_dict({
 | 
			
		||||
                    "summary.weight": getattr(model, "reward_head_weight"),
 | 
			
		||||
                    "summary.bias": getattr(model, "reward_head_bias")
 | 
			
		||||
                })
 | 
			
		||||
 | 
			
		||||
        if stage == "ppo": # load reward model
 | 
			
		||||
            assert is_trainable, "PPO stage cannot be performed at evaluation."
 | 
			
		||||
            assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
 | 
			
		||||
            logger.info("Load reward model from {}".format(model_args.reward_model))
 | 
			
		||||
            model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
 | 
			
		||||
            assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
 | 
			
		||||
 | 
			
		||||
    if not is_trainable:
 | 
			
		||||
        model.requires_grad_(False) # fix all model params
 | 
			
		||||
        model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
 | 
			
		||||
 | 
			
		||||
    print_trainable_params(model)
 | 
			
		||||
 | 
			
		||||
    return model, tokenizer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def prepare_args(
 | 
			
		||||
        stage: Literal["pt", "sft", "rm", "ppo"]
 | 
			
		||||
) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:
 | 
			
		||||
 | 
			
		||||
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))
 | 
			
		||||
 | 
			
		||||
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
 | 
			
		||||
        model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
 | 
			
		||||
    else:
 | 
			
		||||
        model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()
 | 
			
		||||
 | 
			
		||||
    # Setup logging
 | 
			
		||||
    if training_args.should_log:
 | 
			
		||||
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
 | 
			
		||||
        transformers.utils.logging.set_verbosity_info()
 | 
			
		||||
 | 
			
		||||
    log_level = training_args.get_process_log_level()
 | 
			
		||||
    datasets.utils.logging.set_verbosity(log_level)
 | 
			
		||||
    transformers.utils.logging.set_verbosity(log_level)
 | 
			
		||||
    transformers.utils.logging.enable_default_handler()
 | 
			
		||||
    transformers.utils.logging.enable_explicit_format()
 | 
			
		||||
 | 
			
		||||
    # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
 | 
			
		||||
    data_args.init_for_training()
 | 
			
		||||
 | 
			
		||||
    assert stage == "sft" or (not training_args.predict_with_generate), \
 | 
			
		||||
        "`predict_with_generate` cannot be set as True at PT, RM and PPO stages."
 | 
			
		||||
 | 
			
		||||
    assert not (training_args.do_train and training_args.predict_with_generate), \
 | 
			
		||||
        "`predict_with_generate` cannot be set as True while training."
 | 
			
		||||
 | 
			
		||||
    assert (not training_args.do_predict) or training_args.predict_with_generate, \
 | 
			
		||||
        "Please enable `predict_with_generate` to save model predictions."
 | 
			
		||||
 | 
			
		||||
    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
 | 
			
		||||
        "Quantization is only compatible with the LoRA method."
 | 
			
		||||
 | 
			
		||||
    if model_args.checkpoint_dir is not None:
 | 
			
		||||
        if finetuning_args.finetuning_type != "lora":
 | 
			
		||||
            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
 | 
			
		||||
        else:
 | 
			
		||||
            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
 | 
			
		||||
                "Quantized model only accepts a single checkpoint."
 | 
			
		||||
 | 
			
		||||
    if model_args.quantization_bit is not None and (not training_args.do_train):
 | 
			
		||||
        logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
 | 
			
		||||
 | 
			
		||||
    if training_args.do_train and (not training_args.fp16):
 | 
			
		||||
        logger.warning("We recommend enable fp16 mixed precision training.")
 | 
			
		||||
 | 
			
		||||
    if data_args.prompt_template == "default":
 | 
			
		||||
        logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
 | 
			
		||||
 | 
			
		||||
    if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
 | 
			
		||||
        logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
 | 
			
		||||
        training_args.ddp_find_unused_parameters = False
 | 
			
		||||
 | 
			
		||||
    training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
 | 
			
		||||
 | 
			
		||||
    if model_args.quantization_bit is not None:
 | 
			
		||||
        if training_args.fp16:
 | 
			
		||||
            model_args.compute_dtype = torch.float16
 | 
			
		||||
        elif training_args.bf16:
 | 
			
		||||
            model_args.compute_dtype = torch.bfloat16
 | 
			
		||||
        else:
 | 
			
		||||
            model_args.compute_dtype = torch.float32
 | 
			
		||||
 | 
			
		||||
    # Log on each process the small summary:
 | 
			
		||||
    logger.info(
 | 
			
		||||
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
 | 
			
		||||
        + f"  distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
 | 
			
		||||
    )
 | 
			
		||||
    logger.info(f"Training/evaluation parameters {training_args}")
 | 
			
		||||
 | 
			
		||||
    # Set seed before initializing model.
 | 
			
		||||
    transformers.set_seed(training_args.seed)
 | 
			
		||||
 | 
			
		||||
    return model_args, data_args, training_args, finetuning_args
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments, GeneratingArguments]:
 | 
			
		||||
 | 
			
		||||
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FinetuningArguments, GeneratingArguments))
 | 
			
		||||
 | 
			
		||||
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
 | 
			
		||||
        model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
 | 
			
		||||
    else:
 | 
			
		||||
        model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()
 | 
			
		||||
 | 
			
		||||
    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
 | 
			
		||||
        "Quantization is only compatible with the LoRA method."
 | 
			
		||||
 | 
			
		||||
    if model_args.checkpoint_dir is not None:
 | 
			
		||||
        if finetuning_args.finetuning_type != "lora":
 | 
			
		||||
            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
 | 
			
		||||
        else:
 | 
			
		||||
            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
 | 
			
		||||
                "Quantized model only accepts a single checkpoint."
 | 
			
		||||
 | 
			
		||||
    if data_args.prompt_template == "default":
 | 
			
		||||
        logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
 | 
			
		||||
 | 
			
		||||
    return model_args, data_args, finetuning_args, generating_args
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def prepare_data(
 | 
			
		||||
        model_args: ModelArguments,
 | 
			
		||||
        data_args: DataTrainingArguments
 | 
			
		||||
) -> Dataset:
 | 
			
		||||
 | 
			
		||||
    def checksum(file_path, hash):
 | 
			
		||||
        with open(file_path, "rb") as datafile:
 | 
			
		||||
            binary_data = datafile.read()
 | 
			
		||||
        sha1 = hashlib.sha1(binary_data).hexdigest()
 | 
			
		||||
        if sha1 != hash:
 | 
			
		||||
            logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
 | 
			
		||||
 | 
			
		||||
    ext2type = {
 | 
			
		||||
        "csv": "csv",
 | 
			
		||||
        "json": "json",
 | 
			
		||||
        "jsonl": "json",
 | 
			
		||||
        "txt": "text"
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    max_samples = data_args.max_samples
 | 
			
		||||
    all_datasets: List[Dataset] = [] # support multiple datasets
 | 
			
		||||
 | 
			
		||||
    for dataset_attr in data_args.dataset_list:
 | 
			
		||||
 | 
			
		||||
        logger.info("Loading dataset {}...".format(dataset_attr))
 | 
			
		||||
 | 
			
		||||
        if dataset_attr.load_from == "hf_hub":
 | 
			
		||||
            data_path = dataset_attr.dataset_name
 | 
			
		||||
            data_files = None
 | 
			
		||||
        elif dataset_attr.load_from == "script":
 | 
			
		||||
            data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
 | 
			
		||||
            data_files = None
 | 
			
		||||
        elif dataset_attr.load_from == "file":
 | 
			
		||||
            data_path = None
 | 
			
		||||
            data_files: List[str] = []
 | 
			
		||||
 | 
			
		||||
            if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
 | 
			
		||||
                for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
 | 
			
		||||
                    data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
 | 
			
		||||
 | 
			
		||||
                    if data_path is None:
 | 
			
		||||
                        data_path = ext2type.get(data_files[0].split(".")[-1], None)
 | 
			
		||||
                    else:
 | 
			
		||||
                        assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match."
 | 
			
		||||
            elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
 | 
			
		||||
                data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
 | 
			
		||||
                data_path = ext2type.get(data_files[0].split(".")[-1], None)
 | 
			
		||||
            else:
 | 
			
		||||
                raise ValueError("File not found.")
 | 
			
		||||
 | 
			
		||||
            assert data_path, "File extension must be txt, csv, json or jsonl."
 | 
			
		||||
 | 
			
		||||
            if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
 | 
			
		||||
                checksum(data_files[0], dataset_attr.dataset_sha1)
 | 
			
		||||
            else:
 | 
			
		||||
                logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
 | 
			
		||||
        else:
 | 
			
		||||
            raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
        raw_datasets = load_dataset(
 | 
			
		||||
            data_path,
 | 
			
		||||
            data_files=data_files,
 | 
			
		||||
            cache_dir=model_args.cache_dir,
 | 
			
		||||
            use_auth_token=True if model_args.use_auth_token else None
 | 
			
		||||
        )
 | 
			
		||||
        dataset = raw_datasets[data_args.split]
 | 
			
		||||
 | 
			
		||||
        if max_samples is not None:
 | 
			
		||||
            max_samples_temp = min(len(dataset), max_samples)
 | 
			
		||||
            dataset = dataset.select(range(max_samples_temp))
 | 
			
		||||
 | 
			
		||||
        dummy_data = [None] * len(dataset)
 | 
			
		||||
        prefix_data = [dataset_attr.source_prefix] * len(dataset)
 | 
			
		||||
        for column_name, target_name in [
 | 
			
		||||
            ("prompt_column", "prompt"),
 | 
			
		||||
            ("query_column", "query"),
 | 
			
		||||
            ("response_column", "response"),
 | 
			
		||||
            ("history_column", "history")
 | 
			
		||||
        ]: # every dataset will have 4 columns same as each other
 | 
			
		||||
            if getattr(dataset_attr, column_name) != target_name:
 | 
			
		||||
                if getattr(dataset_attr, column_name):
 | 
			
		||||
                    dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
 | 
			
		||||
                else: # None or empty string
 | 
			
		||||
                    dataset = dataset.add_column(target_name, dummy_data)
 | 
			
		||||
        dataset = dataset.add_column("prefix", prefix_data)
 | 
			
		||||
        all_datasets.append(dataset)
 | 
			
		||||
 | 
			
		||||
    if len(data_args.dataset_list) == 1:
 | 
			
		||||
        all_datasets = all_datasets[0]
 | 
			
		||||
    else:
 | 
			
		||||
        all_datasets = concatenate_datasets(all_datasets)
 | 
			
		||||
 | 
			
		||||
    return all_datasets
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def preprocess_data(
 | 
			
		||||
        dataset: Dataset,
 | 
			
		||||
        tokenizer: PreTrainedTokenizer,
 | 
			
		||||
        data_args: DataTrainingArguments,
 | 
			
		||||
        training_args: Seq2SeqTrainingArguments,
 | 
			
		||||
        stage: Literal["pt", "sft", "rm", "ppo"]
 | 
			
		||||
) -> Dataset:
 | 
			
		||||
 | 
			
		||||
    column_names = list(dataset.column_names)
 | 
			
		||||
    prompt_template = Template(data_args.prompt_template)
 | 
			
		||||
 | 
			
		||||
    # support question with a single answer or multiple answers
 | 
			
		||||
    def get_dialog(examples):
 | 
			
		||||
        for i in range(len(examples["prompt"])):
 | 
			
		||||
            if examples["prompt"][i] and examples["response"][i]:
 | 
			
		||||
                query, answer = examples["prompt"][i], examples["response"][i]
 | 
			
		||||
                query = query + "\n" + examples["query"][i] if examples["query"][i] else query
 | 
			
		||||
                prefix = examples["prefix"][i] if examples["prefix"][i] else ""
 | 
			
		||||
                dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
 | 
			
		||||
                yield dialog
 | 
			
		||||
 | 
			
		||||
    def preprocess_pretrain_dataset(examples):
 | 
			
		||||
        # build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
 | 
			
		||||
        text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
 | 
			
		||||
        concatenated_ids = list(chain(*text_ids))
 | 
			
		||||
        total_length = len(concatenated_ids)
 | 
			
		||||
        block_size = data_args.max_source_length - 1
 | 
			
		||||
        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
 | 
			
		||||
        total_length = (total_length // block_size) * block_size
 | 
			
		||||
        # split by chunks of max_source_length
 | 
			
		||||
        result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size]
 | 
			
		||||
                  for i in range(0, total_length, block_size)]
 | 
			
		||||
        return {
 | 
			
		||||
            "input_ids": result,
 | 
			
		||||
            "labels": result.copy()
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    def preprocess_supervised_dataset(examples):
 | 
			
		||||
        # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
 | 
			
		||||
        # for input with history, we build multiple input-label pairs just like:
 | 
			
		||||
        # https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
 | 
			
		||||
        model_inputs = {"input_ids": [], "labels": []}
 | 
			
		||||
        max_length = data_args.max_source_length + data_args.max_target_length
 | 
			
		||||
 | 
			
		||||
        for dialog in get_dialog(examples):
 | 
			
		||||
            input_ids, labels = [], []
 | 
			
		||||
 | 
			
		||||
            for i in range(len(dialog) // 2):
 | 
			
		||||
                source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=(i == 0))
 | 
			
		||||
                target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
 | 
			
		||||
 | 
			
		||||
                if len(source_ids) > data_args.max_source_length:
 | 
			
		||||
                    source_ids = source_ids[:data_args.max_source_length]
 | 
			
		||||
                if len(target_ids) > data_args.max_target_length - 1: # eos token
 | 
			
		||||
                    target_ids = target_ids[:data_args.max_target_length - 1]
 | 
			
		||||
 | 
			
		||||
                if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
 | 
			
		||||
                    break
 | 
			
		||||
 | 
			
		||||
                input_ids += source_ids + target_ids + [tokenizer.eos_token_id]
 | 
			
		||||
                labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
 | 
			
		||||
 | 
			
		||||
            model_inputs["input_ids"].append(input_ids)
 | 
			
		||||
            model_inputs["labels"].append(labels)
 | 
			
		||||
 | 
			
		||||
        return model_inputs
 | 
			
		||||
 | 
			
		||||
    def preprocess_unsupervised_dataset(examples):
 | 
			
		||||
        # build inputs with format `<bos> X` and labels with format `<bos> Y`
 | 
			
		||||
        model_inputs = {"input_ids": [], "labels": []}
 | 
			
		||||
 | 
			
		||||
        for dialog in get_dialog(examples):
 | 
			
		||||
            prompt, answer = "".join(dialog[:-1]), dialog[-1]
 | 
			
		||||
 | 
			
		||||
            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
 | 
			
		||||
            target_ids = tokenizer.encode(text=answer, add_special_tokens=True)
 | 
			
		||||
 | 
			
		||||
            if len(source_ids) > data_args.max_source_length:
 | 
			
		||||
                source_ids = source_ids[:data_args.max_source_length]
 | 
			
		||||
            if len(target_ids) > data_args.max_target_length:
 | 
			
		||||
                target_ids = target_ids[:data_args.max_target_length]
 | 
			
		||||
 | 
			
		||||
            model_inputs["input_ids"].append(source_ids)
 | 
			
		||||
            model_inputs["labels"].append(target_ids)
 | 
			
		||||
 | 
			
		||||
        return model_inputs
 | 
			
		||||
 | 
			
		||||
    def preprocess_pairwise_dataset(examples):
 | 
			
		||||
        # build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
 | 
			
		||||
        model_inputs = {"accept_ids": [], "reject_ids": []}
 | 
			
		||||
        for dialog in get_dialog(examples):
 | 
			
		||||
            prompt, answer = "".join(dialog[:-1]), dialog[-1]
 | 
			
		||||
 | 
			
		||||
            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
 | 
			
		||||
            accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
 | 
			
		||||
            reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
 | 
			
		||||
 | 
			
		||||
            if len(source_ids) > data_args.max_source_length:
 | 
			
		||||
                source_ids = source_ids[:data_args.max_source_length]
 | 
			
		||||
            if len(accept_ids) > data_args.max_target_length - 1: # eos token
 | 
			
		||||
                accept_ids = accept_ids[:data_args.max_target_length - 1]
 | 
			
		||||
            if len(reject_ids) > data_args.max_target_length - 1: # eos token
 | 
			
		||||
                reject_ids = reject_ids[:data_args.max_target_length - 1]
 | 
			
		||||
 | 
			
		||||
            accept_ids = source_ids + accept_ids + [tokenizer.eos_token_id]
 | 
			
		||||
            reject_ids = source_ids + reject_ids + [tokenizer.eos_token_id]
 | 
			
		||||
 | 
			
		||||
            model_inputs["accept_ids"].append(accept_ids)
 | 
			
		||||
            model_inputs["reject_ids"].append(reject_ids)
 | 
			
		||||
        return model_inputs
 | 
			
		||||
 | 
			
		||||
    def print_supervised_dataset_example(example):
 | 
			
		||||
        print("input_ids:\n{}".format(example["input_ids"]))
 | 
			
		||||
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
 | 
			
		||||
        print("label_ids:\n{}".format(example["labels"]))
 | 
			
		||||
        print("labels:\n{}".format(
 | 
			
		||||
            tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]],
 | 
			
		||||
                             skip_special_tokens=False)
 | 
			
		||||
        ))
 | 
			
		||||
 | 
			
		||||
    def print_pairwise_dataset_example(example):
 | 
			
		||||
        print("accept_ids:\n{}".format(example["accept_ids"]))
 | 
			
		||||
        print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
 | 
			
		||||
        print("reject_ids:\n{}".format(example["reject_ids"]))
 | 
			
		||||
        print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))
 | 
			
		||||
 | 
			
		||||
    def print_unsupervised_dataset_example(example):
 | 
			
		||||
        print("input_ids:\n{}".format(example["input_ids"]))
 | 
			
		||||
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
 | 
			
		||||
 | 
			
		||||
    if stage == "pt":
 | 
			
		||||
        preprocess_function = preprocess_pretrain_dataset
 | 
			
		||||
    elif stage == "sft":
 | 
			
		||||
        preprocess_function = preprocess_unsupervised_dataset \
 | 
			
		||||
            if training_args.predict_with_generate else preprocess_supervised_dataset
 | 
			
		||||
    elif stage == "rm":
 | 
			
		||||
        preprocess_function = preprocess_pairwise_dataset
 | 
			
		||||
    elif stage == "ppo":
 | 
			
		||||
        preprocess_function = preprocess_unsupervised_dataset
 | 
			
		||||
 | 
			
		||||
    with training_args.main_process_first(desc="dataset map pre-processing"):
 | 
			
		||||
        dataset = dataset.map(
 | 
			
		||||
            preprocess_function,
 | 
			
		||||
            batched=True,
 | 
			
		||||
            num_proc=data_args.preprocessing_num_workers,
 | 
			
		||||
            remove_columns=column_names,
 | 
			
		||||
            load_from_cache_file=not data_args.overwrite_cache,
 | 
			
		||||
            desc="Running tokenizer on dataset"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if stage == "pt":
 | 
			
		||||
            print_unsupervised_dataset_example(dataset[0])
 | 
			
		||||
        elif stage == "sft":
 | 
			
		||||
            print_supervised_dataset_example(dataset[0])
 | 
			
		||||
        elif stage == "rm":
 | 
			
		||||
            print_pairwise_dataset_example(dataset[0])
 | 
			
		||||
        elif stage == "ppo":
 | 
			
		||||
            print_unsupervised_dataset_example(dataset[0])
 | 
			
		||||
 | 
			
		||||
        return dataset
 | 
			
		||||
@ -1,312 +0,0 @@
 | 
			
		||||
import os
 | 
			
		||||
import json
 | 
			
		||||
import torch
 | 
			
		||||
from typing import Any, Dict, List, Literal, Optional
 | 
			
		||||
from dataclasses import asdict, dataclass, field
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class DatasetAttr:
 | 
			
		||||
 | 
			
		||||
    load_from: str
 | 
			
		||||
    dataset_name: Optional[str] = None
 | 
			
		||||
    dataset_sha1: Optional[str] = None
 | 
			
		||||
    source_prefix: Optional[str] = None
 | 
			
		||||
 | 
			
		||||
    def __repr__(self) -> str:
 | 
			
		||||
        return self.dataset_name
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        self.prompt_column = "instruction"
 | 
			
		||||
        self.query_column = "input"
 | 
			
		||||
        self.response_column = "output"
 | 
			
		||||
        self.history_column = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class ModelArguments:
 | 
			
		||||
    """
 | 
			
		||||
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
 | 
			
		||||
    """
 | 
			
		||||
    model_name_or_path: str = field(
 | 
			
		||||
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
 | 
			
		||||
    )
 | 
			
		||||
    cache_dir: Optional[str] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
 | 
			
		||||
    )
 | 
			
		||||
    use_fast_tokenizer: Optional[bool] = field(
 | 
			
		||||
        default=False,
 | 
			
		||||
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
 | 
			
		||||
    )
 | 
			
		||||
    use_auth_token: Optional[bool] = field(
 | 
			
		||||
        default=False,
 | 
			
		||||
        metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
 | 
			
		||||
    )
 | 
			
		||||
    model_revision: Optional[str] = field(
 | 
			
		||||
        default="main",
 | 
			
		||||
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
 | 
			
		||||
    )
 | 
			
		||||
    padding_side: Optional[Literal["left", "right"]] = field(
 | 
			
		||||
        default="left",
 | 
			
		||||
        metadata={"help": "The side on which the model should have padding applied."}
 | 
			
		||||
    )
 | 
			
		||||
    quantization_bit: Optional[int] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "The number of bits to quantize the model."}
 | 
			
		||||
    )
 | 
			
		||||
    quantization_type: Optional[Literal["fp4", "nf4"]] = field(
 | 
			
		||||
        default="nf4",
 | 
			
		||||
        metadata={"help": "Quantization data type to use in int4 training."}
 | 
			
		||||
    )
 | 
			
		||||
    double_quantization: Optional[bool] = field(
 | 
			
		||||
        default=True,
 | 
			
		||||
        metadata={"help": "Whether to use double quantization in int4 training or not."}
 | 
			
		||||
    )
 | 
			
		||||
    compute_dtype: Optional[torch.dtype] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
 | 
			
		||||
    )
 | 
			
		||||
    checkpoint_dir: Optional[str] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
 | 
			
		||||
    )
 | 
			
		||||
    reward_model: Optional[str] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
 | 
			
		||||
    )
 | 
			
		||||
    resume_lora_training: Optional[bool] = field(
 | 
			
		||||
        default=True,
 | 
			
		||||
        metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
 | 
			
		||||
    )
 | 
			
		||||
    plot_loss: Optional[bool] = field(
 | 
			
		||||
        default=False,
 | 
			
		||||
        metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        if self.checkpoint_dir is not None: # support merging multiple lora weights
 | 
			
		||||
            self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
 | 
			
		||||
 | 
			
		||||
        if self.quantization_bit is not None:
 | 
			
		||||
            assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class DataTrainingArguments:
 | 
			
		||||
    """
 | 
			
		||||
    Arguments pertaining to what data we are going to input our model for training and evaluation.
 | 
			
		||||
    """
 | 
			
		||||
    dataset: Optional[str] = field(
 | 
			
		||||
        default="alpaca_zh",
 | 
			
		||||
        metadata={"help": "The name of provided dataset(s) to use. Use comma to separate multiple datasets."}
 | 
			
		||||
    )
 | 
			
		||||
    dataset_dir: Optional[str] = field(
 | 
			
		||||
        default="data",
 | 
			
		||||
        metadata={"help": "The name of the folder containing datasets."}
 | 
			
		||||
    )
 | 
			
		||||
    split: Optional[str] = field(
 | 
			
		||||
        default="train",
 | 
			
		||||
        metadata={"help": "Which dataset split to use for training and evaluation."}
 | 
			
		||||
    )
 | 
			
		||||
    overwrite_cache: Optional[bool] = field(
 | 
			
		||||
        default=False,
 | 
			
		||||
        metadata={"help": "Overwrite the cached training and evaluation sets."}
 | 
			
		||||
    )
 | 
			
		||||
    preprocessing_num_workers: Optional[int] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "The number of processes to use for the preprocessing."}
 | 
			
		||||
    )
 | 
			
		||||
    max_source_length: Optional[int] = field(
 | 
			
		||||
        default=512,
 | 
			
		||||
        metadata={"help": "The maximum total input sequence length after tokenization."}
 | 
			
		||||
    )
 | 
			
		||||
    max_target_length: Optional[int] = field(
 | 
			
		||||
        default=512,
 | 
			
		||||
        metadata={"help": "The maximum total output sequence length after tokenization."}
 | 
			
		||||
    )
 | 
			
		||||
    max_samples: Optional[int] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
 | 
			
		||||
    )
 | 
			
		||||
    eval_num_beams: Optional[int] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
 | 
			
		||||
    )
 | 
			
		||||
    ignore_pad_token_for_loss: Optional[bool] = field(
 | 
			
		||||
        default=True,
 | 
			
		||||
        metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
 | 
			
		||||
    )
 | 
			
		||||
    source_prefix: Optional[str] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
 | 
			
		||||
    )
 | 
			
		||||
    dev_ratio: Optional[float] = field(
 | 
			
		||||
        default=0,
 | 
			
		||||
        metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
 | 
			
		||||
    )
 | 
			
		||||
    prompt_template: Optional[str] = field(
 | 
			
		||||
        default="default",
 | 
			
		||||
        metadata={"help": "Which template to use for constructing prompts in training and inference."}
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def init_for_training(self): # support mixing multiple datasets
 | 
			
		||||
        dataset_names = [ds.strip() for ds in self.dataset.split(",")]
 | 
			
		||||
        with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
 | 
			
		||||
            dataset_info = json.load(f)
 | 
			
		||||
 | 
			
		||||
        if self.source_prefix is not None:
 | 
			
		||||
            prefix_list = self.source_prefix.split("|")
 | 
			
		||||
            prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
 | 
			
		||||
            assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical with datasets or 1."
 | 
			
		||||
        else:
 | 
			
		||||
            prefix_list = [None] * len(dataset_names)
 | 
			
		||||
 | 
			
		||||
        self.dataset_list: List[DatasetAttr] = []
 | 
			
		||||
        for i, name in enumerate(dataset_names):
 | 
			
		||||
            if name not in dataset_info:
 | 
			
		||||
                raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
 | 
			
		||||
 | 
			
		||||
            if "hf_hub_url" in dataset_info[name]:
 | 
			
		||||
                dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
 | 
			
		||||
            elif "script_url" in dataset_info[name]:
 | 
			
		||||
                dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
 | 
			
		||||
            else:
 | 
			
		||||
                dataset_attr = DatasetAttr(
 | 
			
		||||
                    "file",
 | 
			
		||||
                    dataset_name=dataset_info[name]["file_name"],
 | 
			
		||||
                    dataset_sha1=dataset_info[name].get("file_sha1", None)
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            dataset_attr.source_prefix = prefix_list[i]
 | 
			
		||||
 | 
			
		||||
            if "columns" in dataset_info[name]:
 | 
			
		||||
                dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
 | 
			
		||||
                dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
 | 
			
		||||
                dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
 | 
			
		||||
                dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
 | 
			
		||||
 | 
			
		||||
            self.dataset_list.append(dataset_attr)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class FinetuningArguments:
 | 
			
		||||
    """
 | 
			
		||||
    Arguments pertaining to which techniques we are going to fine-tuning with.
 | 
			
		||||
    """
 | 
			
		||||
    finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
 | 
			
		||||
        default="lora",
 | 
			
		||||
        metadata={"help": "Which fine-tuning method to use."}
 | 
			
		||||
    )
 | 
			
		||||
    num_hidden_layers: Optional[int] = field(
 | 
			
		||||
        default=32,
 | 
			
		||||
        metadata={"help": "Number of decoder blocks in the model. \
 | 
			
		||||
                  LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
 | 
			
		||||
                  BLOOM choices: [\"24\", \"30\", \"70\"], \
 | 
			
		||||
                  Falcon choices: [\"32\", \"60\"], \
 | 
			
		||||
                  Baichuan choices: [\"32\"]"}
 | 
			
		||||
    )
 | 
			
		||||
    num_layer_trainable: Optional[int] = field(
 | 
			
		||||
        default=3,
 | 
			
		||||
        metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
 | 
			
		||||
    )
 | 
			
		||||
    name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
 | 
			
		||||
        default="mlp",
 | 
			
		||||
        metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
 | 
			
		||||
                  LLaMA choices: [\"mlp\", \"self_attn\"], \
 | 
			
		||||
                  BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
 | 
			
		||||
                  Baichuan choices: [\"mlp\", \"self_attn\"]"}
 | 
			
		||||
    )
 | 
			
		||||
    lora_rank: Optional[int] = field(
 | 
			
		||||
        default=8,
 | 
			
		||||
        metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
 | 
			
		||||
    )
 | 
			
		||||
    lora_alpha: Optional[float] = field(
 | 
			
		||||
        default=32.0,
 | 
			
		||||
        metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."}
 | 
			
		||||
    )
 | 
			
		||||
    lora_dropout: Optional[float] = field(
 | 
			
		||||
        default=0.1,
 | 
			
		||||
        metadata={"help": "Dropout rate for the LoRA fine-tuning."}
 | 
			
		||||
    )
 | 
			
		||||
    lora_target: Optional[str] = field(
 | 
			
		||||
        default="q_proj,v_proj",
 | 
			
		||||
        metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \
 | 
			
		||||
                  LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
 | 
			
		||||
                  BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
 | 
			
		||||
                  Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
 | 
			
		||||
            self.lora_target = [target.strip() for target in self.lora_target.split(",")]
 | 
			
		||||
 | 
			
		||||
        if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
 | 
			
		||||
            trainable_layer_ids = [self.num_hidden_layers - k - 1 for k in range(self.num_layer_trainable)]
 | 
			
		||||
        else: # fine-tuning the first n layers if num_layer_trainable < 0
 | 
			
		||||
            trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]
 | 
			
		||||
 | 
			
		||||
        self.trainable_layers = ["{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
 | 
			
		||||
 | 
			
		||||
        assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
 | 
			
		||||
 | 
			
		||||
    def save_to_json(self, json_path: str):
 | 
			
		||||
        """Saves the content of this instance in JSON format inside `json_path`."""
 | 
			
		||||
        json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
 | 
			
		||||
        with open(json_path, "w", encoding="utf-8") as f:
 | 
			
		||||
            f.write(json_string)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def load_from_json(cls, json_path: str):
 | 
			
		||||
        """Creates an instance from the content of `json_path`."""
 | 
			
		||||
        with open(json_path, "r", encoding="utf-8") as f:
 | 
			
		||||
            text = f.read()
 | 
			
		||||
        return cls(**json.loads(text))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class GeneratingArguments:
 | 
			
		||||
    """
 | 
			
		||||
    Arguments pertaining to specify the decoding parameters.
 | 
			
		||||
    """
 | 
			
		||||
    do_sample: Optional[bool] = field(
 | 
			
		||||
        default=True,
 | 
			
		||||
        metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
 | 
			
		||||
    )
 | 
			
		||||
    temperature: Optional[float] = field(
 | 
			
		||||
        default=0.95,
 | 
			
		||||
        metadata={"help": "The value used to modulate the next token probabilities."}
 | 
			
		||||
    )
 | 
			
		||||
    top_p: Optional[float] = field(
 | 
			
		||||
        default=0.7,
 | 
			
		||||
        metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
 | 
			
		||||
    )
 | 
			
		||||
    top_k: Optional[int] = field(
 | 
			
		||||
        default=50,
 | 
			
		||||
        metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
 | 
			
		||||
    )
 | 
			
		||||
    num_beams: Optional[int] = field(
 | 
			
		||||
        default=1,
 | 
			
		||||
        metadata={"help": "Number of beams for beam search. 1 means no beam search."}
 | 
			
		||||
    )
 | 
			
		||||
    max_length: Optional[int] = field(
 | 
			
		||||
        default=None,
 | 
			
		||||
        metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
 | 
			
		||||
    )
 | 
			
		||||
    max_new_tokens: Optional[int] = field(
 | 
			
		||||
        default=512,
 | 
			
		||||
        metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
 | 
			
		||||
    )
 | 
			
		||||
    repetition_penalty: Optional[float] = field(
 | 
			
		||||
        default=1.0,
 | 
			
		||||
        metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
 | 
			
		||||
    )
 | 
			
		||||
    length_penalty: Optional[float] = field(
 | 
			
		||||
        default=1.0,
 | 
			
		||||
        metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    def to_dict(self) -> Dict[str, Any]:
 | 
			
		||||
        args = asdict(self)
 | 
			
		||||
        if args.get("max_new_tokens", None):
 | 
			
		||||
            args.pop("max_length", None)
 | 
			
		||||
        return args
 | 
			
		||||
@ -1,197 +0,0 @@
 | 
			
		||||
import os
 | 
			
		||||
import sys
 | 
			
		||||
import json
 | 
			
		||||
import torch
 | 
			
		||||
import logging
 | 
			
		||||
from typing import Dict, List, Optional
 | 
			
		||||
 | 
			
		||||
from transformers.trainer import TRAINER_STATE_NAME, WEIGHTS_NAME, WEIGHTS_INDEX_NAME
 | 
			
		||||
from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint
 | 
			
		||||
from transformers.generation.utils import LogitsProcessorList
 | 
			
		||||
from transformers.generation.logits_process import LogitsProcessor
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
IGNORE_INDEX = -100
 | 
			
		||||
VALUE_HEAD_FILE_NAME = "value_head.bin"
 | 
			
		||||
FINETUNING_ARGS_NAME = "finetuning_args.json"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_logger(name: str) -> logging.Logger:
 | 
			
		||||
    return logging.getLogger(name)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logging.basicConfig(
 | 
			
		||||
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
 | 
			
		||||
    datefmt="%m/%d/%Y %H:%M:%S",
 | 
			
		||||
    level=logging.INFO,
 | 
			
		||||
    handlers=[logging.StreamHandler(sys.stdout)]
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AverageMeter:
 | 
			
		||||
    r"""
 | 
			
		||||
    Computes and stores the average and current value.
 | 
			
		||||
    """
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.reset()
 | 
			
		||||
 | 
			
		||||
    def reset(self):
 | 
			
		||||
        self.val = 0
 | 
			
		||||
        self.avg = 0
 | 
			
		||||
        self.sum = 0
 | 
			
		||||
        self.count = 0
 | 
			
		||||
 | 
			
		||||
    def update(self, val, n=1):
 | 
			
		||||
        self.val = val
 | 
			
		||||
        self.sum += val * n
 | 
			
		||||
        self.count += n
 | 
			
		||||
        self.avg = self.sum / self.count
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Avoid runtime error in model.generate(do_sample=True).
 | 
			
		||||
class InvalidScoreLogitsProcessor(LogitsProcessor):
 | 
			
		||||
 | 
			
		||||
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
 | 
			
		||||
        if torch.isnan(scores).any() or torch.isinf(scores).any():
 | 
			
		||||
            scores.zero_()
 | 
			
		||||
            scores[..., 0] = 1.0
 | 
			
		||||
        return scores
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_logits_processor() -> LogitsProcessorList:
 | 
			
		||||
    logits_processor = LogitsProcessorList()
 | 
			
		||||
    logits_processor.append(InvalidScoreLogitsProcessor())
 | 
			
		||||
    return logits_processor
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
 | 
			
		||||
# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
 | 
			
		||||
def prepare_model_for_training(
 | 
			
		||||
        model: PreTrainedModel,
 | 
			
		||||
        finetuning_type: str,
 | 
			
		||||
        output_embedding_layer_name: Optional[str] = "lm_head",
 | 
			
		||||
        use_gradient_checkpointing: Optional[bool] = True,
 | 
			
		||||
        layer_norm_names: Optional[List[str]] = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings
 | 
			
		||||
) -> PreTrainedModel:
 | 
			
		||||
 | 
			
		||||
    for name, param in model.named_parameters():
 | 
			
		||||
        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
 | 
			
		||||
            param.data = param.data.to(torch.float32)
 | 
			
		||||
 | 
			
		||||
    if use_gradient_checkpointing:
 | 
			
		||||
        if hasattr(model, "enable_input_require_grads"):
 | 
			
		||||
            model.enable_input_require_grads()
 | 
			
		||||
        else:
 | 
			
		||||
            def make_inputs_require_grad(module, input, output):
 | 
			
		||||
                output.requires_grad_(True)
 | 
			
		||||
            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
 | 
			
		||||
 | 
			
		||||
        model.gradient_checkpointing_enable()
 | 
			
		||||
        model.config.use_cache = False # turn off when gradient checkpointing is enabled
 | 
			
		||||
 | 
			
		||||
    if finetuning_type != "full" and hasattr(model, output_embedding_layer_name):
 | 
			
		||||
        output_embedding_layer: torch.nn.Linear = getattr(model, output_embedding_layer_name)
 | 
			
		||||
        input_dtype = output_embedding_layer.weight.dtype
 | 
			
		||||
 | 
			
		||||
        class CastOutputToFloat(torch.nn.Sequential):
 | 
			
		||||
 | 
			
		||||
            def forward(self, x: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
                return super().forward(x.to(input_dtype)).to(torch.float32)
 | 
			
		||||
 | 
			
		||||
        setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
 | 
			
		||||
 | 
			
		||||
    return model
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def print_trainable_params(model: torch.nn.Module) -> None:
 | 
			
		||||
    trainable_params, all_param = 0, 0
 | 
			
		||||
    for param in model.parameters():
 | 
			
		||||
        num_params = param.numel()
 | 
			
		||||
        # if using DS Zero 3 and the weights are initialized empty
 | 
			
		||||
        if num_params == 0 and hasattr(param, "ds_numel"):
 | 
			
		||||
            num_params = param.ds_numel
 | 
			
		||||
        all_param += num_params
 | 
			
		||||
        if param.requires_grad:
 | 
			
		||||
            trainable_params += num_params
 | 
			
		||||
    print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
 | 
			
		||||
                trainable_params, all_param, 100 * trainable_params / all_param))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get state dict containing trainable parameters
 | 
			
		||||
    state_dict = model.state_dict()
 | 
			
		||||
    filtered_state_dict = {}
 | 
			
		||||
 | 
			
		||||
    for k, v in model.named_parameters():
 | 
			
		||||
        if v.requires_grad:
 | 
			
		||||
            filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
 | 
			
		||||
 | 
			
		||||
    return filtered_state_dict
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
 | 
			
		||||
    weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
 | 
			
		||||
    if os.path.exists(weights_file):
 | 
			
		||||
        model_state_dict = torch.load(weights_file, map_location="cpu")
 | 
			
		||||
        model.load_state_dict(model_state_dict, strict=False) # skip missing keys
 | 
			
		||||
    elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
 | 
			
		||||
        load_sharded_checkpoint(model, checkpoint_dir, strict=False)
 | 
			
		||||
    else:
 | 
			
		||||
        logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
 | 
			
		||||
        return False
 | 
			
		||||
    return True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
 | 
			
		||||
    valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
 | 
			
		||||
    if not os.path.exists(valuehead_file):
 | 
			
		||||
        logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
 | 
			
		||||
        return False
 | 
			
		||||
    valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
 | 
			
		||||
    model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
 | 
			
		||||
    model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
 | 
			
		||||
    model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]))
 | 
			
		||||
    model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
 | 
			
		||||
    return True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]:
 | 
			
		||||
    r"""
 | 
			
		||||
    EMA implementation according to TensorBoard.
 | 
			
		||||
    """
 | 
			
		||||
    last = scalars[0]
 | 
			
		||||
    smoothed = list()
 | 
			
		||||
    for next_val in scalars:
 | 
			
		||||
        smoothed_val = last * weight + (1 - weight) * next_val
 | 
			
		||||
        smoothed.append(smoothed_val)
 | 
			
		||||
        last = smoothed_val
 | 
			
		||||
    return smoothed
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
 | 
			
		||||
    import matplotlib.pyplot as plt
 | 
			
		||||
    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
 | 
			
		||||
        data = json.load(f)
 | 
			
		||||
 | 
			
		||||
    for key in keys:
 | 
			
		||||
        steps, metrics = [], []
 | 
			
		||||
        for i in range(len(data["log_history"])):
 | 
			
		||||
            if key in data["log_history"][i]:
 | 
			
		||||
                steps.append(data["log_history"][i]["step"])
 | 
			
		||||
                metrics.append(data["log_history"][i][key])
 | 
			
		||||
 | 
			
		||||
        if len(metrics) == 0:
 | 
			
		||||
            logger.warning(f"No metric {key} to plot.")
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        plt.figure()
 | 
			
		||||
        plt.plot(steps, metrics, alpha=0.4, label="original")
 | 
			
		||||
        plt.plot(steps, smooth(metrics), label="smoothed")
 | 
			
		||||
        plt.title("training {} of {}".format(key, save_dictionary))
 | 
			
		||||
        plt.xlabel("step")
 | 
			
		||||
        plt.ylabel(key)
 | 
			
		||||
        plt.legend()
 | 
			
		||||
        plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100)
 | 
			
		||||
        print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key)))
 | 
			
		||||
@ -1,60 +0,0 @@
 | 
			
		||||
import torch
 | 
			
		||||
import numpy as np
 | 
			
		||||
from typing import Dict, Sequence, Tuple, Union
 | 
			
		||||
 | 
			
		||||
from transformers import DataCollatorWithPadding
 | 
			
		||||
 | 
			
		||||
from .peft_trainer import PeftTrainer
 | 
			
		||||
 | 
			
		||||
from .other import get_logger
 | 
			
		||||
 | 
			
		||||
logger = get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
 | 
			
		||||
    preds, _ = eval_preds
 | 
			
		||||
    return {"accuracy": (preds[0] > preds[1]).sum() / len(preds[0])}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
 | 
			
		||||
    r"""
 | 
			
		||||
    Data collator for pairwise data.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> Dict[str, torch.Tensor]:
 | 
			
		||||
        r"""
 | 
			
		||||
        Pads batched data to the longest sequence in the batch.
 | 
			
		||||
 | 
			
		||||
        We generate 2 * n examples where the first n examples represent chosen examples and
 | 
			
		||||
        the last n examples represent rejected examples.
 | 
			
		||||
        """
 | 
			
		||||
        features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features]
 | 
			
		||||
        return super().__call__(features)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PairwisePeftTrainer(PeftTrainer):
 | 
			
		||||
    r"""
 | 
			
		||||
    Inherits PeftTrainer to compute pairwise loss.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
        super().__init__(*args, **kwargs)
 | 
			
		||||
        self.can_return_loss = True # override property to return eval_loss
 | 
			
		||||
 | 
			
		||||
    def compute_loss(self, model, inputs, return_outputs=False):
 | 
			
		||||
        r"""
 | 
			
		||||
        Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
 | 
			
		||||
 | 
			
		||||
        We use score on the EOS token to represent reward of the whole sentence.
 | 
			
		||||
 | 
			
		||||
        Subclass and override to inject custom behavior. It should not be directly used by external scripts.
 | 
			
		||||
 | 
			
		||||
        Note that the first element will be removed from the output tuple.
 | 
			
		||||
 | 
			
		||||
        See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
 | 
			
		||||
        """
 | 
			
		||||
        batch_size = inputs["input_ids"].size(0) // 2
 | 
			
		||||
        _, _, values = model(**inputs)
 | 
			
		||||
        r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
 | 
			
		||||
        loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
 | 
			
		||||
        return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
 | 
			
		||||
@ -2,83 +2,26 @@
 | 
			
		||||
# Implements user interface in browser for fine-tuned models.
 | 
			
		||||
# Usage: python web_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
import mdtex2html
 | 
			
		||||
import gradio as gr
 | 
			
		||||
 | 
			
		||||
from threading import Thread
 | 
			
		||||
from utils import (
 | 
			
		||||
    Template,
 | 
			
		||||
    load_pretrained,
 | 
			
		||||
    prepare_infer_args,
 | 
			
		||||
    get_logits_processor
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from transformers import TextIteratorStreamer
 | 
			
		||||
from transformers.utils.versions import require_version
 | 
			
		||||
 | 
			
		||||
from llmtuner import Template, get_infer_args, load_model_and_tokenizer, get_logits_processor
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
 | 
			
		||||
model, tokenizer = load_pretrained(model_args, finetuning_args)
 | 
			
		||||
model_args, data_args, finetuning_args, generating_args = get_infer_args()
 | 
			
		||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
 | 
			
		||||
 | 
			
		||||
prompt_template = Template(data_args.prompt_template)
 | 
			
		||||
source_prefix = data_args.source_prefix if data_args.source_prefix else ""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def postprocess(self, y):
 | 
			
		||||
    r"""
 | 
			
		||||
    Overrides Chatbot.postprocess
 | 
			
		||||
    """
 | 
			
		||||
    if y is None:
 | 
			
		||||
        return []
 | 
			
		||||
    for i, (message, response) in enumerate(y):
 | 
			
		||||
        y[i] = (
 | 
			
		||||
            None if message is None else mdtex2html.convert((message)),
 | 
			
		||||
            None if response is None else mdtex2html.convert(response),
 | 
			
		||||
        )
 | 
			
		||||
    return y
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
gr.Chatbot.postprocess = postprocess
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT
 | 
			
		||||
    lines = text.split("\n")
 | 
			
		||||
    lines = [line for line in lines if line != ""]
 | 
			
		||||
    count = 0
 | 
			
		||||
    for i, line in enumerate(lines):
 | 
			
		||||
        if "```" in line:
 | 
			
		||||
            count += 1
 | 
			
		||||
            items = line.split("`")
 | 
			
		||||
            if count % 2 == 1:
 | 
			
		||||
                lines[i] = "<pre><code class=\"language-{}\">".format(items[-1])
 | 
			
		||||
            else:
 | 
			
		||||
                lines[i] = "<br /></code></pre>"
 | 
			
		||||
        else:
 | 
			
		||||
            if i > 0:
 | 
			
		||||
                if count % 2 == 1:
 | 
			
		||||
                    line = line.replace("`", "\`")
 | 
			
		||||
                    line = line.replace("<", "<")
 | 
			
		||||
                    line = line.replace(">", ">")
 | 
			
		||||
                    line = line.replace(" ", " ")
 | 
			
		||||
                    line = line.replace("*", "*")
 | 
			
		||||
                    line = line.replace("_", "_")
 | 
			
		||||
                    line = line.replace("-", "-")
 | 
			
		||||
                    line = line.replace(".", ".")
 | 
			
		||||
                    line = line.replace("!", "!")
 | 
			
		||||
                    line = line.replace("(", "(")
 | 
			
		||||
                    line = line.replace(")", ")")
 | 
			
		||||
                    line = line.replace("$", "$")
 | 
			
		||||
                lines[i] = "<br />" + line
 | 
			
		||||
    text = "".join(lines)
 | 
			
		||||
    return text
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def predict(query, chatbot, max_new_tokens, top_p, temperature, history):
 | 
			
		||||
    chatbot.append((parse_text(query), ""))
 | 
			
		||||
    chatbot.append((query, ""))
 | 
			
		||||
 | 
			
		||||
    input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"]
 | 
			
		||||
    input_ids = input_ids.to(model.device)
 | 
			
		||||
@ -102,7 +45,7 @@ def predict(query, chatbot, max_new_tokens, top_p, temperature, history):
 | 
			
		||||
    for new_text in streamer:
 | 
			
		||||
        response += new_text
 | 
			
		||||
        new_history = history + [(query, response)]
 | 
			
		||||
        chatbot[-1] = (parse_text(query), parse_text(response))
 | 
			
		||||
        chatbot[-1] = (query, response)
 | 
			
		||||
        yield chatbot, new_history
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user