fix ppo callbacks

This commit is contained in:
hiyouga
2024-07-02 17:34:56 +08:00
parent 33f2ddb8b6
commit 4c296001c4
2 changed files with 5 additions and 5 deletions

View File

@@ -22,7 +22,7 @@ from transformers import DataCollatorWithPadding
from ...data import get_dataset
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
-from ..callbacks import FixValueHeadModelCallback, fix_valuehead_checkpoint
+from ..callbacks import fix_valuehead_checkpoint
from ..trainer_utils import create_ref_model, create_reward_model
from .trainer import CustomPPOTrainer
@@ -59,7 +59,7 @@ def run_ppo(
training_args=training_args,
finetuning_args=finetuning_args,
generating_args=generating_args,
-callbacks=callbacks + [FixValueHeadModelCallback()],
+callbacks=callbacks,
model=model,
reward_model=reward_model,
ref_model=ref_model,