Mirror of https://github.com/hiyouga/LLaMA-Factory.git
Synced 2025-08-04 12:42:51 +08:00
fix bug in web ui
Former-commit-id: 6efa38be46ed536f80fc67002f23862edcb9df8d
This commit is contained in:
parent e4f97615f0
commit 3f53155a90
@@ -3,6 +3,9 @@ import logging
 
 
 class LoggerHandler(logging.Handler):
+    r"""
+    Logger handler used in Web UI.
+    """
 
     def __init__(self):
         super().__init__()
@@ -19,16 +22,10 @@ class LoggerHandler(logging.Handler):
         self.log += "\n\n"
 
 
-def reset_logging():
-    r"""
-    Removes basic config of root logger
-    """
-    root = logging.getLogger()
-    list(map(root.removeHandler, root.handlers))
-    list(map(root.removeFilter, root.filters))
-
-
 def get_logger(name: str) -> logging.Logger:
     r"""
     Gets a standard logger with a stream hander to stdout.
     """
     formatter = logging.Formatter(
         fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
         datefmt="%m/%d/%Y %H:%M:%S"
@@ -41,3 +38,12 @@ def get_logger(name: str) -> logging.Logger:
     logger.addHandler(handler)
 
     return logger
+
+
+def reset_logging() -> None:
+    r"""
+    Removes basic config of root logger. (unused in script)
+    """
+    root = logging.getLogger()
+    list(map(root.removeHandler, root.handlers))
+    list(map(root.removeFilter, root.filters))
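For context, a minimal self-contained sketch (not part of this commit) of how a buffering handler like LoggerHandler lets the Web UI poll accumulated log text. The class name BufferingHandler and the demo logger name are illustrative assumptions; only the self.log buffer and the formatter strings mirror the diff above.

import logging

class BufferingHandler(logging.Handler):  # hypothetical stand-in for LoggerHandler
    def __init__(self):
        super().__init__()
        self.log = ""  # the UI reads this string periodically

    def emit(self, record: logging.LogRecord) -> None:
        self.log += self.format(record)
        self.log += "\n\n"

handler = BufferingHandler()
handler.setFormatter(logging.Formatter(
    fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S"
))
logger = logging.getLogger("webui-demo")
logger.setLevel(logging.INFO)
logger.addHandler(handler)

logger.info("training started")
print(handler.log)  # a UI callback would return this text to the page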
@@ -202,7 +202,6 @@ def load_model_and_tokenizer(
     # Prepare model with valuehead for RLHF
     if stage in ["rm", "ppo"]:
         model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
-        reset_logging()
         vhead_path = (
             model_args.checkpoint_dir[-1] if model_args.checkpoint_dir is not None else model_args.model_name_or_path
         )
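For reference, a hedged sketch of the value-head wrapping step shown in this hunk; the trl/transformers imports and the small stand-in checkpoint ("gpt2") are assumptions, only the from_pretrained(model) call mirrors the diff.

from transformers import AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead

# Load a plain causal LM, then wrap it with a scalar value head for the "rm"/"ppo" stages.
base = AutoModelForCausalLM.from_pretrained("gpt2")  # stand-in checkpoint (assumption)
model = AutoModelForCausalLMWithValueHead.from_pretrained(base)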
@@ -236,11 +236,12 @@ class Runner:
         yield from self._launch(data, do_train=False)
 
     def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        get = lambda name: self.running_data[self.manager.get_elem_by_name(name)]
         self.running = True
-        lang = self.running_data[self.manager.get_elem_by_name("top.lang")]
-        output_dir = self.running_data[self.manager.get_elem_by_name(
+        lang = get("top.lang")
+        output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get(
             "{}.output_dir".format("train" if self.do_train else "eval")
-        )]
+        ))
         while self.thread.is_alive():
             time.sleep(2)
             if self.aborted:
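A self-contained sketch of the lookup-helper pattern introduced in monitor(); the running_data values, the component names, and this toy get_save_dir (which simply joins its parts into a path) are assumptions for illustration, not the project's real helpers.

import os
from typing import Any, Dict

def get_save_dir(model_name: str, finetuning_type: str, output_dir: str) -> str:
    # toy stand-in: the real project helper builds the save path elsewhere
    return os.path.join("saves", model_name, finetuning_type, output_dir)

running_data: Dict[str, Any] = {
    "top.model_name": "LLaMA-7B",     # illustrative values
    "top.finetuning_type": "lora",
    "train.output_dir": "run-1",
}

get = lambda name: running_data[name]  # same shortcut as the diff, minus the UI manager
print(get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")))
# e.g. saves/LLaMA-7B/lora/run-1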