support save args in webui #2807 #3046

Some ideas are borrowed from @marko1616.
hiyouga
2024-03-30 23:09:12 +08:00
parent 257f643a74
commit 7a086ed333
9 changed files with 219 additions and 80 deletions

src/llmtuner/webui/runner.py

@@ -15,7 +15,7 @@ from ..extras.constants import TRAINING_STAGES
 from ..extras.logging import LoggerHandler
 from ..extras.misc import get_device_count, torch_gc
 from ..train import run_exp
-from .common import get_module, get_save_dir, load_config
+from .common import get_module, get_save_dir, load_args, load_config, save_args
 from .locales import ALERTS
 from .utils import gen_cmd, get_eval_results, update_process_bar
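
The counterpart hunks in .common are among the nine changed files but are not shown in this view. A minimal sketch of what the imported save_args and load_args helpers could look like, assuming the form values are serialized as JSON under a local config directory (both the directory name and the file format are assumptions):

import json
import os
from typing import Any, Dict, Optional

DEFAULT_CONFIG_DIR = "config"  # assumed location; the actual common.py hunk is not shown


def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
    """Dump the collected component values to a JSON file and return its path."""
    os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
    save_path = os.path.join(DEFAULT_CONFIG_DIR, config_path)
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(config_dict, f, indent=2, ensure_ascii=False)

    return save_path


def load_args(config_path: str) -> Optional[Dict[str, Any]]:
    """Read a saved config back; return None when the file is missing or unreadable."""
    try:
        with open(os.path.join(DEFAULT_CONFIG_DIR, config_path), "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception:
        return None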
@@ -150,23 +150,21 @@ class Runner:
         args["disable_tqdm"] = True

         if args["finetuning_type"] == "freeze":
-            args["num_layer_trainable"] = int(get("train.num_layer_trainable"))
+            args["num_layer_trainable"] = get("train.num_layer_trainable")
             args["name_module_trainable"] = get("train.name_module_trainable")
         elif args["finetuning_type"] == "lora":
-            args["lora_rank"] = int(get("train.lora_rank"))
-            args["lora_alpha"] = int(get("train.lora_alpha"))
-            args["lora_dropout"] = float(get("train.lora_dropout"))
-            args["lora_target"] = get("train.lora_target") or get_module(get("top.model_name"))
+            args["lora_rank"] = get("train.lora_rank")
+            args["lora_alpha"] = get("train.lora_alpha")
+            args["lora_dropout"] = get("train.lora_dropout")
             args["loraplus_lr_ratio"] = get("train.loraplus_lr_ratio") or None
+            args["create_new_adapter"] = get("train.create_new_adapter")
             args["use_rslora"] = get("train.use_rslora")
             args["use_dora"] = get("train.use_dora")
+            args["lora_target"] = get("train.lora_target") or get_module(get("top.model_name"))
             args["additional_target"] = get("train.additional_target") or None
-            if args["stage"] in ["rm", "ppo", "dpo"]:
-                args["create_new_adapter"] = args["quantization_bit"] is None
-            else:
-                args["create_new_adapter"] = get("train.create_new_adapter")
+
             if args["use_llama_pro"]:
-                args["num_layer_trainable"] = int(get("train.num_layer_trainable"))
+                args["num_layer_trainable"] = get("train.num_layer_trainable")

         if args["stage"] == "ppo":
             args["reward_model"] = ",".join(
@@ -305,3 +303,33 @@ class Runner:
             finish_info = ALERTS["err_failed"][lang]

         yield self._finalize(lang, finish_info), gr.Slider(visible=False)
+
+    def save_args(self, data: Dict["Component", Any]) -> Tuple[str, "gr.Slider"]:
+        error = self._initialize(data, do_train=True, from_preview=True)
+        if error:
+            gr.Warning(error)
+            return error, gr.Slider(visible=False)
+
+        config_dict: Dict[str, Any] = {}
+        lang = data[self.manager.get_elem_by_id("top.lang")]
+        config_path = data[self.manager.get_elem_by_id("train.config_path")]
+        skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"]
+        for elem, value in data.items():
+            elem_id = self.manager.get_id_by_elem(elem)
+            if elem_id not in skip_ids:
+                config_dict[elem_id] = value
+
+        save_path = save_args(config_path, config_dict)
+        return ALERTS["info_config_saved"][lang] + save_path, gr.Slider(visible=False)
+
+    def load_args(self, lang: str, config_path: str) -> Dict["Component", Any]:
+        config_dict = load_args(config_path)
+        if config_dict is None:
+            gr.Warning(ALERTS["err_config_not_found"][lang])
+            return {self.manager.get_elem_by_id("top.lang"): lang}
+
+        output_dict: Dict["Component", Any] = {}
+        for elem_id, value in config_dict.items():
+            output_dict[self.manager.get_elem_by_id(elem_id)] = value
+
+        return output_dict
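
Note that skip_ids keeps the UI language, the local model path, the output directory, and the config path itself out of the saved file, so a saved config stays portable across machines and runs. Both methods return Gradio update payloads: save_args yields a status string plus a hidden slider, and load_args returns a {component: value} mapping that Gradio applies to the matching output components. A hypothetical wiring for the train tab follows; every name below except gr.* is an assumption, since the train.py hunk is among the nine changed files but not shown in this view:

import gradio as gr

# Hypothetical wiring (names assumed); engine is the webui's Engine instance,
# which exposes the same manager and runner objects used in the diff above.
arg_save_btn = gr.Button("Save arguments")
arg_load_btn = gr.Button("Load arguments")

input_elems = engine.manager.get_elem_list()  # assumed helper: every form component
output_box = engine.manager.get_elem_by_id("train.output_box")
progress_bar = engine.manager.get_elem_by_id("train.progress_bar")
config_path = engine.manager.get_elem_by_id("train.config_path")
lang = engine.manager.get_elem_by_id("top.lang")

arg_save_btn.click(engine.runner.save_args, input_elems, [output_box, progress_bar])
arg_load_btn.click(
    engine.runner.load_args,
    [lang, config_path],
    input_elems,  # load_args returns {component: value}; Gradio applies it to these outputs
)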