mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 20:00:36 +08:00
@@ -9,6 +9,7 @@ from transformers.optimization import get_scheduler
|
||||
|
||||
from llmtuner.data import get_dataset, preprocess_dataset
|
||||
from llmtuner.extras.callbacks import FixValueHeadModelCallback
|
||||
from llmtuner.extras.misc import fix_valuehead_checkpoint
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.model import load_model_and_tokenizer
|
||||
from llmtuner.train.utils import create_ref_model, create_reward_model
|
||||
@@ -95,6 +96,8 @@ def run_ppo(
|
||||
if training_args.do_train:
|
||||
ppo_trainer.ppo_train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||
ppo_trainer.save_model()
|
||||
if training_args.should_save:
|
||||
fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
|
||||
ppo_trainer.save_state() # must be called after save_model to have a folder
|
||||
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "reward"])
|
||||
|
||||
Reference in New Issue
Block a user