mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 11:50:35 +08:00
fix ppo callbacks
This commit is contained in:
@@ -22,7 +22,7 @@ from transformers import DataCollatorWithPadding
|
||||
from ...data import get_dataset
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ..callbacks import FixValueHeadModelCallback, fix_valuehead_checkpoint
|
||||
from ..callbacks import fix_valuehead_checkpoint
|
||||
from ..trainer_utils import create_ref_model, create_reward_model
|
||||
from .trainer import CustomPPOTrainer
|
||||
|
||||
@@ -59,7 +59,7 @@ def run_ppo(
|
||||
training_args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
generating_args=generating_args,
|
||||
callbacks=callbacks + [FixValueHeadModelCallback()],
|
||||
callbacks=callbacks,
|
||||
model=model,
|
||||
reward_model=reward_model,
|
||||
ref_model=ref_model,
|
||||
|
||||
Reference in New Issue
Block a user