From 4242897b782ac97516588ca102b1f6ef00b190e5 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 2 Aug 2023 23:17:36 +0800 Subject: [PATCH] modify code structure Former-commit-id: 08f180e78862cad902b6cdbbd8c86e39b5cacf8a --- src/api_demo.py | 6 +- src/cli_demo.py | 3 +- src/export_model.py | 8 +- src/llmtuner/__init__.py | 5 + src/llmtuner/api/__init__.py | 1 + src/llmtuner/api/app.py | 5 +- src/llmtuner/chat/stream_chat.py | 17 +--- src/llmtuner/extras/callbacks.py | 108 ++++++++++++++++------ src/llmtuner/tuner/__init__.py | 6 +- src/llmtuner/tuner/ppo/trainer.py | 27 +++--- src/llmtuner/tuner/ppo/workflow.py | 3 +- src/llmtuner/tuner/pt/workflow.py | 3 +- src/llmtuner/tuner/rm/workflow.py | 3 +- src/llmtuner/tuner/sft/workflow.py | 3 +- src/llmtuner/tuner/tune.py | 36 ++++++++ src/llmtuner/webui/__init__.py | 4 + src/llmtuner/webui/chat.py | 11 +-- src/llmtuner/webui/components/__init__.py | 1 + src/llmtuner/webui/components/export.py | 4 +- src/llmtuner/webui/interface.py | 1 + src/llmtuner/webui/runner.py | 32 ++----- src/llmtuner/webui/utils.py | 23 ++--- src/train_bash.py | 14 +-- src/train_web.py | 2 +- src/web_demo.py | 7 +- 25 files changed, 188 insertions(+), 145 deletions(-) create mode 100644 src/llmtuner/tuner/tune.py diff --git a/src/api_demo.py b/src/api_demo.py index d2d61197..f7649e7b 100644 --- a/src/api_demo.py +++ b/src/api_demo.py @@ -5,13 +5,11 @@ import uvicorn -from llmtuner import ChatModel -from llmtuner.api.app import create_app -from llmtuner.tuner import get_infer_args +from llmtuner import ChatModel, create_app def main(): - chat_model = ChatModel(*get_infer_args()) + chat_model = ChatModel() app = create_app(chat_model) uvicorn.run(app, host="0.0.0.0", port=8000, workers=1) diff --git a/src/cli_demo.py b/src/cli_demo.py index 80e61264..2d0aff7e 100644 --- a/src/cli_demo.py +++ b/src/cli_demo.py @@ -3,11 +3,10 @@ # Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint from llmtuner import ChatModel -from llmtuner.tuner import get_infer_args def main(): - chat_model = ChatModel(*get_infer_args()) + chat_model = ChatModel() history = [] print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.") diff --git a/src/export_model.py b/src/export_model.py index 9e5fe242..a0a86996 100644 --- a/src/export_model.py +++ b/src/export_model.py @@ -2,15 +2,11 @@ # Exports the fine-tuned model. 
# Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model -from llmtuner.tuner import get_train_args, load_model_and_tokenizer +from llmtuner import export_model def main(): - model_args, _, training_args, finetuning_args, _ = get_train_args() - model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args) - model.save_pretrained(training_args.output_dir, max_shard_size="10GB") - tokenizer.save_pretrained(training_args.output_dir) - print("model and tokenizer have been saved at:", training_args.output_dir) + export_model() if __name__ == "__main__": diff --git a/src/llmtuner/__init__.py b/src/llmtuner/__init__.py index b3b28109..04ffbb67 100644 --- a/src/llmtuner/__init__.py +++ b/src/llmtuner/__init__.py @@ -1,4 +1,9 @@ +# Level: api, webui > chat > tuner > dsets > extras, hparams + +from llmtuner.api import create_app from llmtuner.chat import ChatModel +from llmtuner.tuner import export_model, run_exp +from llmtuner.webui import Manager, WebChatModel, create_ui, create_chat_box __version__ = "0.1.5" diff --git a/src/llmtuner/api/__init__.py b/src/llmtuner/api/__init__.py index e69de29b..b3ce183a 100644 --- a/src/llmtuner/api/__init__.py +++ b/src/llmtuner/api/__init__.py @@ -0,0 +1 @@ +from llmtuner.api.app import create_app diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py index 26e41eff..4fc5fc43 100644 --- a/src/llmtuner/api/app.py +++ b/src/llmtuner/api/app.py @@ -5,9 +5,8 @@ from contextlib import asynccontextmanager from sse_starlette import EventSourceResponse from typing import List, Tuple -from llmtuner.tuner import get_infer_args from llmtuner.extras.misc import torch_gc -from llmtuner.chat.stream_chat import ChatModel +from llmtuner.chat import ChatModel from llmtuner.api.protocol import ( Role, Finish, @@ -122,6 +121,6 @@ def create_app(chat_model: ChatModel) -> FastAPI: if __name__ == "__main__": - chat_model = ChatModel(*get_infer_args()) + chat_model = ChatModel() app = create_app(chat_model) uvicorn.run(app, host="0.0.0.0", port=8000, workers=1) diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py index 626e07e4..d5a5f1ad 100644 --- a/src/llmtuner/chat/stream_chat.py +++ b/src/llmtuner/chat/stream_chat.py @@ -1,30 +1,21 @@ import torch -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple +from typing import Any, Dict, Generator, List, Optional, Tuple from threading import Thread from transformers import TextIteratorStreamer from llmtuner.extras.misc import dispatch_model, get_logits_processor from llmtuner.extras.template import get_template -from llmtuner.tuner import load_model_and_tokenizer - -if TYPE_CHECKING: - from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments +from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer class ChatModel: - def __init__( - self, - model_args: "ModelArguments", - data_args: "DataArguments", - finetuning_args: "FinetuningArguments", - generating_args: "GeneratingArguments" - ) -> None: + def __init__(self, args: Optional[Dict[str, Any]] = None) -> None: + model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args) self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args) self.model = dispatch_model(self.model) self.template = get_template(data_args.template) self.source_prefix = data_args.source_prefix - self.generating_args = generating_args def process_args( self, diff --git 
a/src/llmtuner/extras/callbacks.py b/src/llmtuner/extras/callbacks.py index 9c45b31e..d325b0a8 100644 --- a/src/llmtuner/extras/callbacks.py +++ b/src/llmtuner/extras/callbacks.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING from datetime import timedelta from transformers import TrainerCallback +from transformers.trainer_utils import has_length if TYPE_CHECKING: from transformers import TrainingArguments, TrainerState, TrainerControl @@ -14,58 +15,105 @@ class LogCallback(TrainerCallback): def __init__(self, runner=None): self.runner = runner + self.in_training = False self.start_time = time.time() - self.tracker = {} + self.cur_steps = 0 + self.max_steps = 0 + self.elapsed_time = "" + self.remaining_time = "" + + def timing(self): + cur_time = time.time() + elapsed_time = cur_time - self.start_time + avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0 + remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step + self.elapsed_time = str(timedelta(seconds=int(elapsed_time))) + self.remaining_time = str(timedelta(seconds=int(remaining_time))) def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): r""" Event called at the beginning of training. """ - self.start_time = time.time() + if state.is_local_process_zero: + self.in_training = True + self.start_time = time.time() + self.max_steps = state.max_steps - def on_step_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): r""" - Event called at the beginning of a training step. If using gradient accumulation, one training step - might take several inputs. + Event called at the end of training. """ - if self.runner is not None and self.runner.aborted: - control.should_epoch_stop = True - control.should_training_stop = True + if state.is_local_process_zero: + self.in_training = False + self.cur_steps = 0 + self.max_steps = 0 def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): r""" Event called at the end of a substep during gradient accumulation. """ - if self.runner is not None and self.runner.aborted: + if state.is_local_process_zero and self.runner is not None and self.runner.aborted: control.should_epoch_stop = True control.should_training_stop = True + def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the end of a training step. + """ + if state.is_local_process_zero: + self.cur_steps = state.global_step + self.timing() + if self.runner is not None and self.runner.aborted: + control.should_epoch_stop = True + control.should_training_stop = True + + def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called after an evaluation phase. + """ + if state.is_local_process_zero and not self.in_training: + self.cur_steps = 0 + self.max_steps = 0 + + def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs): + r""" + Event called after a successful prediction.
+ """ + if state.is_local_process_zero and not self.in_training: + self.cur_steps = 0 + self.max_steps = 0 + def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None: r""" Event called after logging the last logs. """ - if not state.is_world_process_zero: + if not state.is_local_process_zero: return - cur_time = time.time() - cur_steps = state.log_history[-1].get("step") - elapsed_time = cur_time - self.start_time - avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0 - remaining_steps = state.max_steps - cur_steps - remaining_time = remaining_steps * avg_time_per_step - self.tracker = { - "current_steps": cur_steps, - "total_steps": state.max_steps, - "loss": state.log_history[-1].get("loss", None), - "eval_loss": state.log_history[-1].get("eval_loss", None), - "predict_loss": state.log_history[-1].get("predict_loss", None), - "reward": state.log_history[-1].get("reward", None), - "learning_rate": state.log_history[-1].get("learning_rate", None), - "epoch": state.log_history[-1].get("epoch", None), - "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100, - "elapsed_time": str(timedelta(seconds=int(elapsed_time))), - "remaining_time": str(timedelta(seconds=int(remaining_time))) - } + logs = dict( + current_steps=self.cur_steps, + total_steps=self.max_steps, + loss=state.log_history[-1].get("loss", None), + eval_loss=state.log_history[-1].get("eval_loss", None), + predict_loss=state.log_history[-1].get("predict_loss", None), + reward=state.log_history[-1].get("reward", None), + learning_rate=state.log_history[-1].get("learning_rate", None), + epoch=state.log_history[-1].get("epoch", None), + percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100, + elapsed_time=self.elapsed_time, + remaining_time=self.remaining_time + ) os.makedirs(args.output_dir, exist_ok=True) with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f: - f.write(json.dumps(self.tracker) + "\n") + f.write(json.dumps(logs) + "\n") + + def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called after a prediction step. 
+ """ + eval_dataloader = kwargs.pop("eval_dataloader", None) + if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training: + if self.max_steps == 0: + self.max_steps = len(eval_dataloader) + self.cur_steps += 1 + self.timing() diff --git a/src/llmtuner/tuner/__init__.py b/src/llmtuner/tuner/__init__.py index c329f39a..4d5a83e4 100644 --- a/src/llmtuner/tuner/__init__.py +++ b/src/llmtuner/tuner/__init__.py @@ -1,5 +1 @@ -from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer -from llmtuner.tuner.pt import run_pt -from llmtuner.tuner.sft import run_sft -from llmtuner.tuner.rm import run_rm -from llmtuner.tuner.ppo import run_ppo +from llmtuner.tuner.tune import export_model, run_exp diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py index 35d36787..3392aa4d 100644 --- a/src/llmtuner/tuner/ppo/trainer.py +++ b/src/llmtuner/tuner/ppo/trainer.py @@ -11,7 +11,6 @@ from trl.core import LengthSampler from llmtuner.extras.logging import get_logger from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor - from llmtuner.tuner.core.trainer import PeftTrainer from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model @@ -90,14 +89,13 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): reward_meter = AverageMeter() self.log_callback.on_train_begin(self.args, self.state, self.control) - for step in tqdm(range(max_steps), disable=not self.is_world_process_zero(), leave=False): + for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()): batch = next(dataiter) steps_trained += 1 # Cast to inference mode unwrapped_model.gradient_checkpointing_disable() unwrapped_model.config.use_cache = True - unwrapped_model.eval() # Get inputs queries, responses = self.get_inputs(batch, length_sampler, **gen_kwargs) @@ -106,21 +104,23 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): # Cast to training mode unwrapped_model.gradient_checkpointing_enable() unwrapped_model.config.use_cache = False - unwrapped_model.train() # Run PPO step stats = self.step(queries, responses, rewards) loss_meter.update(stats["ppo/loss/total"], n=len(rewards)) reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards)) - if self.is_world_process_zero() and (step+1) % self.args.logging_steps == 0: + self.state.global_step += 1 + self.log_callback.on_step_end(self.args, self.state, self.control) + + if self.is_local_process_zero() and (step+1) % self.args.logging_steps == 0: logs = dict( loss=round(loss_meter.avg, 4), reward=round(reward_meter.avg, 4), learning_rate=stats["ppo/learning_rate"], epoch=round(step / len_dataloader, 2) ) - print(logs) + tqdm.write(str(logs)) logs["step"] = step self.state.log_history.append(logs) self.log_callback.on_log(self.args, self.state, self.control) @@ -137,10 +137,12 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): dataiter = iter(self.dataloader) steps_trained = 0 + self.log_callback.on_train_end(self.args, self.state, self.control) + @torch.no_grad() def get_inputs( self, - inputs: Dict[str, torch.Tensor], + batch: Dict[str, torch.Tensor], length_sampler: Optional[Callable] = None, **generation_kwargs ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: @@ -152,7 +154,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): self.model, layer_norm_params = cast_layernorm_dtype(self.model) unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) - response: torch.Tensor = unwrapped_model.generate(**inputs, 
**generation_kwargs) + response: torch.Tensor = unwrapped_model.generate(**batch, **generation_kwargs) self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params) # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop @@ -161,7 +163,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): unwrapped_model.pretrained_model.generation_config._from_model_config = False queries, responses = [], [] - query, response = inputs["input_ids"].detach().cpu(), response[:, inputs["input_ids"].size(-1):].detach().cpu() + query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu() for i in range(len(query)): query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0] response_length = (response[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1 @@ -181,11 +183,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): Computes scores using given reward model. """ replace_model(unwrapped_model, target="reward") - _, _, values = self.model( - **self.prepare_model_inputs(queries, responses), - output_hidden_states=True, - return_dict=True - ) + batch = self.prepare_model_inputs(queries, responses) + _, _, values = self.model(**batch, output_hidden_states=True, return_dict=True) rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type replace_model(unwrapped_model, target="default") return rewards diff --git a/src/llmtuner/tuner/ppo/workflow.py b/src/llmtuner/tuner/ppo/workflow.py index eed0707f..0ca8cbd4 100644 --- a/src/llmtuner/tuner/ppo/workflow.py +++ b/src/llmtuner/tuner/ppo/workflow.py @@ -10,7 +10,6 @@ from transformers import DataCollatorForSeq2Seq from transformers.optimization import get_scheduler from llmtuner.dsets import get_dataset, preprocess_dataset -from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.ploting import plot_loss from llmtuner.tuner.core import load_model_and_tokenizer from llmtuner.tuner.ppo.trainer import PPOPeftTrainer @@ -25,7 +24,7 @@ def run_ppo( data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments", - callbacks: Optional[List["TrainerCallback"]] = [LogCallback()] + callbacks: Optional[List["TrainerCallback"]] = None ): dataset = get_dataset(model_args, data_args) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo") diff --git a/src/llmtuner/tuner/pt/workflow.py b/src/llmtuner/tuner/pt/workflow.py index 1dbb6852..2a9f8279 100644 --- a/src/llmtuner/tuner/pt/workflow.py +++ b/src/llmtuner/tuner/pt/workflow.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Optional, List from transformers import DataCollatorForSeq2Seq from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset -from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.ploting import plot_loss from llmtuner.tuner.core import load_model_and_tokenizer @@ -21,7 +20,7 @@ def run_pt( data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments", - callbacks: Optional[List["TrainerCallback"]] = [LogCallback()] + callbacks: Optional[List["TrainerCallback"]] = None ): dataset = get_dataset(model_args, data_args) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt") diff --git a/src/llmtuner/tuner/rm/workflow.py b/src/llmtuner/tuner/rm/workflow.py index 
ec2b8ada..19527ce8 100644 --- a/src/llmtuner/tuner/rm/workflow.py +++ b/src/llmtuner/tuner/rm/workflow.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Optional, List from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset -from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.ploting import plot_loss from llmtuner.tuner.core import load_model_and_tokenizer from llmtuner.tuner.rm.metric import compute_accuracy @@ -22,7 +21,7 @@ def run_rm( data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments", - callbacks: Optional[List["TrainerCallback"]] = [LogCallback()] + callbacks: Optional[List["TrainerCallback"]] = None ): dataset = get_dataset(model_args, data_args) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm") diff --git a/src/llmtuner/tuner/sft/workflow.py b/src/llmtuner/tuner/sft/workflow.py index 9a8feeb0..69d200f3 100644 --- a/src/llmtuner/tuner/sft/workflow.py +++ b/src/llmtuner/tuner/sft/workflow.py @@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Optional, List from transformers import DataCollatorForSeq2Seq from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset -from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.misc import get_logits_processor from llmtuner.extras.ploting import plot_loss @@ -22,7 +21,7 @@ def run_sft( data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments", - callbacks: Optional[List["TrainerCallback"]] = [LogCallback()] + callbacks: Optional[List["TrainerCallback"]] = None ): dataset = get_dataset(model_args, data_args) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft") diff --git a/src/llmtuner/tuner/tune.py b/src/llmtuner/tuner/tune.py new file mode 100644 index 00000000..99f5d2a9 --- /dev/null +++ b/src/llmtuner/tuner/tune.py @@ -0,0 +1,36 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from llmtuner.extras.callbacks import LogCallback +from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer +from llmtuner.tuner.pt import run_pt +from llmtuner.tuner.sft import run_sft +from llmtuner.tuner.rm import run_rm +from llmtuner.tuner.ppo import run_ppo + +if TYPE_CHECKING: + from transformers import TrainerCallback + + +def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None): + model_args, data_args, training_args, finetuning_args, general_args = get_train_args(args) + callbacks = [LogCallback()] if callbacks is None else callbacks + + if general_args.stage == "pt": + run_pt(model_args, data_args, training_args, finetuning_args, callbacks) + elif general_args.stage == "sft": + run_sft(model_args, data_args, training_args, finetuning_args, callbacks) + elif general_args.stage == "rm": + run_rm(model_args, data_args, training_args, finetuning_args, callbacks) + elif general_args.stage == "ppo": + run_ppo(model_args, data_args, training_args, finetuning_args, callbacks) + + +def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"): + model_args, _, training_args, finetuning_args, _ = get_train_args(args) + model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args) + model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size) + 
tokenizer.save_pretrained(training_args.output_dir) + + +if __name__ == "__main__": + run_exp() diff --git a/src/llmtuner/webui/__init__.py b/src/llmtuner/webui/__init__.py index e69de29b..8544957c 100644 --- a/src/llmtuner/webui/__init__.py +++ b/src/llmtuner/webui/__init__.py @@ -0,0 +1,4 @@ +from llmtuner.webui.chat import WebChatModel +from llmtuner.webui.interface import create_ui +from llmtuner.webui.manager import Manager +from llmtuner.webui.components import create_chat_box diff --git a/src/llmtuner/webui/chat.py b/src/llmtuner/webui/chat.py index d3cdbecc..773fb7c7 100644 --- a/src/llmtuner/webui/chat.py +++ b/src/llmtuner/webui/chat.py @@ -1,22 +1,21 @@ import os -from typing import List, Tuple +from typing import Any, Dict, List, Optional, Tuple from llmtuner.chat.stream_chat import ChatModel from llmtuner.extras.misc import torch_gc from llmtuner.hparams import GeneratingArguments -from llmtuner.tuner import get_infer_args from llmtuner.webui.common import get_model_path, get_save_dir from llmtuner.webui.locales import ALERTS class WebChatModel(ChatModel): - def __init__(self, *args): + def __init__(self, args: Optional[Dict[str, Any]]) -> None: self.model = None self.tokenizer = None self.generating_args = GeneratingArguments() - if len(args) != 0: - super().__init__(*args) + if args is not None: + super().__init__(args) def load_model( self, @@ -57,7 +56,7 @@ class WebChatModel(ChatModel): template=template, source_prefix=source_prefix ) - super().__init__(*get_infer_args(args)) + super().__init__(args) yield ALERTS["info_loaded"][lang] diff --git a/src/llmtuner/webui/components/__init__.py b/src/llmtuner/webui/components/__init__.py index 9312f409..5b86f396 100644 --- a/src/llmtuner/webui/components/__init__.py +++ b/src/llmtuner/webui/components/__init__.py @@ -3,3 +3,4 @@ from llmtuner.webui.components.sft import create_sft_tab from llmtuner.webui.components.eval import create_eval_tab from llmtuner.webui.components.infer import create_infer_tab from llmtuner.webui.components.export import create_export_tab +from llmtuner.webui.components.chatbot import create_chat_box diff --git a/src/llmtuner/webui/components/export.py b/src/llmtuner/webui/components/export.py index 6e27ff16..295fc617 100644 --- a/src/llmtuner/webui/components/export.py +++ b/src/llmtuner/webui/components/export.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING, Dict import gradio as gr -from llmtuner.webui.utils import export_model +from llmtuner.webui.utils import save_model if TYPE_CHECKING: from gradio.components import Component @@ -16,7 +16,7 @@ def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component info_box = gr.Textbox(show_label=False, interactive=False) export_btn.click( - export_model, + save_model, [ top_elems["lang"], top_elems["model_name"], diff --git a/src/llmtuner/webui/interface.py b/src/llmtuner/webui/interface.py index cf9b714b..afc50d6e 100644 --- a/src/llmtuner/webui/interface.py +++ b/src/llmtuner/webui/interface.py @@ -47,6 +47,7 @@ def create_ui() -> gr.Blocks: manager.gen_label, [top_elems["lang"]], [elem for elems in elem_list for elem in elems.values()], + queue=False ) return demo diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index 08f72d0a..43ba908b 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -9,7 +9,7 @@ from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.constants import DEFAULT_MODULE from llmtuner.extras.logging import LoggerHandler from llmtuner.extras.misc 
import torch_gc -from llmtuner.tuner import get_train_args, run_sft +from llmtuner.tuner import run_exp from llmtuner.webui.common import get_model_path, get_save_dir from llmtuner.webui.locales import ALERTS from llmtuner.webui.utils import format_info, get_eval_results @@ -105,6 +105,7 @@ class Runner: checkpoint_dir = None args = dict( + stage="sft", model_name_or_path=model_name_or_path, do_train=True, overwrite_cache=True, @@ -141,16 +142,8 @@ class Runner: args["eval_steps"] = save_steps args["load_best_model_at_end"] = True - model_args, data_args, training_args, finetuning_args, _ = get_train_args(args) - - run_args = dict( - model_args=model_args, - data_args=data_args, - training_args=training_args, - finetuning_args=finetuning_args, - callbacks=[trainer_callback] - ) - thread = threading.Thread(target=run_sft, kwargs=run_args) + run_kwargs = dict(args=args, callbacks=[trainer_callback]) + thread = threading.Thread(target=run_exp, kwargs=run_kwargs) thread.start() while thread.is_alive(): @@ -158,7 +151,7 @@ class Runner: if self.aborted: yield ALERTS["info_aborting"][lang] else: - yield format_info(logger_handler.log, trainer_callback.tracker) + yield format_info(logger_handler.log, trainer_callback) yield self.finalize(lang) @@ -194,6 +187,7 @@ class Runner: output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_base") args = dict( + stage="sft", model_name_or_path=model_name_or_path, do_eval=True, overwrite_cache=True, @@ -216,16 +210,8 @@ class Runner: args.pop("do_eval", None) args["do_predict"] = True - model_args, data_args, training_args, finetuning_args, _ = get_train_args(args) - - run_args = dict( - model_args=model_args, - data_args=data_args, - training_args=training_args, - finetuning_args=finetuning_args, - callbacks=[trainer_callback] - ) - thread = threading.Thread(target=run_sft, kwargs=run_args) + run_kwargs = dict(args=args, callbacks=[trainer_callback]) + thread = threading.Thread(target=run_exp, kwargs=run_kwargs) thread.start() while thread.is_alive(): @@ -233,6 +219,6 @@ class Runner: if self.aborted: yield ALERTS["info_aborting"][lang] else: - yield format_info(logger_handler.log, trainer_callback.tracker) + yield format_info(logger_handler.log, trainer_callback) yield self.finalize(lang, get_eval_results(os.path.join(output_dir, "all_results.json"))) diff --git a/src/llmtuner/webui/utils.py b/src/llmtuner/webui/utils.py index 4921195d..8a4f8d5a 100644 --- a/src/llmtuner/webui/utils.py +++ b/src/llmtuner/webui/utils.py @@ -3,20 +3,23 @@ import json import gradio as gr import matplotlib.figure import matplotlib.pyplot as plt -from typing import Any, Dict, Generator, List, Tuple +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple from datetime import datetime from llmtuner.extras.ploting import smooth -from llmtuner.tuner import get_infer_args, load_model_and_tokenizer +from llmtuner.tuner import export_model from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG from llmtuner.webui.locales import ALERTS +if TYPE_CHECKING: + from llmtuner.extras.callbacks import LogCallback -def format_info(log: str, tracker: dict) -> str: + +def format_info(log: str, callback: "LogCallback") -> str: info = log - if "current_steps" in tracker: + if callback.max_steps: info += "Running **{:d}/{:d}**: {} < {}\n".format( - tracker["current_steps"], tracker["total_steps"], tracker["elapsed_time"], tracker["remaining_time"] + callback.cur_steps, callback.max_steps, callback.elapsed_time, callback.remaining_time ) return 
info @@ -87,7 +90,7 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl return fig -def export_model( +def save_model( lang: str, model_name: str, checkpoints: List[str], finetuning_type: str, max_shard_size: int, save_dir: str ) -> Generator[str, None, None]: if not model_name: @@ -114,12 +117,10 @@ def export_model( args = dict( model_name_or_path=model_name_or_path, checkpoint_dir=checkpoint_dir, - finetuning_type=finetuning_type + finetuning_type=finetuning_type, + output_dir=save_dir ) yield ALERTS["info_exporting"][lang] - model_args, _, finetuning_args, _ = get_infer_args(args) - model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args) - model.save_pretrained(save_dir, max_shard_size=str(max_shard_size)+"GB") - tokenizer.save_pretrained(save_dir) + export_model(args, max_shard_size="{}GB".format(max_shard_size)) yield ALERTS["info_exported"][lang] diff --git a/src/train_bash.py b/src/train_bash.py index ec10deaa..0facef9a 100644 --- a/src/train_bash.py +++ b/src/train_bash.py @@ -1,17 +1,7 @@ -from llmtuner.tuner import get_train_args, run_pt, run_sft, run_rm, run_ppo - +from llmtuner import run_exp def main(): - model_args, data_args, training_args, finetuning_args, general_args = get_train_args() - - if general_args.stage == "pt": - run_pt(model_args, data_args, training_args, finetuning_args) - elif general_args.stage == "sft": - run_sft(model_args, data_args, training_args, finetuning_args) - elif general_args.stage == "rm": - run_rm(model_args, data_args, training_args, finetuning_args) - elif general_args.stage == "ppo": - run_ppo(model_args, data_args, training_args, finetuning_args) + run_exp() def _mp_fn(index): diff --git a/src/train_web.py b/src/train_web.py index 15bc8ab0..38efd64d 100644 --- a/src/train_web.py +++ b/src/train_web.py @@ -1,4 +1,4 @@ -from llmtuner.webui.interface import create_ui +from llmtuner import create_ui def main(): diff --git a/src/web_demo.py b/src/web_demo.py index bc79375a..112bac8b 100644 --- a/src/web_demo.py +++ b/src/web_demo.py @@ -5,17 +5,14 @@ import gradio as gr from transformers.utils.versions import require_version -from llmtuner.tuner import get_infer_args -from llmtuner.webui.chat import WebChatModel -from llmtuner.webui.components.chatbot import create_chat_box -from llmtuner.webui.manager import Manager +from llmtuner import Manager, WebChatModel, create_chat_box require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0") def main(): - chat_model = WebChatModel(*get_infer_args()) + chat_model = WebChatModel() with gr.Blocks(title="Web Demo") as demo: lang = gr.Dropdown(choices=["en", "zh"], value="en")
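Note: the sketch below is not part of the patch; it illustrates the unified training and export entry points that tune.py introduces. The path strings are placeholders and the argument set is deliberately minimal; a real run also needs the dataset and training arguments that get_train_args() expects.

from llmtuner import export_model, run_exp

# Training: get_train_args() parses the dict and general_args.stage selects
# run_pt / run_sft / run_rm / run_ppo, each with a LogCallback attached by default.
run_exp(dict(
    stage="sft",                          # one of "pt", "sft", "rm", "ppo"
    model_name_or_path="path_to_model",   # placeholder
    do_train=True,
    output_dir="path_to_save_checkpoint"  # placeholder
))

# Exporting: saves the model and tokenizer under output_dir,
# sharded at no more than max_shard_size per file.
export_model(dict(
    model_name_or_path="path_to_model",   # placeholder
    checkpoint_dir="path_to_checkpoint",  # placeholder
    output_dir="path_to_export"           # placeholder
), max_shard_size="10GB")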
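Note: likewise for inference, mirroring api_demo.py above. ChatModel now parses its own arguments through get_infer_args(), so callers pass an optional dict instead of pre-parsed argument objects; called with no dict, it falls back to command-line flags as the demo scripts do.

import uvicorn

from llmtuner import ChatModel, create_app

chat_model = ChatModel(dict(
    model_name_or_path="path_to_model",   # placeholder
    checkpoint_dir="path_to_checkpoint"   # placeholder
))
app = create_app(chat_model)
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)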
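Note: on the web UI side, on_log appends one JSON record per logging step to trainer_log.jsonl under args.output_dir, and format_info() now reads progress (cur_steps, max_steps, elapsed_time, remaining_time) straight from the callback. The record below is illustrative: the keys are exactly those written by on_log, but the values are invented.

import json

# 10 of 1000 steps done after 100 seconds: percentage = 1.0,
# elapsed_time = "0:01:40", remaining_time = 990 steps * 10 s = "2:45:00".
record = json.loads(
    '{"current_steps": 10, "total_steps": 1000, "loss": 1.23, "eval_loss": null,'
    ' "predict_loss": null, "reward": null, "learning_rate": 5e-05, "epoch": 0.01,'
    ' "percentage": 1.0, "elapsed_time": "0:01:40", "remaining_time": "2:45:00"}'
)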