From 64b4f71673f69ff18bcf9a6107c59b302db41121 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Thu, 20 Jul 2023 15:08:57 +0800 Subject: [PATCH] simplify code Former-commit-id: 67a27730744b71795b10260d050501bfe2329c26 --- src/api_demo.py | 10 +++-- src/cli_demo.py | 3 +- src/export_model.py | 2 +- src/llmtuner/__init__.py | 3 -- src/llmtuner/api/__init__.py | 1 - src/llmtuner/chat/stream_chat.py | 4 +- src/llmtuner/dsets/__init__.py | 1 + src/llmtuner/dsets/callbacks.py | 63 ------------------------------ src/llmtuner/dsets/preprocess.py | 4 +- src/llmtuner/dsets/utils.py | 16 ++++++++ src/llmtuner/extras/template.py | 32 ++++++--------- src/llmtuner/tuner/pt/workflow.py | 14 +------ src/llmtuner/tuner/rm/workflow.py | 14 +------ src/llmtuner/tuner/sft/workflow.py | 14 +------ src/llmtuner/webui/__init__.py | 1 - src/train_bash.py | 2 +- src/train_web.py | 2 +- src/web_demo.py | 2 +- 18 files changed, 52 insertions(+), 136 deletions(-) delete mode 100644 src/llmtuner/dsets/callbacks.py create mode 100644 src/llmtuner/dsets/utils.py diff --git a/src/api_demo.py b/src/api_demo.py index 3041b2e1..45f64faf 100644 --- a/src/api_demo.py +++ b/src/api_demo.py @@ -5,9 +5,13 @@ import uvicorn -from llmtuner import create_app +from llmtuner.api.app import create_app + + +def main(): + app = create_app() + uvicorn.run(app, host="0.0.0.0", port=8000, workers=1) if __name__ == "__main__": - app = create_app() - uvicorn.run(app, host="0.0.0.0", port=8000, workers=1) + main() diff --git a/src/cli_demo.py b/src/cli_demo.py index dd91da77..80e61264 100644 --- a/src/cli_demo.py +++ b/src/cli_demo.py @@ -2,7 +2,8 @@ # Implements stream chat in command line for fine-tuned models. # Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint -from llmtuner import ChatModel, get_infer_args +from llmtuner import ChatModel +from llmtuner.tuner import get_infer_args def main(): diff --git a/src/export_model.py b/src/export_model.py index 3c1ffbbb..9e5fe242 100644 --- a/src/export_model.py +++ b/src/export_model.py @@ -2,7 +2,7 @@ # Exports the fine-tuned model. 
# Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model -from llmtuner import get_train_args, load_model_and_tokenizer +from llmtuner.tuner import get_train_args, load_model_and_tokenizer def main(): diff --git a/src/llmtuner/__init__.py b/src/llmtuner/__init__.py index fde74f57..67fc38e7 100644 --- a/src/llmtuner/__init__.py +++ b/src/llmtuner/__init__.py @@ -1,7 +1,4 @@ -from llmtuner.api import create_app from llmtuner.chat import ChatModel -from llmtuner.tuner import get_train_args, get_infer_args, load_model_and_tokenizer, run_pt, run_sft, run_rm, run_ppo -from llmtuner.webui import create_ui __version__ = "0.1.1" diff --git a/src/llmtuner/api/__init__.py b/src/llmtuner/api/__init__.py index b3ce183a..e69de29b 100644 --- a/src/llmtuner/api/__init__.py +++ b/src/llmtuner/api/__init__.py @@ -1 +0,0 @@ -from llmtuner.api.app import create_app diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py index cedfad7a..42126eea 100644 --- a/src/llmtuner/chat/stream_chat.py +++ b/src/llmtuner/chat/stream_chat.py @@ -4,7 +4,7 @@ from threading import Thread from transformers import TextIteratorStreamer from llmtuner.extras.misc import get_logits_processor -from llmtuner.extras.template import Template +from llmtuner.extras.template import get_template from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments from llmtuner.tuner import load_model_and_tokenizer @@ -19,7 +19,7 @@ class ChatModel: generating_args: GeneratingArguments ) -> None: self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args) - self.template = Template(data_args.prompt_template) + self.template = get_template(data_args.prompt_template) self.source_prefix = data_args.source_prefix if data_args.source_prefix else "" self.generating_args = generating_args diff --git a/src/llmtuner/dsets/__init__.py b/src/llmtuner/dsets/__init__.py index 7667c89c..cccbd745 100644 --- a/src/llmtuner/dsets/__init__.py +++ b/src/llmtuner/dsets/__init__.py @@ -1,2 +1,3 @@ from llmtuner.dsets.loader import get_dataset from llmtuner.dsets.preprocess import preprocess_dataset +from llmtuner.dsets.utils import split_dataset diff --git a/src/llmtuner/dsets/callbacks.py b/src/llmtuner/dsets/callbacks.py deleted file mode 100644 index cb013961..00000000 --- a/src/llmtuner/dsets/callbacks.py +++ /dev/null @@ -1,63 +0,0 @@ -import os -import json -import time -from datetime import timedelta - -from transformers import ( - TrainerCallback, - TrainerControl, - TrainerState, - TrainingArguments -) - - -class LogCallback(TrainerCallback): - - def __init__(self, runner=None): - self.runner = runner - self.start_time = time.time() - self.tracker = {} - - def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): - r""" - Event called at the beginning of a training step. If using gradient accumulation, one training step - might take several inputs. - """ - if self.runner is not None and self.runner.aborted: - control.should_epoch_stop = True - control.should_training_stop = True - - def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): - r""" - Event called at the end of an substep during gradient accumulation. 
- """ - if self.runner is not None and self.runner.aborted: - control.should_epoch_stop = True - control.should_training_stop = True - - def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None: - r""" - Event called after logging the last logs. - """ - if "loss" not in state.log_history[-1]: - return - cur_time = time.time() - cur_steps = state.log_history[-1].get("step") - elapsed_time = cur_time - self.start_time - avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0 - remaining_steps = state.max_steps - cur_steps - remaining_time = remaining_steps * avg_time_per_step - self.tracker = { - "current_steps": cur_steps, - "total_steps": state.max_steps, - "loss": state.log_history[-1].get("loss", None), - "reward": state.log_history[-1].get("reward", None), - "learning_rate": state.log_history[-1].get("learning_rate", None), - "epoch": state.log_history[-1].get("epoch", None), - "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100, - "elapsed_time": str(timedelta(seconds=int(elapsed_time))), - "remaining_time": str(timedelta(seconds=int(remaining_time))) - } - os.makedirs(args.output_dir, exist_ok=True) - with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f: - f.write(json.dumps(self.tracker) + "\n") diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py index 4eb912c1..bf65cc7d 100644 --- a/src/llmtuner/dsets/preprocess.py +++ b/src/llmtuner/dsets/preprocess.py @@ -6,7 +6,7 @@ from transformers.tokenization_utils import PreTrainedTokenizer from datasets import Dataset from llmtuner.extras.constants import IGNORE_INDEX -from llmtuner.extras.template import Template +from llmtuner.extras.template import get_template from llmtuner.hparams import DataArguments @@ -19,7 +19,7 @@ def preprocess_dataset( ) -> Dataset: column_names = list(dataset.column_names) - prompt_template = Template(data_args.prompt_template) + prompt_template = get_template(data_args.prompt_template) # support question with a single answer or multiple answers def get_dialog(examples): diff --git a/src/llmtuner/dsets/utils.py b/src/llmtuner/dsets/utils.py new file mode 100644 index 00000000..64436e70 --- /dev/null +++ b/src/llmtuner/dsets/utils.py @@ -0,0 +1,16 @@ +from typing import Dict +from datasets import Dataset + + +def split_dataset( + dataset: Dataset, dev_ratio: float, do_train: bool +) -> Dict[str, Dataset]: + # Split the dataset + if do_train: + if dev_ratio > 1e-6: + dataset = dataset.train_test_split(test_size=dev_ratio) + return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]} + else: + return {"train_dataset": dataset} + else: # do_eval or do_predict + return {"eval_dataset": dataset} diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index b41e3398..bb331424 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -3,30 +3,13 @@ from dataclasses import dataclass @dataclass -class Format: +class Template: + prefix: str prompt: str sep: str use_history: bool - -templates: Dict[str, Format] = {} - - -@dataclass -class Template: - - name: str - - def __post_init__(self): - if self.name in templates: - self.prefix = templates[self.name].prefix - self.prompt = templates[self.name].prompt - self.sep = templates[self.name].sep - self.use_history = templates[self.name].use_history - else: - raise ValueError("Template {} does not exist.".format(self.name)) - def get_prompt( 
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = "" ) -> str: @@ -61,8 +44,11 @@ class Template: return convs[:-1] # drop last +templates: Dict[str, Template] = {} + + def register_template(name: str, prefix: str, prompt: str, sep: str, use_history: bool) -> None: - templates[name] = Format( + templates[name] = Template( prefix=prefix, prompt=prompt, sep=sep, @@ -70,6 +56,12 @@ def register_template(name: str, prefix: str, prompt: str, sep: str, use_history ) +def get_template(name: str) -> Template: + template = templates.get(name, None) + assert template is not None, "Template {} does not exist.".format(name) + return template + + r""" Supports language model inference without histories. """ diff --git a/src/llmtuner/tuner/pt/workflow.py b/src/llmtuner/tuner/pt/workflow.py index 1837e366..59813532 100644 --- a/src/llmtuner/tuner/pt/workflow.py +++ b/src/llmtuner/tuner/pt/workflow.py @@ -4,7 +4,7 @@ import math from typing import Optional, List from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback -from llmtuner.dsets import get_dataset, preprocess_dataset +from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.ploting import plot_loss @@ -28,16 +28,6 @@ def run_pt( label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id ) - # Split the dataset - if training_args.do_train: - if data_args.dev_ratio > 1e-6: - dataset = dataset.train_test_split(test_size=data_args.dev_ratio) - trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]} - else: - trainer_kwargs = {"train_dataset": dataset} - else: # do_eval or do_predict - trainer_kwargs = {"eval_dataset": dataset} - # Initialize our Trainer trainer = PeftTrainer( finetuning_args=finetuning_args, @@ -46,7 +36,7 @@ def run_pt( tokenizer=tokenizer, data_collator=data_collator, callbacks=callbacks, - **trainer_kwargs + **split_dataset(dataset, data_args.dev_ratio, training_args.do_train) ) # Training diff --git a/src/llmtuner/tuner/rm/workflow.py b/src/llmtuner/tuner/rm/workflow.py index cc0835ad..c2d7104a 100644 --- a/src/llmtuner/tuner/rm/workflow.py +++ b/src/llmtuner/tuner/rm/workflow.py @@ -5,7 +5,7 @@ from typing import Optional, List from transformers import Seq2SeqTrainingArguments, TrainerCallback -from llmtuner.dsets import get_dataset, preprocess_dataset +from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.ploting import plot_loss from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments @@ -29,16 +29,6 @@ def run_rm( training_args.remove_unused_columns = False # important for pairwise dataset - # Split the dataset - if training_args.do_train: - if data_args.dev_ratio > 1e-6: - dataset = dataset.train_test_split(test_size=data_args.dev_ratio) - trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]} - else: - trainer_kwargs = {"train_dataset": dataset} - else: # do_eval or do_predict - trainer_kwargs = {"eval_dataset": dataset} - # Initialize our Trainer trainer = PairwisePeftTrainer( finetuning_args=finetuning_args, @@ -48,7 +38,7 @@ def run_rm( data_collator=data_collator, callbacks=callbacks, compute_metrics=compute_accuracy, - **trainer_kwargs + **split_dataset(dataset, data_args.dev_ratio, 
training_args.do_train) ) # Training diff --git a/src/llmtuner/tuner/sft/workflow.py b/src/llmtuner/tuner/sft/workflow.py index 08889796..6ba2f621 100644 --- a/src/llmtuner/tuner/sft/workflow.py +++ b/src/llmtuner/tuner/sft/workflow.py @@ -3,7 +3,7 @@ from typing import Optional, List from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback -from llmtuner.dsets import get_dataset, preprocess_dataset +from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.misc import get_logits_processor @@ -35,16 +35,6 @@ def run_sft( training_args.generation_num_beams = data_args.eval_num_beams if \ data_args.eval_num_beams is not None else training_args.generation_num_beams - # Split the dataset - if training_args.do_train: - if data_args.dev_ratio > 1e-6: - dataset = dataset.train_test_split(test_size=data_args.dev_ratio) - trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]} - else: - trainer_kwargs = {"train_dataset": dataset} - else: # do_eval or do_predict - trainer_kwargs = {"eval_dataset": dataset} - # Initialize our Trainer trainer = Seq2SeqPeftTrainer( finetuning_args=finetuning_args, @@ -54,7 +44,7 @@ def run_sft( data_collator=data_collator, callbacks=callbacks, compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None, - **trainer_kwargs + **split_dataset(dataset, data_args.dev_ratio, training_args.do_train) ) # Keyword arguments for `model.generate` diff --git a/src/llmtuner/webui/__init__.py b/src/llmtuner/webui/__init__.py index 686cc95f..e69de29b 100644 --- a/src/llmtuner/webui/__init__.py +++ b/src/llmtuner/webui/__init__.py @@ -1 +0,0 @@ -from llmtuner.webui.interface import create_ui diff --git a/src/train_bash.py b/src/train_bash.py index 291c3cf0..ec10deaa 100644 --- a/src/train_bash.py +++ b/src/train_bash.py @@ -1,4 +1,4 @@ -from llmtuner import get_train_args, run_pt, run_sft, run_rm, run_ppo +from llmtuner.tuner import get_train_args, run_pt, run_sft, run_rm, run_ppo def main(): diff --git a/src/train_web.py b/src/train_web.py index 3f7855c0..cdc7d603 100644 --- a/src/train_web.py +++ b/src/train_web.py @@ -1,4 +1,4 @@ -from llmtuner import create_ui +from llmtuner.webui.interface import create_ui def main(): diff --git a/src/web_demo.py b/src/web_demo.py index 682034bc..9ec5a38f 100644 --- a/src/web_demo.py +++ b/src/web_demo.py @@ -5,7 +5,7 @@ import gradio as gr from transformers.utils.versions import require_version -from llmtuner import get_infer_args +from llmtuner.tuner import get_infer_args from llmtuner.webui.chat import WebChatModel from llmtuner.webui.components.chatbot import create_chat_box from llmtuner.webui.manager import Manager
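
For reference, a minimal standalone sketch of the splitting behaviour that the pt/rm/sft workflows now share through `split_dataset`. The helper body below is copied from the new `src/llmtuner/dsets/utils.py` above so the snippet runs without installing `llmtuner`; the toy dataset, the `dev_ratio=0.2` value, and the printed sizes are illustrative assumptions only (the Hugging Face `datasets` package is assumed to be available):

from typing import Dict
from datasets import Dataset


def split_dataset(dataset: Dataset, dev_ratio: float, do_train: bool) -> Dict[str, Dataset]:
    # Same logic as the helper added in src/llmtuner/dsets/utils.py:
    # carve out a dev split only when training and a non-zero dev_ratio is given.
    if do_train:
        if dev_ratio > 1e-6:
            dataset = dataset.train_test_split(test_size=dev_ratio)
            return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
        return {"train_dataset": dataset}
    return {"eval_dataset": dataset}  # do_eval or do_predict


if __name__ == "__main__":
    toy = Dataset.from_dict({"text": ["example {}".format(i) for i in range(10)]})
    trainer_kwargs = split_dataset(toy, dev_ratio=0.2, do_train=True)
    # Each workflow unpacks this dict into its trainer,
    # e.g. PeftTrainer(..., **trainer_kwargs), replacing the per-workflow split code removed above.
    print({key: len(value) for key, value in trainer_kwargs.items()})
    # -> {'train_dataset': 8, 'eval_dataset': 2}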