support BLOOM models

hiyouga
2023-05-31 16:54:06 +08:00
parent a72492e649
commit 740a5daf56
16 changed files with 134 additions and 90 deletions
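The rename from DataCollatorForLLaMA / PPOTrainerForLLaMA to DynamicDataCollatorWithPadding / PPOPeftTrainer drops the LLaMA-specific naming so BLOOM checkpoints can reuse the same PPO pipeline. A minimal sketch of the dynamic-padding idea behind the new collator name (illustrative only, assuming Hugging Face-style feature dicts and a tokenizer with a pad token; not the repository's actual implementation):

import torch
from transformers import PreTrainedTokenizer

def pad_batch(features, tokenizer: PreTrainedTokenizer):
    # Pad every example to the longest sequence in the batch, so the same
    # collator works for LLaMA, BLOOM, or any other causal LM.
    max_len = max(len(f["input_ids"]) for f in features)
    input_ids, attention_mask = [], []
    for f in features:
        ids = list(f["input_ids"])
        pad_len = max_len - len(ids)
        input_ids.append(ids + [tokenizer.pad_token_id] * pad_len)
        attention_mask.append([1] * len(ids) + [0] * pad_len)
    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long),
        "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
    }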


@@ -1,5 +1,5 @@
 # coding=utf-8
-# Implements parameter-efficient PPO training of fine-tuned LLaMA.
+# Implements parameter-efficient PPO training of fine-tuned models.
 # This code is inspired by:
 # https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
@@ -15,8 +15,8 @@ from utils import (
     prepare_data,
     load_pretrained,
     preprocess_data,
-    DataCollatorForLLaMA,
-    PPOTrainerForLLaMA,
+    DynamicDataCollatorWithPadding,
+    PPOPeftTrainer,
     LogCallback,
     plot_loss
 )
@@ -29,7 +29,7 @@ def main():
     dataset = prepare_data(model_args, data_args)
     model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="ppo")
     dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="ppo")
-    data_collator = DataCollatorForLLaMA(tokenizer, model.pretrained_model)
+    data_collator = DynamicDataCollatorWithPadding(tokenizer, model.pretrained_model)

     ppo_config = PPOConfig(
         model_name=model_args.model_name_or_path,
@@ -52,7 +52,7 @@ def main():
     )

     # Initialize our Trainer
-    ppo_trainer = PPOTrainerForLLaMA(
+    ppo_trainer = PPOPeftTrainer(
         training_args=training_args,
         finetuning_args=finetuning_args,
         callbacks=[LogCallback()],