mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 20:52:59 +08:00
update ppo and demo in webui
Former-commit-id: 7537dd434f4c0f0bde06bd8c2ac69bf622772316
This commit is contained in:
parent
0ed0b8f9c5
commit
e4f97615f0
@ -120,10 +120,12 @@ register_model_group(
|
|||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
|
"ChineseLLaMA2-1.3B": "hfl/chinese-llama-2-1.3b",
|
||||||
"ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
|
"ChineseLLaMA2-7B": "hfl/chinese-llama-2-7b",
|
||||||
"ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
|
"ChineseLLaMA2-13B": "hfl/chinese-llama-2-13b",
|
||||||
"ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b"
|
"ChineseLLaMA2-1.3B-Chat": "hfl/chinese-alpaca-2-1.3b",
|
||||||
|
"ChineseLLaMA2-7B-Chat": "hfl/chinese-alpaca-2-7b",
|
||||||
|
"ChineseLLaMA2-13B-Chat": "hfl/chinese-alpaca-2-13b"
|
||||||
},
|
},
|
||||||
template="llama2_zh"
|
template="llama2_zh"
|
||||||
)
|
)
|
||||||
|
@ -25,9 +25,13 @@ class WebChatModel(ChatModel):
|
|||||||
self.model = None
|
self.model = None
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
self.generating_args = GeneratingArguments()
|
self.generating_args = GeneratingArguments()
|
||||||
if not lazy_init:
|
|
||||||
|
if not lazy_init: # read arguments from command line
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
if demo_mode: # load openchat 3.5 by default
|
||||||
|
super().__init__(dict(model_name_or_path="openchat/openchat_3.5", template="openchat"))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def loaded(self) -> bool:
|
def loaded(self) -> bool:
|
||||||
return self.model is not None
|
return self.model is not None
|
||||||
@ -75,6 +79,11 @@ class WebChatModel(ChatModel):
|
|||||||
|
|
||||||
def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
|
def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
|
||||||
lang = data[self.manager.get_elem_by_name("top.lang")]
|
lang = data[self.manager.get_elem_by_name("top.lang")]
|
||||||
|
|
||||||
|
if self.demo_mode:
|
||||||
|
yield ALERTS["err_demo"][lang]
|
||||||
|
return
|
||||||
|
|
||||||
yield ALERTS["info_unloading"][lang]
|
yield ALERTS["info_unloading"][lang]
|
||||||
self.model = None
|
self.model = None
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
|
@ -38,7 +38,7 @@ def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks:
|
|||||||
with gr.Tab("Train"):
|
with gr.Tab("Train"):
|
||||||
engine.manager.all_elems["train"] = create_train_tab(engine)
|
engine.manager.all_elems["train"] = create_train_tab(engine)
|
||||||
|
|
||||||
with gr.Tab("Evaluate"):
|
with gr.Tab("Evaluate & Predict"):
|
||||||
engine.manager.all_elems["eval"] = create_eval_tab(engine)
|
engine.manager.all_elems["eval"] = create_eval_tab(engine)
|
||||||
|
|
||||||
with gr.Tab("Chat"):
|
with gr.Tab("Chat"):
|
||||||
|
@ -31,7 +31,6 @@ class Runner:
|
|||||||
self.thread: "Thread" = None
|
self.thread: "Thread" = None
|
||||||
self.do_train = True
|
self.do_train = True
|
||||||
self.running_data: Dict["Component", Any] = None
|
self.running_data: Dict["Component", Any] = None
|
||||||
self.monitor_inputs: Dict[str, str] = None
|
|
||||||
""" State """
|
""" State """
|
||||||
self.aborted = False
|
self.aborted = False
|
||||||
self.running = False
|
self.running = False
|
||||||
@ -75,6 +74,7 @@ class Runner:
|
|||||||
|
|
||||||
def _finalize(self, lang: str, finish_info: str) -> str:
|
def _finalize(self, lang: str, finish_info: str) -> str:
|
||||||
self.thread = None
|
self.thread = None
|
||||||
|
self.running_data = None
|
||||||
self.running = False
|
self.running = False
|
||||||
torch_gc()
|
torch_gc()
|
||||||
if self.aborted:
|
if self.aborted:
|
||||||
@ -87,9 +87,9 @@ class Runner:
|
|||||||
user_config = load_config()
|
user_config = load_config()
|
||||||
|
|
||||||
if get("top.checkpoints"):
|
if get("top.checkpoints"):
|
||||||
checkpoint_dir = ",".join([
|
checkpoint_dir = ",".join([get_save_dir(
|
||||||
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
|
get("top.model_name"), get("top.finetuning_type"), ckpt
|
||||||
])
|
) for ckpt in get("top.checkpoints")])
|
||||||
else:
|
else:
|
||||||
checkpoint_dir = None
|
checkpoint_dir = None
|
||||||
|
|
||||||
@ -139,7 +139,10 @@ class Runner:
|
|||||||
args["upcast_layernorm"] = True
|
args["upcast_layernorm"] = True
|
||||||
|
|
||||||
if args["stage"] == "ppo":
|
if args["stage"] == "ppo":
|
||||||
args["reward_model"] = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.reward_model"))
|
args["reward_model"] = get_save_dir(
|
||||||
|
get("top.model_name"), get("top.finetuning_type"), get("train.reward_model")
|
||||||
|
)
|
||||||
|
args["reward_model_type"] = "lora" if get("top.finetuning_type") == "lora" else "full"
|
||||||
|
|
||||||
if args["stage"] == "dpo":
|
if args["stage"] == "dpo":
|
||||||
args["dpo_beta"] = get("train.dpo_beta")
|
args["dpo_beta"] = get("train.dpo_beta")
|
||||||
@ -157,9 +160,9 @@ class Runner:
|
|||||||
user_config = load_config()
|
user_config = load_config()
|
||||||
|
|
||||||
if get("top.checkpoints"):
|
if get("top.checkpoints"):
|
||||||
checkpoint_dir = ",".join([
|
checkpoint_dir = ",".join([get_save_dir(
|
||||||
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
|
get("top.model_name"), get("top.finetuning_type"), ckpt
|
||||||
])
|
) for ckpt in get("top.checkpoints")])
|
||||||
output_dir = get_save_dir(
|
output_dir = get_save_dir(
|
||||||
get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints"))
|
get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints"))
|
||||||
)
|
)
|
||||||
@ -216,7 +219,6 @@ class Runner:
|
|||||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||||
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
|
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
|
||||||
self.do_train, self.running_data = do_train, data
|
self.do_train, self.running_data = do_train, data
|
||||||
self.monitor_inputs = dict(lang=data[self.manager.get_elem_by_name("top.lang")], output_dir=args["output_dir"])
|
|
||||||
self.thread = Thread(target=run_exp, kwargs=run_kwargs)
|
self.thread = Thread(target=run_exp, kwargs=run_kwargs)
|
||||||
self.thread.start()
|
self.thread.start()
|
||||||
yield from self.monitor()
|
yield from self.monitor()
|
||||||
@ -235,7 +237,10 @@ class Runner:
|
|||||||
|
|
||||||
def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||||
self.running = True
|
self.running = True
|
||||||
lang, output_dir = self.monitor_inputs["lang"], self.monitor_inputs["output_dir"]
|
lang = self.running_data[self.manager.get_elem_by_name("top.lang")]
|
||||||
|
output_dir = self.running_data[self.manager.get_elem_by_name(
|
||||||
|
"{}.output_dir".format("train" if self.do_train else "eval")
|
||||||
|
)]
|
||||||
while self.thread.is_alive():
|
while self.thread.is_alive():
|
||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
if self.aborted:
|
if self.aborted:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user