diff --git a/README.md b/README.md index 69dfe649..c8e71e9c 100644 --- a/README.md +++ b/README.md @@ -105,7 +105,7 @@ - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) - [UltraChat (en)](https://github.com/thunlp/UltraChat) - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) - - [Ad Gen (zh)](https://arxiv.org/abs/1908.06605) + - [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen) - For reward modeling or DPO training: - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) diff --git a/README_zh.md b/README_zh.md index 628a2b10..212016b0 100644 --- a/README_zh.md +++ b/README_zh.md @@ -105,8 +105,8 @@ - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) - [UltraChat (en)](https://github.com/thunlp/UltraChat) - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) - - [Ad Gen (zh)](https://arxiv.org/abs/1908.06605) -- 用于奖励模型或 DPO 训练: + - [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen) +- 用于训练奖励模型或 DPO 训练: - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) diff --git a/data/README.md b/data/README.md index a7375b5d..3be493b2 100644 --- a/data/README.md +++ b/data/README.md @@ -6,13 +6,13 @@ If you are using a custom dataset, please provide your dataset definition in the "script_url": "the name of the directory containing a dataset loading script. (if specified, ignore below 2 arguments)", "file_name": "the name of the dataset file in the this directory. (required if above are not specified)", "file_sha1": "the SHA-1 hash value of the dataset file. (optional)", + "ranking": "whether the examples contains ranked responses or not. (default: false)", "columns": { "prompt": "the name of the column in the datasets containing the prompts. (default: instruction)", "query": "the name of the column in the datasets containing the queries. (default: input)", "response": "the name of the column in the datasets containing the responses. (default: output)", "history": "the name of the column in the datasets containing the history of chat. 
(default: None)" - }, - "stage": "The stage at which the data is being used: pt, sft, and rm, which correspond to pre-training, supervised fine-tuning(PPO), and reward model (DPO) training, respectively.(default: None)" + } } ``` @@ -27,7 +27,6 @@ For datasets used in reward modeling or DPO training, the `response` column shou "output": [ "Chosen answer", "Rejected answer" - ], - "stage": "rm" + ] } ``` diff --git a/data/README_zh.md b/data/README_zh.md index e23a3e70..a8f62ca2 100644 --- a/data/README_zh.md +++ b/data/README_zh.md @@ -6,19 +6,19 @@ "script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略下列两个参数)", "file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)", "file_sha1": "数据集文件的SHA-1哈希值(可选)", + "ranking": "数据集是否包含排序后的回答(默认:false)", "columns": { "prompt": "数据集代表提示词的表头名称(默认:instruction)", "query": "数据集代表请求的表头名称(默认:input)", "response": "数据集代表回答的表头名称(默认:output)", "history": "数据集代表历史对话的表头名称(默认:None)" - }, - "stage": "数据所应用的训练阶段,可选值有 pt, sft, rm 三个,对应预训练,指令监督微调(PPO),奖励模型(DPO)训练, 默认为None,表示不限制" + } } ``` 其中 `prompt` 和 `response` 列应当是非空的字符串。`query` 列的内容将会和 `prompt` 列拼接作为模型输入。`history` 列应当是一个列表,其中每个元素是一个字符串二元组,分别代表用户请求和模型答复。 -对于奖励模型或 DPO 训练的数据集,`response` 列应当是一个字符串列表,排在前面的代表更优的答案,例如: +对于训练奖励模型或 DPO 训练的数据集,`response` 列应当是一个字符串列表,排在前面的代表更优的答案,例如: ```json { @@ -27,7 +27,6 @@ "output": [ "Chosen answer", "Rejected answer" - ], - "stage": "rm" + ] } ``` diff --git a/requirements.txt b/requirements.txt index fb5fa72f..0c725b8f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ transformers>=4.29.1 datasets>=2.12.0 accelerate>=0.21.0 peft>=0.4.0 -trl>=0.5.0 +trl>=0.7.1 scipy sentencepiece tiktoken diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py index 1fc146f8..efe7e97e 100644 --- a/src/llmtuner/dsets/preprocess.py +++ b/src/llmtuner/dsets/preprocess.py @@ -31,11 +31,13 @@ def preprocess_dataset( yield query, response, history, system def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]: - # build grouped texts with format `X1 X2 X3 ...` (without ) - if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen) - kwargs = dict(allowed_special="all") - else: - kwargs = dict(add_special_tokens=False) + # build grouped texts with format `X1 X2 X3 ...` + if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): + kwargs = dict(allowed_special="all") # for tiktoken tokenizer (Qwen) + + if hasattr(tokenizer, "add_bos_token") and hasattr(tokenizer, "add_eos_token"): + setattr(tokenizer, "add_bos_token", True) # for LLaMA tokenizer + setattr(tokenizer, "add_eos_token", True) tokenized_examples = tokenizer(examples["prompt"], **kwargs) concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py index f10aaaa3..523ba91b 100644 --- a/src/llmtuner/extras/constants.py +++ b/src/llmtuner/extras/constants.py @@ -10,20 +10,12 @@ LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] METHODS = ["full", "freeze", "lora"] -STAGES = [ - "SFT", - "Reward Modeling", - "PPO", - "DPO", - "Pre-Training" -] - -DATASET_STAGE_MAP = { - "SFT": "sft", - "Pre-Training": "pt", +TRAINING_STAGES = { + "Supervised Fine-Tuning": "sft", "Reward Modeling": "rm", - "PPO": "sft", - "DPO": "rm" + "PPO": "ppo", + "DPO": "dpo", + "Pre-Training": "pt" } SUPPORTED_MODELS = { diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index 33e2c5f7..beb8e1f9 100644 --- 
a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -137,6 +137,8 @@ class Template: token_ids = [] for elem in context: if isinstance(elem, str): + if len(elem) == 0: + continue elem = elem.replace("{{system}}", system, 1) if system is not None else elem elem = elem.replace("{{query}}", query, 1) if query is not None else elem elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py index db7702cd..63e8dacb 100644 --- a/src/llmtuner/hparams/data_args.py +++ b/src/llmtuner/hparams/data_args.py @@ -11,17 +11,15 @@ class DatasetAttr: dataset_name: Optional[str] = None dataset_sha1: Optional[str] = None system_prompt: Optional[str] = None - stage: Optional[str] = None + ranking: Optional[bool] = False + prompt: Optional[str] = "instruction" + query: Optional[str] = "input" + response: Optional[str] = "output" + history: Optional[str] = None def __repr__(self) -> str: return self.dataset_name - def __post_init__(self): - self.prompt = "instruction" - self.query = "input" - self.response = "output" - self.history = None - @dataclass class DataArguments: @@ -114,21 +112,14 @@ class DataArguments: raise ValueError("Undefined dataset {} in dataset_info.json.".format(name)) if "hf_hub_url" in dataset_info[name]: - dataset_attr = DatasetAttr( - "hf_hub", - dataset_name=dataset_info[name]["hf_hub_url"], - stage=dataset_info[name].get("stage", None)) + dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) elif "script_url" in dataset_info[name]: - dataset_attr = DatasetAttr( - "script", - dataset_name=dataset_info[name]["script_url"], - stage=dataset_info[name].get("stage", None)) + dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"]) else: dataset_attr = DatasetAttr( "file", dataset_name=dataset_info[name]["file_name"], - dataset_sha1=dataset_info[name].get("file_sha1", None), - stage=dataset_info[name].get("stage", None) + dataset_sha1=dataset_info[name].get("file_sha1", None) ) if "columns" in dataset_info[name]: @@ -137,5 +128,6 @@ class DataArguments: dataset_attr.response = dataset_info[name]["columns"].get("response", None) dataset_attr.history = dataset_info[name]["columns"].get("history", None) + dataset_attr.ranking = dataset_info[name].get("ranking", False) dataset_attr.system_prompt = prompt_list[i] self.dataset_list.append(dataset_attr) diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py index dc515f51..8969b0c1 100644 --- a/src/llmtuner/hparams/model_args.py +++ b/src/llmtuner/hparams/model_args.py @@ -16,7 +16,7 @@ class ModelArguments: metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."} ) use_fast_tokenizer: Optional[bool] = field( - default=False, + default=True, metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."} ) use_auth_token: Optional[bool] = field( diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py index 1f6d39b1..f85dbd8a 100644 --- a/src/llmtuner/tuner/core/loader.py +++ b/src/llmtuner/tuner/core/loader.py @@ -36,7 +36,7 @@ check_min_version("4.29.1") require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0") require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0") require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0") -require_version("trl>=0.5.0", "To fix: pip install trl>=0.5.0") 
+require_version("trl>=0.7.1", "To fix: pip install trl>=0.7.1") def load_model_and_tokenizer( diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py index 4cd90af9..9d4b68bd 100644 --- a/src/llmtuner/tuner/core/parser.py +++ b/src/llmtuner/tuner/core/parser.py @@ -5,6 +5,7 @@ import datasets import transformers from typing import Any, Dict, Optional, Tuple from transformers import HfArgumentParser, Seq2SeqTrainingArguments +from transformers.utils.versions import require_version from transformers.trainer_utils import get_last_checkpoint from llmtuner.extras.logging import get_logger @@ -110,6 +111,11 @@ def get_train_args( if general_args.stage in ["ppo", "dpo"] and not training_args.do_train: raise ValueError("PPO and DPO stages can only be performed at training.") + if general_args.stage in ["rm", "dpo"]: + for dataset_attr in data_args.dataset_list: + if not dataset_attr.ranking: + raise ValueError("Please use ranked datasets for reward modeling or DPO training.") + if general_args.stage == "ppo" and model_args.reward_model is None: raise ValueError("Reward model is necessary for PPO training.") @@ -166,6 +172,7 @@ def get_train_args( and os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir ): + require_version("transformers>=4.31.0", "Resuming training requires transformers>=4.31.0.") last_checkpoint = get_last_checkpoint(training_args.output_dir) if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.") @@ -186,18 +193,6 @@ def get_train_args( else: model_args.compute_dtype = torch.float16 - # transfer training stage to dataset stage - dataset_stage = general_args.stage - if general_args.stage == "ppo": - dataset_stage = "sft" - elif general_args.stage == "dpo": - dataset_stage = "rm" - - for dataset_attr in data_args.dataset_list: - if dataset_attr.stage and dataset_attr.stage != dataset_stage: - raise ValueError("Dataset {} is not supported for the stage {}" - .format(dataset_attr.dataset_name, general_args.stage)) - model_args.model_max_length = data_args.max_source_length + data_args.max_target_length # Log on each process the small summary: diff --git a/src/llmtuner/tuner/dpo/trainer.py b/src/llmtuner/tuner/dpo/trainer.py index 3f7f7af5..1004bf60 100644 --- a/src/llmtuner/tuner/dpo/trainer.py +++ b/src/llmtuner/tuner/dpo/trainer.py @@ -1,9 +1,9 @@ import torch from collections import defaultdict -from peft import PeftModel from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union from transformers import BatchEncoding, Trainer from trl import DPOTrainer +from trl.trainer.utils import disable_dropout_in_model from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.tuner.core.trainer import PeftModelMixin @@ -18,9 +18,16 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer): def __init__( self, finetuning_args: "FinetuningArguments", + model: Union["PreTrainedModel", torch.nn.Module], ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None, + disable_dropout: Optional[bool] = True, **kwargs ): + if disable_dropout: + disable_dropout_in_model(model) + if ref_model is not None: + disable_dropout_in_model(ref_model) + self.finetuning_args = finetuning_args self.ref_model = ref_model self.use_dpo_data_collator = True # hack to avoid warning @@ -29,12 +36,16 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer): self.beta = finetuning_args.dpo_beta self._stored_metrics = 
defaultdict(lambda: defaultdict(list)) - Trainer.__init__(self, **kwargs) + Trainer.__init__(self, model=model, **kwargs) if not hasattr(self, "accelerator"): raise AttributeError("Please update `transformers`.") if ref_model is not None: - self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + if self.is_deepspeed_enabled: + self.ref_model = self.accelerator._prepare_deepspeed(self.ref_model) + self.ref_model.eval() + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) def concatenated_forward( self, @@ -42,27 +53,12 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer): batch: Optional[Dict[str, torch.Tensor]] = None ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error - unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model) - if not torch.is_grad_enabled(): - unwrapped_model.gradient_checkpointing_disable() - - if model is None and isinstance(unwrapped_model, PeftModel): # peft model has no ref_model - with unwrapped_model.disable_adapter(): - all_logits = self.model( - input_ids=batch_copied["input_ids"], - attention_mask=batch_copied["attention_mask"], - return_dict=True - ).logits.to(torch.float32) - else: - all_logits = model( - input_ids=batch_copied["input_ids"], - attention_mask=batch_copied["attention_mask"], - return_dict=True - ).logits.to(torch.float32) - - if not torch.is_grad_enabled(): - unwrapped_model.gradient_checkpointing_enable() + all_logits = model( + input_ids=batch_copied["input_ids"], + attention_mask=batch_copied["attention_mask"], + return_dict=True + ).logits.to(torch.float32) all_logps = self._get_batch_logps( all_logits, diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py index cd34c531..59aede98 100644 --- a/src/llmtuner/tuner/ppo/trainer.py +++ b/src/llmtuner/tuner/ppo/trainer.py @@ -202,7 +202,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): queries: torch.Tensor, responses: torch.Tensor, model_inputs: dict, - return_logits: Optional[bool] = False + return_logits: Optional[bool] = False, + response_masks: Optional[torch.Tensor] = None ): r""" Calculates model outputs in multiple batches. 
@@ -220,6 +221,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()} query_batch = queries[i * fbs : (i + 1) * fbs] response_batch = responses[i * fbs : (i + 1) * fbs] + if response_masks is not None: + response_masks_batch = response_masks[i * fbs : (i + 1) * fbs] input_ids = input_kwargs["input_ids"] attention_mask = input_kwargs["attention_mask"] @@ -239,8 +242,15 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer): start += attention_mask[j, :].nonzero()[0] end = start + len(response_batch[j]) + if response_masks is not None: + response_masks_batch = torch.cat( + (torch.zeros_like(query_batch[j]), response_masks_batch[j]) + )[1:] + masks[j, :start] = 0 masks[j, end:] = 0 + if response_masks is not None: + masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end] if return_logits: all_logits.append(logits) diff --git a/src/llmtuner/tuner/ppo/workflow.py b/src/llmtuner/tuner/ppo/workflow.py index c243e322..48cd703a 100644 --- a/src/llmtuner/tuner/ppo/workflow.py +++ b/src/llmtuner/tuner/ppo/workflow.py @@ -44,7 +44,6 @@ def run_ppo( ) if finetuning_args.ppo_score_norm: - require_version("trl>=0.5.1.dev0", "To fix: pip install git+https://github.com/huggingface/trl.git") ppo_config.use_score_scaling = True ppo_config.use_score_norm = True diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py index 98fface4..c5f22294 100644 --- a/src/llmtuner/webui/common.py +++ b/src/llmtuner/webui/common.py @@ -6,7 +6,7 @@ import gradio as gr from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME -from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, DATASET_STAGE_MAP +from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES DEFAULT_CACHE_DIR = "cache" @@ -78,11 +78,10 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Any]: return {} -def list_dataset(dataset_dir: Optional[str] = None, stage: Optional[str] = None) -> Dict[str, Any]: +def list_dataset( + dataset_dir: Optional[str] = None, training_stage: Optional[str] = list(TRAINING_STAGES.keys())[0] +) -> Dict[str, Any]: dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR) - if stage: - dataset_stage = DATASET_STAGE_MAP[stage] - dataset_info = {key: value for key, value in dataset_info.items() - if ("stage" not in value) or value["stage"] == dataset_stage} - - return gr.update(value=[], choices=list(dataset_info.keys())) \ No newline at end of file + ranking = TRAINING_STAGES[training_stage] in ["rm", "dpo"] + datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking] + return gr.update(value=[], choices=datasets) diff --git a/src/llmtuner/webui/components/train.py b/src/llmtuner/webui/components/train.py index 7b69944c..90bf56bf 100644 --- a/src/llmtuner/webui/components/train.py +++ b/src/llmtuner/webui/components/train.py @@ -3,7 +3,7 @@ from transformers.trainer_utils import SchedulerType import gradio as gr -from llmtuner.extras.constants import STAGES +from llmtuner.extras.constants import TRAINING_STAGES from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR from llmtuner.webui.components.data import create_preview_box from llmtuner.webui.utils import can_preview, get_preview, gen_plot @@ -15,7 +15,9 @@ if TYPE_CHECKING: def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> 
Dict[str, "Component"]: with gr.Row(): - training_stage = gr.Dropdown(choices=STAGES, value=STAGES[0], scale=2) + training_stage = gr.Dropdown( + choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=2 + ) dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2) dataset = gr.Dropdown(multiselect=True, scale=4) data_preview_btn = gr.Button(interactive=False, scale=1) diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index 9e13a651..9f127852 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -8,7 +8,7 @@ from transformers.trainer import TRAINING_ARGS_NAME from typing import Any, Dict, Generator, List, Tuple from llmtuner.extras.callbacks import LogCallback -from llmtuner.extras.constants import DEFAULT_MODULE +from llmtuner.extras.constants import DEFAULT_MODULE, TRAINING_STAGES from llmtuner.extras.logging import LoggerHandler from llmtuner.extras.misc import torch_gc from llmtuner.tuner import run_exp @@ -106,7 +106,7 @@ class Runner: output_dir = os.path.join(get_save_dir(model_name), finetuning_type, output_dir) args = dict( - stage="sft", + stage=TRAINING_STAGES[training_stage], model_name_or_path=get_model_path(model_name), do_train=True, overwrite_cache=True, @@ -133,26 +133,20 @@ class Runner: lora_rank=lora_rank, lora_dropout=lora_dropout, lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"), - resume_lora_training=resume_lora_training, + resume_lora_training=( + False if TRAINING_STAGES[training_stage] in ["rm", "ppo", "dpo"] else resume_lora_training + ), output_dir=output_dir ) args[compute_type] = True - if training_stage == "Reward Modeling": - args["stage"] = "rm" - args["resume_lora_training"] = False - elif training_stage == "PPO": - args["stage"] = "ppo" - args["resume_lora_training"] = False + if args["stage"] == "ppo": args["reward_model"] = reward_model args["padding_side"] = "left" val_size = 0 - elif training_stage == "DPO": - args["stage"] = "dpo" - args["resume_lora_training"] = False + + if args["stage"] == "dpo": args["dpo_beta"] = dpo_beta - elif training_stage == "Pre-Training": - args["stage"] = "pt" if val_size > 1e-6: args["val_size"] = val_size diff --git a/src/llmtuner/webui/utils.py b/src/llmtuner/webui/utils.py index d32d719c..32625cba 100644 --- a/src/llmtuner/webui/utils.py +++ b/src/llmtuner/webui/utils.py @@ -3,10 +3,9 @@ import json import gradio as gr import matplotlib.figure import matplotlib.pyplot as plt -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple from datetime import datetime -from llmtuner.dsets.utils import EXT2TYPE from llmtuner.extras.ploting import smooth from llmtuner.tuner import export_model from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG @@ -37,6 +36,7 @@ def get_time() -> str: def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]: with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: dataset_info = json.load(f) + if ( len(dataset) > 0 and "file_name" in dataset_info[dataset[0]] @@ -47,25 +47,26 @@ def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]: return gr.update(interactive=False) -def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, Dict[str, Any]]: +def get_preview( + dataset_dir: str, dataset: list, start: Optional[int] = 0, end: Optional[int] = 2 +) -> Tuple[int, list, Dict[str, Any]]: with 
open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: dataset_info = json.load(f) - data_file = dataset_info[dataset[0]]["file_name"] - data = [] - data_format = EXT2TYPE.get(data_file.split(".")[-1], None) - if data_format == "text": - with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f: - for line in f: - data.append(line.strip()) - elif data_format == "json": - with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f: + + data_file: str = dataset_info[dataset[0]]["file_name"] + with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f: + if data_file.endswith(".json"): data = json.load(f) - return len(data), data[:2], gr.update(visible=True) + elif data_file.endswith(".jsonl"): + data = [json.load(line) for line in f] + else: + data = [line for line in f] + return len(data), data[start:end], gr.update(visible=True) def can_quantize(finetuning_type: str) -> Dict[str, Any]: if finetuning_type != "lora": - return gr.update(value="", interactive=False) + return gr.update(value="None", interactive=False) else: return gr.update(interactive=True) @@ -73,7 +74,7 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]: def gen_cmd(args: Dict[str, Any]) -> str: if args.get("do_train", None): args["plot_loss"] = True - cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python "] + cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python src/train_bash.py"] for k, v in args.items(): if v is not None and v != "": cmd_lines.append(" --{} {} ".format(k, str(v)))
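
The main user-facing change in this patch is that the per-dataset `stage` field in `dataset_info.json` gives way to a boolean `ranking` flag, while the web UI maps its stage labels to internal codes through `TRAINING_STAGES` and only lists ranked datasets for reward modeling and DPO. Below is a minimal, standalone sketch of that selection logic, not the project's actual API: the `TRAINING_STAGES` mapping and the `ranking` key come from the patch, whereas the sample dataset entries and the helper name are made up for illustration.

```python
# Sketch: selecting datasets by the new `ranking` flag instead of a `stage` string.
# Only TRAINING_STAGES and the "ranking" key mirror the patch; the entries are invented.

TRAINING_STAGES = {  # UI label -> internal stage code (see extras/constants.py)
    "Supervised Fine-Tuning": "sft",
    "Reward Modeling": "rm",
    "PPO": "ppo",
    "DPO": "dpo",
    "Pre-Training": "pt",
}

dataset_info = {  # hypothetical dataset_info.json contents
    "alpaca_demo": {"file_name": "alpaca_demo.json"},  # "ranking" defaults to False
    "comparison_demo": {"file_name": "comparison_demo.json", "ranking": True},
}

def list_datasets_for(training_stage: str) -> list:
    """Return dataset names whose ranking flag matches the chosen training stage."""
    needs_ranking = TRAINING_STAGES[training_stage] in ["rm", "dpo"]
    return [name for name, attr in dataset_info.items()
            if attr.get("ranking", False) == needs_ranking]

print(list_datasets_for("DPO"))                     # ['comparison_demo']
print(list_datasets_for("Supervised Fine-Tuning")) # ['alpaca_demo']
```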
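
The rewritten `get_preview` in `src/llmtuner/webui/utils.py` drops the `EXT2TYPE` lookup and dispatches on the file extension instead. The following is a rough sketch of that idea, assuming local `.json`, `.jsonl`, or plain-text files; the helper name, path handling, and slice defaults are illustrative, and JSON Lines records are parsed one per line with `json.loads`.

```python
# Sketch: previewing a dataset file by extension. Not the exact webui function;
# the signature and defaults are illustrative.
import json
import os

def preview_file(dataset_dir: str, data_file: str, start: int = 0, end: int = 2):
    path = os.path.join(dataset_dir, data_file)
    with open(path, "r", encoding="utf-8") as f:
        if data_file.endswith(".json"):
            data = json.load(f)                      # whole file is one JSON array/object
        elif data_file.endswith(".jsonl"):
            data = [json.loads(line) for line in f]  # one JSON record per line
        else:
            data = [line.rstrip("\n") for line in f] # plain text, one sample per line
    return len(data), data[start:end]
```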