support BLOOM models

hiyouga
2023-05-31 16:54:06 +08:00
parent a72492e649
commit 740a5daf56
16 changed files with 134 additions and 90 deletions


@@ -1,5 +1,5 @@
 # coding=utf-8
-# Implements parameter-efficient training of a reward model based on LLaMA.
+# Implements parameter-efficient training of reward models.
 # This code is inspired by:
 # https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
 # https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
@@ -10,8 +10,8 @@ from utils import (
     prepare_data,
     load_pretrained,
     preprocess_data,
-    PairwiseDataCollatorForLLaMA,
-    PairwiseTrainerForLLaMA,
+    PairwiseDataCollatorWithPadding,
+    PairwisePeftTrainer,
     LogCallback,
     plot_loss
 )
@@ -23,7 +23,7 @@ def main():
     dataset = prepare_data(model_args, data_args)
     model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="rm")
     dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rm")
-    data_collator = PairwiseDataCollatorForLLaMA(tokenizer, model.pretrained_model)
+    data_collator = PairwiseDataCollatorWithPadding(tokenizer, model.pretrained_model)
     training_args.remove_unused_columns = False # important for pairwise dataset
@@ -38,7 +38,7 @@ def main():
     trainer_kwargs = {"eval_dataset": dataset}
     # Initialize our Trainer
-    trainer = PairwiseTrainerForLLaMA(
+    trainer = PairwisePeftTrainer(
         finetuning_args=finetuning_args,
         model=model,
         args=training_args,
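
Taken together, these hunks swap the LLaMA-specific pairwise collator and trainer for model-agnostic equivalents, which is what lets the same reward-model script drive BLOOM checkpoints. A minimal sketch of the resulting flow, assembled only from the lines visible above; the argument objects (model_args, data_args, finetuning_args, training_args) and any trainer kwargs beyond this hunk are assumptions, not part of this diff:

# Sketch only: the argument objects are assumed to come from the script's
# own argument-parsing step, which this diff does not show.
from utils import (
    prepare_data,
    load_pretrained,
    preprocess_data,
    PairwiseDataCollatorWithPadding,
    PairwisePeftTrainer,
)

def main():
    dataset = prepare_data(model_args, data_args)
    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="rm")
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rm")
    data_collator = PairwiseDataCollatorWithPadding(tokenizer, model.pretrained_model)
    training_args.remove_unused_columns = False  # important for pairwise dataset
    trainer = PairwisePeftTrainer(
        finetuning_args=finetuning_args,
        model=model,
        args=training_args,
        data_collator=data_collator,  # assumed: the diff truncates before the remaining kwargs
    )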