Former-commit-id: b19c14870d30c57fbea81e9cfa737d762922c54b
This commit is contained in:
hiyouga 2024-03-28 18:31:17 +08:00
parent 89c400633a
commit eac2a5b1d3

View File

@ -58,9 +58,17 @@ class LogCallback(TrainerCallback):
self.in_training = True self.in_training = True
self.start_time = time.time() self.start_time = time.time()
self.max_steps = state.max_steps self.max_steps = state.max_steps
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
logger.warning("Previous log file in this folder will be deleted.") if args.save_on_each_node:
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME)) if not state.is_local_process_zero:
return
else:
if not state.is_world_process_zero:
return
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
logger.warning("Previous log file in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r""" r"""
@ -112,8 +120,12 @@ class LogCallback(TrainerCallback):
r""" r"""
Event called after logging the last logs. Event called after logging the last logs.
""" """
if not state.is_local_process_zero: if args.save_on_each_node:
return if not state.is_local_process_zero:
return
else:
if not state.is_world_process_zero:
return
logs = dict( logs = dict(
current_steps=self.cur_steps, current_steps=self.cur_steps,