disentangle model from tuner and rename modules

Former-commit-id: 4736344eb1
This commit is contained in:
hiyouga
2023-11-15 16:29:09 +08:00
parent 8ee48a9c9e
commit 06a4820836
57 changed files with 324 additions and 263 deletions

View File

@@ -2,7 +2,7 @@ import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from llmtuner.chat.stream_chat import ChatModel
from llmtuner.chat import ChatModel
from llmtuner.extras.misc import torch_gc
from llmtuner.hparams import GeneratingArguments
from llmtuner.webui.common import get_save_dir

View File

@@ -4,7 +4,7 @@ import logging
import gradio as gr
from threading import Thread
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
from typing import TYPE_CHECKING, Any, Dict, Generator, Tuple
import transformers
from transformers.trainer import TRAINING_ARGS_NAME
@@ -13,7 +13,7 @@ from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import TRAINING_STAGES
from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import run_exp
from llmtuner.train import run_exp
from llmtuner.webui.common import get_module, get_save_dir, load_config
from llmtuner.webui.locales import ALERTS
from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar

View File

@@ -1,17 +1,20 @@
import os
import json
import gradio as gr
import matplotlib.figure
import matplotlib.pyplot as plt
from typing import TYPE_CHECKING, Any, Dict
from datetime import datetime
from llmtuner.extras.packages import is_matplotlib_available
from llmtuner.extras.ploting import smooth
from llmtuner.webui.common import get_save_dir
if TYPE_CHECKING:
from llmtuner.extras.callbacks import LogCallback
if is_matplotlib_available():
import matplotlib.figure
import matplotlib.pyplot as plt
def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
if not callback.max_steps:
@@ -56,7 +59,7 @@ def get_eval_results(path: os.PathLike) -> str:
return "```json\n{}\n```\n".format(result)
def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> "matplotlib.figure.Figure":
if not base_model:
return
log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl")