mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-15 03:10:35 +08:00
release v0.1.0
This commit is contained in:
177
src/llmtuner/webui/runner.py
Normal file
177
src/llmtuner/webui/runner.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import transformers
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.logging import LoggerHandler
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.tuner import get_train_args, run_sft
|
||||
from llmtuner.webui.common import get_model_path, get_save_dir
|
||||
from llmtuner.webui.locales import ALERTS
|
||||
from llmtuner.webui.utils import format_info, get_eval_results
|
||||
|
||||
|
||||
class Runner:
|
||||
|
||||
def __init__(self):
|
||||
self.aborted = False
|
||||
self.running = False
|
||||
|
||||
def set_abort(self):
|
||||
self.aborted = True
|
||||
self.running = False
|
||||
|
||||
def initialize(self, lang: str, model_name: str, dataset: list) -> Tuple[str, str, LoggerHandler, LogCallback]:
|
||||
if self.running:
|
||||
return None, ALERTS["err_conflict"][lang], None, None
|
||||
|
||||
if not model_name:
|
||||
return None, ALERTS["err_no_model"][lang], None, None
|
||||
|
||||
model_name_or_path = get_model_path(model_name)
|
||||
if not model_name_or_path:
|
||||
return None, ALERTS["err_no_path"][lang], None, None
|
||||
|
||||
if len(dataset) == 0:
|
||||
return None, ALERTS["err_no_dataset"][lang], None, None
|
||||
|
||||
self.aborted = False
|
||||
self.running = True
|
||||
|
||||
logger_handler = LoggerHandler()
|
||||
logger_handler.setLevel(logging.INFO)
|
||||
logging.root.addHandler(logger_handler)
|
||||
transformers.logging.add_handler(logger_handler)
|
||||
trainer_callback = LogCallback(self)
|
||||
|
||||
return model_name_or_path, "", logger_handler, trainer_callback
|
||||
|
||||
def finalize(self, lang: str, finish_info: Optional[str] = None) -> str:
|
||||
self.running = False
|
||||
torch_gc()
|
||||
if self.aborted:
|
||||
return ALERTS["info_aborted"][lang]
|
||||
else:
|
||||
return finish_info if finish_info is not None else ALERTS["info_finished"][lang]
|
||||
|
||||
def run_train(
|
||||
self, lang, model_name, checkpoints, finetuning_type, template,
|
||||
dataset, dataset_dir, learning_rate, num_train_epochs, max_samples,
|
||||
fp16, quantization_bit, batch_size, gradient_accumulation_steps,
|
||||
lr_scheduler_type, logging_steps, save_steps, output_dir
|
||||
):
|
||||
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
|
||||
if error:
|
||||
yield error
|
||||
return
|
||||
|
||||
if checkpoints:
|
||||
checkpoint_dir = ",".join(
|
||||
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
|
||||
)
|
||||
else:
|
||||
checkpoint_dir = None
|
||||
|
||||
args = dict(
|
||||
model_name_or_path=model_name_or_path,
|
||||
do_train=True,
|
||||
finetuning_type=finetuning_type,
|
||||
prompt_template=template,
|
||||
dataset=",".join(dataset),
|
||||
dataset_dir=dataset_dir,
|
||||
max_samples=int(max_samples),
|
||||
output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir),
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
overwrite_cache=True,
|
||||
per_device_train_batch_size=batch_size,
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
lr_scheduler_type=lr_scheduler_type,
|
||||
logging_steps=logging_steps,
|
||||
save_steps=save_steps,
|
||||
learning_rate=float(learning_rate),
|
||||
num_train_epochs=float(num_train_epochs),
|
||||
fp16=fp16,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit else None
|
||||
)
|
||||
model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
|
||||
|
||||
run_args = dict(
|
||||
model_args=model_args,
|
||||
data_args=data_args,
|
||||
training_args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
callbacks=[trainer_callback]
|
||||
)
|
||||
thread = threading.Thread(target=run_sft, kwargs=run_args)
|
||||
thread.start()
|
||||
|
||||
while thread.is_alive():
|
||||
time.sleep(1)
|
||||
if self.aborted:
|
||||
yield ALERTS["info_aborting"][lang]
|
||||
else:
|
||||
yield format_info(logger_handler.log, trainer_callback.tracker)
|
||||
|
||||
yield self.finalize(lang)
|
||||
|
||||
def run_eval(
|
||||
self, lang, model_name, checkpoints, finetuning_type, template,
|
||||
dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict
|
||||
):
|
||||
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
|
||||
if error:
|
||||
yield error
|
||||
return
|
||||
|
||||
if checkpoints:
|
||||
checkpoint_dir = ",".join(
|
||||
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
|
||||
)
|
||||
output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_" + "_".join(checkpoints))
|
||||
else:
|
||||
checkpoint_dir = None
|
||||
output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_base")
|
||||
|
||||
args = dict(
|
||||
model_name_or_path=model_name_or_path,
|
||||
do_eval=True,
|
||||
finetuning_type=finetuning_type,
|
||||
prompt_template=template,
|
||||
dataset=",".join(dataset),
|
||||
dataset_dir=dataset_dir,
|
||||
max_samples=int(max_samples),
|
||||
output_dir=output_dir,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
overwrite_cache=True,
|
||||
predict_with_generate=True,
|
||||
per_device_eval_batch_size=batch_size,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit else None
|
||||
)
|
||||
|
||||
if predict:
|
||||
args.pop("do_eval", None)
|
||||
args["do_predict"] = True
|
||||
|
||||
model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
|
||||
|
||||
run_args = dict(
|
||||
model_args=model_args,
|
||||
data_args=data_args,
|
||||
training_args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
callbacks=[trainer_callback]
|
||||
)
|
||||
thread = threading.Thread(target=run_sft, kwargs=run_args)
|
||||
thread.start()
|
||||
|
||||
while thread.is_alive():
|
||||
time.sleep(1)
|
||||
if self.aborted:
|
||||
yield ALERTS["info_aborting"][lang]
|
||||
else:
|
||||
yield format_info(logger_handler.log, trainer_callback.tracker)
|
||||
|
||||
yield self.finalize(lang, get_eval_results(os.path.join(output_dir, "all_results.json")))
|
||||
Reference in New Issue
Block a user