release v0.1.0

Former-commit-id: f8193e8009451cf569a28a10eb4bd88831844441
This commit is contained in:
hiyouga 2023-07-18 00:18:25 +08:00
parent 799524b37b
commit 091805d38e
30 changed files with 1513 additions and 309 deletions

View File

@ -10,7 +10,9 @@
## Changelog ## Changelog
[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-13B-Base`, `--padding_side right` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--prompt_template baichuan` argument when you are using the Baichuan-13B-Chat model. [23/07/18] Now we develop an all-in-one Web UI for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Please replace the Baichuan-13B model file with `tests/modeling_baichuan.py` and try `--model_name_or_path path_to_baichuan_model` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--prompt_template baichuan` argument when you are using the Baichuan-13B-Chat model.
[23/07/09] Now we release [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested. [23/07/09] Now we release [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
@ -125,14 +127,10 @@ cd LLaMA-Efficient-Tuning
pip install -r requirements.txt pip install -r requirements.txt
``` ```
### LLaMA Weights Preparation (optional) ### All-in-one Web UI
1. Download the weights of the LLaMA models.
2. Convert them to HF format using the following command.
```bash ```bash
python -m transformers.models.llama.convert_llama_weights_to_hf \ python src/train_web.py
--input_dir path_to_llama_weights --model_size 7B --output_dir path_to_llama_model
``` ```
### (Continually) Pre-Training ### (Continually) Pre-Training
@ -275,10 +273,20 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit evaluation. We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit evaluation.
### API / CLI / Web Demo ### API Demo
```bash ```bash
python src/xxx_demo.py \ python src/api_demo.py \
--model_name_or_path path_to_your_model \
--checkpoint_dir path_to_checkpoint
```
See `http://localhost:8000/docs` for API documentation.
### CLI Demo
```bash
python src/cli_demo.py \
--model_name_or_path path_to_your_model \ --model_name_or_path path_to_your_model \
--checkpoint_dir path_to_checkpoint --checkpoint_dir path_to_checkpoint
``` ```

View File

@ -3,14 +3,14 @@ transformers>=4.29.1
datasets>=2.12.0 datasets>=2.12.0
accelerate>=0.19.0 accelerate>=0.19.0
peft>=0.3.0 peft>=0.3.0
trl==0.4.4 trl>=0.4.7
sentencepiece sentencepiece
jieba jieba
rouge-chinese rouge-chinese
nltk nltk
gradio>=3.36.0 gradio>=3.36.0
uvicorn uvicorn
pydantic==1.10.7 pydantic
fastapi fastapi
sse-starlette sse-starlette
matplotlib matplotlib

View File

@ -1,6 +1,7 @@
from llmtuner.api import create_app from llmtuner.api import create_app
from llmtuner.chat import ChatModel from llmtuner.chat import ChatModel
from llmtuner.tuner import get_train_args, get_infer_args, load_model_and_tokenizer, run_pt, run_sft, run_rm, run_ppo from llmtuner.tuner import get_train_args, get_infer_args, load_model_and_tokenizer, run_pt, run_sft, run_rm, run_ppo
from llmtuner.webui import create_ui
__version__ = "0.0.9" __version__ = "0.1.0"

View File

@ -1,3 +1,4 @@
import json
import uvicorn import uvicorn
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
@ -93,7 +94,7 @@ def create_app():
finish_reason=None finish_reason=None
) )
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk") chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
yield chunk.json(exclude_unset=True, ensure_ascii=False) yield json.dumps(chunk, ensure_ascii=False)
for new_text in chat_model.stream_chat( for new_text in chat_model.stream_chat(
query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
@ -107,7 +108,7 @@ def create_app():
finish_reason=None finish_reason=None
) )
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk") chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
yield chunk.json(exclude_unset=True, ensure_ascii=False) yield json.dumps(chunk, ensure_ascii=False)
choice_data = ChatCompletionResponseStreamChoice( choice_data = ChatCompletionResponseStreamChoice(
index=0, index=0,
@ -115,7 +116,7 @@ def create_app():
finish_reason="stop" finish_reason="stop"
) )
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk") chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
yield chunk.json(exclude_unset=True, ensure_ascii=False) yield json.dumps(chunk, ensure_ascii=False)
yield "[DONE]" yield "[DONE]"
return app return app

View File

@ -5,3 +5,27 @@ VALUE_HEAD_FILE_NAME = "value_head.bin"
FINETUNING_ARGS_NAME = "finetuning_args.json" FINETUNING_ARGS_NAME = "finetuning_args.json"
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings
METHODS = ["full", "freeze", "lora"]
SUPPORTED_MODELS = {
"LLaMA-7B": "huggyllama/llama-7b",
"LLaMA-13B": "huggyllama/llama-13b",
"LLaMA-30B": "huggyllama/llama-30b",
"LLaMA-65B": "huggyllama/llama-65b",
"BLOOM-560M": "bigscience/bloom-560m",
"BLOOM-3B": "bigscience/bloom-3b",
"BLOOM-7B1": "bigscience/bloom-7b1",
"BLOOMZ-560M": "bigscience/bloomz-560m",
"BLOOMZ-3B": "bigscience/bloomz-3b",
"BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt",
"Falcon-7B-Base": "tiiuae/falcon-7b",
"Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
"Falcon-40B-Base": "tiiuae/falcon-40b",
"Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
"Baichuan-7B": "baichuan-inc/Baichuan-7B",
"Baichuan-13B-Base": "baichuan-inc/Baichuan-13B-Base",
"Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat",
"InternLM-7B-Base": "internlm/internlm-7b",
"InternLM-7B-Chat": "internlm/internlm-chat-7b"
}

View File

@ -2,6 +2,20 @@ import sys
import logging import logging
class LoggerHandler(logging.Handler):
def __init__(self):
super().__init__()
self.log = ""
def emit(self, record):
if record.name == "httpx":
return
log_entry = self.format(record)
self.log += log_entry
self.log += "\n\n"
def get_logger(name: str) -> logging.Logger: def get_logger(name: str) -> logging.Logger:
formatter = logging.Formatter( formatter = logging.Formatter(

View File

@ -1,4 +1,5 @@
import os import os
import math
import json import json
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from typing import List, Optional from typing import List, Optional
@ -10,12 +11,13 @@ from llmtuner.extras.logging import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]: def smooth(scalars: List[float]) -> List[float]:
r""" r"""
EMA implementation according to TensorBoard. EMA implementation according to TensorBoard.
""" """
last = scalars[0] last = scalars[0]
smoothed = list() smoothed = list()
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
for next_val in scalars: for next_val in scalars:
smoothed_val = last * weight + (1 - weight) * next_val smoothed_val = last * weight + (1 - weight) * next_val
smoothed.append(smoothed_val) smoothed.append(smoothed_val)

View File

@ -1,141 +1,29 @@
from typing import List, Optional, Tuple from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass from dataclasses import dataclass
@dataclass
class Format:
prefix: str
prompt: str
sep: str
use_history: bool
templates: Dict[str, Format] = {}
@dataclass @dataclass
class Template: class Template:
name: str name: str
def __post_init__(self): def __post_init__(self):
if self.name in templates:
if self.name == "vanilla": self.prefix = templates[self.name].prefix
r""" self.prompt = templates[self.name].prompt
Supports language model inference without histories. self.sep = templates[self.name].sep
""" self.use_history = templates[self.name].use_history
self._register_template(
prefix="",
prompt="{query}",
sep="",
use_history=False
)
elif self.name == "default":
r"""
Default template.
"""
self._register_template(
prefix="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
prompt="Human: {query}\nAssistant: ",
sep="\n",
use_history=True
)
elif self.name == "alpaca":
r"""
Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
https://github.com/ymcui/Chinese-LLaMA-Alpaca
"""
self._register_template(
prefix="Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.",
prompt="### Instruction:\n{query}\n\n### Response:\n",
sep="\n\n",
use_history=True
)
elif self.name == "vicuna":
r"""
Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
https://huggingface.co/lmsys/vicuna-13b-delta-v1.1
"""
self._register_template(
prefix="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
prompt="USER: {query} ASSISTANT: ",
sep="</s>",
use_history=True
)
elif self.name == "belle":
r"""
Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
"""
self._register_template(
prefix="",
prompt="Human: {query}\n\nBelle: ",
sep="\n\n",
use_history=True
)
elif self.name == "linly":
r"""
Supports: https://github.com/CVI-SZU/Linly
"""
self._register_template(
prefix="",
prompt="User: {query}\nBot: ",
sep="\n",
use_history=True
)
elif self.name == "billa":
r"""
Supports: https://github.com/Neutralzz/BiLLa
"""
self._register_template(
prefix="",
prompt="Human: {query}\nAssistant: ",
sep="\n",
use_history=True
)
elif self.name == "ziya":
r"""
Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
"""
self._register_template(
prefix="",
prompt="<human>:{query}\n<bot>:",
sep="\n",
use_history=True
)
elif self.name == "aquila":
r"""
Supports: https://huggingface.co/qhduan/aquilachat-7b
"""
self._register_template(
prefix="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
prompt="Human: {query}###Assistant: ",
sep="###",
use_history=True
)
elif self.name == "intern":
r"""
Supports: https://huggingface.co/internlm/internlm-chat-7b
"""
self._register_template(
prefix="",
prompt="<|User|>:{query}<eoh>\n<|Bot|>:",
sep="<eoa>\n",
use_history=True
)
elif self.name == "baichuan":
r"""
Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
"""
self._register_template(
prefix="",
prompt="<reserved_102>{query}<reserved_103>",
sep="",
use_history=True
)
else: else:
raise ValueError("Template {} does not exist.".format(self.name)) raise ValueError("Template {} does not exist.".format(self.name))
@ -155,14 +43,6 @@ class Template:
""" """
return self._format_example(query, history, prefix) + [resp] return self._format_example(query, history, prefix) + [resp]
def _register_template(
self, prefix: str, prompt: str, sep: str, use_history: Optional[bool] = True
) -> None:
self.prefix = prefix
self.prompt = prompt
self.sep = sep
self.use_history = use_history
def _format_example( def _format_example(
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = "" self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
) -> List[str]: ) -> List[str]:
@ -179,3 +59,150 @@ class Template:
convs.append(self.sep + self.prompt.format(query=user_query)) convs.append(self.sep + self.prompt.format(query=user_query))
convs.append(bot_resp) convs.append(bot_resp)
return convs[:-1] # drop last return convs[:-1] # drop last
def register_template(name: str, prefix: str, prompt: str, sep: str, use_history: bool) -> None:
templates[name] = Format(
prefix=prefix,
prompt=prompt,
sep=sep,
use_history=use_history
)
r"""
Supports language model inference without histories.
"""
register_template(
name="vanilla",
prefix="",
prompt="{query}",
sep="",
use_history=False
)
r"""
Default template.
"""
register_template(
name="default",
prefix="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
prompt="Human: {query}\nAssistant: ",
sep="\n",
use_history=True
)
r"""
Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
https://github.com/ymcui/Chinese-LLaMA-Alpaca
"""
register_template(
name="alpaca",
prefix="Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.",
prompt="### Instruction:\n{query}\n\n### Response:\n",
sep="\n\n",
use_history=True
)
r"""
Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
https://huggingface.co/lmsys/vicuna-13b-delta-v1.1
"""
register_template(
name="vicuna",
prefix="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
prompt="USER: {query} ASSISTANT: ",
sep="</s>",
use_history=True
)
r"""
Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
"""
register_template(
name="belle",
prefix="",
prompt="Human: {query}\n\nBelle: ",
sep="\n\n",
use_history=True
)
r"""
Supports: https://github.com/CVI-SZU/Linly
"""
register_template(
name="linly",
prefix="",
prompt="User: {query}\nBot: ",
sep="\n",
use_history=True
)
r"""
Supports: https://github.com/Neutralzz/BiLLa
"""
register_template(
name="billa",
prefix="",
prompt="Human: {query}\nAssistant: ",
sep="\n",
use_history=True
)
r"""
Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
"""
register_template(
name="ziya",
prefix="",
prompt="<human>:{query}\n<bot>:",
sep="\n",
use_history=True
)
r"""
Supports: https://huggingface.co/qhduan/aquilachat-7b
"""
register_template(
name="aquila",
prefix="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
prompt="Human: {query}###Assistant: ",
sep="###",
use_history=True
)
r"""
Supports: https://huggingface.co/internlm/internlm-chat-7b
"""
register_template(
name="intern",
prefix="",
prompt="<|User|>:{query}<eoh>\n<|Bot|>:",
sep="<eoa>\n",
use_history=True
)
r"""
Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
"""
register_template(
name="baichuan",
prefix="",
prompt="<reserved_102>{query}<reserved_103>",
sep="",
use_history=True
)

View File

@ -28,7 +28,7 @@ check_min_version("4.29.1")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0") require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0") require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0") require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
require_version("trl==0.4.4", "To fix: pip install trl==0.4.4") require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7")
def load_model_and_tokenizer( def load_model_and_tokenizer(

View File

@ -25,7 +25,6 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
r""" r"""
Inherits PPOTrainer. Inherits PPOTrainer.
""" """
def __init__( def __init__(
self, self,
training_args: Seq2SeqTrainingArguments, training_args: Seq2SeqTrainingArguments,
@ -46,12 +45,13 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
r""" r"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer. Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
""" """
total_train_batch_size = self.config.batch_size * self.config.gradient_accumulation_steps * self.args.world_size total_train_batch_size = (
self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size
)
len_dataloader = len(self.dataloader) len_dataloader = len(self.dataloader)
num_steps_per_epoch = max(len_dataloader // self.config.gradient_accumulation_steps, 1)
num_examples = len(self.dataset) num_examples = len(self.dataset)
num_train_epochs = self.args.num_train_epochs num_train_epochs = self.args.num_train_epochs
max_steps = math.ceil(num_train_epochs * num_steps_per_epoch) max_steps = math.ceil(num_train_epochs * len_dataloader)
self.state.max_steps = max_steps self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs self.state.num_train_epochs = num_train_epochs
@ -62,9 +62,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
logger.info("***** Running training *****") logger.info("***** Running training *****")
logger.info(f" Num examples = {num_examples}") logger.info(f" Num examples = {num_examples}")
logger.info(f" Num Epochs = {num_train_epochs}") logger.info(f" Num Epochs = {num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {self.config.batch_size}") logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
logger.info(f" Gradient Accumulation steps = {self.config.gradient_accumulation_steps}") logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps}") logger.info(f" Total optimization steps = {max_steps}")
logger.info(f" Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}") logger.info(f" Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")
@ -77,7 +77,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
"eos_token_id": self.tokenizer.eos_token_id, "eos_token_id": self.tokenizer.eos_token_id,
"logits_processor": get_logits_processor() "logits_processor": get_logits_processor()
} }
output_length_sampler = LengthSampler(max_target_length // 2, max_target_length) length_sampler = LengthSampler(max_target_length // 2, max_target_length)
unwrapped_model: PreTrainedModel = self.accelerator.unwrap_model(self.model) unwrapped_model: PreTrainedModel = self.accelerator.unwrap_model(self.model)
dataiter = iter(self.dataloader) dataiter = iter(self.dataloader)
@ -87,59 +87,45 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
self.log_callback.on_train_begin(self.args, self.state, self.control) self.log_callback.on_train_begin(self.args, self.state, self.control)
for step in tqdm(range(max_steps), disable=not self.is_world_process_zero(), leave=False): for step in tqdm(range(max_steps), disable=not self.is_world_process_zero(), leave=False):
batch = next(dataiter)
steps_trained += 1
for _ in range(self.config.gradient_accumulation_steps): unwrapped_model.gradient_checkpointing_disable()
unwrapped_model.config.use_cache = True
batch = next(dataiter) # Get responses
steps_trained += 1 query_tensors = batch["input_ids"]
response_tensors = self.generate(batch, length_sampler, return_prompt=False, **gen_kwargs)
unwrapped_model.gradient_checkpointing_disable() queries, responses = [], []
unwrapped_model.config.use_cache = True for i in range(len(query_tensors)):
query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0]
response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
queries.append(query_tensors[i, query_length:]) # remove padding from left
responses.append(response_tensors[i, :response_length]) # remove padding from right
# Get response from model # Compute rewards
query_tensors: torch.Tensor = batch["input_ids"] replace_model(unwrapped_model, target="reward")
response_tensors = self.generate(batch, length_sampler=output_length_sampler, return_prompt=False, **gen_kwargs) with torch.no_grad():
queries: List[torch.Tensor] = []
responses: List[torch.Tensor] = []
for i in range(len(query_tensors)):
query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0]
response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
queries.append(query_tensors[i, query_length:]) # remove padding from left
if response_length < 2: # make response have at least 2 tokens
responses.append(response_tensors.new_empty(2).fill_(self.tokenizer.eos_token_id))
else:
responses.append(response_tensors[i, :response_length]) # remove padding from right
# Compute rewards
replace_model(unwrapped_model, target="reward")
_, _, values = self.model(**self.prepare_model_inputs(queries, responses)) _, _, values = self.model(**self.prepare_model_inputs(queries, responses))
rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type rewards = [reward for reward in values[-1].to(torch.float32)] # use float32 type
replace_model(unwrapped_model, target="default") # make sure the model is default at the end replace_model(unwrapped_model, target="default")
# Run PPO step # Run PPO step
unwrapped_model.gradient_checkpointing_enable() unwrapped_model.gradient_checkpointing_enable()
unwrapped_model.config.use_cache = False unwrapped_model.config.use_cache = False
stats = self.step(queries, responses, rewards)
stats = self.step(queries, responses, rewards) loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
if self.control.should_epoch_stop or self.control.should_training_stop:
break
if steps_trained == len_dataloader:
dataiter = iter(self.dataloader)
steps_trained = 0
if self.is_world_process_zero() and (step+1) % self.args.logging_steps == 0: if self.is_world_process_zero() and (step+1) % self.args.logging_steps == 0:
logs = { logs = dict(
"loss": round(loss_meter.avg, 4), loss=round(loss_meter.avg, 4),
"reward": round(reward_meter.avg, 4), reward=round(reward_meter.avg, 4),
"learning_rate": stats["ppo/learning_rate"], learning_rate=stats["ppo/learning_rate"],
"epoch": round(step / num_steps_per_epoch, 2) epoch=round(step / len_dataloader, 2)
} )
print(logs) print(logs)
logs["step"] = step logs["step"] = step
self.state.log_history.append(logs) self.state.log_history.append(logs)
@ -150,10 +136,14 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
if (step+1) % self.args.save_steps == 0: # save checkpoint if (step+1) % self.args.save_steps == 0: # save checkpoint
self.save_model(os.path.join(self.args.output_dir, f"checkpoint-{step+1}")) self.save_model(os.path.join(self.args.output_dir, f"checkpoint-{step+1}"))
if self.control.should_training_stop: if self.control.should_epoch_stop or self.control.should_training_stop:
break break
@torch.inference_mode() if steps_trained == len_dataloader:
dataiter = iter(self.dataloader)
steps_trained = 0
@torch.no_grad()
def generate( def generate(
self, self,
inputs: Dict[str, torch.Tensor], inputs: Dict[str, torch.Tensor],

View File

@ -4,7 +4,8 @@
import math import math
from trl import PPOConfig from trl import PPOConfig
from torch.optim import AdamW from torch.optim import AdamW
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments from typing import Optional, List
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, TrainerCallback
from transformers.optimization import get_scheduler from transformers.optimization import get_scheduler
from llmtuner.dsets import get_dataset, preprocess_dataset from llmtuner.dsets import get_dataset, preprocess_dataset
@ -19,7 +20,8 @@ def run_ppo(
model_args: ModelArguments, model_args: ModelArguments,
data_args: DataArguments, data_args: DataArguments,
training_args: Seq2SeqTrainingArguments, training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments finetuning_args: FinetuningArguments,
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
): ):
dataset = get_dataset(model_args, data_args) dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo") model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
@ -30,7 +32,7 @@ def run_ppo(
model_name=model_args.model_name_or_path, model_name=model_args.model_name_or_path,
learning_rate=training_args.learning_rate, learning_rate=training_args.learning_rate,
mini_batch_size=training_args.per_device_train_batch_size, mini_batch_size=training_args.per_device_train_batch_size,
batch_size=training_args.per_device_train_batch_size, batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
gradient_accumulation_steps=training_args.gradient_accumulation_steps, gradient_accumulation_steps=training_args.gradient_accumulation_steps,
ppo_epochs=1, ppo_epochs=1,
max_grad_norm=training_args.max_grad_norm max_grad_norm=training_args.max_grad_norm
@ -50,7 +52,7 @@ def run_ppo(
ppo_trainer = PPOPeftTrainer( ppo_trainer = PPOPeftTrainer(
training_args=training_args, training_args=training_args,
finetuning_args=finetuning_args, finetuning_args=finetuning_args,
callbacks=[LogCallback()], callbacks=callbacks,
config=ppo_config, config=ppo_config,
model=model, model=model,
ref_model=None, ref_model=None,

View File

@ -2,7 +2,8 @@
# https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py # https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py # https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
from transformers import Seq2SeqTrainingArguments from typing import Optional, List
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.dsets import get_dataset, preprocess_dataset from llmtuner.dsets import get_dataset, preprocess_dataset
from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.callbacks import LogCallback
@ -18,7 +19,8 @@ def run_rm(
model_args: ModelArguments, model_args: ModelArguments,
data_args: DataArguments, data_args: DataArguments,
training_args: Seq2SeqTrainingArguments, training_args: Seq2SeqTrainingArguments,
finetuning_args: FinetuningArguments finetuning_args: FinetuningArguments,
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
): ):
dataset = get_dataset(model_args, data_args) dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm") model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
@ -44,7 +46,7 @@ def run_rm(
args=training_args, args=training_args,
tokenizer=tokenizer, tokenizer=tokenizer,
data_collator=data_collator, data_collator=data_collator,
callbacks=[LogCallback()], callbacks=callbacks,
compute_metrics=compute_accuracy, compute_metrics=compute_accuracy,
**trainer_kwargs **trainer_kwargs
) )

View File

@ -0,0 +1 @@
from llmtuner.webui.interface import create_ui

View File

@ -0,0 +1,79 @@
import os
from typing import List, Tuple
from llmtuner.chat.stream_chat import ChatModel
from llmtuner.extras.misc import torch_gc
from llmtuner.hparams import GeneratingArguments
from llmtuner.tuner import get_infer_args
from llmtuner.webui.common import get_model_path, get_save_dir
from llmtuner.webui.locales import ALERTS
class WebChatModel(ChatModel):
def __init__(self):
self.model = None
self.tokenizer = None
self.generating_args = GeneratingArguments()
def load_model(
self, lang: str, model_name: str, checkpoints: list,
finetuning_type: str, template: str, quantization_bit: str
):
if self.model is not None:
yield ALERTS["err_exists"][lang]
return
if not model_name:
yield ALERTS["err_no_model"][lang]
return
model_name_or_path = get_model_path(model_name)
if not model_name_or_path:
yield ALERTS["err_no_path"][lang]
return
if checkpoints:
checkpoint_dir = ",".join(
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
)
else:
checkpoint_dir = None
yield ALERTS["info_loading"][lang]
args = dict(
model_name_or_path=model_name_or_path,
finetuning_type=finetuning_type,
prompt_template=template,
checkpoint_dir=checkpoint_dir,
quantization_bit=int(quantization_bit) if quantization_bit else None
)
super().__init__(*get_infer_args(args))
yield ALERTS["info_loaded"][lang]
def unload_model(self, lang: str):
yield ALERTS["info_unloading"][lang]
self.model = None
self.tokenizer = None
torch_gc()
yield ALERTS["info_unloaded"][lang]
def predict(
self,
chatbot: List[Tuple[str, str]],
query: str,
history: List[Tuple[str, str]],
max_new_tokens: int,
top_p: float,
temperature: float
):
chatbot.append([query, ""])
response = ""
for new_text in self.stream_chat(
query, history, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
):
response += new_text
new_history = history + [(query, response)]
chatbot[-1] = [query, response]
yield chatbot, new_history

View File

@ -0,0 +1,75 @@
import json
import os
from typing import Any, Dict, Optional
import gradio as gr
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from llmtuner.extras.constants import SUPPORTED_MODELS
DEFAULT_CACHE_DIR = "cache"
DEFAULT_DATA_DIR = "data"
DEFAULT_SAVE_DIR = "saves"
USER_CONFIG = "user.config"
DATA_CONFIG = "dataset_info.json"
def get_save_dir(model_name: str) -> str:
return os.path.join(DEFAULT_SAVE_DIR, os.path.split(model_name)[-1])
def get_config_path() -> os.PathLike:
return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
def load_config() -> Dict[str, Any]:
try:
with open(get_config_path(), "r", encoding="utf-8") as f:
return json.load(f)
except:
return {"last_model": "", "path_dict": {}}
def save_config(model_name: str, model_path: str) -> None:
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
user_config = load_config()
user_config["last_model"] = model_name
user_config["path_dict"][model_name] = model_path
with open(get_config_path(), "w", encoding="utf-8") as f:
json.dump(user_config, f, indent=2, ensure_ascii=False)
def get_model_path(model_name: str) -> str:
user_config = load_config()
return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, ""))
def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
checkpoints = []
save_dir = os.path.join(get_save_dir(model_name), finetuning_type)
if save_dir and os.path.isdir(save_dir):
for checkpoint in os.listdir(save_dir):
if (
os.path.isdir(os.path.join(save_dir, checkpoint))
and any([
os.path.isfile(os.path.join(save_dir, checkpoint, name))
for name in (WEIGHTS_NAME, WEIGHTS_INDEX_NAME, PEFT_WEIGHTS_NAME)
])
):
checkpoints.append(checkpoint)
return gr.update(value=[], choices=checkpoints)
def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
try:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
return json.load(f)
except:
return {}
def list_dataset(dataset_dir: Optional[str] = None) -> Dict[str, Any]:
dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
return gr.update(value=[], choices=list(dataset_info.keys()))

View File

@ -0,0 +1,4 @@
from llmtuner.webui.components.eval import create_eval_tab
from llmtuner.webui.components.infer import create_infer_tab
from llmtuner.webui.components.top import create_top
from llmtuner.webui.components.sft import create_sft_tab

View File

@ -0,0 +1,54 @@
from typing import Dict, Tuple
import gradio as gr
from gradio.blocks import Block
from gradio.components import Component
from llmtuner.webui.chat import WebChatModel
def create_chat_box(
chat_model: WebChatModel
) -> Tuple[Block, Component, Component, Dict[str, Component]]:
with gr.Box(visible=False) as chat_box:
chatbot = gr.Chatbot()
with gr.Row():
with gr.Column(scale=4):
with gr.Column(scale=12):
query = gr.Textbox(show_label=False, lines=8)
with gr.Column(min_width=32, scale=1):
submit_btn = gr.Button(variant="primary")
with gr.Column(scale=1):
clear_btn = gr.Button()
max_new_tokens = gr.Slider(
10, 2048, value=chat_model.generating_args.max_new_tokens, step=1, interactive=True
)
top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01, interactive=True)
temperature = gr.Slider(
0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01, interactive=True
)
history = gr.State([])
submit_btn.click(
chat_model.predict,
[chatbot, query, history, max_new_tokens, top_p, temperature],
[chatbot, history],
show_progress=True
).then(
lambda: gr.update(value=""), outputs=[query]
)
clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
return chat_box, chatbot, history, dict(
query=query,
submit_btn=submit_btn,
clear_btn=clear_btn,
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature
)

View File

@ -0,0 +1,19 @@
import gradio as gr
from gradio.blocks import Block
from gradio.components import Component
from typing import Tuple
def create_preview_box() -> Tuple[Block, Component, Component, Component]:
with gr.Box(visible=False, elem_classes="modal-box") as preview_box:
with gr.Row():
preview_count = gr.Number(interactive=False)
with gr.Row():
preview_samples = gr.JSON(interactive=False)
close_btn = gr.Button()
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box])
return preview_box, preview_count, preview_samples, close_btn

View File

@ -0,0 +1,60 @@
from typing import Dict
import gradio as gr
from gradio.components import Component
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.runner import Runner
from llmtuner.webui.utils import can_preview, get_preview
def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=2)
dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4)
preview_btn = gr.Button(interactive=False, scale=1)
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
with gr.Row():
max_samples = gr.Textbox(value="100000", interactive=True)
batch_size = gr.Slider(value=8, minimum=1, maximum=128, step=1, interactive=True)
quantization_bit = gr.Dropdown([8, 4])
predict = gr.Checkbox(value=True)
with gr.Row():
start_btn = gr.Button()
stop_btn = gr.Button()
output_box = gr.Markdown()
start_btn.click(
runner.run_eval,
[
top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
top_elems["finetuning_type"], top_elems["template"],
dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict
],
[output_box]
)
stop_btn.click(runner.set_abort, queue=False)
return dict(
dataset_dir=dataset_dir,
dataset=dataset,
preview_btn=preview_btn,
preview_count=preview_count,
preview_samples=preview_samples,
close_btn=close_btn,
max_samples=max_samples,
batch_size=batch_size,
quantization_bit=quantization_bit,
predict=predict,
start_btn=start_btn,
stop_btn=stop_btn,
output_box=output_box
)

View File

@ -0,0 +1,47 @@
from typing import Dict
import gradio as gr
from gradio.components import Component
from llmtuner.webui.chat import WebChatModel
from llmtuner.webui.components.chatbot import create_chat_box
def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
with gr.Row():
load_btn = gr.Button()
unload_btn = gr.Button()
quantization_bit = gr.Dropdown([8, 4])
info_box = gr.Markdown()
chat_model = WebChatModel()
chat_box, chatbot, history, chat_elems = create_chat_box(chat_model)
load_btn.click(
chat_model.load_model,
[
top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
top_elems["finetuning_type"], top_elems["template"],
quantization_bit
],
[info_box]
).then(
lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
)
unload_btn.click(
chat_model.unload_model, [top_elems["lang"]], [info_box]
).then(
lambda: ([], []), outputs=[chatbot, history]
).then(
lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
)
return dict(
quantization_bit=quantization_bit,
info_box=info_box,
load_btn=load_btn,
unload_btn=unload_btn,
**chat_elems
)

View File

@ -0,0 +1,94 @@
from typing import Dict
from transformers.trainer_utils import SchedulerType
import gradio as gr
from gradio.components import Component
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.runner import Runner
from llmtuner.webui.utils import can_preview, get_preview, gen_plot
def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=1)
dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4)
preview_btn = gr.Button(interactive=False, scale=1)
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
with gr.Row():
learning_rate = gr.Textbox(value="5e-5", interactive=True)
num_train_epochs = gr.Textbox(value="3.0", interactive=True)
max_samples = gr.Textbox(value="100000", interactive=True)
quantization_bit = gr.Dropdown([8, 4])
with gr.Row():
batch_size = gr.Slider(value=4, minimum=1, maximum=128, step=1, interactive=True)
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=32, step=1, interactive=True)
lr_scheduler_type = gr.Dropdown(
value="cosine", choices=[scheduler.value for scheduler in SchedulerType], interactive=True
)
fp16 = gr.Checkbox(value=True)
with gr.Row():
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5, interactive=True)
save_steps = gr.Slider(value=100, minimum=10, maximum=2000, step=10, interactive=True)
with gr.Row():
start_btn = gr.Button()
stop_btn = gr.Button()
with gr.Row():
with gr.Column(scale=4):
output_dir = gr.Textbox(interactive=True)
output_box = gr.Markdown()
with gr.Column(scale=1):
loss_viewer = gr.Plot()
start_btn.click(
runner.run_train,
[
top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"],
top_elems["finetuning_type"], top_elems["template"],
dataset, dataset_dir, learning_rate, num_train_epochs, max_samples,
fp16, quantization_bit, batch_size, gradient_accumulation_steps,
lr_scheduler_type, logging_steps, save_steps, output_dir
],
[output_box]
)
stop_btn.click(runner.set_abort, queue=False)
output_box.change(
gen_plot, [top_elems["model_name"], top_elems["finetuning_type"], output_dir], loss_viewer, queue=False
)
return dict(
dataset_dir=dataset_dir,
dataset=dataset,
preview_btn=preview_btn,
preview_count=preview_count,
preview_samples=preview_samples,
close_btn=close_btn,
learning_rate=learning_rate,
num_train_epochs=num_train_epochs,
max_samples=max_samples,
quantization_bit=quantization_bit,
batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
lr_scheduler_type=lr_scheduler_type,
fp16=fp16,
logging_steps=logging_steps,
save_steps=save_steps,
start_btn=start_btn,
stop_btn=stop_btn,
output_dir=output_dir,
output_box=output_box,
loss_viewer=loss_viewer
)

View File

@ -0,0 +1,42 @@
from typing import Dict
import gradio as gr
from gradio.components import Component
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
from llmtuner.extras.template import templates
from llmtuner.webui.common import list_checkpoint, get_model_path, save_config
def create_top() -> Dict[str, Component]:
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
with gr.Row():
lang = gr.Dropdown(choices=["en", "zh"], value="en", interactive=True, scale=1)
model_name = gr.Dropdown(choices=available_models, scale=3)
model_path = gr.Textbox(scale=3)
with gr.Row():
finetuning_type = gr.Dropdown(value="lora", choices=METHODS, interactive=True, scale=1)
template = gr.Dropdown(value="default", choices=list(templates.keys()), interactive=True, scale=1)
checkpoints = gr.Dropdown(multiselect=True, interactive=True, scale=4)
refresh_btn = gr.Button(scale=1)
model_name.change(
list_checkpoint, [model_name, finetuning_type], [checkpoints]
).then(
get_model_path, [model_name], [model_path]
) # do not save config since the below line will save
model_path.change(save_config, [model_name, model_path])
finetuning_type.change(list_checkpoint, [model_name, finetuning_type], [checkpoints])
refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints])
return dict(
lang=lang,
model_name=model_name,
model_path=model_path,
finetuning_type=finetuning_type,
template=template,
checkpoints=checkpoints,
refresh_btn=refresh_btn
)

18
src/llmtuner/webui/css.py Normal file
View File

@ -0,0 +1,18 @@
CSS = r"""
.modal-box {
position: fixed !important;
top: 50%;
left: 50%;
transform: translate(-50%, -50%); /* center horizontally */
max-width: 1000px;
max-height: 750px;
overflow-y: scroll !important;
background-color: var(--input-background-fill);
border: 2px solid black !important;
z-index: 1000;
}
.dark .modal-box {
border: 2px solid white !important;
}
"""

View File

@ -0,0 +1,54 @@
import gradio as gr
from transformers.utils.versions import require_version
from llmtuner.webui.components import (
create_top,
create_sft_tab,
create_eval_tab,
create_infer_tab
)
from llmtuner.webui.css import CSS
from llmtuner.webui.manager import Manager
from llmtuner.webui.runner import Runner
require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
def create_ui() -> gr.Blocks:
runner = Runner()
with gr.Blocks(title="Web Tuner", css=CSS) as demo:
top_elems = create_top()
with gr.Tab("SFT"):
sft_elems = create_sft_tab(top_elems, runner)
with gr.Tab("Evaluate"):
eval_elems = create_eval_tab(top_elems, runner)
with gr.Tab("Inference"):
infer_elems = create_infer_tab(top_elems)
elem_list = [top_elems, sft_elems, eval_elems, infer_elems]
manager = Manager(elem_list)
demo.load(
manager.gen_label,
[top_elems["lang"]],
[elem for elems in elem_list for elem in elems.values()],
)
top_elems["lang"].change(
manager.gen_label,
[top_elems["lang"]],
[elem for elems in elem_list for elem in elems.values()],
)
return demo
if __name__ == "__main__":
demo = create_ui()
demo.queue()
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)

View File

@ -0,0 +1,384 @@
LOCALES = {
"lang": {
"en": {
"label": "Lang"
},
"zh": {
"label": "语言"
}
},
"model_name": {
"en": {
"label": "Model name"
},
"zh": {
"label": "模型名称"
}
},
"model_path": {
"en": {
"label": "Model path",
"info": "Path to pretrained model or model identifier from Hugging Face."
},
"zh": {
"label": "模型路径",
"info": "本地模型的文件路径或 Hugging Face 的模型标识符。"
}
},
"checkpoints": {
"en": {
"label": "Checkpoints"
},
"zh": {
"label": "模型断点"
}
},
"template": {
"en": {
"label": "Prompt template"
},
"zh": {
"label": "提示模板"
}
},
"refresh_btn": {
"en": {
"value": "Refresh checkpoints"
},
"zh": {
"value": "刷新断点"
}
},
"dataset_dir": {
"en": {
"label": "Data dir",
"info": "Path of the data directory."
},
"zh": {
"label": "数据路径",
"info": "数据文件夹的路径。"
}
},
"dataset": {
"en": {
"label": "Dataset"
},
"zh": {
"label": "数据集"
}
},
"preview_btn": {
"en": {
"value": "Preview"
},
"zh": {
"value": "预览"
}
},
"preview_count": {
"en": {
"label": "Count"
},
"zh": {
"label": "数量"
}
},
"preview_samples": {
"en": {
"label": "Samples"
},
"zh": {
"label": "样例"
}
},
"close_btn": {
"en": {
"value": "Close"
},
"zh": {
"value": "关闭"
}
},
"max_samples": {
"en": {
"label": "Max samples",
"info": "Maximum samples per dataset."
},
"zh": {
"label": "最大样本数",
"info": "每个数据集最多使用的样本数。"
}
},
"batch_size": {
"en": {
"label": "Batch size",
"info": "Number of samples to process per GPU."
},
"zh":{
"label": "批处理大小",
"info": "每块 GPU 上处理的样本数量。"
}
},
"quantization_bit": {
"en": {
"label": "Quantization bit",
"info": "Enable 4/8-bit model quantization."
},
"zh": {
"label": "量化",
"info": "启用 4/8 比特模型量化。"
}
},
"start_btn": {
"en": {
"value": "Start"
},
"zh": {
"value": "开始"
}
},
"stop_btn": {
"en": {
"value": "Abort"
},
"zh": {
"value": "中断"
}
},
"output_box": {
"en": {
"value": "Ready."
},
"zh": {
"value": "准备就绪。"
}
},
"finetuning_type": {
"en": {
"label": "Finetuning method"
},
"zh": {
"label": "微调方法"
}
},
"learning_rate": {
"en": {
"label": "Learning rate",
"info": "Initial learning rate for AdamW."
},
"zh": {
"label": "学习率",
"info": "AdamW 优化器的初始学习率。"
}
},
"num_train_epochs": {
"en": {
"label": "Epochs",
"info": "Total number of training epochs to perform."
},
"zh": {
"label": "训练轮数",
"info": "需要执行的训练总轮数。"
}
},
"gradient_accumulation_steps": {
"en": {
"label": "Gradient accumulation",
"info": "Number of gradient accumulation steps."
},
"zh": {
"label": "梯度累积",
"info": "梯度累积的步数。"
}
},
"lr_scheduler_type": {
"en": {
"label": "LR Scheduler",
"info": "Name of learning rate scheduler.",
},
"zh": {
"label": "学习率调节器",
"info": "采用的学习率调节器名称。"
}
},
"fp16": {
"en": {
"label": "fp16",
"info": "Whether to use fp16 mixed precision training."
},
"zh": {
"label": "fp16",
"info": "是否启用 FP16 混合精度训练。"
}
},
"logging_steps": {
"en": {
"label": "Logging steps",
"info": "Number of update steps between two logs."
},
"zh": {
"label": "日志间隔",
"info": "每两次日志输出间的更新步数。"
}
},
"save_steps": {
"en": {
"label": "Save steps",
"info": "Number of updates steps between two checkpoints."
},
"zh": {
"label": "保存间隔",
"info": "每两次断点保存间的更新步数。"
}
},
"output_dir": {
"en": {
"label": "Checkpoint name",
"info": "Directory to save checkpoint."
},
"zh": {
"label": "断点名称",
"info": "保存模型断点的文件夹名称。"
}
},
"loss_viewer": {
"en": {
"label": "Loss"
},
"zh": {
"label": "损失"
}
},
"predict": {
"en": {
"label": "Save predictions"
},
"zh": {
"label": "保存预测结果"
}
},
"info_box": {
"en": {
"value": "Model unloaded, please load a model first."
},
"zh": {
"value": "模型未加载,请先加载模型。"
}
},
"load_btn": {
"en": {
"value": "Load model"
},
"zh": {
"value": "加载模型"
}
},
"unload_btn": {
"en": {
"value": "Unload model"
},
"zh": {
"value": "卸载模型"
}
},
"query": {
"en": {
"placeholder": "Input..."
},
"zh": {
"placeholder": "输入..."
}
},
"submit_btn": {
"en": {
"value": "Submit"
},
"zh": {
"value": "提交"
}
},
"clear_btn": {
"en": {
"value": "Clear history"
},
"zh": {
"value": "清空历史"
}
},
"max_new_tokens": {
"en": {
"label": "Maximum new tokens"
},
"zh": {
"label": "最大生成长度"
}
},
"top_p": {
"en": {
"label": "Top-p"
},
"zh": {
"label": "Top-p 采样值"
}
},
"temperature": {
"en": {
"label": "Temperature"
},
"zh": {
"label": "温度系数"
}
}
}
ALERTS = {
"err_conflict": {
"en": "A process is in running, please abort it firstly.",
"zh": "任务已存在,请先中断训练。"
},
"err_exists": {
"en": "You have loaded a model, please unload it first.",
"zh": "模型已存在,请先卸载模型。"
},
"err_no_model": {
"en": "Please select a model.",
"zh": "请选择模型。"
},
"err_no_path": {
"en": "Model not found.",
"zh": "模型未找到。"
},
"err_no_dataset": {
"en": "Please choose a dataset.",
"zh": "请选择数据集。"
},
"info_aborting": {
"en": "Aborted, wait for terminating...",
"zh": "训练中断,正在等待线程结束……"
},
"info_aborted": {
"en": "Ready.",
"zh": "准备就绪。"
},
"info_finished": {
"en": "Finished.",
"zh": "训练完毕。"
},
"info_loading": {
"en": "Loading model...",
"zh": "加载中……"
},
"info_unloading": {
"en": "Unloading model...",
"zh": "卸载中……"
},
"info_loaded": {
"en": "Model loaded, now you can chat with your model!",
"zh": "模型已加载,可以开始聊天了!"
},
"info_unloaded": {
"en": "Model unloaded.",
"zh": "模型已卸载。"
}
}

View File

@ -0,0 +1,35 @@
import gradio as gr
from typing import Any, Dict, List
from gradio.components import Component
from llmtuner.webui.common import get_model_path, list_dataset, load_config
from llmtuner.webui.locales import LOCALES
from llmtuner.webui.utils import get_time
class Manager:
def __init__(self, elem_list: List[Dict[str, Component]]):
self.elem_list = elem_list
def gen_refresh(self) -> Dict[str, Any]:
refresh_dict = {
"dataset": {"choices": list_dataset()["choices"]},
"output_dir": {"value": get_time()}
}
user_config = load_config()
if user_config["last_model"]:
refresh_dict["model_name"] = {"value": user_config["last_model"]}
refresh_dict["model_path"] = {"value": get_model_path(user_config["last_model"])}
return refresh_dict
def gen_label(self, lang: str) -> Dict[Component, dict]:
update_dict = {}
refresh_dict = self.gen_refresh()
for elems in self.elem_list:
for name, component in elems.items():
update_dict[component] = gr.update(**LOCALES[name][lang], **refresh_dict.get(name, {}))
return update_dict

View File

@ -0,0 +1,177 @@
import logging
import os
import threading
import time
import transformers
from typing import Optional, Tuple
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import get_train_args, run_sft
from llmtuner.webui.common import get_model_path, get_save_dir
from llmtuner.webui.locales import ALERTS
from llmtuner.webui.utils import format_info, get_eval_results
class Runner:
def __init__(self):
self.aborted = False
self.running = False
def set_abort(self):
self.aborted = True
self.running = False
def initialize(self, lang: str, model_name: str, dataset: list) -> Tuple[str, str, LoggerHandler, LogCallback]:
if self.running:
return None, ALERTS["err_conflict"][lang], None, None
if not model_name:
return None, ALERTS["err_no_model"][lang], None, None
model_name_or_path = get_model_path(model_name)
if not model_name_or_path:
return None, ALERTS["err_no_path"][lang], None, None
if len(dataset) == 0:
return None, ALERTS["err_no_dataset"][lang], None, None
self.aborted = False
self.running = True
logger_handler = LoggerHandler()
logger_handler.setLevel(logging.INFO)
logging.root.addHandler(logger_handler)
transformers.logging.add_handler(logger_handler)
trainer_callback = LogCallback(self)
return model_name_or_path, "", logger_handler, trainer_callback
def finalize(self, lang: str, finish_info: Optional[str] = None) -> str:
self.running = False
torch_gc()
if self.aborted:
return ALERTS["info_aborted"][lang]
else:
return finish_info if finish_info is not None else ALERTS["info_finished"][lang]
def run_train(
self, lang, model_name, checkpoints, finetuning_type, template,
dataset, dataset_dir, learning_rate, num_train_epochs, max_samples,
fp16, quantization_bit, batch_size, gradient_accumulation_steps,
lr_scheduler_type, logging_steps, save_steps, output_dir
):
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
if error:
yield error
return
if checkpoints:
checkpoint_dir = ",".join(
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
)
else:
checkpoint_dir = None
args = dict(
model_name_or_path=model_name_or_path,
do_train=True,
finetuning_type=finetuning_type,
prompt_template=template,
dataset=",".join(dataset),
dataset_dir=dataset_dir,
max_samples=int(max_samples),
output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir),
checkpoint_dir=checkpoint_dir,
overwrite_cache=True,
per_device_train_batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
lr_scheduler_type=lr_scheduler_type,
logging_steps=logging_steps,
save_steps=save_steps,
learning_rate=float(learning_rate),
num_train_epochs=float(num_train_epochs),
fp16=fp16,
quantization_bit=int(quantization_bit) if quantization_bit else None
)
model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
run_args = dict(
model_args=model_args,
data_args=data_args,
training_args=training_args,
finetuning_args=finetuning_args,
callbacks=[trainer_callback]
)
thread = threading.Thread(target=run_sft, kwargs=run_args)
thread.start()
while thread.is_alive():
time.sleep(1)
if self.aborted:
yield ALERTS["info_aborting"][lang]
else:
yield format_info(logger_handler.log, trainer_callback.tracker)
yield self.finalize(lang)
def run_eval(
self, lang, model_name, checkpoints, finetuning_type, template,
dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict
):
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
if error:
yield error
return
if checkpoints:
checkpoint_dir = ",".join(
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
)
output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_" + "_".join(checkpoints))
else:
checkpoint_dir = None
output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_base")
args = dict(
model_name_or_path=model_name_or_path,
do_eval=True,
finetuning_type=finetuning_type,
prompt_template=template,
dataset=",".join(dataset),
dataset_dir=dataset_dir,
max_samples=int(max_samples),
output_dir=output_dir,
checkpoint_dir=checkpoint_dir,
overwrite_cache=True,
predict_with_generate=True,
per_device_eval_batch_size=batch_size,
quantization_bit=int(quantization_bit) if quantization_bit else None
)
if predict:
args.pop("do_eval", None)
args["do_predict"] = True
model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
run_args = dict(
model_args=model_args,
data_args=data_args,
training_args=training_args,
finetuning_args=finetuning_args,
callbacks=[trainer_callback]
)
thread = threading.Thread(target=run_sft, kwargs=run_args)
thread.start()
while thread.is_alive():
time.sleep(1)
if self.aborted:
yield ALERTS["info_aborting"][lang]
else:
yield format_info(logger_handler.log, trainer_callback.tracker)
yield self.finalize(lang, get_eval_results(os.path.join(output_dir, "all_results.json")))

View File

@ -0,0 +1,74 @@
import os
import json
import gradio as gr
import matplotlib.figure
import matplotlib.pyplot as plt
from typing import Tuple
from datetime import datetime
from llmtuner.extras.ploting import smooth
from llmtuner.webui.common import get_save_dir, DATA_CONFIG
def format_info(log: str, tracker: dict) -> str:
info = log
if "current_steps" in tracker:
info += "Running **{:d}/{:d}**: {} < {}\n".format(
tracker["current_steps"], tracker["total_steps"], tracker["elapsed_time"], tracker["remaining_time"]
)
return info
def get_time() -> str:
return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
def can_preview(dataset_dir: str, dataset: list) -> dict:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f)
if (
len(dataset) > 0
and "file_name" in dataset_info[dataset[0]]
and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]))
):
return gr.update(interactive=True)
else:
return gr.update(interactive=False)
def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, dict]:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f)
data_file = dataset_info[dataset[0]]["file_name"]
with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
data = json.load(f)
return len(data), data[:2], gr.update(visible=True)
def get_eval_results(path: os.PathLike) -> str:
with open(path, "r", encoding="utf-8") as f:
result = json.dumps(json.load(f), indent=4)
return "```json\n{}\n```\n".format(result)
def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
log_file = os.path.join(get_save_dir(base_model), finetuning_type, output_dir, "trainer_log.jsonl")
if not os.path.isfile(log_file):
return None
plt.close("all")
fig = plt.figure()
ax = fig.add_subplot(111)
steps, losses = [], []
with open(log_file, "r", encoding="utf-8") as f:
for line in f:
log_info = json.loads(line)
if log_info.get("loss", None):
steps.append(log_info["current_steps"])
losses.append(log_info["loss"])
ax.plot(steps, losses, alpha=0.4, label="original")
ax.plot(steps, smooth(losses), label="smoothed")
ax.legend()
ax.set_xlabel("step")
ax.set_ylabel("loss")
return fig

11
src/train_web.py Normal file
View File

@ -0,0 +1,11 @@
from llmtuner import create_ui
def main():
demo = create_ui()
demo.queue()
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
if __name__ == "__main__":
main()

View File

@ -1,95 +0,0 @@
# coding=utf-8
# Implements user interface in browser for fine-tuned models.
# Usage: python web_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
import gradio as gr
from threading import Thread
from transformers import TextIteratorStreamer
from transformers.utils.versions import require_version
from llmtuner import Template, get_infer_args, load_model_and_tokenizer, get_logits_processor
require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0")
model_args, data_args, finetuning_args, generating_args = get_infer_args()
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
prompt_template = Template(data_args.prompt_template)
source_prefix = data_args.source_prefix if data_args.source_prefix else ""
def predict(query, chatbot, max_new_tokens, top_p, temperature, history):
chatbot.append((query, ""))
input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"]
input_ids = input_ids.to(model.device)
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs = generating_args.to_dict()
gen_kwargs.update({
"input_ids": input_ids,
"top_p": top_p,
"temperature": temperature,
"max_new_tokens": max_new_tokens,
"logits_processor": get_logits_processor(),
"streamer": streamer
})
thread = Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()
response = ""
for new_text in streamer:
response += new_text
new_history = history + [(query, response)]
chatbot[-1] = (query, response)
yield chatbot, new_history
def reset_user_input():
return gr.update(value="")
def reset_state():
return [], []
with gr.Blocks() as demo:
gr.HTML("""
<h1 align="center">
<a href="https://github.com/hiyouga/LLaMA-Efficient-Tuning" target="_blank">
LLaMA Efficient Tuning
</a>
</h1>
""")
chatbot = gr.Chatbot()
with gr.Row():
with gr.Column(scale=4):
with gr.Column(scale=12):
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(container=False)
with gr.Column(min_width=32, scale=1):
submitBtn = gr.Button("Submit", variant="primary")
with gr.Column(scale=1):
emptyBtn = gr.Button("Clear History")
max_new_tokens = gr.Slider(10, 2048, value=generating_args.max_new_tokens, step=1.0,
label="Maximum new tokens", interactive=True)
top_p = gr.Slider(0.01, 1, value=generating_args.top_p, step=0.01,
label="Top P", interactive=True)
temperature = gr.Slider(0.01, 1.5, value=generating_args.temperature, step=0.01,
label="Temperature", interactive=True)
history = gr.State([])
submitBtn.click(predict, [user_input, chatbot, max_new_tokens, top_p, temperature, history], [chatbot, history], show_progress=True)
submitBtn.click(reset_user_input, [], [user_input])
emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
demo.queue().launch(server_name="0.0.0.0", share=True, inbrowser=True)