Former-commit-id: a5e823ae75556eaa3b52ce7a887a6e7838a1eb5f
This commit is contained in:
hiyouga 2024-03-28 18:31:17 +08:00
parent 59e6ebf039
commit 14b75a0b93

View File

@ -58,9 +58,17 @@ class LogCallback(TrainerCallback):
self.in_training = True self.in_training = True
self.start_time = time.time() self.start_time = time.time()
self.max_steps = state.max_steps self.max_steps = state.max_steps
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
logger.warning("Previous log file in this folder will be deleted.") if args.save_on_each_node:
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME)) if not state.is_local_process_zero:
return
else:
if not state.is_world_process_zero:
return
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
logger.warning("Previous log file in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r""" r"""
@ -112,8 +120,12 @@ class LogCallback(TrainerCallback):
r""" r"""
Event called after logging the last logs. Event called after logging the last logs.
""" """
if not state.is_local_process_zero: if args.save_on_each_node:
return if not state.is_local_process_zero:
return
else:
if not state.is_world_process_zero:
return
logs = dict( logs = dict(
current_steps=self.cur_steps, current_steps=self.cur_steps,