mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00
parent
89c400633a
commit
eac2a5b1d3
@ -58,9 +58,17 @@ class LogCallback(TrainerCallback):
|
|||||||
self.in_training = True
|
self.in_training = True
|
||||||
self.start_time = time.time()
|
self.start_time = time.time()
|
||||||
self.max_steps = state.max_steps
|
self.max_steps = state.max_steps
|
||||||
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
|
|
||||||
logger.warning("Previous log file in this folder will be deleted.")
|
if args.save_on_each_node:
|
||||||
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
|
if not state.is_local_process_zero:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
if not state.is_world_process_zero:
|
||||||
|
return
|
||||||
|
|
||||||
|
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
|
||||||
|
logger.warning("Previous log file in this folder will be deleted.")
|
||||||
|
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
|
||||||
|
|
||||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||||
r"""
|
r"""
|
||||||
@ -112,8 +120,12 @@ class LogCallback(TrainerCallback):
|
|||||||
r"""
|
r"""
|
||||||
Event called after logging the last logs.
|
Event called after logging the last logs.
|
||||||
"""
|
"""
|
||||||
if not state.is_local_process_zero:
|
if args.save_on_each_node:
|
||||||
return
|
if not state.is_local_process_zero:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
if not state.is_world_process_zero:
|
||||||
|
return
|
||||||
|
|
||||||
logs = dict(
|
logs = dict(
|
||||||
current_steps=self.cur_steps,
|
current_steps=self.cur_steps,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user