Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-12-16 11:50:35 +08:00
add pre-training script
@@ -130,7 +130,7 @@ def load_pretrained(
     model_args: ModelArguments,
     finetuning_args: Optional[FinetuningArguments] = None,
     is_trainable: Optional[bool] = False,
-    stage: Optional[Literal["sft", "rm", "ppo"]] = "sft"
+    stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
 ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
     r"""
     Loads pretrained model and tokenizer.
@@ -142,11 +142,14 @@ def load_pretrained(
         logger.warning("Checkpoint is not found at evaluation, load the original model.")
         finetuning_args = FinetuningArguments(finetuning_type="none")
     elif os.path.exists(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)):
-        finetuning_args = FinetuningArguments.load_from_json(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME))
+        finetuning_args = FinetuningArguments.load_from_json(
+            os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)
+        )
     else:
         raise ValueError("Missing fine-tuning arguments in the provided dictionary.")

-    assert stage == "sft" or finetuning_args.finetuning_type == "lora", "RM and PPO training can only be performed with LoRA method."
+    assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
+        "RM and PPO training can only be performed with LoRA method."

     tokenizer = LlamaTokenizer.from_pretrained(
         model_args.model_name_or_path,
@@ -207,7 +210,7 @@ def load_pretrained(


 def prepare_args(
-    stage: Literal["sft", "rm", "ppo"]
+    stage: Literal["pt", "sft", "rm", "ppo"]
 ) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:

     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))
@@ -230,7 +233,7 @@ def prepare_args(

     # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
     if stage != "sft" and training_args.predict_with_generate:
-        raise ValueError("`predict_with_generate` cannot be set as True in RM and PPO stages.")
+        raise ValueError("`predict_with_generate` cannot be set as True in PT, RM and PPO stages.")

     if training_args.do_train and training_args.predict_with_generate:
         raise ValueError("`predict_with_generate` cannot be set as True while training.")
@@ -290,7 +293,7 @@ def prepare_data(
                cache_dir=model_args.cache_dir
            )
        elif dataset_attr.load_from == "file":
-            data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name) # support json, jsonl and csv
+            data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name)
            extension = dataset_attr.file_name.split(".")[-1]

            if dataset_attr.file_sha1 is not None:
@@ -299,7 +302,7 @@ def prepare_data(
                logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")

            raw_datasets = load_dataset(
-                extension,
+                extension if extension in ["csv", "json"] else "text",
                data_files=data_file,
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None
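
Note (added for context): the fallback to the "text" builder is what lets a raw .txt corpus be loaded at all, since the Hugging Face datasets "text" loader reads such a file line by line into a single `text` column. A minimal, self-contained sketch of that behaviour, with a made-up file name and contents:

    # Illustrative only: how the "text" builder handles a plaintext corpus.
    # The file name and contents are made up for this sketch.
    from datasets import load_dataset

    with open("corpus.txt", "w", encoding="utf-8") as f:
        f.write("First pre-training document.\nSecond pre-training document.\n")

    raw = load_dataset("text", data_files="corpus.txt")
    print(raw["train"].column_names)  # ['text'] -- hence the rename in the next hunk
    print(raw["train"][0])            # {'text': 'First pre-training document.'}
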
@@ -313,6 +316,9 @@ def prepare_data(
            max_samples_temp = min(len(dataset), max_samples)
            dataset = dataset.select(range(max_samples_temp))

+        if dataset.column_names[0] == "text": # for plaintext (in pre-training)
+            dataset = dataset.rename_column("text", getattr(dataset_attr, "prompt_column"))
+
        dummy_data = [None] * len(dataset)
        for column_name, target_name in [
            ("prompt_column", "prompt"),
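
Continuing the sketch above, the added rename folds the plaintext `text` column into the configured prompt column so the rest of the pipeline sees one schema; `"instruction"` below is only an assumed value of `prompt_column`, not something this commit fixes:

    # Fold the "text" column into the prompt column (assumed name: "instruction").
    dataset = raw["train"]
    if dataset.column_names[0] == "text":
        dataset = dataset.rename_column("text", "instruction")
    print(dataset.column_names)  # ['instruction']
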
@@ -340,7 +346,7 @@ def preprocess_data(
     tokenizer: PreTrainedTokenizer,
     data_args: DataTrainingArguments,
     training_args: Seq2SeqTrainingArguments,
-    stage: Optional[Literal["sft", "rm", "ppo"]] = "sft"
+    stage: Literal["pt", "sft", "rm", "ppo"]
 ) -> Dataset:

     column_names = list(dataset.column_names)
@@ -363,7 +369,7 @@ def preprocess_data(
             yield prompt, answer

     def preprocess_pretrain_dataset(examples):
-        # build grouped texts with format `<s>??`
+        # build grouped texts with format `<s> X1 X2 X3 ...` (without </s>)
         text_ids = tokenizer(examples["prompt"])["input_ids"]
         concatenated_ids = list(chain(*text_ids))
         total_length = len(concatenated_ids)
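
The function is cut off by the hunk, but the surviving lines follow the usual causal-LM packing recipe: tokenize every prompt, concatenate the token ids, drop the remainder, and slice the stream into fixed-size blocks whose labels are copies of the inputs. A self-contained sketch of that technique with toy token ids; block_size=8 is arbitrary, since the real block size is not shown in this hunk:

    from itertools import chain

    def group_texts(batch_of_token_ids, block_size=8):
        # Concatenate every tokenized example into one long stream of ids ...
        concatenated_ids = list(chain(*batch_of_token_ids))
        # ... drop the tail that does not fill a whole block ...
        total_length = (len(concatenated_ids) // block_size) * block_size
        # ... and slice the stream into fixed-size blocks.
        input_ids = [concatenated_ids[i: i + block_size] for i in range(0, total_length, block_size)]
        # For causal language modeling the labels are simply a copy of the inputs.
        return {"input_ids": input_ids, "labels": [ids[:] for ids in input_ids]}

    # Toy ids standing in for tokenizer(examples["prompt"])["input_ids"]:
    toy_ids = [[1, 101, 102, 103], [1, 201, 202], [1, 301, 302, 303, 304]]
    print(group_texts(toy_ids))  # one block of 8 ids, labels identical to inputs
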
@@ -395,7 +401,7 @@ def preprocess_data(
             model_inputs["labels"].append(labels)
         return model_inputs

-    def preprocess_evaluation_dataset(examples):
+    def preprocess_unsupervised_dataset(examples):
         # build inputs with format `X <s>` and labels with format `Y <s>`
         model_inputs = {"input_ids": [], "labels": []}
         for prompt, answer in format_example(examples):
@@ -436,7 +442,7 @@ def preprocess_data(
             model_inputs["reject_ids"].append(reject_ids)
         return model_inputs

-    def print_sft_dataset_example(example):
+    def print_supervised_dataset_example(example):
         print("input_ids:\n{}".format(example["input_ids"]))
         print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
         print("label_ids:\n{}".format(example["labels"]))
@@ -450,19 +456,19 @@ def preprocess_data(
         print("reject_ids:\n{}".format(example["reject_ids"]))
         print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"])))

-    def print_ppo_dataset_example(example):
+    def print_unsupervised_dataset_example(example):
         print("input_ids:\n{}".format(example["input_ids"]))
         print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))

+    if stage == "pt":
+        preprocess_function = preprocess_pretrain_dataset
+    elif stage == "sft":
-        preprocess_function = preprocess_evaluation_dataset \
+        preprocess_function = preprocess_unsupervised_dataset \
             if training_args.predict_with_generate else preprocess_supervised_dataset
     elif stage == "rm":
         preprocess_function = preprocess_pairwise_dataset
     elif stage == "ppo":
-        preprocess_function = preprocess_evaluation_dataset
+        preprocess_function = preprocess_unsupervised_dataset

     with training_args.main_process_first(desc="dataset map pre-processing"):
         dataset = dataset.map(
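
The `dataset.map(` call is truncated at this hunk boundary (the next hunk shows its trailing `desc="Running tokenizer on dataset"` argument). The pattern is the standard batched map that drops the raw columns after tokenization; a toy, self-contained illustration with made-up data and a fake tokenizer standing in for the stage-specific preprocess_function:

    # Self-contained illustration of the batched map pattern; data and the
    # "tokenizer" are made up and are not part of this commit.
    from datasets import Dataset

    toy = Dataset.from_dict({"prompt": ["hello world", "goodbye world"]})

    def fake_tokenize(examples):
        # Stand-in for the stage-specific preprocess_function: split on spaces.
        return {"input_ids": [s.split() for s in examples["prompt"]]}

    toy = toy.map(fake_tokenize, batched=True, remove_columns=["prompt"], desc="Running tokenizer on dataset")
    print(toy[0])  # {'input_ids': ['hello', 'world']}
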
@@ -474,11 +480,13 @@ def preprocess_data(
             desc="Running tokenizer on dataset"
         )

-    if stage == "sft":
-        print_sft_dataset_example(dataset[0])
+    if stage == "pt":
+        print_unsupervised_dataset_example(dataset[0])
+    elif stage == "sft":
+        print_supervised_dataset_example(dataset[0])
     elif stage == "rm":
         print_pairwise_dataset_example(dataset[0])
     elif stage == "ppo":
-        print_ppo_dataset_example(dataset[0])
+        print_unsupervised_dataset_example(dataset[0])

     return dataset