mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-15 19:30:36 +08:00
Release v0.1.6
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import gradio as gr
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
@@ -13,7 +14,7 @@ from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.tuner import run_exp
|
||||
from llmtuner.webui.common import get_model_path, get_save_dir
|
||||
from llmtuner.webui.locales import ALERTS
|
||||
from llmtuner.webui.utils import format_info, get_eval_results
|
||||
from llmtuner.webui.utils import get_eval_results, update_process_bar
|
||||
|
||||
|
||||
class Runner:
|
||||
@@ -88,14 +89,16 @@ class Runner:
|
||||
save_steps: int,
|
||||
warmup_steps: int,
|
||||
compute_type: str,
|
||||
padding_side: str,
|
||||
lora_rank: int,
|
||||
lora_dropout: float,
|
||||
lora_target: str,
|
||||
resume_lora_training: bool,
|
||||
output_dir: str
|
||||
) -> Generator[str, None, None]:
|
||||
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
|
||||
if error:
|
||||
yield error
|
||||
yield error, gr.update(visible=False)
|
||||
return
|
||||
|
||||
if checkpoints:
|
||||
@@ -133,9 +136,11 @@ class Runner:
|
||||
warmup_steps=warmup_steps,
|
||||
fp16=(compute_type == "fp16"),
|
||||
bf16=(compute_type == "bf16"),
|
||||
padding_side=padding_side,
|
||||
lora_rank=lora_rank,
|
||||
lora_dropout=lora_dropout,
|
||||
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
|
||||
resume_lora_training=resume_lora_training,
|
||||
output_dir=output_dir
|
||||
)
|
||||
|
||||
@@ -150,18 +155,18 @@ class Runner:
|
||||
thread.start()
|
||||
|
||||
while thread.is_alive():
|
||||
time.sleep(1)
|
||||
time.sleep(2)
|
||||
if self.aborted:
|
||||
yield ALERTS["info_aborting"][lang]
|
||||
yield ALERTS["info_aborting"][lang], gr.update(visible=False)
|
||||
else:
|
||||
yield format_info(logger_handler.log, trainer_callback)
|
||||
yield logger_handler.log, update_process_bar(trainer_callback)
|
||||
|
||||
if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
|
||||
finish_info = ALERTS["info_finished"][lang]
|
||||
else:
|
||||
finish_info = ALERTS["err_failed"][lang]
|
||||
|
||||
yield self.finalize(lang, finish_info)
|
||||
yield self.finalize(lang, finish_info), gr.update(visible=False)
|
||||
|
||||
def run_eval(
|
||||
self,
|
||||
@@ -182,7 +187,7 @@ class Runner:
|
||||
) -> Generator[str, None, None]:
|
||||
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
|
||||
if error:
|
||||
yield error
|
||||
yield error, gr.update(visible=False)
|
||||
return
|
||||
|
||||
if checkpoints:
|
||||
@@ -223,15 +228,15 @@ class Runner:
|
||||
thread.start()
|
||||
|
||||
while thread.is_alive():
|
||||
time.sleep(1)
|
||||
time.sleep(2)
|
||||
if self.aborted:
|
||||
yield ALERTS["info_aborting"][lang]
|
||||
yield ALERTS["info_aborting"][lang], gr.update(visible=False)
|
||||
else:
|
||||
yield format_info(logger_handler.log, trainer_callback)
|
||||
yield logger_handler.log, update_process_bar(trainer_callback)
|
||||
|
||||
if os.path.exists(os.path.join(output_dir, "all_results.json")):
|
||||
finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
|
||||
else:
|
||||
finish_info = ALERTS["err_failed"][lang]
|
||||
|
||||
yield self.finalize(lang, finish_info)
|
||||
yield self.finalize(lang, finish_info), gr.update(visible=False)
|
||||
|
||||
Reference in New Issue
Block a user