Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-06 21:52:50 +08:00

Commit f3fa47fa7d (parent 0503d45782)

refactor export, fix #1190

Former-commit-id: ea82f8a82a7356bbdf204190d596d0b1c8ef1a84
@@ -371,7 +371,7 @@ python src/export_model.py \
     --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
-    --output_dir path_to_export \
+    --export_dir path_to_export \
     --fp16
 ```

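Note (annotation, not from the commit): the renamed flag lands on the new `export_dir` field added to ModelArguments below. A minimal, self-contained sketch of how HfArgumentParser performs that flag-to-field mapping, using a hypothetical stand-in dataclass rather than the real ModelArguments:

# Sketch: how --export_dir maps onto a dataclass field via HfArgumentParser.
# ExportArgs is a stand-in for illustration only.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class ExportArgs:
    export_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the directory to save the exported model."}
    )

(export_args,) = HfArgumentParser(ExportArgs).parse_args_into_dataclasses(["--export_dir", "path_to_export"])
assert export_args.export_dir == "path_to_export"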
@@ -1,5 +1,4 @@
 from .data_args import DataArguments
 from .finetuning_args import FinetuningArguments
-from .general_args import GeneralArguments
 from .generating_args import GeneratingArguments
 from .model_args import ModelArguments
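Note: GeneralArguments disappears from the hparams package because the `stage` field it carried moves onto FinetuningArguments (next hunk). A sketch of the resulting import surface:

# Import surface after this commit: four argument dataclasses, no GeneralArguments.
from llmtuner.hparams import (
    DataArguments,
    FinetuningArguments,
    GeneratingArguments,
    ModelArguments,
)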
@@ -8,6 +8,10 @@ class FinetuningArguments:
     r"""
     Arguments pertaining to which techniques we are going to fine-tuning with.
     """
+    stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
+        default="sft",
+        metadata={"help": "Which stage will be performed in training."}
+    )
     finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field(
         default="lora",
         metadata={"help": "Which fine-tuning method to use."}
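Note: because the annotation is a Literal, HfArgumentParser should expose `--stage` with exactly those five choices and reject anything else. A minimal sketch of the new surface (not from the commit; remaining fields left at their defaults):

# Sketch: `stage` now rides on FinetuningArguments, so one dataclass selects
# both the training stage and the tuning method.
from llmtuner.hparams import FinetuningArguments

ft_args = FinetuningArguments(stage="rm", finetuning_type="lora")
assert ft_args.stage in ("pt", "sft", "rm", "ppo", "dpo")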
@@ -46,6 +46,10 @@ class ModelArguments:
         default=None,
         metadata={"help": "Adopt scaled rotary positional embeddings."}
     )
+    checkpoint_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
+    )
     flash_attn: Optional[bool] = field(
         default=False,
         metadata={"help": "Enable FlashAttention-2 for faster training."}
@@ -54,14 +58,14 @@ class ModelArguments:
         default=False,
         metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
     )
-    checkpoint_dir: Optional[str] = field(
-        default=None,
-        metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
-    )
     reward_model: Optional[str] = field(
         default=None,
         metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
     )
+    upcast_layernorm: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to upcast the layernorm weights in fp32."}
+    )
     plot_loss: Optional[bool] = field(
         default=False,
         metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
@@ -70,9 +74,9 @@ class ModelArguments:
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."}
     )
-    upcast_layernorm: Optional[bool] = field(
-        default=False,
-        metadata={"help": "Whether to upcast the layernorm weights in fp32."}
+    export_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the directory to save the exported model."}
     )

     def __post_init__(self):
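Note: the net effect on ModelArguments is that `checkpoint_dir` moves up beside the other loading options, `upcast_layernorm` moves down beside the training toggles, and the new `export_dir` gives exports a destination of their own instead of borrowing `training_args.output_dir`. A sketch (not from the commit) with placeholder paths:

# Sketch: after this commit, everything the export path needs lives on
# ModelArguments: the base model, the adapter checkpoint(s) to merge,
# and the destination directory.
from llmtuner.hparams import ModelArguments

model_args = ModelArguments(
    model_name_or_path="path_to_llama_model",  # placeholder
    checkpoint_dir="path_to_checkpoint",       # "directory(s)": a comma-separated list is also accepted
    export_dir="path_to_export",
)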
@@ -5,7 +5,6 @@ import datasets
 import transformers
 from typing import Any, Dict, Optional, Tuple
 from transformers import HfArgumentParser, Seq2SeqTrainingArguments
-from transformers.utils.versions import require_version
 from transformers.trainer_utils import get_last_checkpoint

 from llmtuner.extras.logging import get_logger
@@ -13,8 +12,7 @@ from llmtuner.hparams import (
     ModelArguments,
     DataArguments,
     FinetuningArguments,
-    GeneratingArguments,
-    GeneralArguments
+    GeneratingArguments
 )


@@ -39,16 +37,14 @@ def parse_train_args(
     DataArguments,
     Seq2SeqTrainingArguments,
     FinetuningArguments,
-    GeneratingArguments,
-    GeneralArguments
+    GeneratingArguments
 ]:
     parser = HfArgumentParser((
         ModelArguments,
         DataArguments,
         Seq2SeqTrainingArguments,
         FinetuningArguments,
-        GeneratingArguments,
-        GeneralArguments
+        GeneratingArguments
     ))
     return _parse_args(parser, args)

@@ -77,10 +73,9 @@ def get_train_args(
     DataArguments,
     Seq2SeqTrainingArguments,
     FinetuningArguments,
-    GeneratingArguments,
-    GeneralArguments
+    GeneratingArguments
 ]:
-    model_args, data_args, training_args, finetuning_args, generating_args, general_args = parse_train_args(args)
+    model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args)

     # Setup logging
     if training_args.should_log:
@@ -96,36 +91,36 @@ def get_train_args(
     # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
     data_args.init_for_training()

-    if general_args.stage != "pt" and data_args.template is None:
+    if finetuning_args.stage != "pt" and data_args.template is None:
         raise ValueError("Please specify which `template` to use.")

-    if general_args.stage != "sft" and training_args.predict_with_generate:
+    if finetuning_args.stage != "sft" and training_args.predict_with_generate:
         raise ValueError("`predict_with_generate` cannot be set as True except SFT.")

-    if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
+    if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
         raise ValueError("Please enable `predict_with_generate` to save model predictions.")

-    if general_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
+    if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
         raise ValueError("RM and PPO stages can only be performed with the LoRA method.")

-    if general_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
+    if finetuning_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
         raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")

-    if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
+    if finetuning_args.stage in ["ppo", "dpo"] and not training_args.do_train:
         raise ValueError("PPO and DPO stages can only be performed at training.")

-    if general_args.stage in ["rm", "dpo"]:
+    if finetuning_args.stage in ["rm", "dpo"]:
         for dataset_attr in data_args.dataset_list:
             if not dataset_attr.ranking:
                 raise ValueError("Please use ranked datasets for reward modeling or DPO training.")

-    if general_args.stage == "ppo" and model_args.reward_model is None:
+    if finetuning_args.stage == "ppo" and model_args.reward_model is None:
         raise ValueError("Reward model is necessary for PPO training.")

-    if general_args.stage == "ppo" and data_args.streaming:
+    if finetuning_args.stage == "ppo" and data_args.streaming:
         raise ValueError("Streaming mode does not suppport PPO training currently.")

-    if general_args.stage == "ppo" and model_args.shift_attn:
+    if finetuning_args.stage == "ppo" and model_args.shift_attn:
         raise ValueError("PPO training is incompatible with S^2-Attn.")

     if training_args.max_steps == -1 and data_args.streaming:
@@ -205,7 +200,7 @@ def get_train_args(
     # Set seed before initializing model.
     transformers.set_seed(training_args.seed)

-    return model_args, data_args, training_args, finetuning_args, generating_args, general_args
+    return model_args, data_args, training_args, finetuning_args, generating_args


 def get_infer_args(
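Note: every caller of get_train_args now unpacks five values instead of six, and all of the stage checks read `finetuning_args.stage`. A sketch of the updated call shape (not from the commit; the dict mirrors a typical SFT command line and its values are illustrative placeholders):

# Sketch: the 5-tuple returned after this commit.
from llmtuner.tuner.core import get_train_args

model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(dict(
    stage="sft",
    model_name_or_path="path_to_llama_model",
    do_train=True,
    dataset="alpaca_en",
    template="default",
    finetuning_type="lora",
    output_dir="path_to_sft_checkpoint",
))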
@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional

 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.logging import get_logger
-from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer
+from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
 from llmtuner.tuner.pt import run_pt
 from llmtuner.tuner.sft import run_sft
 from llmtuner.tuner.rm import run_rm
@@ -17,31 +17,32 @@ logger = get_logger(__name__)


 def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
-    model_args, data_args, training_args, finetuning_args, generating_args, general_args = get_train_args(args)
+    model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
     callbacks = [LogCallback()] if callbacks is None else callbacks

-    if general_args.stage == "pt":
+    if finetuning_args.stage == "pt":
         run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
-    elif general_args.stage == "sft":
+    elif finetuning_args.stage == "sft":
         run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
-    elif general_args.stage == "rm":
+    elif finetuning_args.stage == "rm":
         run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
-    elif general_args.stage == "ppo":
+    elif finetuning_args.stage == "ppo":
         run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
-    elif general_args.stage == "dpo":
+    elif finetuning_args.stage == "dpo":
         run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
     else:
         raise ValueError("Unknown task.")


 def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"):
-    model_args, _, training_args, finetuning_args, _, _ = get_train_args(args)
+    model_args, _, finetuning_args, _ = get_infer_args(args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
+    model.config.use_cache = True
     tokenizer.padding_side = "left" # restore padding side
     tokenizer.init_kwargs["padding_side"] = "left"
-    model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size)
+    model.save_pretrained(model_args.export_dir, max_shard_size=max_shard_size)
     try:
-        tokenizer.save_pretrained(training_args.output_dir)
+        tokenizer.save_pretrained(model_args.export_dir)
     except:
         logger.warning("Cannot save tokenizer, please copy the files manually.")

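Note: this hunk is the heart of the fix for #1190. Export now parses inference arguments via get_infer_args, so it no longer requires the training-only `output_dir`, and the merged weights land in `model_args.export_dir`. Re-enabling `use_cache` before saving matters because the config is serialized with the model, and loading for training typically leaves the KV cache disabled; an exported checkpoint should ship with it on for inference. A usage sketch (not from the commit; the import path assumes this commit's layout, where export_model lives in llmtuner's tune module, and all values are illustrative):

# Sketch: exporting a merged model programmatically after this commit.
from llmtuner.tuner.tune import export_model

export_model(
    dict(
        model_name_or_path="path_to_llama_model",  # placeholder
        checkpoint_dir="path_to_checkpoint",
        template="default",
        finetuning_type="lora",
        export_dir="path_to_export",  # replaces the old output_dir
    ),
    max_shard_size="5GB",  # illustrative override; the function defaults to "10GB"
)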
@@ -12,7 +12,7 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
     elem_dict = dict()

     with gr.Row():
-        save_dir = gr.Textbox()
+        export_dir = gr.Textbox()
         max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)

     export_btn = gr.Button()
@@ -28,13 +28,13 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
             engine.manager.get_elem("top.finetuning_type"),
             engine.manager.get_elem("top.template"),
             max_shard_size,
-            save_dir
+            export_dir
         ],
         [info_box]
     )

     elem_dict.update(dict(
-        save_dir=save_dir,
+        export_dir=export_dir,
         max_shard_size=max_shard_size,
         export_btn=export_btn,
         info_box=info_box
@@ -531,7 +531,7 @@ LOCALES = {
             "label": "温度系数"
         }
     },
-    "save_dir": {
+    "export_dir": {
         "en": {
             "label": "Export dir",
             "info": "Directory to save exported model."
@@ -587,7 +587,7 @@ ALERTS = {
         "en": "Please select a checkpoint.",
         "zh": "请选择断点。"
     },
-    "err_no_save_dir": {
+    "err_no_export_dir": {
         "en": "Please provide export dir.",
         "zh": "请填写导出目录"
     },
@@ -124,7 +124,7 @@ def save_model(
     finetuning_type: str,
     template: str,
     max_shard_size: int,
-    save_dir: str
+    export_dir: str
 ) -> Generator[str, None, None]:
     if not model_name:
         yield ALERTS["err_no_model"][lang]
@@ -138,8 +138,8 @@ def save_model(
         yield ALERTS["err_no_checkpoint"][lang]
         return

-    if not save_dir:
-        yield ALERTS["err_no_save_dir"][lang]
+    if not export_dir:
+        yield ALERTS["err_no_export_dir"][lang]
         return

     args = dict(
@@ -147,7 +147,7 @@ def save_model(
         checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]),
         finetuning_type=finetuning_type,
         template=template,
-        output_dir=save_dir
+        export_dir=export_dir
     )

     yield ALERTS["info_exporting"][lang]
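Note: the web UI rename is mechanical, but the args dict it assembles is worth a look; the handler now emits the same argument names the CLI uses, so both front ends funnel into the one export path. A sketch of such a dict (not the commit's exact dict, whose first entries are elided above; values are illustrative, and checkpoint paths are comma-joined to match ModelArguments.checkpoint_dir):

# Sketch: the style of argument dict the web handler forwards to the export routine.
args = dict(
    model_name_or_path="path_to_llama_model",
    checkpoint_dir="path_to_checkpoint_a,path_to_checkpoint_b",  # comma-joined
    finetuning_type="lora",
    template="default",
    export_dir="path_to_export",  # formerly output_dir=save_dir
)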