update web UI, support rm predict #210

Former-commit-id: ed0e186a134de816d6a9278f4e47baa6250a52d1
hiyouga 2023-07-21 13:27:27 +08:00
parent 5ad443eaa6
commit f769c2d3fc
13 changed files with 192 additions and 27 deletions

View File

@@ -143,8 +143,10 @@ def preprocess_dataset(
     if stage == "pt":
         preprocess_function = preprocess_pretrain_dataset
     elif stage == "sft":
-        preprocess_function = preprocess_unsupervised_dataset \
-            if training_args.predict_with_generate else preprocess_supervised_dataset
+        if not training_args.predict_with_generate:
+            preprocess_function = preprocess_supervised_dataset
+        else:
+            preprocess_function = preprocess_unsupervised_dataset
     elif stage == "rm":
         preprocess_function = preprocess_pairwise_dataset
     elif stage == "ppo":

View File

@ -54,7 +54,7 @@ def get_train_args(
assert not (training_args.do_train and training_args.predict_with_generate), \ assert not (training_args.do_train and training_args.predict_with_generate), \
"`predict_with_generate` cannot be set as True while training." "`predict_with_generate` cannot be set as True while training."
assert (not training_args.do_predict) or training_args.predict_with_generate, \ assert general_args.stage != "sft" or (not training_args.do_predict) or training_args.predict_with_generate, \
"Please enable `predict_with_generate` to save model predictions." "Please enable `predict_with_generate` to save model predictions."
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \ assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
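
The relaxed assertion above is what allows reward-model prediction without `predict_with_generate`. A minimal sketch of the new condition (the helper name below is hypothetical; the boolean logic is taken from the diff):

# Hypothetical helper mirroring the assertion condition from the parser diff.
def prediction_args_ok(stage: str, do_predict: bool, predict_with_generate: bool) -> bool:
    # Only SFT prediction needs generation; RM prediction saves reward scores instead.
    return stage != "sft" or (not do_predict) or predict_with_generate

assert prediction_args_ok("rm", True, False)        # RM prediction now passes
assert not prediction_args_ok("sft", True, False)   # SFT prediction still needs predict_with_generate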

View File

@@ -4,7 +4,8 @@ from typing import Dict, Optional

 from transformers import Seq2SeqTrainer
 from transformers.trainer import TRAINING_ARGS_NAME
-from transformers.modeling_utils import unwrap_model
+from transformers.modeling_utils import PreTrainedModel, unwrap_model
+from peft import PeftModel

 from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
 from llmtuner.extras.logging import get_logger
@@ -49,9 +50,9 @@ class PeftTrainer(Seq2SeqTrainer):
         else:
             backbone_model = model

-        if self.finetuning_args.finetuning_type == "lora":
+        if isinstance(backbone_model, PeftModel):  # LoRA tuning
             backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
-        else: # freeze/full tuning
+        elif isinstance(backbone_model, PreTrainedModel):  # freeze/full tuning
             backbone_model.config.use_cache = True
             backbone_model.save_pretrained(
                 output_dir,
@@ -61,6 +62,8 @@ class PeftTrainer(Seq2SeqTrainer):
             backbone_model.config.use_cache = False
             if self.tokenizer is not None:
                 self.tokenizer.save_pretrained(output_dir)
+        else:
+            logger.warning("No model to save.")

         with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
             f.write(self.args.to_json_string() + "\n")
@@ -77,8 +80,8 @@ class PeftTrainer(Seq2SeqTrainer):
         model = unwrap_model(self.model)
         backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model

-        if self.finetuning_args.finetuning_type == "lora":
-            backbone_model.load_adapter(self.state.best_model_checkpoint, getattr(backbone_model, "active_adapter"))
+        if isinstance(backbone_model, PeftModel):
+            backbone_model.load_adapter(self.state.best_model_checkpoint, backbone_model.active_adapter)
             if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
                 model.v_head.load_state_dict({
                     "summary.weight": getattr(model, "reward_head_weight"),

View File

@@ -1,10 +1,17 @@
+import os
+import json
 import torch
 from typing import Dict, List, Optional, Tuple, Union
+from transformers.trainer import PredictionOutput
 from transformers.modeling_utils import PreTrainedModel

+from llmtuner.extras.logging import get_logger
 from llmtuner.tuner.core.trainer import PeftTrainer

+logger = get_logger(__name__)
+

 class PairwisePeftTrainer(PeftTrainer):
     r"""
     Inherits PeftTrainer to compute pairwise loss.
@@ -36,3 +43,26 @@ class PairwisePeftTrainer(PeftTrainer):
         r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
         loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
         return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
+
+    def save_predictions(
+        self,
+        predict_results: PredictionOutput
+    ) -> None:
+        r"""
+        Saves model predictions to `output_dir`.
+
+        A custom behavior that not contained in Seq2SeqTrainer.
+        """
+        if not self.is_world_process_zero():
+            return
+
+        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
+        logger.info(f"Saving prediction results to {output_prediction_file}")
+
+        acc_scores, rej_scores = predict_results.predictions
+
+        with open(output_prediction_file, "w", encoding="utf-8") as writer:
+            res: List[str] = []
+            for acc_score, rej_score in zip(acc_scores, rej_scores):
+                res.append(json.dumps({"accept": round(float(acc_score), 2), "reject": round(float(rej_score), 2)}))
+            writer.write("\n".join(res))
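
For reference, `save_predictions` writes one JSON object per pair with rounded reward scores. A minimal sketch of reading the file back (the score values are illustrative only):

import json

# Reads the generated_predictions.jsonl produced by save_predictions above.
# Each line looks like {"accept": 1.23, "reject": -0.45} (illustrative values).
with open("generated_predictions.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)
        margin = record["accept"] - record["reject"]  # positive margin = chosen response scored higher
        print(record["accept"], record["reject"], round(margin, 2))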

View File

@@ -56,3 +56,10 @@ def run_rm(
         metrics = trainer.evaluate(metric_key_prefix="eval")
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
+
+    # Predict
+    if training_args.do_predict:
+        predict_results = trainer.predict(dataset, metric_key_prefix="predict")
+        trainer.log_metrics("predict", predict_results.metrics)
+        trainer.save_metrics("predict", predict_results.metrics)
+        trainer.save_predictions(predict_results)
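
For context, a hypothetical argument set that would drive this new prediction branch; `stage`, `do_predict`, and `output_dir` appear in the diff or are standard HuggingFace TrainingArguments fields, while the paths and dataset name below are placeholders, not values from this commit:

# Hypothetical arguments for reward-model prediction after this change.
predict_args = dict(
    stage="rm",
    model_name_or_path="path/to/base_model",          # placeholder
    checkpoint_dir="path/to/reward_model_checkpoint",  # placeholder
    dataset="comparison_dataset",                      # placeholder dataset name
    finetuning_type="lora",
    do_predict=True,
    output_dir="path/to/rm_predict_output",            # placeholder
)
# run_rm then writes predict metrics and generated_predictions.jsonl
# (see save_predictions above) under output_dir.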

View File

@ -1,4 +1,5 @@
from llmtuner.webui.components.eval import create_eval_tab
from llmtuner.webui.components.infer import create_infer_tab
from llmtuner.webui.components.top import create_top from llmtuner.webui.components.top import create_top
from llmtuner.webui.components.sft import create_sft_tab from llmtuner.webui.components.sft import create_sft_tab
from llmtuner.webui.components.eval import create_eval_tab
from llmtuner.webui.components.infer import create_infer_tab
from llmtuner.webui.components.export import create_export_tab

View File

@@ -22,13 +22,9 @@ def create_chat_box(
             with gr.Column(scale=1):
                 clear_btn = gr.Button()

-                max_new_tokens = gr.Slider(
-                    10, 2048, value=chat_model.generating_args.max_new_tokens, step=1, interactive=True
-                )
-                top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01, interactive=True)
-                temperature = gr.Slider(
-                    0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01, interactive=True
-                )
+                max_new_tokens = gr.Slider(10, 2048, value=chat_model.generating_args.max_new_tokens, step=1)
+                top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01)
+                temperature = gr.Slider(0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01)

     history = gr.State([])

View File

@@ -0,0 +1,34 @@
+from typing import Dict
+import gradio as gr
+from gradio.components import Component
+
+from llmtuner.webui.utils import export_model
+
+
+def create_export_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
+    with gr.Row():
+        save_dir = gr.Textbox()
+        max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)
+
+    export_btn = gr.Button()
+    info_box = gr.Textbox(show_label=False, interactive=False)
+
+    export_btn.click(
+        export_model,
+        [
+            top_elems["lang"],
+            top_elems["model_name"],
+            top_elems["checkpoints"],
+            top_elems["finetuning_type"],
+            max_shard_size,
+            save_dir
+        ],
+        [info_box]
+    )
+
+    return dict(
+        save_dir=save_dir,
+        max_shard_size=max_shard_size,
+        export_btn=export_btn,
+        info_box=info_box
+    )

View File

@@ -57,7 +57,7 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
         with gr.Row():
             with gr.Column(scale=4):
-                output_dir = gr.Textbox(interactive=True)
+                output_dir = gr.Textbox()

                 with gr.Box():
                     output_box = gr.Markdown()

View File

@@ -5,7 +5,8 @@ from llmtuner.webui.components import (
     create_top,
     create_sft_tab,
     create_eval_tab,
-    create_infer_tab
+    create_infer_tab,
+    create_export_tab
 )
 from llmtuner.webui.css import CSS
 from llmtuner.webui.manager import Manager
@@ -30,7 +31,10 @@ def create_ui() -> gr.Blocks:
         with gr.Tab("Chat"):
             infer_elems = create_infer_tab(top_elems)

-        elem_list = [top_elems, sft_elems, eval_elems, infer_elems]
+        with gr.Tab("Export"):
+            export_elems = create_export_tab(top_elems)
+
+        elem_list = [top_elems, sft_elems, eval_elems, infer_elems, export_elems]
         manager = Manager(elem_list)

         demo.load(

View File

@@ -452,6 +452,34 @@ LOCALES = {
         "zh": {
            "label": "温度系数"
         }
+    },
+    "save_dir": {
+        "en": {
+            "label": "Export dir",
+            "info": "Directory to save exported model."
+        },
+        "zh": {
+            "label": "导出目录",
+            "info": "保存导出模型的文件夹路径。"
+        }
+    },
+    "max_shard_size": {
+        "en": {
+            "label": "Max shard size (GB)",
+            "info": "The maximum size for a model file."
+        },
+        "zh": {
+            "label": "最大分块大小（GB）",
+            "info": "模型文件的最大大小。"
+        }
+    },
+    "export_btn": {
+        "en": {
+            "value": "Export"
+        },
+        "zh": {
+            "value": "开始导出"
+        }
     }
 }
@@ -477,6 +505,14 @@ ALERTS = {
         "en": "Please choose a dataset.",
         "zh": "请选择数据集。"
     },
+    "err_no_checkpoint": {
+        "en": "Please select a checkpoint.",
+        "zh": "请选择断点。"
+    },
+    "err_no_save_dir": {
+        "en": "Please provide export dir.",
+        "zh": "请填写导出目录"
+    },
     "info_aborting": {
         "en": "Aborted, wait for terminating...",
         "zh": "训练中断，正在等待线程结束……"
     },
@@ -504,5 +540,13 @@ ALERTS = {
     "info_unloaded": {
         "en": "Model unloaded.",
         "zh": "模型已卸载。"
+    },
+    "info_exporting": {
+        "en": "Exporting model...",
+        "zh": "正在导出模型……"
+    },
+    "info_exported": {
+        "en": "Model exported.",
+        "zh": "模型导出完成。"
     }
 }

View File

@ -3,7 +3,7 @@ import os
import threading import threading
import time import time
import transformers import transformers
from typing import List, Optional, Tuple from typing import Generator, List, Optional, Tuple
from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import DEFAULT_MODULE from llmtuner.extras.constants import DEFAULT_MODULE
@ -25,7 +25,9 @@ class Runner:
self.aborted = True self.aborted = True
self.running = False self.running = False
def initialize(self, lang: str, model_name: str, dataset: list) -> Tuple[str, str, LoggerHandler, LogCallback]: def initialize(
self, lang: str, model_name: str, dataset: list
) -> Tuple[str, str, LoggerHandler, LogCallback]:
if self.running: if self.running:
return None, ALERTS["err_conflict"][lang], None, None return None, ALERTS["err_conflict"][lang], None, None
@ -50,7 +52,9 @@ class Runner:
return model_name_or_path, "", logger_handler, trainer_callback return model_name_or_path, "", logger_handler, trainer_callback
def finalize(self, lang: str, finish_info: Optional[str] = None) -> str: def finalize(
self, lang: str, finish_info: Optional[str] = None
) -> str:
self.running = False self.running = False
torch_gc() torch_gc()
if self.aborted: if self.aborted:
@ -87,7 +91,7 @@ class Runner:
lora_dropout: float, lora_dropout: float,
lora_target: str, lora_target: str,
output_dir: str output_dir: str
): ) -> Generator[str, None, None]:
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset) model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
if error: if error:
yield error yield error
@ -174,7 +178,7 @@ class Runner:
max_samples: str, max_samples: str,
batch_size: int, batch_size: int,
predict: bool predict: bool
): ) -> Generator[str, None, None]:
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset) model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
if error: if error:
yield error yield error

View File

@@ -3,11 +3,13 @@ import json
 import gradio as gr
 import matplotlib.figure
 import matplotlib.pyplot as plt
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Generator, List, Tuple
 from datetime import datetime

 from llmtuner.extras.ploting import smooth
-from llmtuner.webui.common import get_save_dir, DATA_CONFIG
+from llmtuner.tuner import get_infer_args, load_model_and_tokenizer
+from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
+from llmtuner.webui.locales import ALERTS


 def format_info(log: str, tracker: dict) -> str:
@@ -83,3 +85,41 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
     ax.set_xlabel("step")
     ax.set_ylabel("loss")
     return fig
+
+
+def export_model(
+    lang: str, model_name: str, checkpoints: List[str], finetuning_type: str, max_shard_size: int, save_dir: str
+) -> Generator[str, None, None]:
+    if not model_name:
+        yield ALERTS["err_no_model"][lang]
+        return
+
+    model_name_or_path = get_model_path(model_name)
+    if not model_name_or_path:
+        yield ALERTS["err_no_path"][lang]
+        return
+
+    if not checkpoints:
+        yield ALERTS["err_no_checkpoint"][lang]
+        return
+
+    checkpoint_dir = ",".join(
+        [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
+    )
+
+    if not save_dir:
+        yield ALERTS["err_no_save_dir"][lang]
+        return
+
+    args = dict(
+        model_name_or_path=model_name_or_path,
+        checkpoint_dir=checkpoint_dir,
+        finetuning_type=finetuning_type
+    )
+
+    yield ALERTS["info_exporting"][lang]
+    model_args, _, finetuning_args, _ = get_infer_args(args)
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
+    model.save_pretrained(save_dir, max_shard_size=str(max_shard_size)+"GB")
+    tokenizer.save_pretrained(save_dir)
+    yield ALERTS["info_exported"][lang]
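
A minimal sketch of calling `export_model` directly, the way the Export tab's click handler does; the model name, checkpoint folder, and save directory below are placeholders, and the checkpoints must exist under the web UI's save directory:

from llmtuner.webui.utils import export_model

# export_model is a generator that yields status messages from ALERTS.
for status in export_model(
    lang="en",
    model_name="LLaMA-7B",             # placeholder: must be a known model name
    checkpoints=["checkpoint-1000"],   # placeholder: checkpoint folder(s) to merge
    finetuning_type="lora",
    max_shard_size=10,                 # GB, passed to save_pretrained as "10GB"
    save_dir="exported_model"          # placeholder output directory
):
    print(status)  # "Exporting model..." then "Model exported."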