From f3fa47fa7d09cdd5b9ef1a0eecf2ee2bd533946d Mon Sep 17 00:00:00 2001 From: hiyouga Date: Sun, 15 Oct 2023 16:01:48 +0800 Subject: [PATCH] refactor export, fix #1190 Former-commit-id: ea82f8a82a7356bbdf204190d596d0b1c8ef1a84 --- README.md | 2 +- src/llmtuner/hparams/__init__.py | 1 - src/llmtuner/hparams/finetuning_args.py | 4 +++ src/llmtuner/hparams/model_args.py | 18 +++++++----- src/llmtuner/tuner/core/parser.py | 37 +++++++++++-------------- src/llmtuner/tuner/tune.py | 21 +++++++------- src/llmtuner/webui/components/export.py | 6 ++-- src/llmtuner/webui/locales.py | 4 +-- src/llmtuner/webui/utils.py | 8 +++--- 9 files changed, 52 insertions(+), 49 deletions(-) diff --git a/README.md b/README.md index 9acf6e61..312364a1 100644 --- a/README.md +++ b/README.md @@ -371,7 +371,7 @@ python src/export_model.py \ --template default \ --finetuning_type lora \ --checkpoint_dir path_to_checkpoint \ - --output_dir path_to_export \ + --export_dir path_to_export \ --fp16 ``` diff --git a/src/llmtuner/hparams/__init__.py b/src/llmtuner/hparams/__init__.py index 0fabfa33..f0547cc5 100644 --- a/src/llmtuner/hparams/__init__.py +++ b/src/llmtuner/hparams/__init__.py @@ -1,5 +1,4 @@ from .data_args import DataArguments from .finetuning_args import FinetuningArguments -from .general_args import GeneralArguments from .generating_args import GeneratingArguments from .model_args import ModelArguments diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py index a153e440..b0e99193 100644 --- a/src/llmtuner/hparams/finetuning_args.py +++ b/src/llmtuner/hparams/finetuning_args.py @@ -8,6 +8,10 @@ class FinetuningArguments: r""" Arguments pertaining to which techniques we are going to fine-tuning with. """ + stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field( + default="sft", + metadata={"help": "Which stage will be performed in training."} + ) finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field( default="lora", metadata={"help": "Which fine-tuning method to use."} diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py index f3ebfd39..85cf2319 100644 --- a/src/llmtuner/hparams/model_args.py +++ b/src/llmtuner/hparams/model_args.py @@ -46,6 +46,10 @@ class ModelArguments: default=None, metadata={"help": "Adopt scaled rotary positional embeddings."} ) + checkpoint_dir: Optional[str] = field( + default=None, + metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."} + ) flash_attn: Optional[bool] = field( default=False, metadata={"help": "Enable FlashAttention-2 for faster training."} @@ -54,14 +58,14 @@ class ModelArguments: default=False, metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."} ) - checkpoint_dir: Optional[str] = field( - default=None, - metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."} - ) reward_model: Optional[str] = field( default=None, metadata={"help": "Path to the directory containing the checkpoints of the reward model."} ) + upcast_layernorm: Optional[bool] = field( + default=False, + metadata={"help": "Whether to upcast the layernorm weights in fp32."} + ) plot_loss: Optional[bool] = field( default=False, metadata={"help": "Whether to plot the training loss after fine-tuning or not."} @@ -70,9 +74,9 @@ class ModelArguments: default=None, metadata={"help": "Auth token to log in with Hugging Face Hub."} ) - upcast_layernorm: 
Optional[bool] = field( - default=False, - metadata={"help": "Whether to upcast the layernorm weights in fp32."} + export_dir: Optional[str] = field( + default=None, + metadata={"help": "Path to the directory to save the exported model."} ) def __post_init__(self): diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py index 56cf02eb..22f46266 100644 --- a/src/llmtuner/tuner/core/parser.py +++ b/src/llmtuner/tuner/core/parser.py @@ -5,7 +5,6 @@ import datasets import transformers from typing import Any, Dict, Optional, Tuple from transformers import HfArgumentParser, Seq2SeqTrainingArguments -from transformers.utils.versions import require_version from transformers.trainer_utils import get_last_checkpoint from llmtuner.extras.logging import get_logger @@ -13,8 +12,7 @@ from llmtuner.hparams import ( ModelArguments, DataArguments, FinetuningArguments, - GeneratingArguments, - GeneralArguments + GeneratingArguments ) @@ -39,16 +37,14 @@ def parse_train_args( DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, - GeneratingArguments, - GeneralArguments + GeneratingArguments ]: parser = HfArgumentParser(( ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, - GeneratingArguments, - GeneralArguments + GeneratingArguments )) return _parse_args(parser, args) @@ -77,10 +73,9 @@ def get_train_args( DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, - GeneratingArguments, - GeneralArguments + GeneratingArguments ]: - model_args, data_args, training_args, finetuning_args, generating_args, general_args = parse_train_args(args) + model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args) # Setup logging if training_args.should_log: @@ -96,36 +91,36 @@ def get_train_args( # Check arguments (do not check finetuning_args since it may be loaded from checkpoints) data_args.init_for_training() - if general_args.stage != "pt" and data_args.template is None: + if finetuning_args.stage != "pt" and data_args.template is None: raise ValueError("Please specify which `template` to use.") - if general_args.stage != "sft" and training_args.predict_with_generate: + if finetuning_args.stage != "sft" and training_args.predict_with_generate: raise ValueError("`predict_with_generate` cannot be set as True except SFT.") - if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate: + if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate: raise ValueError("Please enable `predict_with_generate` to save model predictions.") - if general_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora": + if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora": raise ValueError("RM and PPO stages can only be performed with the LoRA method.") - if general_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None: + if finetuning_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None: raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.") - if general_args.stage in ["ppo", "dpo"] and not training_args.do_train: + if finetuning_args.stage in ["ppo", "dpo"] and not training_args.do_train: raise ValueError("PPO and DPO stages can only be performed at training.") - if general_args.stage in ["rm", "dpo"]: + if finetuning_args.stage in ["rm", "dpo"]: for dataset_attr in data_args.dataset_list: if not 
dataset_attr.ranking: raise ValueError("Please use ranked datasets for reward modeling or DPO training.") - if general_args.stage == "ppo" and model_args.reward_model is None: + if finetuning_args.stage == "ppo" and model_args.reward_model is None: raise ValueError("Reward model is necessary for PPO training.") - if general_args.stage == "ppo" and data_args.streaming: + if finetuning_args.stage == "ppo" and data_args.streaming: raise ValueError("Streaming mode does not suppport PPO training currently.") - if general_args.stage == "ppo" and model_args.shift_attn: + if finetuning_args.stage == "ppo" and model_args.shift_attn: raise ValueError("PPO training is incompatible with S^2-Attn.") if training_args.max_steps == -1 and data_args.streaming: @@ -205,7 +200,7 @@ def get_train_args( # Set seed before initializing model. transformers.set_seed(training_args.seed) - return model_args, data_args, training_args, finetuning_args, generating_args, general_args + return model_args, data_args, training_args, finetuning_args, generating_args def get_infer_args( diff --git a/src/llmtuner/tuner/tune.py b/src/llmtuner/tuner/tune.py index 54f72e73..f0917f37 100644 --- a/src/llmtuner/tuner/tune.py +++ b/src/llmtuner/tuner/tune.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.logging import get_logger -from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer +from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer from llmtuner.tuner.pt import run_pt from llmtuner.tuner.sft import run_sft from llmtuner.tuner.rm import run_rm @@ -17,31 +17,32 @@ logger = get_logger(__name__) def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None): - model_args, data_args, training_args, finetuning_args, generating_args, general_args = get_train_args(args) + model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args) callbacks = [LogCallback()] if callbacks is None else callbacks - if general_args.stage == "pt": + if finetuning_args.stage == "pt": run_pt(model_args, data_args, training_args, finetuning_args, callbacks) - elif general_args.stage == "sft": + elif finetuning_args.stage == "sft": run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks) - elif general_args.stage == "rm": + elif finetuning_args.stage == "rm": run_rm(model_args, data_args, training_args, finetuning_args, callbacks) - elif general_args.stage == "ppo": + elif finetuning_args.stage == "ppo": run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks) - elif general_args.stage == "dpo": + elif finetuning_args.stage == "dpo": run_dpo(model_args, data_args, training_args, finetuning_args, callbacks) else: raise ValueError("Unknown task.") def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"): - model_args, _, training_args, finetuning_args, _, _ = get_train_args(args) + model_args, _, finetuning_args, _ = get_infer_args(args) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args) + model.config.use_cache = True tokenizer.padding_side = "left" # restore padding side tokenizer.init_kwargs["padding_side"] = "left" - model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size) + model.save_pretrained(model_args.export_dir, max_shard_size=max_shard_size) try: - 
tokenizer.save_pretrained(training_args.output_dir) + tokenizer.save_pretrained(model_args.export_dir) except: logger.warning("Cannot save tokenizer, please copy the files manually.") diff --git a/src/llmtuner/webui/components/export.py b/src/llmtuner/webui/components/export.py index fa4cc770..bfdc5dc8 100644 --- a/src/llmtuner/webui/components/export.py +++ b/src/llmtuner/webui/components/export.py @@ -12,7 +12,7 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]: elem_dict = dict() with gr.Row(): - save_dir = gr.Textbox() + export_dir = gr.Textbox() max_shard_size = gr.Slider(value=10, minimum=1, maximum=100) export_btn = gr.Button() @@ -28,13 +28,13 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]: engine.manager.get_elem("top.finetuning_type"), engine.manager.get_elem("top.template"), max_shard_size, - save_dir + export_dir ], [info_box] ) elem_dict.update(dict( - save_dir=save_dir, + export_dir=export_dir, max_shard_size=max_shard_size, export_btn=export_btn, info_box=info_box diff --git a/src/llmtuner/webui/locales.py b/src/llmtuner/webui/locales.py index 93005e52..831a3eff 100644 --- a/src/llmtuner/webui/locales.py +++ b/src/llmtuner/webui/locales.py @@ -531,7 +531,7 @@ LOCALES = { "label": "温度系数" } }, - "save_dir": { + "export_dir": { "en": { "label": "Export dir", "info": "Directory to save exported model." @@ -587,7 +587,7 @@ ALERTS = { "en": "Please select a checkpoint.", "zh": "请选择断点。" }, - "err_no_save_dir": { + "err_no_export_dir": { "en": "Please provide export dir.", "zh": "请填写导出目录" }, diff --git a/src/llmtuner/webui/utils.py b/src/llmtuner/webui/utils.py index ef324425..181f0810 100644 --- a/src/llmtuner/webui/utils.py +++ b/src/llmtuner/webui/utils.py @@ -124,7 +124,7 @@ def save_model( finetuning_type: str, template: str, max_shard_size: int, - save_dir: str + export_dir: str ) -> Generator[str, None, None]: if not model_name: yield ALERTS["err_no_model"][lang] @@ -138,8 +138,8 @@ def save_model( yield ALERTS["err_no_checkpoint"][lang] return - if not save_dir: - yield ALERTS["err_no_save_dir"][lang] + if not export_dir: + yield ALERTS["err_no_export_dir"][lang] return args = dict( @@ -147,7 +147,7 @@ def save_model( checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]), finetuning_type=finetuning_type, template=template, - output_dir=save_dir + export_dir=export_dir ) yield ALERTS["info_exporting"][lang]
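
Taken together, the diff gives export its own `export_dir` field on ModelArguments (replacing the reuse of `training_args.output_dir`) and routes `export_model` through `get_infer_args` instead of `get_train_args`. Below is a minimal usage sketch, not part of the patch itself: it assumes the dict-based entry point shown in the tune.py hunk above, and the paths are placeholders mirroring the README example.

from llmtuner.tuner.tune import export_model

# Export a fine-tuned checkpoint with the refactored API; keys mirror the
# CLI flags from the README hunk above. All paths are placeholders.
export_model(dict(
    model_name_or_path="path_to_llama_model",
    template="default",
    finetuning_type="lora",
    checkpoint_dir="path_to_checkpoint",
    export_dir="path_to_export",          # new field, replaces --output_dir
))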
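The same refactor folds the old GeneralArguments into FinetuningArguments, so the training stage now travels with the fine-tuning options and `run_exp` dispatches on `finetuning_args.stage`. A sketch of driving training after this change, again assuming the dict-parsing entry point; the dataset name and output path are illustrative placeholders only.

from llmtuner.tuner.tune import run_exp

# Stage selection now lives on finetuning_args; "sft" is the default per the
# finetuning_args.py hunk, so omitting the key also runs supervised fine-tuning.
run_exp(dict(
    stage="sft",                          # formerly general_args.stage
    model_name_or_path="path_to_llama_model",
    do_train=True,
    dataset="example_dataset",            # placeholder dataset name
    template="default",
    finetuning_type="lora",
    output_dir="path_to_sft_checkpoint",  # training output, distinct from export_dir
))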