# Mirror of https://github.com/hiyouga/LLaMA-Factory.git
# Synced 2025-11-05 18:32:14 +08:00 · 313 lines · 13 KiB · Python
import os
|
|
import json
|
|
import torch
|
|
from typing import Any, Dict, List, Literal, Optional
|
|
from dataclasses import asdict, dataclass, field
|
|
|
|
|
|
@dataclass
class DatasetAttr:
    """Describes where one dataset is loaded from and which columns it uses.

    ``DataTrainingArguments.init_for_training`` builds one instance per requested
    dataset from the entries of ``dataset_info.json``.
    """

    load_from: str  # source kind; init_for_training uses "hf_hub", "script" or "file"
    dataset_name: Optional[str] = None  # hub URL, script URL or file name, depending on load_from
    dataset_sha1: Optional[str] = None  # optional checksum (only filled for "file" datasets)
    source_prefix: Optional[str] = None  # optional prefix to add before the source text

    def __repr__(self) -> str:
        # __repr__ must return a str; dataset_name is Optional, so fall back to a
        # descriptive placeholder instead of raising TypeError when it is None.
        if self.dataset_name is not None:
            return self.dataset_name
        return "DatasetAttr(load_from={!r})".format(self.load_from)

    def __post_init__(self):
        # Default column names. These are plain attributes (not dataclass fields), so
        # they are not constructor arguments; callers overwrite them after construction.
        self.prompt_column: Optional[str] = "instruction"
        self.query_column: Optional[str] = "input"
        self.response_column: Optional[str] = "output"
        self.history_column: Optional[str] = None
|
|
|
|
|
|
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
    """
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
    )
    use_fast_tokenizer: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
    )
    use_auth_token: Optional[bool] = field(
        default=False,
        metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
    )
    model_revision: Optional[str] = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
    )
    padding_side: Optional[Literal["left", "right"]] = field(
        default="left",
        metadata={"help": "The side on which the model should have padding applied."}
    )
    quantization_bit: Optional[int] = field(
        default=None,
        metadata={"help": "The number of bits to quantize the model."}
    )
    quantization_type: Optional[Literal["fp4", "nf4"]] = field(
        default="nf4",
        metadata={"help": "Quantization data type to use in int4 training."}
    )
    double_quantization: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to use double quantization in int4 training or not."}
    )
    compute_dtype: Optional[torch.dtype] = field(
        default=None,
        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
    )
    # Declared as str for the argument parser; __post_init__ turns it into a list of paths.
    checkpoint_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
    )
    reward_model: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
    )
    resume_lora_training: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
    )
    plot_loss: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
    )

    def __post_init__(self):
        """Normalize ``checkpoint_dir`` and validate ``quantization_bit``.

        Raises:
            ValueError: if ``quantization_bit`` is set to anything other than 4 or 8.
        """
        # Support merging multiple LoRA weights: "a, b" -> ["a", "b"]. Only split while the
        # value is still a string, so passing a ready-made list or re-running
        # __post_init__ (e.g. via dataclasses.replace) does not call list.split.
        if isinstance(self.checkpoint_dir, str):
            self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]

        # Raise instead of assert so the check is not stripped under `python -O`.
        if self.quantization_bit is not None and self.quantization_bit not in (4, 8):
            raise ValueError("We only accept 4-bit or 8-bit quantization.")
|
|
|
|
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and evaluation.
    """
    dataset: Optional[str] = field(
        default="alpaca_zh",
        metadata={"help": "The name of provided dataset(s) to use. Use comma to separate multiple datasets."}
    )
    dataset_dir: Optional[str] = field(
        default="data",
        metadata={"help": "The name of the folder containing datasets."}
    )
    split: Optional[str] = field(
        default="train",
        metadata={"help": "Which dataset split to use for training and evaluation."}
    )
    overwrite_cache: Optional[bool] = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets."}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."}
    )
    max_source_length: Optional[int] = field(
        default=512,
        metadata={"help": "The maximum total input sequence length after tokenization."}
    )
    max_target_length: Optional[int] = field(
        default=512,
        metadata={"help": "The maximum total output sequence length after tokenization."}
    )
    max_samples: Optional[int] = field(
        default=None,
        metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
    )
    eval_num_beams: Optional[int] = field(
        default=None,
        metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
    )
    ignore_pad_token_for_loss: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
    )
    source_prefix: Optional[str] = field(
        default=None,
        metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
    )
    dev_ratio: Optional[float] = field(
        default=0,
        metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
    )
    prompt_template: Optional[str] = field(
        default="default",
        metadata={"help": "Which template to use for constructing prompts in training and inference."}
    )

    def init_for_training(self):  # support mixing multiple datasets
        """Resolve ``self.dataset`` against ``<dataset_dir>/dataset_info.json``.

        Sets ``self.dataset_list`` to one ``DatasetAttr`` per comma-separated dataset
        name, in the order given.

        Raises:
            ValueError: if the number of ``|``-separated prefixes is neither 1 nor the
                number of datasets, or if a dataset is missing from dataset_info.json.
        """
        dataset_names = [ds.strip() for ds in self.dataset.split(",")]
        # Read as UTF-8 explicitly: the file may contain non-ASCII text and the
        # platform's locale encoding is not guaranteed to be UTF-8.
        info_path = os.path.join(self.dataset_dir, "dataset_info.json")
        with open(info_path, "r", encoding="utf-8") as f:
            dataset_info = json.load(f)

        if self.source_prefix is not None:
            prefix_list = self.source_prefix.split("|")
            if len(prefix_list) == 1:  # a single prefix is shared by every dataset
                prefix_list = prefix_list * len(dataset_names)
            # Raise instead of assert so the check is not stripped under `python -O`.
            if len(prefix_list) != len(dataset_names):
                raise ValueError("The number of prefixes should be either identical with datasets or 1.")
        else:
            prefix_list = [None] * len(dataset_names)

        self.dataset_list: List[DatasetAttr] = []
        for name, prefix in zip(dataset_names, prefix_list):
            if name not in dataset_info:
                raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
            info = dataset_info[name]

            if "hf_hub_url" in info:
                dataset_attr = DatasetAttr("hf_hub", dataset_name=info["hf_hub_url"])
            elif "script_url" in info:
                dataset_attr = DatasetAttr("script", dataset_name=info["script_url"])
            else:
                dataset_attr = DatasetAttr(
                    "file",
                    dataset_name=info["file_name"],
                    dataset_sha1=info.get("file_sha1", None)
                )

            dataset_attr.source_prefix = prefix

            if "columns" in info:
                # Every column attribute is overwritten; keys absent from "columns"
                # become None rather than keeping DatasetAttr's defaults.
                columns = info["columns"]
                dataset_attr.prompt_column = columns.get("prompt", None)
                dataset_attr.query_column = columns.get("query", None)
                dataset_attr.response_column = columns.get("response", None)
                dataset_attr.history_column = columns.get("history", None)

            self.dataset_list.append(dataset_attr)
|
|
|
|
|
|
@dataclass
class FinetuningArguments:
    """
    Arguments pertaining to which techniques we are going to fine-tuning with.
    """
    finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
        default="lora",
        metadata={"help": "Which fine-tuning method to use."}
    )
    num_hidden_layers: Optional[int] = field(
        default=32,
        metadata={"help": "Number of decoder blocks in the model. "
                          "LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], "
                          "BLOOM choices: [\"24\", \"30\", \"70\"], "
                          "Falcon choices: [\"32\", \"60\"], "
                          "Baichuan choices: [\"32\"]"}
    )
    num_layer_trainable: Optional[int] = field(
        default=3,
        metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
    )
    name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
        default="mlp",
        metadata={"help": "Name of trainable modules for Freeze fine-tuning. "
                          "LLaMA choices: [\"mlp\", \"self_attn\"], "
                          "BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], "
                          "Baichuan choices: [\"mlp\", \"self_attn\"]"}
    )
    lora_rank: Optional[int] = field(
        default=8,
        metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
    )
    lora_alpha: Optional[float] = field(
        default=32.0,
        metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."}
    )
    lora_dropout: Optional[float] = field(
        default=0.1,
        metadata={"help": "Dropout rate for the LoRA fine-tuning."}
    )
    # Declared as str for the argument parser; __post_init__ turns it into a list of names.
    lora_target: Optional[str] = field(
        default="q_proj,v_proj",
        metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. "
                          "LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], "
                          "BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], "
                          "Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
    )

    def __post_init__(self):
        """Validate the method, split ``lora_target`` and derive ``trainable_layers``.

        Raises:
            ValueError: if ``finetuning_type`` is not one of the supported methods.
        """
        # Validate before deriving anything; raise instead of assert so the check is
        # not stripped under `python -O`.
        if self.finetuning_type not in ["none", "freeze", "lora", "full"]:
            raise ValueError("Invalid fine-tuning method.")

        # Only split while still a string: load_from_json passes back the list that
        # save_to_json wrote.
        if isinstance(self.lora_target, str):  # support custom target modules/layers of LoRA
            self.lora_target = [target.strip() for target in self.lora_target.split(",")]

        if self.num_layer_trainable > 0:  # fine-tuning the last n layers if num_layer_trainable > 0
            trainable_layer_ids = [self.num_hidden_layers - k - 1 for k in range(self.num_layer_trainable)]
        else:  # fine-tuning the first n layers if num_layer_trainable < 0
            trainable_layer_ids = list(range(-self.num_layer_trainable))

        # Derived attribute, not a dataclass field: asdict()/save_to_json omit it and
        # it is recomputed here whenever an instance is constructed.
        self.trainable_layers = ["{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]

    def save_to_json(self, json_path: str):
        """Saves the content of this instance in JSON format inside `json_path`."""
        json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
        with open(json_path, "w", encoding="utf-8") as f:
            f.write(json_string)

    @classmethod
    def load_from_json(cls, json_path: str):
        """Creates an instance from the content of `json_path`."""
        with open(json_path, "r", encoding="utf-8") as f:
            text = f.read()
        return cls(**json.loads(text))
|
|
|
|
|
|
@dataclass
class GeneratingArguments:
    """
    Decoding parameters used when generating text.

    Every field maps one-to-one onto a keyword argument; ``to_dict`` returns them as a
    plain dictionary.
    """
    do_sample: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
    )
    temperature: Optional[float] = field(
        default=0.95,
        metadata={"help": "The value used to modulate the next token probabilities."}
    )
    top_p: Optional[float] = field(
        default=0.7,
        metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
    )
    top_k: Optional[int] = field(
        default=50,
        metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
    )
    num_beams: Optional[int] = field(
        default=1,
        metadata={"help": "Number of beams for beam search. 1 means no beam search."}
    )
    max_length: Optional[int] = field(
        default=None,
        metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
    )
    max_new_tokens: Optional[int] = field(
        default=512,
        metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
    )
    repetition_penalty: Optional[float] = field(
        default=1.0,
        metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
    )
    length_penalty: Optional[float] = field(
        default=1.0,
        metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
    )

    def to_dict(self) -> Dict[str, Any]:
        """Return all decoding fields as a dict.

        When ``max_new_tokens`` is set to a truthy value, the ``max_length`` entry is
        dropped so that only one length limit is passed on.
        """
        generation_kwargs = asdict(self)
        # Guard clause: with no (or a falsy) max_new_tokens, keep max_length as the limit.
        if not generation_kwargs.get("max_new_tokens", None):
            return generation_kwargs
        generation_kwargs.pop("max_length", None)
        return generation_kwargs
|