release v0.6.1

This commit is contained in:
hiyouga
2024-03-29 11:36:08 +08:00
parent c1fe6ce782
commit ca793028c6
5 changed files with 95 additions and 57 deletions

View File

@@ -1,19 +1,15 @@
# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
import math
from typing import TYPE_CHECKING, List, Optional
from torch.optim import AdamW
from transformers import DataCollatorWithPadding
from transformers.optimization import get_scheduler
from trl import PPOConfig
from ...data import get_dataset
from ...extras.callbacks import FixValueHeadModelCallback
from ...extras.misc import fix_valuehead_checkpoint
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..utils import create_custom_optimzer, create_custom_scheduler, create_ref_model, create_reward_model
from ..utils import create_ref_model, create_reward_model
from .trainer import CustomPPOTrainer
@@ -42,46 +38,6 @@ def run_ppo(
ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
reward_model = create_reward_model(model, model_args, finetuning_args)
# Create ppo config
backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
ppo_config = PPOConfig(
model_name=model_args.model_name_or_path,
learning_rate=training_args.learning_rate,
mini_batch_size=training_args.per_device_train_batch_size,
batch_size=backward_batch_size * finetuning_args.ppo_buffer_size,
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
ppo_epochs=finetuning_args.ppo_epochs,
max_grad_norm=training_args.max_grad_norm,
seed=training_args.seed,
optimize_device_cache=True,
target=finetuning_args.ppo_target,
use_score_scaling=finetuning_args.ppo_score_norm,
use_score_norm=finetuning_args.ppo_score_norm,
whiten_rewards=finetuning_args.ppo_whiten_rewards,
accelerator_kwargs={"step_scheduler_with_optimizer": False},
log_with=training_args.report_to[0] if training_args.report_to is not None else None,
project_kwargs={"logging_dir": training_args.logging_dir},
)
# Create optimizer and scheduler
if training_args.max_steps > 0:
num_training_steps = training_args.max_steps
else:
total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
optimizer = create_custom_optimzer(model, training_args, finetuning_args)
if optimizer is None:
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
create_custom_scheduler(training_args, num_training_steps, optimizer)
lr_scheduler = get_scheduler(
training_args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
# Initialize our Trainer
ppo_trainer = CustomPPOTrainer(
model_args=model_args,
@@ -89,15 +45,12 @@ def run_ppo(
finetuning_args=finetuning_args,
generating_args=generating_args,
callbacks=callbacks + [FixValueHeadModelCallback()],
reward_model=reward_model,
config=ppo_config,
model=model,
reward_model=reward_model,
ref_model=ref_model,
tokenizer=tokenizer,
dataset=dataset,
data_collator=data_collator,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
)
# Training