modify code structure

Former-commit-id: 08f180e78862cad902b6cdbbd8c86e39b5cacf8a
hiyouga 2023-08-02 23:17:36 +08:00
parent 4b8e4398bc
commit 4242897b78
25 changed files with 188 additions and 145 deletions

View File

@@ -5,13 +5,11 @@
 import uvicorn

-from llmtuner import ChatModel
-from llmtuner.api.app import create_app
-from llmtuner.tuner import get_infer_args
+from llmtuner import ChatModel, create_app


 def main():
-    chat_model = ChatModel(*get_infer_args())
+    chat_model = ChatModel()
     app = create_app(chat_model)
     uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
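With the refactored entry point, the API demo can be exercised from a plain HTTP client. A minimal client sketch, assuming the FastAPI app exposes an OpenAI-style /v1/chat/completions route (the Role/Finish protocol imports below suggest this); the endpoint path and payload fields are assumptions, not shown in this diff:

# Hedged client sketch; route and payload shape are assumptions, not confirmed by this commit.
import requests

payload = {
    "model": "default",  # placeholder model identifier
    "messages": [{"role": "user", "content": "Hello!"}]
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json())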

View File

@@ -3,11 +3,10 @@
 # Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint

 from llmtuner import ChatModel
-from llmtuner.tuner import get_infer_args


 def main():
-    chat_model = ChatModel(*get_infer_args())
+    chat_model = ChatModel()
     history = []
     print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")

View File

@@ -2,15 +2,11 @@
 # Exports the fine-tuned model.
 # Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model

-from llmtuner.tuner import get_train_args, load_model_and_tokenizer
+from llmtuner import export_model


 def main():
-    model_args, _, training_args, finetuning_args, _ = get_train_args()
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
-    model.save_pretrained(training_args.output_dir, max_shard_size="10GB")
-    tokenizer.save_pretrained(training_args.output_dir)
-    print("model and tokenizer have been saved at:", training_args.output_dir)
+    export_model()


 if __name__ == "__main__":

View File

@@ -1,4 +1,9 @@
+# Level: api, webui > chat > tuner > dsets > extras, hparams
+
+from llmtuner.api import create_app
 from llmtuner.chat import ChatModel
+from llmtuner.tuner import export_model, run_exp
+from llmtuner.webui import Manager, WebChatModel, create_ui, create_chat_box


 __version__ = "0.1.5"
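The package now re-exports its high-level entry points at the top level, so downstream scripts need a single import. A minimal sketch of the flattened namespace; the argument values are placeholders, not taken from this diff:

# Sketch of the flattened top-level API after this commit (values are placeholders).
from llmtuner import ChatModel, create_app, export_model, run_exp

chat_model = ChatModel({"model_name_or_path": "path_to_model"})  # inference args as a dict
app = create_app(chat_model)                                     # FastAPI app for serving
# run_exp({...}) and export_model({...}) cover training and export from the same namespace.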

View File

@@ -0,0 +1 @@
+from llmtuner.api.app import create_app

View File

@@ -5,9 +5,8 @@ from contextlib import asynccontextmanager
 from sse_starlette import EventSourceResponse
 from typing import List, Tuple

-from llmtuner.tuner import get_infer_args
 from llmtuner.extras.misc import torch_gc
-from llmtuner.chat.stream_chat import ChatModel
+from llmtuner.chat import ChatModel
 from llmtuner.api.protocol import (
     Role,
     Finish,
@@ -122,6 +121,6 @@ def create_app(chat_model: ChatModel) -> FastAPI:

 if __name__ == "__main__":
-    chat_model = ChatModel(*get_infer_args())
+    chat_model = ChatModel()
     app = create_app(chat_model)
     uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)

View File

@@ -1,30 +1,21 @@
 import torch
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
+from typing import Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
 from transformers import TextIteratorStreamer

 from llmtuner.extras.misc import dispatch_model, get_logits_processor
 from llmtuner.extras.template import get_template
-from llmtuner.tuner import load_model_and_tokenizer
-
-if TYPE_CHECKING:
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer


 class ChatModel:

-    def __init__(
-        self,
-        model_args: "ModelArguments",
-        data_args: "DataArguments",
-        finetuning_args: "FinetuningArguments",
-        generating_args: "GeneratingArguments"
-    ) -> None:
+    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
+        model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
         self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
         self.model = dispatch_model(self.model)
         self.template = get_template(data_args.template)
         self.source_prefix = data_args.source_prefix
-        self.generating_args = generating_args

     def process_args(
         self,
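ChatModel now resolves its own arguments through get_infer_args, either from sys.argv or from an optional dict. A hedged usage sketch of the new constructor; the key names mirror hparams used elsewhere in this diff, and the concrete paths are placeholders:

# Hedged sketch: pass a plain dict and let get_infer_args() build the argument dataclasses.
from llmtuner import ChatModel

chat_model = ChatModel({
    "model_name_or_path": "path_to_model",
    "checkpoint_dir": "path_to_checkpoint",   # optional adapter/LoRA checkpoint
    "finetuning_type": "lora"
})
# ChatModel() with no dict falls back to parsing sys.argv, which is what
# cli_demo.py and api_demo.py rely on after this change.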

View File

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
 from datetime import timedelta

 from transformers import TrainerCallback
+from transformers.trainer_utils import has_length

 if TYPE_CHECKING:
     from transformers import TrainingArguments, TrainerState, TrainerControl
@@ -14,58 +15,105 @@ class LogCallback(TrainerCallback):
     def __init__(self, runner=None):
         self.runner = runner
+        self.in_training = False
         self.start_time = time.time()
-        self.tracker = {}
+        self.cur_steps = 0
+        self.max_steps = 0
+        self.elapsed_time = ""
+        self.remaining_time = ""
+
+    def timing(self):
+        cur_time = time.time()
+        elapsed_time = cur_time - self.start_time
+        avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
+        remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
+        self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
+        self.remaining_time = str(timedelta(seconds=int(remaining_time)))

     def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the beginning of training.
         """
-        self.start_time = time.time()
+        if state.is_local_process_zero:
+            self.in_training = True
+            self.start_time = time.time()
+            self.max_steps = state.max_steps

-    def on_step_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
-        Event called at the beginning of a training step. If using gradient accumulation, one training step
-        might take several inputs.
+        Event called at the end of training.
         """
-        if self.runner is not None and self.runner.aborted:
-            control.should_epoch_stop = True
-            control.should_training_stop = True
+        if state.is_local_process_zero:
+            self.in_training = False
+            self.cur_steps = 0
+            self.max_steps = 0

     def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of an substep during gradient accumulation.
         """
-        if self.runner is not None and self.runner.aborted:
-            control.should_epoch_stop = True
-            control.should_training_stop = True
+        if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
+            control.should_epoch_stop = True
+            control.should_training_stop = True
+
+    def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+        r"""
+        Event called at the end of a training step.
+        """
+        if state.is_local_process_zero:
+            self.cur_steps = state.global_step
+            self.timing()
+            if self.runner is not None and self.runner.aborted:
+                control.should_epoch_stop = True
+                control.should_training_stop = True
+
+    def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+        r"""
+        Event called after an evaluation phase.
+        """
+        if state.is_local_process_zero and not self.in_training:
+            self.cur_steps = 0
+            self.max_steps = 0
+
+    def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs):
+        r"""
+        Event called after a successful prediction.
+        """
+        if state.is_local_process_zero and not self.in_training:
+            self.cur_steps = 0
+            self.max_steps = 0

     def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
         r"""
         Event called after logging the last logs.
         """
-        if not state.is_world_process_zero:
+        if not state.is_local_process_zero:
             return

-        cur_time = time.time()
-        cur_steps = state.log_history[-1].get("step")
-        elapsed_time = cur_time - self.start_time
-        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
-        remaining_steps = state.max_steps - cur_steps
-        remaining_time = remaining_steps * avg_time_per_step
-        self.tracker = {
-            "current_steps": cur_steps,
-            "total_steps": state.max_steps,
-            "loss": state.log_history[-1].get("loss", None),
-            "eval_loss": state.log_history[-1].get("eval_loss", None),
-            "predict_loss": state.log_history[-1].get("predict_loss", None),
-            "reward": state.log_history[-1].get("reward", None),
-            "learning_rate": state.log_history[-1].get("learning_rate", None),
-            "epoch": state.log_history[-1].get("epoch", None),
-            "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
-            "elapsed_time": str(timedelta(seconds=int(elapsed_time))),
-            "remaining_time": str(timedelta(seconds=int(remaining_time)))
-        }
+        logs = dict(
+            current_steps=self.cur_steps,
+            total_steps=self.max_steps,
+            loss=state.log_history[-1].get("loss", None),
+            eval_loss=state.log_history[-1].get("eval_loss", None),
+            predict_loss=state.log_history[-1].get("predict_loss", None),
+            reward=state.log_history[-1].get("reward", None),
+            learning_rate=state.log_history[-1].get("learning_rate", None),
+            epoch=state.log_history[-1].get("epoch", None),
+            percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
+            elapsed_time=self.elapsed_time,
+            remaining_time=self.remaining_time
+        )
         os.makedirs(args.output_dir, exist_ok=True)
         with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
-            f.write(json.dumps(self.tracker) + "\n")
+            f.write(json.dumps(logs) + "\n")
+
+    def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+        r"""
+        Event called after a prediction step.
+        """
+        eval_dataloader = kwargs.pop("eval_dataloader", None)
+        if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
+            if self.max_steps == 0:
+                self.max_steps = len(eval_dataloader)
+            self.cur_steps += 1
+            self.timing()
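LogCallback now tracks cur_steps/max_steps itself and appends one JSON object per log event to trainer_log.jsonl under the run's output_dir, which the web UI reads for progress. A small sketch of consuming that file, using the field names from the logs dict above; the directory is a placeholder:

# Sketch: read the trainer_log.jsonl entries that LogCallback.on_log() appends.
import json
import os

log_file = os.path.join("output_dir", "trainer_log.jsonl")  # placeholder output_dir
with open(log_file, encoding="utf-8") as f:
    for line in f:
        entry = json.loads(line)
        print("{current_steps}/{total_steps} ({percentage}%) "
              "loss={loss} eta={remaining_time}".format(**entry))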

View File

@@ -1,5 +1 @@
-from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
-from llmtuner.tuner.pt import run_pt
-from llmtuner.tuner.sft import run_sft
-from llmtuner.tuner.rm import run_rm
-from llmtuner.tuner.ppo import run_ppo
+from llmtuner.tuner.tune import export_model, run_exp

View File

@@ -11,7 +11,6 @@ from trl.core import LengthSampler
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
-from llmtuner.tuner.core.trainer import PeftTrainer
 from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
@@ -90,14 +89,13 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         reward_meter = AverageMeter()
         self.log_callback.on_train_begin(self.args, self.state, self.control)

-        for step in tqdm(range(max_steps), disable=not self.is_world_process_zero(), leave=False):
+        for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
             batch = next(dataiter)
             steps_trained += 1

             # Cast to inference mode
             unwrapped_model.gradient_checkpointing_disable()
             unwrapped_model.config.use_cache = True
-            unwrapped_model.eval()

             # Get inputs
             queries, responses = self.get_inputs(batch, length_sampler, **gen_kwargs)
@@ -106,21 +104,23 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             # Cast to training mode
             unwrapped_model.gradient_checkpointing_enable()
             unwrapped_model.config.use_cache = False
-            unwrapped_model.train()

             # Run PPO step
             stats = self.step(queries, responses, rewards)

             loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
             reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))

-            if self.is_world_process_zero() and (step+1) % self.args.logging_steps == 0:
+            self.state.global_step += 1
+            self.log_callback.on_step_end(self.args, self.state, self.control)
+
+            if self.is_local_process_zero() and (step+1) % self.args.logging_steps == 0:
                 logs = dict(
                     loss=round(loss_meter.avg, 4),
                     reward=round(reward_meter.avg, 4),
                     learning_rate=stats["ppo/learning_rate"],
                     epoch=round(step / len_dataloader, 2)
                 )
-                print(logs)
+                tqdm.write(str(logs))
                 logs["step"] = step
                 self.state.log_history.append(logs)
                 self.log_callback.on_log(self.args, self.state, self.control)
@@ -137,10 +137,12 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
                 dataiter = iter(self.dataloader)
                 steps_trained = 0

+        self.log_callback.on_train_end(self.args, self.state, self.control)
+
     @torch.no_grad()
     def get_inputs(
         self,
-        inputs: Dict[str, torch.Tensor],
+        batch: Dict[str, torch.Tensor],
         length_sampler: Optional[Callable] = None,
         **generation_kwargs
     ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
@@ -152,7 +154,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         self.model, layer_norm_params = cast_layernorm_dtype(self.model)
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-        response: torch.Tensor = unwrapped_model.generate(**inputs, **generation_kwargs)
+        response: torch.Tensor = unwrapped_model.generate(**batch, **generation_kwargs)
         self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)

         # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
@@ -161,7 +163,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         unwrapped_model.pretrained_model.generation_config._from_model_config = False

         queries, responses = [], []
-        query, response = inputs["input_ids"].detach().cpu(), response[:, inputs["input_ids"].size(-1):].detach().cpu()
+        query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
         for i in range(len(query)):
             query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
             response_length = (response[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
@@ -181,11 +183,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         Computes scores using given reward model.
         """
         replace_model(unwrapped_model, target="reward")
-        _, _, values = self.model(
-            **self.prepare_model_inputs(queries, responses),
-            output_hidden_states=True,
-            return_dict=True
-        )
+        batch = self.prepare_model_inputs(queries, responses)
+        _, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
         rewards = [reward for reward in values[:, -1].float().detach().cpu()]  # use fp32 type
         replace_model(unwrapped_model, target="default")
         return rewards

View File

@@ -10,7 +10,6 @@ from transformers import DataCollatorForSeq2Seq
 from transformers.optimization import get_scheduler

 from llmtuner.dsets import get_dataset, preprocess_dataset
-from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
@@ -25,7 +24,7 @@ def run_ppo(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
+    callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")

View File

@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorForSeq2Seq

 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
@@ -21,7 +20,7 @@ def run_pt(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
+    callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")

View File

@@ -5,7 +5,6 @@
 from typing import TYPE_CHECKING, Optional, List

 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.rm.metric import compute_accuracy
@@ -22,7 +21,7 @@ def run_rm(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
+    callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")

View File

@@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorForSeq2Seq

 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.misc import get_logits_processor
 from llmtuner.extras.ploting import plot_loss
@@ -22,7 +21,7 @@ def run_sft(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
+    callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")

View File

@@ -0,0 +1,36 @@
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from llmtuner.extras.callbacks import LogCallback
+from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer
+from llmtuner.tuner.pt import run_pt
+from llmtuner.tuner.sft import run_sft
+from llmtuner.tuner.rm import run_rm
+from llmtuner.tuner.ppo import run_ppo
+
+if TYPE_CHECKING:
+    from transformers import TrainerCallback
+
+
+def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
+    model_args, data_args, training_args, finetuning_args, general_args = get_train_args(args)
+    callbacks = [LogCallback()] if callbacks is None else callbacks
+
+    if general_args.stage == "pt":
+        run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
+    elif general_args.stage == "sft":
+        run_sft(model_args, data_args, training_args, finetuning_args, callbacks)
+    elif general_args.stage == "rm":
+        run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
+    elif general_args.stage == "ppo":
+        run_ppo(model_args, data_args, training_args, finetuning_args, callbacks)
+
+
+def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"):
+    model_args, _, training_args, finetuning_args, _ = get_train_args(args)
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
+    model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size)
+    tokenizer.save_pretrained(training_args.output_dir)
+
+
+if __name__ == "__main__":
+    run_exp()
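run_exp becomes the single programmatic entry point: it parses a dict (or sys.argv) via get_train_args, injects a default LogCallback, and dispatches on stage. A hedged sketch of driving it from Python, mirroring how the web UI's Runner threads it; keys other than stage are placeholders rather than values taken from this diff:

# Hedged sketch of the new programmatic entry points (placeholder dataset/paths).
from llmtuner import run_exp, export_model

run_exp(dict(
    stage="sft",
    model_name_or_path="path_to_model",
    do_train=True,
    dataset="alpaca_gpt4_en",          # placeholder dataset name
    output_dir="path_to_sft_checkpoint"
))

# export_model() reuses get_train_args(), so the same dict style covers export:
export_model(dict(
    model_name_or_path="path_to_model",
    checkpoint_dir="path_to_sft_checkpoint",
    output_dir="path_to_export"
), max_shard_size="10GB")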

View File

@@ -0,0 +1,4 @@
+from llmtuner.webui.chat import WebChatModel
+from llmtuner.webui.interface import create_ui
+from llmtuner.webui.manager import Manager
+from llmtuner.webui.components import create_chat_box

View File

@@ -1,22 +1,21 @@
 import os
-from typing import List, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 from llmtuner.chat.stream_chat import ChatModel
 from llmtuner.extras.misc import torch_gc
 from llmtuner.hparams import GeneratingArguments
-from llmtuner.tuner import get_infer_args
 from llmtuner.webui.common import get_model_path, get_save_dir
 from llmtuner.webui.locales import ALERTS


 class WebChatModel(ChatModel):

-    def __init__(self, *args):
+    def __init__(self, args: Optional[Dict[str, Any]]) -> None:
         self.model = None
         self.tokenizer = None
         self.generating_args = GeneratingArguments()
-        if len(args) != 0:
-            super().__init__(*args)
+        if args is not None:
+            super().__init__(args)

     def load_model(
         self,
@@ -57,7 +56,7 @@ class WebChatModel(ChatModel):
             template=template,
             source_prefix=source_prefix
         )
-        super().__init__(*get_infer_args(args))
+        super().__init__(args)

         yield ALERTS["info_loaded"][lang]

View File

@@ -3,3 +3,4 @@ from llmtuner.webui.components.sft import create_sft_tab
 from llmtuner.webui.components.eval import create_eval_tab
 from llmtuner.webui.components.infer import create_infer_tab
 from llmtuner.webui.components.export import create_export_tab
+from llmtuner.webui.components.chatbot import create_chat_box

View File

@@ -1,7 +1,7 @@
 from typing import TYPE_CHECKING, Dict

 import gradio as gr

-from llmtuner.webui.utils import export_model
+from llmtuner.webui.utils import save_model

 if TYPE_CHECKING:
     from gradio.components import Component
@@ -16,7 +16,7 @@ def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
     info_box = gr.Textbox(show_label=False, interactive=False)

     export_btn.click(
-        export_model,
+        save_model,
         [
             top_elems["lang"],
             top_elems["model_name"],

View File

@@ -47,6 +47,7 @@ def create_ui() -> gr.Blocks:
        manager.gen_label,
        [top_elems["lang"]],
        [elem for elems in elem_list for elem in elems.values()],
+       queue=False
    )

    return demo

View File

@@ -9,7 +9,7 @@ from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import DEFAULT_MODULE
 from llmtuner.extras.logging import LoggerHandler
 from llmtuner.extras.misc import torch_gc
-from llmtuner.tuner import get_train_args, run_sft
+from llmtuner.tuner import run_exp
 from llmtuner.webui.common import get_model_path, get_save_dir
 from llmtuner.webui.locales import ALERTS
 from llmtuner.webui.utils import format_info, get_eval_results
@@ -105,6 +105,7 @@ class Runner:
             checkpoint_dir = None

         args = dict(
+            stage="sft",
             model_name_or_path=model_name_or_path,
             do_train=True,
             overwrite_cache=True,
@@ -141,16 +142,8 @@ class Runner:
             args["eval_steps"] = save_steps
             args["load_best_model_at_end"] = True

-        model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
-
-        run_args = dict(
-            model_args=model_args,
-            data_args=data_args,
-            training_args=training_args,
-            finetuning_args=finetuning_args,
-            callbacks=[trainer_callback]
-        )
-        thread = threading.Thread(target=run_sft, kwargs=run_args)
+        run_kwargs = dict(args=args, callbacks=[trainer_callback])
+        thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
         thread.start()

         while thread.is_alive():
@@ -158,7 +151,7 @@ class Runner:
             if self.aborted:
                 yield ALERTS["info_aborting"][lang]
             else:
-                yield format_info(logger_handler.log, trainer_callback.tracker)
+                yield format_info(logger_handler.log, trainer_callback)

         yield self.finalize(lang)
@@ -194,6 +187,7 @@ class Runner:
         output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_base")

         args = dict(
+            stage="sft",
             model_name_or_path=model_name_or_path,
             do_eval=True,
             overwrite_cache=True,
@@ -216,16 +210,8 @@ class Runner:
             args.pop("do_eval", None)
             args["do_predict"] = True

-        model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
-
-        run_args = dict(
-            model_args=model_args,
-            data_args=data_args,
-            training_args=training_args,
-            finetuning_args=finetuning_args,
-            callbacks=[trainer_callback]
-        )
-        thread = threading.Thread(target=run_sft, kwargs=run_args)
+        run_kwargs = dict(args=args, callbacks=[trainer_callback])
+        thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
         thread.start()

         while thread.is_alive():
@@ -233,6 +219,6 @@ class Runner:
             if self.aborted:
                 yield ALERTS["info_aborting"][lang]
             else:
-                yield format_info(logger_handler.log, trainer_callback.tracker)
+                yield format_info(logger_handler.log, trainer_callback)
         yield self.finalize(lang, get_eval_results(os.path.join(output_dir, "all_results.json")))

View File

@@ -3,20 +3,23 @@ import json
 import gradio as gr
 import matplotlib.figure
 import matplotlib.pyplot as plt
-from typing import Any, Dict, Generator, List, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
 from datetime import datetime

 from llmtuner.extras.ploting import smooth
-from llmtuner.tuner import get_infer_args, load_model_and_tokenizer
+from llmtuner.tuner import export_model
 from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
 from llmtuner.webui.locales import ALERTS

+if TYPE_CHECKING:
+    from llmtuner.extras.callbacks import LogCallback
+

-def format_info(log: str, tracker: dict) -> str:
+def format_info(log: str, callback: "LogCallback") -> str:
     info = log
-    if "current_steps" in tracker:
+    if callback.max_steps:
         info += "Running **{:d}/{:d}**: {} < {}\n".format(
-            tracker["current_steps"], tracker["total_steps"], tracker["elapsed_time"], tracker["remaining_time"]
+            callback.cur_steps, callback.max_steps, callback.elapsed_time, callback.remaining_time
         )
     return info
@@ -87,7 +90,7 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
     return fig


-def export_model(
+def save_model(
     lang: str, model_name: str, checkpoints: List[str], finetuning_type: str, max_shard_size: int, save_dir: str
 ) -> Generator[str, None, None]:
     if not model_name:
@@ -114,12 +117,10 @@ def export_model(
     args = dict(
         model_name_or_path=model_name_or_path,
         checkpoint_dir=checkpoint_dir,
-        finetuning_type=finetuning_type
+        finetuning_type=finetuning_type,
+        output_dir=save_dir
     )

     yield ALERTS["info_exporting"][lang]
-    model_args, _, finetuning_args, _ = get_infer_args(args)
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
-    model.save_pretrained(save_dir, max_shard_size=str(max_shard_size)+"GB")
-    tokenizer.save_pretrained(save_dir)
+    export_model(args, max_shard_size="{}GB".format(max_shard_size))
     yield ALERTS["info_exported"][lang]

View File

@@ -1,17 +1,7 @@
-from llmtuner.tuner import get_train_args, run_pt, run_sft, run_rm, run_ppo
+from llmtuner import run_exp


 def main():
-    model_args, data_args, training_args, finetuning_args, general_args = get_train_args()
-
-    if general_args.stage == "pt":
-        run_pt(model_args, data_args, training_args, finetuning_args)
-    elif general_args.stage == "sft":
-        run_sft(model_args, data_args, training_args, finetuning_args)
-    elif general_args.stage == "rm":
-        run_rm(model_args, data_args, training_args, finetuning_args)
-    elif general_args.stage == "ppo":
-        run_ppo(model_args, data_args, training_args, finetuning_args)
+    run_exp()


 def _mp_fn(index):

View File

@@ -1,4 +1,4 @@
-from llmtuner.webui.interface import create_ui
+from llmtuner import create_ui


 def main():

View File

@@ -5,17 +5,14 @@
 import gradio as gr
 from transformers.utils.versions import require_version

-from llmtuner.tuner import get_infer_args
-from llmtuner.webui.chat import WebChatModel
-from llmtuner.webui.components.chatbot import create_chat_box
-from llmtuner.webui.manager import Manager
+from llmtuner import Manager, WebChatModel, create_chat_box

 require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")


 def main():
-    chat_model = WebChatModel(*get_infer_args())
+    chat_model = WebChatModel()
     with gr.Blocks(title="Web Demo") as demo:
         lang = gr.Dropdown(choices=["en", "zh"], value="en")