Update FlashAttention; fix PPO model saving

This commit is contained in:
hiyouga
2023-09-11 17:25:36 +08:00
parent b218c271ed
commit 0fbece85a7
5 changed files with 105 additions and 518 deletions

View File

@@ -25,16 +25,16 @@ class SavePeftModelCallback(TrainerCallback):
r"""
Event called after a checkpoint save.
"""
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
getattr(kwargs.get("model"), "pretrained_model").save_pretrained(output_dir)
return control
if args.should_save:
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
getattr(kwargs.get("model"), "pretrained_model").save_pretrained(output_dir)
def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
"""
getattr(kwargs.get("model"), "pretrained_model").save_pretrained(args.output_dir)
return control
if args.should_save:
getattr(kwargs.get("model"), "pretrained_model").save_pretrained(args.output_dir)
class LogCallback(TrainerCallback):