update api and support abort eval in webui

This commit is contained in:
hiyouga
2024-05-04 15:59:15 +08:00
parent d4283bb6bf
commit ed8f8be752
11 changed files with 277 additions and 192 deletions

View File

@@ -2,6 +2,7 @@ import json
import logging
import os
import signal
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
@@ -91,6 +92,18 @@ class LogCallback(TrainerCallback):
self.thread_pool.shutdown(wait=True)
self.thread_pool = None
def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of the initialization of the `Trainer`.
"""
if (
args.should_save
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
and args.overwrite_output_dir
):
logger.warning("Previous trainer log in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of training.
@@ -100,14 +113,6 @@ class LogCallback(TrainerCallback):
self._reset(max_steps=state.max_steps)
self._create_thread_pool(output_dir=args.output_dir)
if (
args.should_save
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
and args.overwrite_output_dir
):
logger.warning("Previous trainer log in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
@@ -126,9 +131,6 @@ class LogCallback(TrainerCallback):
r"""
Event called at the end of a training step.
"""
if args.should_save:
self._timing(cur_steps=state.global_step)
if self.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
@@ -152,6 +154,7 @@ class LogCallback(TrainerCallback):
if not args.should_save:
return
self._timing(cur_steps=state.global_step)
logs = dict(
current_steps=self.cur_steps,
total_steps=self.max_steps,
@@ -183,8 +186,17 @@ class LogCallback(TrainerCallback):
r"""
Event called after a prediction step.
"""
if self.do_train:
return
if self.aborted:
sys.exit(0)
if not args.should_save:
return
eval_dataloader = kwargs.pop("eval_dataloader", None)
if args.should_save and has_length(eval_dataloader) and not self.do_train:
if has_length(eval_dataloader):
if self.max_steps == 0:
self._reset(max_steps=len(eval_dataloader))
self._create_thread_pool(output_dir=args.output_dir)