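"""Background job runner for the LLaMA-Factory web UI.

Validates form input, translates it into `run_exp` arguments, executes
training or evaluation on a worker thread, and streams logs and progress
back to the Gradio front end.
"""
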
import logging
import os
import time
from threading import Thread
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple

import gradio as gr
import transformers
from gradio.components import Component  # cannot use TYPE_CHECKING here
from transformers.trainer import TRAINING_ARGS_NAME

from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import TRAINING_STAGES
from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import torch_gc
from llmtuner.train import run_exp
from llmtuner.webui.common import get_module, get_save_dir, load_config
from llmtuner.webui.locales import ALERTS
from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar

if TYPE_CHECKING:
    from llmtuner.webui.manager import Manager


class Runner:

    def __init__(self, manager: "Manager") -> None:
        self.manager = manager
        # Resume
        self.thread: Optional[Thread] = None
        self.do_train = True
        self.running_data: Optional[Dict[Component, Any]] = None
        self.monitor_inputs: Optional[Dict[str, str]] = None
        # State
        self.aborted = False
        self.running = False
        # Handler: mirror library logs into the web UI
        self.logger_handler = LoggerHandler()
        self.logger_handler.setLevel(logging.INFO)
        logging.root.addHandler(self.logger_handler)
        transformers.logging.add_handler(self.logger_handler)

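    # A live thread handle means a job is in flight; aborting only raises
    # flags for cooperating code to observe.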
    @property
    def alive(self) -> bool:
        return self.thread is not None

    def set_abort(self) -> None:
        self.aborted = True
        self.running = False

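    # Validates the form inputs and resets logging state; returns an error
    # message for the UI, or an empty string if the job may proceed.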
    def _initialize(self, data: Dict[Component, Any], do_train: bool) -> str:
        get = lambda name: data[self.manager.get_elem_by_name(name)]
        lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
        dataset = get("train.dataset") if do_train else get("eval.dataset")

        if self.running:
            return ALERTS["err_conflict"][lang]

        if not model_name:
            return ALERTS["err_no_model"][lang]

        if not model_path:
            return ALERTS["err_no_path"][lang]

        if len(dataset) == 0:
            return ALERTS["err_no_dataset"][lang]

        self.aborted = False
        self.logger_handler.reset()
        self.trainer_callback = LogCallback(self)
        return ""

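    # Tears down the finished job and frees GPU memory before reporting
    # either the abort notice or the caller-supplied finish message.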
    def _finalize(self, lang: str, finish_info: str) -> str:
        self.thread = None
        self.running = False
        torch_gc()
        if self.aborted:
            return ALERTS["info_aborted"][lang]
        else:
            return finish_info

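    # Maps the training tab's form values onto the keyword arguments
    # expected by `run_exp`.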
    def _parse_train_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
        get = lambda name: data[self.manager.get_elem_by_name(name)]
        user_config = load_config()

        if get("top.checkpoints"):
            checkpoint_dir = ",".join([
                get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt)
                for ckpt in get("top.checkpoints")
            ])
        else:
            checkpoint_dir = None

        args = dict(
            stage=TRAINING_STAGES[get("train.training_stage")],
            model_name_or_path=get("top.model_path"),
            do_train=True,
            cache_dir=user_config.get("cache_dir", None),
            checkpoint_dir=checkpoint_dir,
            finetuning_type=get("top.finetuning_type"),
            quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
            template=get("top.template"),
            system_prompt=get("top.system_prompt"),
            flash_attn=get("top.flash_attn"),
            shift_attn=get("top.shift_attn"),
            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
            dataset_dir=get("train.dataset_dir"),
            dataset=",".join(get("train.dataset")),
            cutoff_len=get("train.cutoff_len"),
            learning_rate=float(get("train.learning_rate")),
            num_train_epochs=float(get("train.num_train_epochs")),
            max_samples=int(get("train.max_samples")),
            per_device_train_batch_size=get("train.batch_size"),
            gradient_accumulation_steps=get("train.gradient_accumulation_steps"),
            lr_scheduler_type=get("train.lr_scheduler_type"),
            max_grad_norm=float(get("train.max_grad_norm")),
            logging_steps=get("train.logging_steps"),
            save_steps=get("train.save_steps"),
            warmup_steps=get("train.warmup_steps"),
            neft_alpha=get("train.neft_alpha"),
            train_on_prompt=get("train.train_on_prompt"),
            upcast_layernorm=get("train.upcast_layernorm"),
            lora_rank=get("train.lora_rank"),
            lora_dropout=get("train.lora_dropout"),
            lora_target=get("train.lora_target") or get_module(get("top.model_name")),
            additional_target=get("train.additional_target") if get("train.additional_target") else None,
            resume_lora_training=get("train.resume_lora_training"),
            output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir"))
        )
        args[get("train.compute_type")] = True  # enable the selected precision flag (e.g. fp16 or bf16)
        args["disable_tqdm"] = True

        if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
            args["resume_lora_training"] = (args["quantization_bit"] is not None)

        if args["quantization_bit"] is not None:
            args["upcast_layernorm"] = True

        if args["stage"] == "ppo":
            args["reward_model"] = get_save_dir(
                get("top.model_name"), get("top.finetuning_type"), get("train.reward_model")
            )

        if args["stage"] == "dpo":
            args["dpo_beta"] = get("train.dpo_beta")

        # A validation split is not used during PPO
        if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
            args["val_size"] = get("train.val_size")
            args["evaluation_strategy"] = "steps"
            args["eval_steps"] = get("train.save_steps")
            args["load_best_model_at_end"] = True

        return args

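    # Maps the evaluation tab's form values onto `run_exp` keyword
    # arguments; evaluation reuses the SFT stage with generation enabled.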
    def _parse_eval_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
        get = lambda name: data[self.manager.get_elem_by_name(name)]
        user_config = load_config()

        if get("top.checkpoints"):
            checkpoint_dir = ",".join([
                get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt)
                for ckpt in get("top.checkpoints")
            ])
            output_dir = get_save_dir(
                get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints"))
            )
        else:
            checkpoint_dir = None
            output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), "eval_base")

        args = dict(
            stage="sft",
            model_name_or_path=get("top.model_path"),
            do_eval=True,
            predict_with_generate=True,
            cache_dir=user_config.get("cache_dir", None),
            checkpoint_dir=checkpoint_dir,
            finetuning_type=get("top.finetuning_type"),
            quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
            template=get("top.template"),
            system_prompt=get("top.system_prompt"),
            flash_attn=get("top.flash_attn"),
            shift_attn=get("top.shift_attn"),
            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
            dataset_dir=get("eval.dataset_dir"),
            dataset=",".join(get("eval.dataset")),
            cutoff_len=get("eval.cutoff_len"),
            max_samples=int(get("eval.max_samples")),
            per_device_eval_batch_size=get("eval.batch_size"),
            max_new_tokens=get("eval.max_new_tokens"),
            top_p=get("eval.top_p"),
            temperature=get("eval.temperature"),
            output_dir=output_dir
        )

        if get("eval.predict"):
            args.pop("do_eval", None)
            args["do_predict"] = True

        return args

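    # Renders the equivalent command-line invocation without running it.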
    def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        error = self._initialize(data, do_train)
        if error:
            gr.Warning(error)
            yield error, gr.update(visible=False)
        else:
            args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
            yield gen_cmd(args), gr.update(visible=False)

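    # Starts `run_exp` on a worker thread and streams progress updates
    # back to the UI via `monitor`.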
    def _launch(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        error = self._initialize(data, do_train)
        if error:
            gr.Warning(error)
            yield error, gr.update(visible=False)
        else:
            args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
            run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
            self.running = True
            self.do_train, self.running_data = do_train, data
            self.monitor_inputs = dict(lang=data[self.manager.get_elem_by_name("top.lang")], output_dir=args["output_dir"])
            self.thread = Thread(target=run_exp, kwargs=run_kwargs)
            self.thread.start()
            yield from self.monitor()

    def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        yield from self._preview(data, do_train=True)

    def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        yield from self._preview(data, do_train=False)

    def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        yield from self._launch(data, do_train=True)

    def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        yield from self._launch(data, do_train=False)

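    # Polls the worker thread every two seconds, yielding log text and a
    # progress-bar update, then inspects the output directory to decide
    # whether the job succeeded.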
    def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        lang, output_dir = self.monitor_inputs["lang"], self.monitor_inputs["output_dir"]
        while self.thread.is_alive():
            time.sleep(2)
            if self.aborted:
                yield ALERTS["info_aborting"][lang], gr.update(visible=False)
            else:
                yield self.logger_handler.log, update_process_bar(self.trainer_callback)

        if self.do_train:
            # Training is considered successful only if the trainer saved its arguments file.
            if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
                finish_info = ALERTS["info_finished"][lang]
            else:
                finish_info = ALERTS["err_failed"][lang]
        else:
            if os.path.exists(os.path.join(output_dir, "all_results.json")):
                finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
            else:
                finish_info = ALERTS["err_failed"][lang]

        yield self._finalize(lang, finish_info), gr.update(visible=False)