Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-09-12 16:12:48 +08:00)

commit 4242897b78 (parent 4b8e4398bc)
Former-commit-id: 08f180e78862cad902b6cdbbd8c86e39b5cacf8a

    modify code structure

src/api_demo.py
@@ -5,13 +5,11 @@
 import uvicorn

-from llmtuner import ChatModel
-from llmtuner.api.app import create_app
-from llmtuner.tuner import get_infer_args
+from llmtuner import ChatModel, create_app


 def main():
-    chat_model = ChatModel(*get_infer_args())
+    chat_model = ChatModel()
     app = create_app(chat_model)
     uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
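
With get_infer_args now called inside ChatModel itself, the demo picks up its model settings from the command line. A hedged launch example, following the `# Usage:` convention of the other demo scripts (paths are placeholders):

# Usage (sketch): python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint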

src/cli_demo.py
@@ -3,11 +3,10 @@
 # Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint

 from llmtuner import ChatModel
-from llmtuner.tuner import get_infer_args


 def main():
-    chat_model = ChatModel(*get_infer_args())
+    chat_model = ChatModel()
     history = []
     print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")

src/export_model.py
@@ -2,15 +2,11 @@
 # Exports the fine-tuned model.
 # Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model

-from llmtuner.tuner import get_train_args, load_model_and_tokenizer
+from llmtuner import export_model


 def main():
-    model_args, _, training_args, finetuning_args, _ = get_train_args()
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
-    model.save_pretrained(training_args.output_dir, max_shard_size="10GB")
-    tokenizer.save_pretrained(training_args.output_dir)
-    print("model and tokenizer have been saved at:", training_args.output_dir)
+    export_model()


 if __name__ == "__main__":

src/llmtuner/__init__.py
@@ -1,4 +1,9 @@
+# Level: api, webui > chat > tuner > dsets > extras, hparams
+
+from llmtuner.api import create_app
 from llmtuner.chat import ChatModel
+from llmtuner.tuner import export_model, run_exp
+from llmtuner.webui import Manager, WebChatModel, create_ui, create_chat_box


 __version__ = "0.1.5"
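
With these re-exports in place, downstream scripts can import every public entry point from the package root instead of reaching into submodules, which is exactly what the entry-script diffs below do. A sketch:

from llmtuner import ChatModel, create_app, create_ui, export_model, run_exp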

src/llmtuner/api/__init__.py (new file)
@@ -0,0 +1 @@
+from llmtuner.api.app import create_app

src/llmtuner/api/app.py
@@ -5,9 +5,8 @@ from contextlib import asynccontextmanager
 from sse_starlette import EventSourceResponse
 from typing import List, Tuple

-from llmtuner.tuner import get_infer_args
 from llmtuner.extras.misc import torch_gc
-from llmtuner.chat.stream_chat import ChatModel
+from llmtuner.chat import ChatModel
 from llmtuner.api.protocol import (
     Role,
     Finish,
@@ -122,6 +121,6 @@ def create_app(chat_model: ChatModel) -> FastAPI:


 if __name__ == "__main__":
-    chat_model = ChatModel(*get_infer_args())
+    chat_model = ChatModel()
     app = create_app(chat_model)
     uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)

src/llmtuner/chat/stream_chat.py
@@ -1,30 +1,21 @@
 import torch
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
+from typing import Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
 from transformers import TextIteratorStreamer

 from llmtuner.extras.misc import dispatch_model, get_logits_processor
 from llmtuner.extras.template import get_template
-from llmtuner.tuner import load_model_and_tokenizer
-
-if TYPE_CHECKING:
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer


 class ChatModel:

-    def __init__(
-        self,
-        model_args: "ModelArguments",
-        data_args: "DataArguments",
-        finetuning_args: "FinetuningArguments",
-        generating_args: "GeneratingArguments"
-    ) -> None:
+    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
+        model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
         self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
         self.model = dispatch_model(self.model)
         self.template = get_template(data_args.template)
         self.source_prefix = data_args.source_prefix
-        self.generating_args = generating_args

     def process_args(
         self,
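
ChatModel now resolves its own arguments through get_infer_args, so callers pass a plain dict (or nothing, to fall back to parsing sys.argv) instead of pre-built argument objects. A minimal sketch, with hypothetical values:

from llmtuner import ChatModel

chat_model = ChatModel({
    "model_name_or_path": "path_to_model",  # hypothetical path
    "template": "default"                   # hypothetical template name
})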

src/llmtuner/extras/callbacks.py
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
 from datetime import timedelta

 from transformers import TrainerCallback
+from transformers.trainer_utils import has_length

 if TYPE_CHECKING:
     from transformers import TrainingArguments, TrainerState, TrainerControl
@@ -14,58 +15,105 @@ class LogCallback(TrainerCallback):

     def __init__(self, runner=None):
         self.runner = runner
+        self.in_training = False
         self.start_time = time.time()
-        self.tracker = {}
+        self.cur_steps = 0
+        self.max_steps = 0
+        self.elapsed_time = ""
+        self.remaining_time = ""
+
+    def timing(self):
+        cur_time = time.time()
+        elapsed_time = cur_time - self.start_time
+        avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
+        remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
+        self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
+        self.remaining_time = str(timedelta(seconds=int(remaining_time)))

     def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the beginning of training.
         """
-        self.start_time = time.time()
+        if state.is_local_process_zero:
+            self.in_training = True
+            self.start_time = time.time()
+            self.max_steps = state.max_steps

-    def on_step_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
-        Event called at the beginning of a training step. If using gradient accumulation, one training step
-        might take several inputs.
+        Event called at the end of training.
         """
-        if self.runner is not None and self.runner.aborted:
-            control.should_epoch_stop = True
-            control.should_training_stop = True
+        if state.is_local_process_zero:
+            self.in_training = False
+            self.cur_steps = 0
+            self.max_steps = 0

     def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of an substep during gradient accumulation.
         """
+        if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
+            control.should_epoch_stop = True
+            control.should_training_stop = True
+
+    def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+        r"""
+        Event called at the end of a training step.
+        """
+        if state.is_local_process_zero:
+            self.cur_steps = state.global_step
+            self.timing()
         if self.runner is not None and self.runner.aborted:
             control.should_epoch_stop = True
             control.should_training_stop = True

+    def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+        r"""
+        Event called after an evaluation phase.
+        """
+        if state.is_local_process_zero and not self.in_training:
+            self.cur_steps = 0
+            self.max_steps = 0
+
+    def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs):
+        r"""
+        Event called after a successful prediction.
+        """
+        if state.is_local_process_zero and not self.in_training:
+            self.cur_steps = 0
+            self.max_steps = 0
+
     def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
         r"""
         Event called after logging the last logs.
         """
-        if not state.is_world_process_zero:
+        if not state.is_local_process_zero:
             return

-        cur_time = time.time()
-        cur_steps = state.log_history[-1].get("step")
-        elapsed_time = cur_time - self.start_time
-        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
-        remaining_steps = state.max_steps - cur_steps
-        remaining_time = remaining_steps * avg_time_per_step
-        self.tracker = {
-            "current_steps": cur_steps,
-            "total_steps": state.max_steps,
-            "loss": state.log_history[-1].get("loss", None),
-            "eval_loss": state.log_history[-1].get("eval_loss", None),
-            "predict_loss": state.log_history[-1].get("predict_loss", None),
-            "reward": state.log_history[-1].get("reward", None),
-            "learning_rate": state.log_history[-1].get("learning_rate", None),
-            "epoch": state.log_history[-1].get("epoch", None),
-            "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
-            "elapsed_time": str(timedelta(seconds=int(elapsed_time))),
-            "remaining_time": str(timedelta(seconds=int(remaining_time)))
-        }
+        logs = dict(
+            current_steps=self.cur_steps,
+            total_steps=self.max_steps,
+            loss=state.log_history[-1].get("loss", None),
+            eval_loss=state.log_history[-1].get("eval_loss", None),
+            predict_loss=state.log_history[-1].get("predict_loss", None),
+            reward=state.log_history[-1].get("reward", None),
+            learning_rate=state.log_history[-1].get("learning_rate", None),
+            epoch=state.log_history[-1].get("epoch", None),
+            percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
+            elapsed_time=self.elapsed_time,
+            remaining_time=self.remaining_time
+        )
         os.makedirs(args.output_dir, exist_ok=True)
         with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
-            f.write(json.dumps(self.tracker) + "\n")
+            f.write(json.dumps(logs) + "\n")
+
+    def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+        r"""
+        Event called after a prediction step.
+        """
+        eval_dataloader = kwargs.pop("eval_dataloader", None)
+        if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
+            if self.max_steps == 0:
+                self.max_steps = len(eval_dataloader)
+            self.cur_steps += 1
+            self.timing()
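
Each call to on_log now appends one JSON object per line to trainer_log.jsonl, using the keys built above. A hedged sketch of consuming that file (the output path is hypothetical):

import json

with open("output/trainer_log.jsonl", encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)  # keys: current_steps, total_steps, loss, ..., remaining_time
        print("{current_steps}/{total_steps} ({percentage}%) eta {remaining_time}".format(**record))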

src/llmtuner/tuner/__init__.py
@@ -1,5 +1 @@
-from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
-from llmtuner.tuner.pt import run_pt
-from llmtuner.tuner.sft import run_sft
-from llmtuner.tuner.rm import run_rm
-from llmtuner.tuner.ppo import run_ppo
+from llmtuner.tuner.tune import export_model, run_exp

src/llmtuner/tuner/ppo/trainer.py
@@ -11,7 +11,6 @@ from trl.core import LengthSampler

 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
-
 from llmtuner.tuner.core.trainer import PeftTrainer
 from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model

@@ -90,14 +89,13 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         reward_meter = AverageMeter()
         self.log_callback.on_train_begin(self.args, self.state, self.control)

-        for step in tqdm(range(max_steps), disable=not self.is_world_process_zero(), leave=False):
+        for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
             batch = next(dataiter)
             steps_trained += 1

             # Cast to inference mode
             unwrapped_model.gradient_checkpointing_disable()
             unwrapped_model.config.use_cache = True
-            unwrapped_model.eval()

             # Get inputs
             queries, responses = self.get_inputs(batch, length_sampler, **gen_kwargs)
@@ -106,21 +104,23 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             # Cast to training mode
             unwrapped_model.gradient_checkpointing_enable()
             unwrapped_model.config.use_cache = False
-            unwrapped_model.train()

             # Run PPO step
             stats = self.step(queries, responses, rewards)
             loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
             reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))

-            if self.is_world_process_zero() and (step+1) % self.args.logging_steps == 0:
+            self.state.global_step += 1
+            self.log_callback.on_step_end(self.args, self.state, self.control)
+
+            if self.is_local_process_zero() and (step+1) % self.args.logging_steps == 0:
                 logs = dict(
                     loss=round(loss_meter.avg, 4),
                     reward=round(reward_meter.avg, 4),
                     learning_rate=stats["ppo/learning_rate"],
                     epoch=round(step / len_dataloader, 2)
                 )
-                print(logs)
+                tqdm.write(str(logs))
                 logs["step"] = step
                 self.state.log_history.append(logs)
                 self.log_callback.on_log(self.args, self.state, self.control)
@@ -137,10 +137,12 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
                 dataiter = iter(self.dataloader)
                 steps_trained = 0

+        self.log_callback.on_train_end(self.args, self.state, self.control)
+
     @torch.no_grad()
     def get_inputs(
         self,
-        inputs: Dict[str, torch.Tensor],
+        batch: Dict[str, torch.Tensor],
         length_sampler: Optional[Callable] = None,
         **generation_kwargs
     ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
@@ -152,7 +154,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):

         self.model, layer_norm_params = cast_layernorm_dtype(self.model)
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-        response: torch.Tensor = unwrapped_model.generate(**inputs, **generation_kwargs)
+        response: torch.Tensor = unwrapped_model.generate(**batch, **generation_kwargs)
         self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)

         # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
@@ -161,7 +163,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         unwrapped_model.pretrained_model.generation_config._from_model_config = False

         queries, responses = [], []
-        query, response = inputs["input_ids"].detach().cpu(), response[:, inputs["input_ids"].size(-1):].detach().cpu()
+        query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
         for i in range(len(query)):
             query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
             response_length = (response[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
@@ -181,11 +183,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         Computes scores using given reward model.
         """
         replace_model(unwrapped_model, target="reward")
-        _, _, values = self.model(
-            **self.prepare_model_inputs(queries, responses),
-            output_hidden_states=True,
-            return_dict=True
-        )
+        batch = self.prepare_model_inputs(queries, responses)
+        _, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
         rewards = [reward for reward in values[:, -1].float().detach().cpu()]  # use fp32 type
         replace_model(unwrapped_model, target="default")
         return rewards

src/llmtuner/tuner/ppo/workflow.py
@@ -10,7 +10,6 @@ from transformers import DataCollatorForSeq2Seq
 from transformers.optimization import get_scheduler

 from llmtuner.dsets import get_dataset, preprocess_dataset
-from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
@@ -25,7 +24,7 @@ def run_ppo(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
+    callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")

src/llmtuner/tuner/pt/workflow.py
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorForSeq2Seq

 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
@@ -21,7 +20,7 @@ def run_pt(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
+    callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")

src/llmtuner/tuner/rm/workflow.py
@@ -5,7 +5,6 @@
 from typing import TYPE_CHECKING, Optional, List

 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.rm.metric import compute_accuracy
@@ -22,7 +21,7 @@ def run_rm(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
+    callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")

src/llmtuner/tuner/sft/workflow.py
@@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorForSeq2Seq

 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.misc import get_logits_processor
 from llmtuner.extras.ploting import plot_loss
@@ -22,7 +21,7 @@ def run_sft(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
+    callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
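
The same one-line change lands in all four workflows because `callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]` is the classic mutable-default pitfall: the list and the LogCallback inside it are created once, when the function is defined, and then shared by every call. The replacement idiom, used by run_exp in the new tune.py below, is sketched here:

from llmtuner.extras.callbacks import LogCallback

def run_stage(callbacks=None):
    # Create a fresh list (and a fresh LogCallback) per call instead of
    # sharing one instance across all invocations.
    callbacks = [LogCallback()] if callbacks is None else callbacks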

src/llmtuner/tuner/tune.py (new file, 36 lines)
@@ -0,0 +1,36 @@
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from llmtuner.extras.callbacks import LogCallback
+from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer
+from llmtuner.tuner.pt import run_pt
+from llmtuner.tuner.sft import run_sft
+from llmtuner.tuner.rm import run_rm
+from llmtuner.tuner.ppo import run_ppo
+
+if TYPE_CHECKING:
+    from transformers import TrainerCallback
+
+
+def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
+    model_args, data_args, training_args, finetuning_args, general_args = get_train_args(args)
+    callbacks = [LogCallback()] if callbacks is None else callbacks
+
+    if general_args.stage == "pt":
+        run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
+    elif general_args.stage == "sft":
+        run_sft(model_args, data_args, training_args, finetuning_args, callbacks)
+    elif general_args.stage == "rm":
+        run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
+    elif general_args.stage == "ppo":
+        run_ppo(model_args, data_args, training_args, finetuning_args, callbacks)
+
+
+def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"):
+    model_args, _, training_args, finetuning_args, _ = get_train_args(args)
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
+    model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size)
+    tokenizer.save_pretrained(training_args.output_dir)
+
+
+if __name__ == "__main__":
+    run_exp()
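
run_exp is now the single programmatic entry point: it parses a dict (or sys.argv) into argument objects and dispatches on `stage`. A hedged usage sketch with placeholder values:

from llmtuner import run_exp

run_exp(args={
    "stage": "sft",
    "model_name_or_path": "path_to_model",  # placeholder
    "do_train": True,
    "output_dir": "path_to_save_model"      # placeholder
})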

src/llmtuner/webui/__init__.py (new file)
@@ -0,0 +1,4 @@
+from llmtuner.webui.chat import WebChatModel
+from llmtuner.webui.interface import create_ui
+from llmtuner.webui.manager import Manager
+from llmtuner.webui.components import create_chat_box

src/llmtuner/webui/chat.py
@@ -1,22 +1,21 @@
 import os
-from typing import List, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 from llmtuner.chat.stream_chat import ChatModel
 from llmtuner.extras.misc import torch_gc
 from llmtuner.hparams import GeneratingArguments
-from llmtuner.tuner import get_infer_args
 from llmtuner.webui.common import get_model_path, get_save_dir
 from llmtuner.webui.locales import ALERTS


 class WebChatModel(ChatModel):

-    def __init__(self, *args):
+    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
         self.model = None
         self.tokenizer = None
         self.generating_args = GeneratingArguments()
-        if len(args) != 0:
-            super().__init__(*args)
+        if args is not None:
+            super().__init__(args)

     def load_model(
         self,
@@ -57,7 +56,7 @@ class WebChatModel(ChatModel):
             template=template,
             source_prefix=source_prefix
         )
-        super().__init__(*get_infer_args(args))
+        super().__init__(args)

         yield ALERTS["info_loaded"][lang]
|
@ -3,3 +3,4 @@ from llmtuner.webui.components.sft import create_sft_tab
|
|||||||
from llmtuner.webui.components.eval import create_eval_tab
|
from llmtuner.webui.components.eval import create_eval_tab
|
||||||
from llmtuner.webui.components.infer import create_infer_tab
|
from llmtuner.webui.components.infer import create_infer_tab
|
||||||
from llmtuner.webui.components.export import create_export_tab
|
from llmtuner.webui.components.export import create_export_tab
|
||||||
|
from llmtuner.webui.components.chatbot import create_chat_box
|
||||||
|

src/llmtuner/webui/components/export.py
@@ -1,7 +1,7 @@
 from typing import TYPE_CHECKING, Dict
 import gradio as gr

-from llmtuner.webui.utils import export_model
+from llmtuner.webui.utils import save_model

 if TYPE_CHECKING:
     from gradio.components import Component
@@ -16,7 +16,7 @@ def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
     info_box = gr.Textbox(show_label=False, interactive=False)

     export_btn.click(
-        export_model,
+        save_model,
         [
             top_elems["lang"],
             top_elems["model_name"],

src/llmtuner/webui/interface.py
@@ -47,6 +47,7 @@ def create_ui() -> gr.Blocks:
         manager.gen_label,
         [top_elems["lang"]],
         [elem for elems in elem_list for elem in elems.values()],
+        queue=False
     )

     return demo

src/llmtuner/webui/runner.py
@@ -9,7 +9,7 @@ from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import DEFAULT_MODULE
 from llmtuner.extras.logging import LoggerHandler
 from llmtuner.extras.misc import torch_gc
-from llmtuner.tuner import get_train_args, run_sft
+from llmtuner.tuner import run_exp
 from llmtuner.webui.common import get_model_path, get_save_dir
 from llmtuner.webui.locales import ALERTS
 from llmtuner.webui.utils import format_info, get_eval_results
@@ -105,6 +105,7 @@ class Runner:
             checkpoint_dir = None

         args = dict(
+            stage="sft",
             model_name_or_path=model_name_or_path,
             do_train=True,
             overwrite_cache=True,
@@ -141,16 +142,8 @@ class Runner:
             args["eval_steps"] = save_steps
             args["load_best_model_at_end"] = True

-        model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
-        run_args = dict(
-            model_args=model_args,
-            data_args=data_args,
-            training_args=training_args,
-            finetuning_args=finetuning_args,
-            callbacks=[trainer_callback]
-        )
-        thread = threading.Thread(target=run_sft, kwargs=run_args)
+        run_kwargs = dict(args=args, callbacks=[trainer_callback])
+        thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
         thread.start()

         while thread.is_alive():
@@ -158,7 +151,7 @@ class Runner:
             if self.aborted:
                 yield ALERTS["info_aborting"][lang]
             else:
-                yield format_info(logger_handler.log, trainer_callback.tracker)
+                yield format_info(logger_handler.log, trainer_callback)

         yield self.finalize(lang)
@@ -194,6 +187,7 @@ class Runner:
         output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_base")

         args = dict(
+            stage="sft",
             model_name_or_path=model_name_or_path,
             do_eval=True,
             overwrite_cache=True,
@@ -216,16 +210,8 @@ class Runner:
         args.pop("do_eval", None)
         args["do_predict"] = True

-        model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
-        run_args = dict(
-            model_args=model_args,
-            data_args=data_args,
-            training_args=training_args,
-            finetuning_args=finetuning_args,
-            callbacks=[trainer_callback]
-        )
-        thread = threading.Thread(target=run_sft, kwargs=run_args)
+        run_kwargs = dict(args=args, callbacks=[trainer_callback])
+        thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
         thread.start()

         while thread.is_alive():
@@ -233,6 +219,6 @@ class Runner:
             if self.aborted:
                 yield ALERTS["info_aborting"][lang]
             else:
-                yield format_info(logger_handler.log, trainer_callback.tracker)
+                yield format_info(logger_handler.log, trainer_callback)

         yield self.finalize(lang, get_eval_results(os.path.join(output_dir, "all_results.json")))

src/llmtuner/webui/utils.py
@@ -3,20 +3,23 @@ import json
 import gradio as gr
 import matplotlib.figure
 import matplotlib.pyplot as plt
-from typing import Any, Dict, Generator, List, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
 from datetime import datetime

 from llmtuner.extras.ploting import smooth
-from llmtuner.tuner import get_infer_args, load_model_and_tokenizer
+from llmtuner.tuner import export_model
 from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
 from llmtuner.webui.locales import ALERTS

+if TYPE_CHECKING:
+    from llmtuner.extras.callbacks import LogCallback
+

-def format_info(log: str, tracker: dict) -> str:
+def format_info(log: str, callback: "LogCallback") -> str:
     info = log
-    if "current_steps" in tracker:
+    if callback.max_steps:
         info += "Running **{:d}/{:d}**: {} < {}\n".format(
-            tracker["current_steps"], tracker["total_steps"], tracker["elapsed_time"], tracker["remaining_time"]
+            callback.cur_steps, callback.max_steps, callback.elapsed_time, callback.remaining_time
         )
     return info
@@ -87,7 +90,7 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
     return fig


-def export_model(
+def save_model(
     lang: str, model_name: str, checkpoints: List[str], finetuning_type: str, max_shard_size: int, save_dir: str
 ) -> Generator[str, None, None]:
     if not model_name:
@@ -114,12 +117,10 @@ def export_model(
     args = dict(
         model_name_or_path=model_name_or_path,
         checkpoint_dir=checkpoint_dir,
-        finetuning_type=finetuning_type
+        finetuning_type=finetuning_type,
+        output_dir=save_dir
     )

     yield ALERTS["info_exporting"][lang]
-    model_args, _, finetuning_args, _ = get_infer_args(args)
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
-    model.save_pretrained(save_dir, max_shard_size=str(max_shard_size)+"GB")
-    tokenizer.save_pretrained(save_dir)
+    export_model(args, max_shard_size="{}GB".format(max_shard_size))
     yield ALERTS["info_exported"][lang]
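
The WebUI helper is renamed from export_model to save_model so it no longer shadows llmtuner.tuner.export_model, to which it now delegates. Because it yields status strings for the Gradio info box, calling it outside the UI means draining the generator; a sketch, with every argument value hypothetical:

from llmtuner.webui.utils import save_model

for status in save_model("en", "model_name", ["checkpoint_dir"], "lora", 10, "path_to_save_model"):
    print(status)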

src/train_bash.py
@@ -1,17 +1,7 @@
-from llmtuner.tuner import get_train_args, run_pt, run_sft, run_rm, run_ppo
+from llmtuner import run_exp


 def main():
-    model_args, data_args, training_args, finetuning_args, general_args = get_train_args()
-
-    if general_args.stage == "pt":
-        run_pt(model_args, data_args, training_args, finetuning_args)
-    elif general_args.stage == "sft":
-        run_sft(model_args, data_args, training_args, finetuning_args)
-    elif general_args.stage == "rm":
-        run_rm(model_args, data_args, training_args, finetuning_args)
-    elif general_args.stage == "ppo":
-        run_ppo(model_args, data_args, training_args, finetuning_args)
+    run_exp()


 def _mp_fn(index):
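
train_bash.py is now a thin wrapper over run_exp, which reads everything, including the stage selector, from the command line. A hedged invocation sketch (flags beyond --stage are placeholders):

# Sketch: python src/train_bash.py --stage sft --model_name_or_path path_to_model --do_train --output_dir path_to_save_model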

src/train_web.py
@@ -1,4 +1,4 @@
-from llmtuner.webui.interface import create_ui
+from llmtuner import create_ui


 def main():

src/web_demo.py
@@ -5,17 +5,14 @@
 import gradio as gr
 from transformers.utils.versions import require_version

-from llmtuner.tuner import get_infer_args
-from llmtuner.webui.chat import WebChatModel
-from llmtuner.webui.components.chatbot import create_chat_box
-from llmtuner.webui.manager import Manager
+from llmtuner import Manager, WebChatModel, create_chat_box


 require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")


 def main():
-    chat_model = WebChatModel(*get_infer_args())
+    chat_model = WebChatModel()

     with gr.Blocks(title="Web Demo") as demo:
         lang = gr.Dropdown(choices=["en", "zh"], value="en")