diff --git a/README.md b/README.md
index 975a783f..1124ade2 100644
--- a/README.md
+++ b/README.md
@@ -68,7 +68,7 @@
 | ---------------------- | -------------- | ----------------- | ---- | ----- |
 | Pre-Training           | ✅             | ✅                | ✅   | ✅    |
 | Supervised Fine-Tuning | ✅             | ✅                | ✅   | ✅    |
-| Reward Model Training  |                |                   | ✅   | ✅    |
+| Reward Modeling        |                |                   | ✅   | ✅    |
 | PPO Training           |                |                   | ✅   | ✅    |
 | DPO Training           | ✅             |                   | ✅   | ✅    |
@@ -103,7 +103,7 @@
   - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
   - [UltraChat (en)](https://github.com/thunlp/UltraChat)
   - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
-- For reward modelling or DPO training:
+- For reward modeling or DPO training:
   - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
   - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
   - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
@@ -206,7 +206,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --fp16
 ```
 
-### Reward Model Training
+### Reward Modeling
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
diff --git a/src/llmtuner/tuner/ppo/workflow.py b/src/llmtuner/tuner/ppo/workflow.py
index 6734ab78..243468cc 100644
--- a/src/llmtuner/tuner/ppo/workflow.py
+++ b/src/llmtuner/tuner/ppo/workflow.py
@@ -37,7 +37,9 @@ def run_ppo(
         batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
         gradient_accumulation_steps=training_args.gradient_accumulation_steps,
         ppo_epochs=1,
-        max_grad_norm=training_args.max_grad_norm
+        max_grad_norm=training_args.max_grad_norm,
+        seed=training_args.seed,
+        optimize_cuda_cache=True
     )
 
     optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py
index bf1d18fb..04b0774b 100644
--- a/src/llmtuner/webui/common.py
+++ b/src/llmtuner/webui/common.py
@@ -29,14 +29,16 @@ def load_config() -> Dict[str, Any]:
         with open(get_config_path(), "r", encoding="utf-8") as f:
             return json.load(f)
     except:
-        return {"last_model": "", "path_dict": {}}
+        return {"lang": "", "last_model": "", "path_dict": {}}
 
 
-def save_config(model_name: str, model_path: str) -> None:
+def save_config(lang: str, model_name: str, model_path: str) -> None:
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
     user_config = load_config()
-    user_config["last_model"] = model_name
-    user_config["path_dict"][model_name] = model_path
+    user_config["lang"] = lang or user_config["lang"]
+    if model_name:
+        user_config["last_model"] = model_name
+        user_config["path_dict"][model_name] = model_path
     with open(get_config_path(), "w", encoding="utf-8") as f:
         json.dump(user_config, f, indent=2, ensure_ascii=False)
diff --git a/src/llmtuner/webui/components/__init__.py b/src/llmtuner/webui/components/__init__.py
index 5b86f396..32228b8e 100644
--- a/src/llmtuner/webui/components/__init__.py
+++ b/src/llmtuner/webui/components/__init__.py
@@ -1,5 +1,5 @@
 from llmtuner.webui.components.top import create_top
-from llmtuner.webui.components.sft import create_sft_tab
+from llmtuner.webui.components.train import create_train_tab
 from llmtuner.webui.components.eval import create_eval_tab
 from llmtuner.webui.components.infer import create_infer_tab
 from llmtuner.webui.components.export import create_export_tab
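A note on the `save_config` change above: it now merges into the saved JSON instead of overwriting it, so a language-only update no longer clobbers the remembered model. Below is a minimal, self-contained sketch of that merge behavior; the hard-coded `CONFIG_PATH` is an illustrative stand-in for the real `get_config_path()` under `DEFAULT_CACHE_DIR`.

```python
import json
from typing import Any, Dict

CONFIG_PATH = "/tmp/llmtuner_user_config.json"  # stand-in for get_config_path()

def load_config() -> Dict[str, Any]:
    # fall back to fresh defaults, which now include the "lang" key
    try:
        with open(CONFIG_PATH, "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception:
        return {"lang": "", "last_model": "", "path_dict": {}}

def save_config(lang: str, model_name: str, model_path: str) -> None:
    user_config = load_config()
    user_config["lang"] = lang or user_config["lang"]  # empty lang keeps the old one
    if model_name:  # language-only updates skip the model fields entirely
        user_config["last_model"] = model_name
        user_config["path_dict"][model_name] = model_path
    with open(CONFIG_PATH, "w", encoding="utf-8") as f:
        json.dump(user_config, f, indent=2, ensure_ascii=False)

save_config("en", "", "")                        # records only the language
save_config("", "LLaMA-7B", "/models/llama-7b")  # records the model, keeps "en"
print(load_config())
```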
diff --git a/src/llmtuner/webui/components/top.py b/src/llmtuner/webui/components/top.py
index 97a332c1..8611e280 100644
--- a/src/llmtuner/webui/components/top.py
+++ b/src/llmtuner/webui/components/top.py
@@ -20,22 +20,25 @@ def create_top() -> Dict[str, "Component"]:
         model_path = gr.Textbox(scale=3)
 
     with gr.Row():
-        finetuning_type = gr.Dropdown(value="lora", choices=METHODS, scale=1)
+        finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
         checkpoints = gr.Dropdown(multiselect=True, scale=5)
         refresh_btn = gr.Button(scale=1)
 
     with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
         with gr.Row():
-            quantization_bit = gr.Dropdown(["", "8", "4"], scale=1)
-            template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=1)
+            quantization_bit = gr.Dropdown(choices=["None", "8", "4"], value="None", scale=1)
+            template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
             source_prefix = gr.Textbox(scale=2)
 
+    lang.change(save_config, [lang, model_name, model_path])
+
     model_name.change(
         list_checkpoint, [model_name, finetuning_type], [checkpoints]
     ).then(
         get_model_path, [model_name], [model_path]
     ) # do not save config since the below line will save
-    model_path.change(save_config, [model_name, model_path])
+
+    model_path.change(save_config, [lang, model_name, model_path])
 
     finetuning_type.change(
         list_checkpoint, [model_name, finetuning_type], [checkpoints]
@@ -43,7 +46,9 @@ def create_top() -> Dict[str, "Component"]:
         can_quantize, [finetuning_type], [quantization_bit]
     )
 
-    refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False)
+    refresh_btn.click(
+        list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
+    )
 
     return dict(
         lang=lang,
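The event wiring in `top.py` leans on two Gradio behaviors: `.then()` chains callbacks in order, and a component's `change` event also fires when a callback updates its value, which is exactly what the inline "do not save config" comment relies on. Filling in `model_path` triggers `model_path.change`, so `save_config` runs once per edit. A stripped-down sketch of the pattern; the two helper functions here are illustrative, not the repo's:

```python
import gradio as gr

def get_model_path(name: str) -> str:
    # illustrative lookup; the real app resolves this from the user config
    return f"/models/{name.lower()}" if name else ""

def save_config(lang: str, name: str, path: str) -> None:
    print(f"saving: lang={lang!r} model={name!r} path={path!r}")

with gr.Blocks() as demo:
    lang = gr.Dropdown(choices=["en", "zh"], value="en")
    model_name = gr.Textbox(label="Model name")
    model_path = gr.Textbox(label="Model path")

    # updating model_path here fires model_path.change below,
    # so save_config runs exactly once per model_name edit
    model_name.change(get_model_path, [model_name], [model_path])
    model_path.change(save_config, [lang, model_name, model_path])

demo.launch()
```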
diff --git a/src/llmtuner/webui/components/sft.py b/src/llmtuner/webui/components/train.py
similarity index 80%
rename from src/llmtuner/webui/components/sft.py
rename to src/llmtuner/webui/components/train.py
index a4101d15..a4d593c8 100644
--- a/src/llmtuner/webui/components/sft.py
+++ b/src/llmtuner/webui/components/train.py
@@ -3,7 +3,7 @@ from transformers.trainer_utils import SchedulerType
 
 import gradio as gr
 
-from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
+from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
 from llmtuner.webui.components.data import create_preview_box
 from llmtuner.webui.utils import can_preview, get_preview, gen_plot
 
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
     from llmtuner.webui.runner import Runner
 
 
-def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
+def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
     with gr.Row():
         dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
         dataset = gr.Dropdown(multiselect=True, scale=4)
@@ -40,7 +40,7 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
             batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
             gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
             lr_scheduler_type = gr.Dropdown(
-                value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
+                choices=[scheduler.value for scheduler in SchedulerType], value="cosine"
             )
             max_grad_norm = gr.Textbox(value="1.0")
             val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
@@ -60,6 +60,20 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
             lora_target = gr.Textbox(scale=2)
             resume_lora_training = gr.Checkbox(value=True, scale=1)
 
+    with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
+        with gr.Row():
+            rlhf_method = gr.Dropdown(choices=["None", "Reward Modeling", "PPO", "DPO"], value="None", scale=1)
+            dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=2)
+            reward_model = gr.Dropdown(scale=2)
+            refresh_btn = gr.Button(scale=1)
+
+    refresh_btn.click(
+        list_checkpoint,
+        [top_elems["model_name"], top_elems["finetuning_type"]],
+        [reward_model],
+        queue=False
+    )
+
     with gr.Row():
         cmd_preview_btn = gr.Button()
         start_btn = gr.Button()
@@ -79,7 +93,7 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
         with gr.Column(scale=1):
             loss_viewer = gr.Plot()
 
-    input_list = [
+    input_components = [
        top_elems["lang"],
        top_elems["model_name"],
        top_elems["checkpoints"],
@@ -108,16 +122,19 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
         lora_dropout,
         lora_target,
         resume_lora_training,
+        rlhf_method,
+        dpo_beta,
+        reward_model,
         output_dir
     ]
 
-    output_list = [
+    output_components = [
         output_box,
         process_bar
     ]
 
-    cmd_preview_btn.click(runner.preview_train, input_list, output_list)
-    start_btn.click(runner.run_train, input_list, output_list)
+    cmd_preview_btn.click(runner.preview_train, input_components, output_components)
+    start_btn.click(runner.run_train, input_components, output_components)
     stop_btn.click(runner.set_abort, queue=False)
 
     process_bar.change(
@@ -152,6 +169,11 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
         lora_dropout=lora_dropout,
         lora_target=lora_target,
         resume_lora_training=resume_lora_training,
+        rlhf_tab=rlhf_tab,
+        rlhf_method=rlhf_method,
+        dpo_beta=dpo_beta,
+        reward_model=reward_model,
+        refresh_btn=refresh_btn,
         cmd_preview_btn=cmd_preview_btn,
         start_btn=start_btn,
         stop_btn=stop_btn,
diff --git a/src/llmtuner/webui/interface.py b/src/llmtuner/webui/interface.py
index 8539e18c..0ae59224 100644
--- a/src/llmtuner/webui/interface.py
+++ b/src/llmtuner/webui/interface.py
@@ -3,7 +3,7 @@ from transformers.utils.versions import require_version
 
 from llmtuner.webui.components import (
     create_top,
-    create_sft_tab,
+    create_train_tab,
     create_eval_tab,
     create_infer_tab,
     create_export_tab,
@@ -24,8 +24,8 @@ def create_ui() -> gr.Blocks:
     with gr.Blocks(title="Web Tuner", css=CSS) as demo:
         top_elems = create_top()
 
-        with gr.Tab("SFT"):
-            sft_elems = create_sft_tab(top_elems, runner)
+        with gr.Tab("Train"):
+            train_elems = create_train_tab(top_elems, runner)
 
         with gr.Tab("Evaluate"):
             eval_elems = create_eval_tab(top_elems, runner)
@@ -36,7 +36,7 @@ def create_ui() -> gr.Blocks:
         with gr.Tab("Export"):
             export_elems = create_export_tab(top_elems)
 
-        elem_list = [top_elems, sft_elems, eval_elems, infer_elems, export_elems]
+        elem_list = [top_elems, train_elems, eval_elems, infer_elems, export_elems]
         manager = Manager(elem_list)
 
         demo.load(
@@ -59,7 +59,7 @@ def create_web_demo() -> gr.Blocks:
     chat_model = WebChatModel(lazy_init=False)
 
     with gr.Blocks(title="Web Demo", css=CSS) as demo:
-        lang = gr.Dropdown(choices=["en", "zh"], value="en")
+        lang = gr.Dropdown(choices=["en", "zh"], value="")
 
         _, _, _, chat_elems = create_chat_box(chat_model, visible=True)
 
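In `create_web_demo`, the language dropdown now starts empty on purpose: the `demo.load` handler can then distinguish 'no choice made yet' from an explicit selection and fall back to the saved config. The resolution order, matching the `gen_refresh` change further down, appears to be explicit choice, then saved language, then English:

```python
from typing import Any, Dict

def resolve_lang(requested: str, user_config: Dict[str, Any]) -> str:
    """Fallback order: explicit choice > saved config > English."""
    if requested:
        return requested
    return user_config["lang"] if user_config.get("lang") else "en"

assert resolve_lang("zh", {"lang": "en"}) == "zh"  # the dropdown wins
assert resolve_lang("", {"lang": "zh"}) == "zh"    # saved config wins
assert resolve_lang("", {"lang": ""}) == "en"      # default
print("fallback order ok")
```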
diff --git a/src/llmtuner/webui/locales.py b/src/llmtuner/webui/locales.py
index aba4acae..0c7abae0 100644
--- a/src/llmtuner/webui/locales.py
+++ b/src/llmtuner/webui/locales.py
@@ -335,6 +335,44 @@ LOCALES = {
             "info": "接着上次的 LoRA 权重训练或创建一个新的 LoRA 权重。"
         }
     },
+    "rlhf_tab": {
+        "en": {
+            "label": "RLHF configurations"
+        },
+        "zh": {
+            "label": "RLHF 参数设置"
+        }
+    },
+    "rlhf_method": {
+        "en": {
+            "label": "RLHF method",
+            "info": "The RLHF algorithm to adopt."
+        },
+        "zh": {
+            "label": "RLHF 方法",
+            "info": "RLHF 阶段使用的算法。"
+        }
+    },
+    "dpo_beta": {
+        "en": {
+            "label": "DPO beta",
+            "info": "Value of the beta parameter in the DPO loss."
+        },
+        "zh": {
+            "label": "DPO beta 参数",
+            "info": "DPO 损失函数中 beta 超参数大小。"
+        }
+    },
+    "reward_model": {
+        "en": {
+            "label": "Reward model",
+            "info": "Checkpoint of the reward model for PPO training."
+        },
+        "zh": {
+            "label": "奖励模型",
+            "info": "PPO 训练中奖励模型的断点路径。"
+        }
+    },
     "cmd_preview_btn": {
         "en": {
             "value": "Preview command"
diff --git a/src/llmtuner/webui/manager.py b/src/llmtuner/webui/manager.py
index c8f797a4..2d5a0a39 100644
--- a/src/llmtuner/webui/manager.py
+++ b/src/llmtuner/webui/manager.py
@@ -12,12 +12,18 @@ class Manager:
     def __init__(self, elem_list: List[Dict[str, Component]]):
         self.elem_list = elem_list
 
-    def gen_refresh(self) -> Dict[str, Any]:
+    def gen_refresh(self, lang: str) -> Dict[str, Any]:
         refresh_dict = {
             "dataset": {"choices": list_dataset()["choices"]},
             "output_dir": {"value": get_time()}
         }
 
+        user_config = load_config()
+        if lang:
+            refresh_dict["lang"] = {"value": lang}
+        else:
+            refresh_dict["lang"] = {"value": user_config["lang"] if user_config["lang"] else "en"}
+
         if user_config["last_model"]:
             refresh_dict["model_name"] = {"value": user_config["last_model"]}
             refresh_dict["model_path"] = {"value": get_model_path(user_config["last_model"])}
@@ -26,10 +32,12 @@ class Manager:
 
     def gen_label(self, lang: str) -> Dict[Component, Dict[str, Any]]: # cannot use TYPE_CHECKING
         update_dict = {}
-        refresh_dict = self.gen_refresh()
+        refresh_dict = self.gen_refresh(lang)
 
         for elems in self.elem_list:
             for name, component in elems.items():
-                update_dict[component] = gr.update(**LOCALES[name][lang], **refresh_dict.get(name, {}))
+                update_dict[component] = gr.update(
+                    **LOCALES[name][refresh_dict["lang"]["value"]], **refresh_dict.get(name, {})
+                )
 
         return update_dict
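`gen_label` now reads the display language back out of `refresh_dict["lang"]["value"]` rather than trusting its `lang` argument, so the rendered labels and the language dropdown cannot disagree. A toy version of the per-component merge it performs, without Gradio (`gr.update` becomes a plain dict; the locale table below is made-up sample data):

```python
from typing import Any, Dict

# toy locale table in the same nested shape as llmtuner's LOCALES
LOCALES = {
    "start_btn": {"en": {"value": "Start"}, "zh": {"value": "开始"}},
    "lang": {"en": {"label": "Language"}, "zh": {"label": "语言"}},
}

def gen_updates(refresh_dict: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    resolved = refresh_dict["lang"]["value"]  # single source of truth
    updates = {}
    for name in LOCALES:
        # locale strings first, refreshed values layered on top
        updates[name] = {**LOCALES[name][resolved], **refresh_dict.get(name, {})}
    return updates

print(gen_updates({"lang": {"value": "zh"}}))
# {'start_btn': {'value': '开始'}, 'lang': {'label': '语言', 'value': 'zh'}}
```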
diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py
index 3a28dae6..afef4736 100644
--- a/src/llmtuner/webui/runner.py
+++ b/src/llmtuner/webui/runner.py
@@ -91,6 +91,9 @@ class Runner:
         lora_dropout: float,
         lora_target: str,
         resume_lora_training: bool,
+        rlhf_method: str,
+        dpo_beta: float,
+        reward_model: str,
         output_dir: str
     ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
         if checkpoints:
@@ -109,7 +112,7 @@
             overwrite_cache=True,
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
-            quantization_bit=int(quantization_bit) if quantization_bit else None,
+            quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
             template=template,
             source_prefix=source_prefix,
             dataset_dir=dataset_dir,
@@ -134,6 +137,21 @@
             output_dir=output_dir
         )
         args[compute_type] = True
+
+        if rlhf_method == "Reward Modeling":
+            args["stage"] = "rm"
+            args["resume_lora_training"] = False
+        elif rlhf_method == "PPO":
+            args["stage"] = "ppo"
+            args["resume_lora_training"] = False
+            args["reward_model"] = reward_model
+            args["padding_side"] = "left"
+            val_size = 0
+        elif rlhf_method == "DPO":
+            args["stage"] = "dpo"
+            args["resume_lora_training"] = False
+            args["dpo_beta"] = dpo_beta
+
         if val_size > 1e-6:
             args["val_size"] = val_size
             args["evaluation_strategy"] = "steps"
@@ -176,7 +194,7 @@
             predict_with_generate=True,
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
-            quantization_bit=int(quantization_bit) if quantization_bit else None,
+            quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
             template=template,
             source_prefix=source_prefix,
             dataset_dir=dataset_dir,
diff --git a/src/llmtuner/webui/utils.py b/src/llmtuner/webui/utils.py
index 152df6e3..362fa008 100644
--- a/src/llmtuner/webui/utils.py
+++ b/src/llmtuner/webui/utils.py
@@ -63,7 +63,8 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
 
 
 def gen_cmd(args: Dict[str, Any]) -> str:
-    args["plot_loss"] = True
+    if args.get("do_train", None):
+        args["plot_loss"] = True
     cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python "]
     for k, v in args.items():
         if v is not None and v != "":
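To summarize the `Runner` change, the new dropdown value maps onto `train_bash` arguments as below. This is condensed directly from the diff, with comments on the two non-obvious lines (left padding because PPO generates from prompts, and `val_size = 0` presumably because the PPO stage runs without an evaluation split); the demo values in the final call are illustrative:

```python
from typing import Any, Dict

def apply_rlhf_args(args: Dict[str, Any], rlhf_method: str,
                    dpo_beta: float, reward_model: str, val_size: float) -> float:
    """Translate the web UI's RLHF selection into training arguments."""
    if rlhf_method == "Reward Modeling":
        args["stage"] = "rm"
        args["resume_lora_training"] = False
    elif rlhf_method == "PPO":
        args["stage"] = "ppo"
        args["resume_lora_training"] = False
        args["reward_model"] = reward_model
        args["padding_side"] = "left"  # generation needs left-padded prompts
        val_size = 0                   # PPO runs without an eval split
    elif rlhf_method == "DPO":
        args["stage"] = "dpo"
        args["resume_lora_training"] = False
        args["dpo_beta"] = dpo_beta
    return val_size

args: Dict[str, Any] = {"stage": "sft"}
print(apply_rlhf_args(args, "PPO", 0.1, "reward-ckpt", 0.1), args["stage"])  # 0 ppo
```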