This commit is contained in:
hiyouga
2024-01-09 18:31:27 +08:00
parent ebee4f6a2a
commit 4571068e1e
9 changed files with 78 additions and 50 deletions

View File

@@ -8,7 +8,7 @@ from transformers import DataCollatorWithPadding
from transformers.optimization import get_scheduler
from llmtuner.data import get_dataset, preprocess_dataset
-from llmtuner.extras.callbacks import SavePeftModelCallback
+from llmtuner.extras.callbacks import FixValueHeadModelCallback
from llmtuner.extras.ploting import plot_loss
from llmtuner.model import load_model_and_tokenizer
from llmtuner.train.utils import create_ref_model, create_reward_model
@@ -79,7 +79,7 @@ def run_ppo(
training_args=training_args,
finetuning_args=finetuning_args,
generating_args=generating_args,
-callbacks=callbacks + [SavePeftModelCallback()],
+callbacks=callbacks + [FixValueHeadModelCallback()],
reward_model=reward_model,
config=ppo_config,
model=model,