Release v0.1.6

This commit is contained in:
hiyouga
2023-08-11 23:25:57 +08:00
parent 156710a995
commit a48cb0d474
18 changed files with 127 additions and 41 deletions

View File

@@ -1,3 +1,4 @@
import gradio as gr
import logging
import os
import threading
@@ -13,7 +14,7 @@ from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import run_exp
from llmtuner.webui.common import get_model_path, get_save_dir
from llmtuner.webui.locales import ALERTS
from llmtuner.webui.utils import format_info, get_eval_results
from llmtuner.webui.utils import get_eval_results, update_process_bar
class Runner:
@@ -88,14 +89,16 @@ class Runner:
save_steps: int,
warmup_steps: int,
compute_type: str,
padding_side: str,
lora_rank: int,
lora_dropout: float,
lora_target: str,
resume_lora_training: bool,
output_dir: str
) -> Generator[str, None, None]:
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
if error:
yield error
yield error, gr.update(visible=False)
return
if checkpoints:
@@ -133,9 +136,11 @@ class Runner:
warmup_steps=warmup_steps,
fp16=(compute_type == "fp16"),
bf16=(compute_type == "bf16"),
padding_side=padding_side,
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
resume_lora_training=resume_lora_training,
output_dir=output_dir
)
@@ -150,18 +155,18 @@ class Runner:
thread.start()
while thread.is_alive():
time.sleep(1)
time.sleep(2)
if self.aborted:
yield ALERTS["info_aborting"][lang]
yield ALERTS["info_aborting"][lang], gr.update(visible=False)
else:
yield format_info(logger_handler.log, trainer_callback)
yield logger_handler.log, update_process_bar(trainer_callback)
if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
finish_info = ALERTS["info_finished"][lang]
else:
finish_info = ALERTS["err_failed"][lang]
yield self.finalize(lang, finish_info)
yield self.finalize(lang, finish_info), gr.update(visible=False)
def run_eval(
self,
@@ -182,7 +187,7 @@ class Runner:
) -> Generator[str, None, None]:
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
if error:
yield error
yield error, gr.update(visible=False)
return
if checkpoints:
@@ -223,15 +228,15 @@ class Runner:
thread.start()
while thread.is_alive():
time.sleep(1)
time.sleep(2)
if self.aborted:
yield ALERTS["info_aborting"][lang]
yield ALERTS["info_aborting"][lang], gr.update(visible=False)
else:
yield format_info(logger_handler.log, trainer_callback)
yield logger_handler.log, update_process_bar(trainer_callback)
if os.path.exists(os.path.join(output_dir, "all_results.json")):
finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
else:
finish_info = ALERTS["err_failed"][lang]
yield self.finalize(lang, finish_info)
yield self.finalize(lang, finish_info), gr.update(visible=False)