some ideas are borrowed from @marko1616
@@ -15,7 +15,7 @@ from ..extras.constants import TRAINING_STAGES
 from ..extras.logging import LoggerHandler
 from ..extras.misc import get_device_count, torch_gc
 from ..train import run_exp
-from .common import get_module, get_save_dir, load_config
+from .common import get_module, get_save_dir, load_args, load_config, save_args
 from .locales import ALERTS
 from .utils import gen_cmd, get_eval_results, update_process_bar
 
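Only the import line changes here; the load_args and save_args helpers themselves live in .common and are not part of this diff. A minimal sketch of what they could look like, assuming one JSON file per config name under a local cache directory (DEFAULT_CACHE_DIR, the JSON format, and the exact signatures are assumptions made for illustration, not taken from this commit):

# Hypothetical sketch of the .common helpers used by the Runner below; not the code
# from this commit. Assumes one JSON file per saved config under a local cache dir.
import json
import os
from typing import Any, Dict, Optional

DEFAULT_CACHE_DIR = "cache"  # assumed location for saved webui configs


def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
    """Persist the element-id -> value mapping and return the path it was written to."""
    os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
    save_path = os.path.join(DEFAULT_CACHE_DIR, config_path)
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(config_dict, f, indent=2, ensure_ascii=False)

    return save_path


def load_args(config_path: str) -> Optional[Dict[str, Any]]:
    """Return the saved mapping, or None when no config exists under that name."""
    try:
        with open(os.path.join(DEFAULT_CACHE_DIR, config_path), encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return None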
@@ -150,23 +150,21 @@ class Runner:
         args["disable_tqdm"] = True
 
         if args["finetuning_type"] == "freeze":
-            args["num_layer_trainable"] = int(get("train.num_layer_trainable"))
+            args["num_layer_trainable"] = get("train.num_layer_trainable")
             args["name_module_trainable"] = get("train.name_module_trainable")
         elif args["finetuning_type"] == "lora":
-            args["lora_rank"] = int(get("train.lora_rank"))
-            args["lora_alpha"] = int(get("train.lora_alpha"))
-            args["lora_dropout"] = float(get("train.lora_dropout"))
-            args["lora_target"] = get("train.lora_target") or get_module(get("top.model_name"))
+            args["lora_rank"] = get("train.lora_rank")
+            args["lora_alpha"] = get("train.lora_alpha")
+            args["lora_dropout"] = get("train.lora_dropout")
+            args["loraplus_lr_ratio"] = get("train.loraplus_lr_ratio") or None
+            args["create_new_adapter"] = get("train.create_new_adapter")
             args["use_rslora"] = get("train.use_rslora")
             args["use_dora"] = get("train.use_dora")
+            args["lora_target"] = get("train.lora_target") or get_module(get("top.model_name"))
             args["additional_target"] = get("train.additional_target") or None
-            if args["stage"] in ["rm", "ppo", "dpo"]:
-                args["create_new_adapter"] = args["quantization_bit"] is None
-            else:
-                args["create_new_adapter"] = get("train.create_new_adapter")
 
             if args["use_llama_pro"]:
-                args["num_layer_trainable"] = int(get("train.num_layer_trainable"))
+                args["num_layer_trainable"] = get("train.num_layer_trainable")
 
         if args["stage"] == "ppo":
             args["reward_model"] = ",".join(
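Every get("...") call in the hunk above reads the current value of a webui component. The diff does not show how get is built, but it is presumably a small closure over the data dict keyed by Gradio components; a self-contained illustration of that pattern (DummyManager and the values are stand-ins, not project code):

# Illustration only; not part of this diff. The webui presumably builds `get` as a
# closure over `data`, which maps each Gradio component to its current value, with the
# manager translating element ids such as "train.lora_rank" back to components. Since
# the components already hold typed values, the int()/float() casts become unnecessary.
from typing import Any, Dict


class DummyManager:
    """Stand-in for the webui manager, used only for this illustration."""

    def __init__(self, elems: Dict[str, Any]) -> None:
        self._elems = elems

    def get_elem_by_id(self, elem_id: str) -> Any:
        return self._elems[elem_id]


elems = {"train.lora_rank": object()}   # pretend Gradio component
data = {elems["train.lora_rank"]: 8}    # the slider already holds an int
manager = DummyManager(elems)

get = lambda elem_id: data[manager.get_elem_by_id(elem_id)]
assert get("train.lora_rank") == 8      # no int(...) cast required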
@@ -305,3 +303,33 @@ class Runner:
             finish_info = ALERTS["err_failed"][lang]
 
         yield self._finalize(lang, finish_info), gr.Slider(visible=False)
+
+    def save_args(self, data: Dict[Component, Any]) -> Tuple[str, "gr.Slider"]:
+        error = self._initialize(data, do_train=True, from_preview=True)
+        if error:
+            gr.Warning(error)
+            return error, gr.Slider(visible=False)
+
+        config_dict: Dict[str, Any] = {}
+        lang = data[self.manager.get_elem_by_id("top.lang")]
+        config_path = data[self.manager.get_elem_by_id("train.config_path")]
+        skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"]
+        for elem, value in data.items():
+            elem_id = self.manager.get_id_by_elem(elem)
+            if elem_id not in skip_ids:
+                config_dict[elem_id] = value
+
+        save_path = save_args(config_path, config_dict)
+        return ALERTS["info_config_saved"][lang] + save_path, gr.Slider(visible=False)
+
+    def load_args(self, lang: str, config_path: str) -> Dict[Component, Any]:
+        config_dict = load_args(config_path)
+        if config_dict is None:
+            gr.Warning(ALERTS["err_config_not_found"][lang])
+            return {self.manager.get_elem_by_id("top.lang"): lang}
+
+        output_dict: Dict["Component", Any] = {}
+        for elem_id, value in config_dict.items():
+            output_dict[self.manager.get_elem_by_id(elem_id)] = value
+
+        return output_dict
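The new save_args and load_args methods are only half of the feature; the buttons that trigger them sit in the train tab, which this diff does not touch. Roughly how the Gradio wiring might look (every component name here is an assumption for illustration):

# Hypothetical wiring sketch; the real hook-up lives in the train tab, outside this diff.
# `input_elems` is assumed to be a set of components, so Gradio hands save_args a
# {component: value} dict, matching its signature above.
arg_save_btn.click(engine.runner.save_args, input_elems, [output_box, progress_bar])

# load_args returns a {component: value} dict, so every component it may update has to
# be listed in the outputs.
arg_load_btn.click(engine.runner.load_args, [lang, config_path], list(input_elems))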