use fp16 model, add logcallback

This commit is contained in:
hiyouga
2023-05-28 21:30:28 +08:00
parent 769c6ab56b
commit 0c9fda01e3
7 changed files with 112 additions and 10 deletions

View File

@@ -17,6 +17,7 @@ from utils import (
preprocess_data,
DataCollatorForLLaMA,
PPOTrainerForLLaMA,
LogCallback,
plot_loss
)
@@ -54,6 +55,7 @@ def main():
ppo_trainer = PPOTrainerForLLaMA(
training_args=training_args,
finetuning_args=finetuning_args,
callbacks=[LogCallback()],
config=ppo_config,
model=model,
ref_model=None,