refactor export, fix #1190

Former-commit-id: ea82f8a82a7356bbdf204190d596d0b1c8ef1a84
hiyouga 2023-10-15 16:01:48 +08:00
parent 0503d45782
commit f3fa47fa7d
9 changed files with 52 additions and 49 deletions

View File

@@ -371,7 +371,7 @@ python src/export_model.py \
     --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
-    --output_dir path_to_export \
+    --export_dir path_to_export \
     --fp16
 ```

View File

@@ -1,5 +1,4 @@
 from .data_args import DataArguments
 from .finetuning_args import FinetuningArguments
-from .general_args import GeneralArguments
 from .generating_args import GeneratingArguments
 from .model_args import ModelArguments

View File

@@ -8,6 +8,10 @@ class FinetuningArguments:
     r"""
     Arguments pertaining to which techniques we are going to fine-tuning with.
     """
+    stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
+        default="sft",
+        metadata={"help": "Which stage will be performed in training."}
+    )
     finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field(
         default="lora",
         metadata={"help": "Which fine-tuning method to use."}

View File

@@ -46,6 +46,10 @@ class ModelArguments:
         default=None,
         metadata={"help": "Adopt scaled rotary positional embeddings."}
     )
+    checkpoint_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
+    )
     flash_attn: Optional[bool] = field(
         default=False,
         metadata={"help": "Enable FlashAttention-2 for faster training."}
@@ -54,14 +58,14 @@ class ModelArguments:
         default=False,
         metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
     )
-    checkpoint_dir: Optional[str] = field(
-        default=None,
-        metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
-    )
     reward_model: Optional[str] = field(
         default=None,
         metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
     )
+    upcast_layernorm: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to upcast the layernorm weights in fp32."}
+    )
     plot_loss: Optional[bool] = field(
         default=False,
         metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
@@ -70,9 +74,9 @@ class ModelArguments:
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."}
     )
-    upcast_layernorm: Optional[bool] = field(
-        default=False,
-        metadata={"help": "Whether to upcast the layernorm weights in fp32."}
-    )
+    export_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the directory to save the exported model."}
+    )

     def __post_init__(self):

View File

@@ -5,7 +5,6 @@ import datasets
 import transformers
 from typing import Any, Dict, Optional, Tuple
 from transformers import HfArgumentParser, Seq2SeqTrainingArguments
-from transformers.utils.versions import require_version
 from transformers.trainer_utils import get_last_checkpoint

 from llmtuner.extras.logging import get_logger
@@ -13,8 +12,7 @@ from llmtuner.hparams import (
     ModelArguments,
     DataArguments,
     FinetuningArguments,
-    GeneratingArguments,
-    GeneralArguments
+    GeneratingArguments
 )
@@ -39,16 +37,14 @@ def parse_train_args(
     DataArguments,
     Seq2SeqTrainingArguments,
     FinetuningArguments,
-    GeneratingArguments,
-    GeneralArguments
+    GeneratingArguments
 ]:
     parser = HfArgumentParser((
         ModelArguments,
         DataArguments,
         Seq2SeqTrainingArguments,
         FinetuningArguments,
-        GeneratingArguments,
-        GeneralArguments
+        GeneratingArguments
     ))
     return _parse_args(parser, args)
@@ -77,10 +73,9 @@ def get_train_args(
     DataArguments,
     Seq2SeqTrainingArguments,
     FinetuningArguments,
-    GeneratingArguments,
-    GeneralArguments
+    GeneratingArguments
 ]:
-    model_args, data_args, training_args, finetuning_args, generating_args, general_args = parse_train_args(args)
+    model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args)

     # Setup logging
     if training_args.should_log:
@@ -96,36 +91,36 @@ def get_train_args(
     # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
     data_args.init_for_training()

-    if general_args.stage != "pt" and data_args.template is None:
+    if finetuning_args.stage != "pt" and data_args.template is None:
         raise ValueError("Please specify which `template` to use.")

-    if general_args.stage != "sft" and training_args.predict_with_generate:
+    if finetuning_args.stage != "sft" and training_args.predict_with_generate:
         raise ValueError("`predict_with_generate` cannot be set as True except SFT.")

-    if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
+    if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
         raise ValueError("Please enable `predict_with_generate` to save model predictions.")

-    if general_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
+    if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
         raise ValueError("RM and PPO stages can only be performed with the LoRA method.")

-    if general_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
+    if finetuning_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
         raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")

-    if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
+    if finetuning_args.stage in ["ppo", "dpo"] and not training_args.do_train:
         raise ValueError("PPO and DPO stages can only be performed at training.")

-    if general_args.stage in ["rm", "dpo"]:
+    if finetuning_args.stage in ["rm", "dpo"]:
         for dataset_attr in data_args.dataset_list:
             if not dataset_attr.ranking:
                 raise ValueError("Please use ranked datasets for reward modeling or DPO training.")

-    if general_args.stage == "ppo" and model_args.reward_model is None:
+    if finetuning_args.stage == "ppo" and model_args.reward_model is None:
         raise ValueError("Reward model is necessary for PPO training.")

-    if general_args.stage == "ppo" and data_args.streaming:
+    if finetuning_args.stage == "ppo" and data_args.streaming:
         raise ValueError("Streaming mode does not suppport PPO training currently.")

-    if general_args.stage == "ppo" and model_args.shift_attn:
+    if finetuning_args.stage == "ppo" and model_args.shift_attn:
         raise ValueError("PPO training is incompatible with S^2-Attn.")

     if training_args.max_steps == -1 and data_args.streaming:
@@ -205,7 +200,7 @@ def get_train_args(
     # Set seed before initializing model.
     transformers.set_seed(training_args.seed)

-    return model_args, data_args, training_args, finetuning_args, generating_args, general_args
+    return model_args, data_args, training_args, finetuning_args, generating_args

 def get_infer_args(

View File

@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional

 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.logging import get_logger
-from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer
+from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
 from llmtuner.tuner.pt import run_pt
 from llmtuner.tuner.sft import run_sft
 from llmtuner.tuner.rm import run_rm
@@ -17,31 +17,32 @@ logger = get_logger(__name__)

 def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
-    model_args, data_args, training_args, finetuning_args, generating_args, general_args = get_train_args(args)
+    model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
     callbacks = [LogCallback()] if callbacks is None else callbacks

-    if general_args.stage == "pt":
+    if finetuning_args.stage == "pt":
         run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
-    elif general_args.stage == "sft":
+    elif finetuning_args.stage == "sft":
         run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
-    elif general_args.stage == "rm":
+    elif finetuning_args.stage == "rm":
         run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
-    elif general_args.stage == "ppo":
+    elif finetuning_args.stage == "ppo":
         run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
-    elif general_args.stage == "dpo":
+    elif finetuning_args.stage == "dpo":
         run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
     else:
         raise ValueError("Unknown task.")

 def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"):
-    model_args, _, training_args, finetuning_args, _, _ = get_train_args(args)
+    model_args, _, finetuning_args, _ = get_infer_args(args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
+    model.config.use_cache = True
     tokenizer.padding_side = "left" # restore padding side
     tokenizer.init_kwargs["padding_side"] = "left"
-    model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size)
+    model.save_pretrained(model_args.export_dir, max_shard_size=max_shard_size)
     try:
-        tokenizer.save_pretrained(training_args.output_dir)
+        tokenizer.save_pretrained(model_args.export_dir)
     except:
         logger.warning("Cannot save tokenizer, please copy the files manually.")
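With this change, export parsing goes through `get_infer_args` and the model is saved to `model_args.export_dir` instead of `training_args.output_dir`. A hedged usage sketch of the refactored helper follows; the import path and the placeholder paths are assumptions for illustration, not part of the diff:

```python
# Sketch only: export a LoRA checkpoint with the refactored helper.
from llmtuner.tuner.tune import export_model  # actual import path may differ

export_model(
    args=dict(
        model_name_or_path="path_to_llama_model",  # placeholder
        template="default",
        finetuning_type="lora",
        checkpoint_dir="path_to_checkpoint",       # placeholder
        export_dir="path_to_export",               # new ModelArguments field added in this commit
    ),
    max_shard_size="10GB",
)
```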

View File

@@ -12,7 +12,7 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
     elem_dict = dict()

     with gr.Row():
-        save_dir = gr.Textbox()
+        export_dir = gr.Textbox()
         max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)
         export_btn = gr.Button()
@@ -28,13 +28,13 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
             engine.manager.get_elem("top.finetuning_type"),
             engine.manager.get_elem("top.template"),
             max_shard_size,
-            save_dir
+            export_dir
         ],
         [info_box]
     )

     elem_dict.update(dict(
-        save_dir=save_dir,
+        export_dir=export_dir,
         max_shard_size=max_shard_size,
         export_btn=export_btn,
         info_box=info_box

View File

@@ -531,7 +531,7 @@ LOCALES = {
             "label": "温度系数"
         }
     },
-    "save_dir": {
+    "export_dir": {
         "en": {
             "label": "Export dir",
             "info": "Directory to save exported model."
@@ -587,7 +587,7 @@ ALERTS = {
         "en": "Please select a checkpoint.",
         "zh": "请选择断点。"
     },
-    "err_no_save_dir": {
+    "err_no_export_dir": {
         "en": "Please provide export dir.",
         "zh": "请填写导出目录"
     },

View File

@@ -124,7 +124,7 @@ def save_model(
     finetuning_type: str,
     template: str,
     max_shard_size: int,
-    save_dir: str
+    export_dir: str
 ) -> Generator[str, None, None]:
     if not model_name:
         yield ALERTS["err_no_model"][lang]
@@ -138,8 +138,8 @@ def save_model(
         yield ALERTS["err_no_checkpoint"][lang]
         return

-    if not save_dir:
-        yield ALERTS["err_no_save_dir"][lang]
+    if not export_dir:
+        yield ALERTS["err_no_export_dir"][lang]
         return

     args = dict(
@@ -147,7 +147,7 @@ def save_model(
         checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]),
         finetuning_type=finetuning_type,
         template=template,
-        output_dir=save_dir
+        export_dir=export_dir
     )

     yield ALERTS["info_exporting"][lang]