refactor dataset_attr, add eos in pt, fix #757

Former-commit-id: a9d1fb72f791ae57a4d12f4e3a7e2abccf6a7077
This commit is contained in:
hiyouga 2023-09-01 19:00:45 +08:00
parent 173022473d
commit a4fd976048
20 changed files with 160 additions and 198 deletions

View File

@ -105,7 +105,7 @@
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Ad Gen (zh)](https://arxiv.org/abs/1908.06605)
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- For reward modeling or DPO training:
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)

View File

@ -105,8 +105,8 @@
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Ad Gen (zh)](https://arxiv.org/abs/1908.06605)
- 用于奖励模型或 DPO 训练:
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- 用于训练奖励模型或 DPO 训练:
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)

View File

@ -6,13 +6,13 @@ If you are using a custom dataset, please provide your dataset definition in the
"script_url": "the name of the directory containing a dataset loading script. (if specified, ignore below 2 arguments)",
"file_name": "the name of the dataset file in the this directory. (required if above are not specified)",
"file_sha1": "the SHA-1 hash value of the dataset file. (optional)",
"ranking": "whether the examples contains ranked responses or not. (default: false)",
"columns": {
"prompt": "the name of the column in the datasets containing the prompts. (default: instruction)",
"query": "the name of the column in the datasets containing the queries. (default: input)",
"response": "the name of the column in the datasets containing the responses. (default: output)",
"history": "the name of the column in the datasets containing the history of chat. (default: None)"
},
"stage": "The stage at which the data is being used: pt, sft, and rm, which correspond to pre-training, supervised fine-tuning(PPO), and reward model (DPO) training, respectively.(default: None)"
}
}
```
@ -27,7 +27,6 @@ For datasets used in reward modeling or DPO training, the `response` column shou
"output": [
"Chosen answer",
"Rejected answer"
],
"stage": "rm"
]
}
```

View File

@ -6,19 +6,19 @@
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略下列两个参数)",
"file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
"file_sha1": "数据集文件的SHA-1哈希值可选",
"ranking": "数据集是否包含排序后的回答默认false",
"columns": {
"prompt": "数据集代表提示词的表头名称默认instruction",
"query": "数据集代表请求的表头名称默认input",
"response": "数据集代表回答的表头名称默认output",
"history": "数据集代表历史对话的表头名称默认None"
},
"stage": "数据所应用的训练阶段,可选值有 pt, sft, rm 三个,对应预训练,指令监督微调(PPO),奖励模型(DPO)训练, 默认为None表示不限制"
}
}
```
其中 `prompt``response` 列应当是非空的字符串。`query` 列的内容将会和 `prompt` 列拼接作为模型输入。`history` 列应当是一个列表,其中每个元素是一个字符串二元组,分别代表用户请求和模型答复。
对于奖励模型或 DPO 训练的数据集,`response` 列应当是一个字符串列表,排在前面的代表更优的答案,例如:
对于训练奖励模型或 DPO 训练的数据集,`response` 列应当是一个字符串列表,排在前面的代表更优的答案,例如:
```json
{
@ -27,7 +27,6 @@
"output": [
"Chosen answer",
"Rejected answer"
],
"stage": "rm"
]
}
```

View File

@ -1,28 +1,23 @@
{
"alpaca_en": {
"file_name": "alpaca_data_en_52k.json",
"file_sha1": "607f94a7f581341e59685aef32f531095232cf23",
"stage": "sft"
"file_sha1": "607f94a7f581341e59685aef32f531095232cf23"
},
"alpaca_zh": {
"file_name": "alpaca_data_zh_51k.json",
"file_sha1": "e655af3db557a4197f7b0cf92e1986b08fae6311",
"stage": "sft"
"file_sha1": "e655af3db557a4197f7b0cf92e1986b08fae6311"
},
"alpaca_gpt4_en": {
"file_name": "alpaca_gpt4_data_en.json",
"file_sha1": "647f4ad447bd993e4b6b6223d1be15208bab694a",
"stage": "sft"
"file_sha1": "647f4ad447bd993e4b6b6223d1be15208bab694a"
},
"alpaca_gpt4_zh": {
"file_name": "alpaca_gpt4_data_zh.json",
"file_sha1": "3eaa3bda364ccdd59925d7448a698256c31ef845",
"stage": "sft"
"file_sha1": "3eaa3bda364ccdd59925d7448a698256c31ef845"
},
"self_cognition": {
"file_name": "self_cognition.json",
"file_sha1": "6287a730ada924fc5d9eadc6d8f865e01b7a6f67",
"stage": "sft"
"file_sha1": "6287a730ada924fc5d9eadc6d8f865e01b7a6f67"
},
"oaast_sft": {
"file_name": "oaast_sft.json",
@ -32,8 +27,7 @@
"query": "input",
"response": "output",
"history": "history"
},
"stage": "sft"
}
},
"oaast_sft_zh": {
"file_name": "oaast_sft_zh.json",
@ -43,8 +37,7 @@
"query": "input",
"response": "output",
"history": "history"
},
"stage": "sft"
}
},
"sharegpt_zh": {
"file_name": "sharegpt_zh_27k.json",
@ -54,8 +47,7 @@
"query": "input",
"response": "output",
"history": "history"
},
"stage": "sft"
}
},
"lima": {
"file_name": "lima.json",
@ -65,8 +57,7 @@
"query": "input",
"response": "output",
"history": "history"
},
"stage": "sft"
}
},
"example": {
"script_url": "example_dataset",
@ -75,32 +66,25 @@
"query": "input",
"response": "output",
"history": "history"
},
"stage": "sft"
}
},
"guanaco": {
"hf_hub_url": "JosephusCheung/GuanacoDataset",
"stage": "sft"
"hf_hub_url": "JosephusCheung/GuanacoDataset"
},
"belle_0.5m": {
"hf_hub_url": "BelleGroup/train_0.5M_CN",
"stage": "sft"
"hf_hub_url": "BelleGroup/train_0.5M_CN"
},
"belle_1m": {
"hf_hub_url": "BelleGroup/train_1M_CN",
"stage": "sft"
"hf_hub_url": "BelleGroup/train_1M_CN"
},
"belle_2m": {
"hf_hub_url": "BelleGroup/train_2M_CN",
"stage": "sft"
"hf_hub_url": "BelleGroup/train_2M_CN"
},
"belle_dialog": {
"hf_hub_url": "BelleGroup/generated_chat_0.4M",
"stage": "sft"
"hf_hub_url": "BelleGroup/generated_chat_0.4M"
},
"belle_math": {
"hf_hub_url": "BelleGroup/school_math_0.25M",
"stage": "sft"
"hf_hub_url": "BelleGroup/school_math_0.25M"
},
"belle_multiturn": {
"script_url": "belle_multiturn",
@ -109,8 +93,7 @@
"query": "",
"response": "output",
"history": "history"
},
"stage": "sft"
}
},
"firefly": {
"hf_hub_url": "YeungNLP/firefly-train-1.1M",
@ -119,16 +102,13 @@
"query": "",
"response": "target",
"history": ""
},
"stage": "sft"
}
},
"codealpaca": {
"hf_hub_url": "sahil2801/CodeAlpaca-20k",
"stage": "sft"
"hf_hub_url": "sahil2801/CodeAlpaca-20k"
},
"alpaca_cot": {
"hf_hub_url": "QingyiSi/Alpaca-CoT",
"stage": "sft"
"hf_hub_url": "QingyiSi/Alpaca-CoT"
},
"webqa": {
"hf_hub_url": "suolyer/webqa",
@ -137,8 +117,7 @@
"query": "",
"response": "output",
"history": ""
},
"stage": "sft"
}
},
"ultra_chat": {
"script_url": "ultra_chat",
@ -147,32 +126,29 @@
"query": "",
"response": "output",
"history": "history"
},
"stage": "sft"
}
},
"novel_tokens512_50k": {
"hf_hub_url": "zxbsmk/webnovel_cn",
"stage": "sft"
"hf_hub_url": "zxbsmk/webnovel_cn"
},
"ad_gen": {
"adgen": {
"hf_hub_url": "HasturOfficial/adgen",
"columns": {
"prompt": "content",
"query": "",
"response": "summary",
"history": ""
},
"stage": "sft"
}
},
"comparison_gpt4_en": {
"file_name": "comparison_gpt4_data_en.json",
"file_sha1": "96fa18313544e22444fe20eead7754b17da452ae",
"stage": "rm"
"ranking": true
},
"comparison_gpt4_zh": {
"file_name": "comparison_gpt4_data_zh.json",
"file_sha1": "515b18ed497199131ddcc1af950345c11dc5c7fd",
"stage": "rm"
"ranking": true
},
"hh_rlhf_en": {
"script_url": "hh_rlhf_en",
@ -182,7 +158,7 @@
"response": "output",
"history": "history"
},
"stage": "rm"
"ranking": true
},
"oaast_rm": {
"file_name": "oaast_rm.json",
@ -193,7 +169,7 @@
"response": "output",
"history": "history"
},
"stage": "rm"
"ranking": true
},
"oaast_rm_zh": {
"file_name": "oaast_rm_zh.json",
@ -204,7 +180,7 @@
"response": "output",
"history": "history"
},
"stage": "rm"
"ranking": true
},
"wiki_demo": {
"file_name": "wiki_demo.txt",
@ -214,8 +190,7 @@
"query": "",
"response": "",
"history": ""
},
"stage": "pt"
}
},
"refinedweb": {
"hf_hub_url": "tiiuae/falcon-refinedweb",
@ -224,18 +199,7 @@
"query": "",
"response": "",
"history": ""
},
"stage": "pt"
},
"starcoder": {
"hf_hub_url": "bigcode/starcoderdata",
"columns": {
"prompt": "content",
"query": "",
"response": "",
"history": ""
},
"stage": "pt"
}
},
"wikipedia_en": {
"hf_hub_url": "olm/olm-wikipedia-20221220",
@ -244,8 +208,7 @@
"query": "",
"response": "",
"history": ""
},
"stage": "pt"
}
},
"wikipedia_zh": {
"hf_hub_url": "pleisto/wikipedia-cn-20230720-filtered",
@ -254,7 +217,24 @@
"query": "",
"response": "",
"history": ""
},
"stage": "pt"
}
},
"the_stack": {
"hf_hub_url": "bigcode/the-stack",
"columns": {
"prompt": "content",
"query": "",
"response": "",
"history": ""
}
},
"starcoder": {
"hf_hub_url": "bigcode/starcoderdata",
"columns": {
"prompt": "content",
"query": "",
"response": "",
"history": ""
}
}
}

View File

@ -3,7 +3,7 @@ transformers>=4.29.1
datasets>=2.12.0
accelerate>=0.21.0
peft>=0.4.0
trl>=0.5.0
trl>=0.7.1
scipy
sentencepiece
tiktoken

View File

@ -31,11 +31,13 @@ def preprocess_dataset(
yield query, response, history, system
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
# build grouped texts with format `X1 X2 X3 ...` (without <eos>)
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
else:
kwargs = dict(add_special_tokens=False)
# build grouped texts with format `X1 X2 X3 ...`
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding):
kwargs = dict(allowed_special="all") # for tiktoken tokenizer (Qwen)
if hasattr(tokenizer, "add_bos_token") and hasattr(tokenizer, "add_eos_token"):
setattr(tokenizer, "add_bos_token", True) # for LLaMA tokenizer
setattr(tokenizer, "add_eos_token", True)
tokenized_examples = tokenizer(examples["prompt"], **kwargs)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}

View File

@ -10,20 +10,12 @@ LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]
METHODS = ["full", "freeze", "lora"]
STAGES = [
"SFT",
"Reward Modeling",
"PPO",
"DPO",
"Pre-Training"
]
DATASET_STAGE_MAP = {
"SFT": "sft",
"Pre-Training": "pt",
TRAINING_STAGES = {
"Supervised Fine-Tuning": "sft",
"Reward Modeling": "rm",
"PPO": "sft",
"DPO": "rm"
"PPO": "ppo",
"DPO": "dpo",
"Pre-Training": "pt"
}
SUPPORTED_MODELS = {

View File

@ -137,6 +137,8 @@ class Template:
token_ids = []
for elem in context:
if isinstance(elem, str):
if len(elem) == 0:
continue
elem = elem.replace("{{system}}", system, 1) if system is not None else elem
elem = elem.replace("{{query}}", query, 1) if query is not None else elem
elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem

View File

@ -11,17 +11,15 @@ class DatasetAttr:
dataset_name: Optional[str] = None
dataset_sha1: Optional[str] = None
system_prompt: Optional[str] = None
stage: Optional[str] = None
ranking: Optional[bool] = False
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
response: Optional[str] = "output"
history: Optional[str] = None
def __repr__(self) -> str:
return self.dataset_name
def __post_init__(self):
self.prompt = "instruction"
self.query = "input"
self.response = "output"
self.history = None
@dataclass
class DataArguments:
@ -114,21 +112,14 @@ class DataArguments:
raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
if "hf_hub_url" in dataset_info[name]:
dataset_attr = DatasetAttr(
"hf_hub",
dataset_name=dataset_info[name]["hf_hub_url"],
stage=dataset_info[name].get("stage", None))
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
elif "script_url" in dataset_info[name]:
dataset_attr = DatasetAttr(
"script",
dataset_name=dataset_info[name]["script_url"],
stage=dataset_info[name].get("stage", None))
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
else:
dataset_attr = DatasetAttr(
"file",
dataset_name=dataset_info[name]["file_name"],
dataset_sha1=dataset_info[name].get("file_sha1", None),
stage=dataset_info[name].get("stage", None)
dataset_sha1=dataset_info[name].get("file_sha1", None)
)
if "columns" in dataset_info[name]:
@ -137,5 +128,6 @@ class DataArguments:
dataset_attr.response = dataset_info[name]["columns"].get("response", None)
dataset_attr.history = dataset_info[name]["columns"].get("history", None)
dataset_attr.ranking = dataset_info[name].get("ranking", False)
dataset_attr.system_prompt = prompt_list[i]
self.dataset_list.append(dataset_attr)

View File

@ -16,7 +16,7 @@ class ModelArguments:
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
)
use_fast_tokenizer: Optional[bool] = field(
default=False,
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
)
use_auth_token: Optional[bool] = field(

View File

@ -36,7 +36,7 @@ check_min_version("4.29.1")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
require_version("trl>=0.5.0", "To fix: pip install trl>=0.5.0")
require_version("trl>=0.7.1", "To fix: pip install trl>=0.7.1")
def load_model_and_tokenizer(

View File

@ -5,6 +5,7 @@ import datasets
import transformers
from typing import Any, Dict, Optional, Tuple
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.utils.versions import require_version
from transformers.trainer_utils import get_last_checkpoint
from llmtuner.extras.logging import get_logger
@ -110,6 +111,11 @@ def get_train_args(
if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
raise ValueError("PPO and DPO stages can only be performed at training.")
if general_args.stage in ["rm", "dpo"]:
for dataset_attr in data_args.dataset_list:
if not dataset_attr.ranking:
raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
if general_args.stage == "ppo" and model_args.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.")
@ -166,6 +172,7 @@ def get_train_args(
and os.path.isdir(training_args.output_dir)
and not training_args.overwrite_output_dir
):
require_version("transformers>=4.31.0", "Resuming training requires transformers>=4.31.0.")
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.")
@ -186,18 +193,6 @@ def get_train_args(
else:
model_args.compute_dtype = torch.float16
# transfer training stage to dataset stage
dataset_stage = general_args.stage
if general_args.stage == "ppo":
dataset_stage = "sft"
elif general_args.stage == "dpo":
dataset_stage = "rm"
for dataset_attr in data_args.dataset_list:
if dataset_attr.stage and dataset_attr.stage != dataset_stage:
raise ValueError("Dataset {} is not supported for the stage {}"
.format(dataset_attr.dataset_name, general_args.stage))
model_args.model_max_length = data_args.max_source_length + data_args.max_target_length
# Log on each process the small summary:

View File

@ -1,9 +1,9 @@
import torch
from collections import defaultdict
from peft import PeftModel
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
from transformers import BatchEncoding, Trainer
from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.tuner.core.trainer import PeftModelMixin
@ -18,9 +18,16 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
def __init__(
self,
finetuning_args: "FinetuningArguments",
model: Union["PreTrainedModel", torch.nn.Module],
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
disable_dropout: Optional[bool] = True,
**kwargs
):
if disable_dropout:
disable_dropout_in_model(model)
if ref_model is not None:
disable_dropout_in_model(ref_model)
self.finetuning_args = finetuning_args
self.ref_model = ref_model
self.use_dpo_data_collator = True # hack to avoid warning
@ -29,12 +36,16 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
self.beta = finetuning_args.dpo_beta
self._stored_metrics = defaultdict(lambda: defaultdict(list))
Trainer.__init__(self, **kwargs)
Trainer.__init__(self, model=model, **kwargs)
if not hasattr(self, "accelerator"):
raise AttributeError("Please update `transformers`.")
if ref_model is not None:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
if self.is_deepspeed_enabled:
self.ref_model = self.accelerator._prepare_deepspeed(self.ref_model)
self.ref_model.eval()
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
def concatenated_forward(
self,
@ -42,27 +53,12 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
batch: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
if not torch.is_grad_enabled():
unwrapped_model.gradient_checkpointing_disable()
if model is None and isinstance(unwrapped_model, PeftModel): # peft model has no ref_model
with unwrapped_model.disable_adapter():
all_logits = self.model(
input_ids=batch_copied["input_ids"],
attention_mask=batch_copied["attention_mask"],
return_dict=True
).logits.to(torch.float32)
else:
all_logits = model(
input_ids=batch_copied["input_ids"],
attention_mask=batch_copied["attention_mask"],
return_dict=True
).logits.to(torch.float32)
if not torch.is_grad_enabled():
unwrapped_model.gradient_checkpointing_enable()
all_logits = model(
input_ids=batch_copied["input_ids"],
attention_mask=batch_copied["attention_mask"],
return_dict=True
).logits.to(torch.float32)
all_logps = self._get_batch_logps(
all_logits,

View File

@ -202,7 +202,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
queries: torch.Tensor,
responses: torch.Tensor,
model_inputs: dict,
return_logits: Optional[bool] = False
return_logits: Optional[bool] = False,
response_masks: Optional[torch.Tensor] = None
):
r"""
Calculates model outputs in multiple batches.
@ -220,6 +221,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
query_batch = queries[i * fbs : (i + 1) * fbs]
response_batch = responses[i * fbs : (i + 1) * fbs]
if response_masks is not None:
response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
input_ids = input_kwargs["input_ids"]
attention_mask = input_kwargs["attention_mask"]
@ -239,8 +242,15 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
start += attention_mask[j, :].nonzero()[0]
end = start + len(response_batch[j])
if response_masks is not None:
response_masks_batch = torch.cat(
(torch.zeros_like(query_batch[j]), response_masks_batch[j])
)[1:]
masks[j, :start] = 0
masks[j, end:] = 0
if response_masks is not None:
masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]
if return_logits:
all_logits.append(logits)

View File

@ -44,7 +44,6 @@ def run_ppo(
)
if finetuning_args.ppo_score_norm:
require_version("trl>=0.5.1.dev0", "To fix: pip install git+https://github.com/huggingface/trl.git")
ppo_config.use_score_scaling = True
ppo_config.use_score_norm = True

View File

@ -6,7 +6,7 @@ import gradio as gr
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, DATASET_STAGE_MAP
from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES
DEFAULT_CACHE_DIR = "cache"
@ -78,11 +78,10 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
return {}
def list_dataset(dataset_dir: Optional[str] = None, stage: Optional[str] = None) -> Dict[str, Any]:
def list_dataset(
dataset_dir: Optional[str] = None, training_stage: Optional[str] = list(TRAINING_STAGES.keys())[0]
) -> Dict[str, Any]:
dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
if stage:
dataset_stage = DATASET_STAGE_MAP[stage]
dataset_info = {key: value for key, value in dataset_info.items()
if ("stage" not in value) or value["stage"] == dataset_stage}
return gr.update(value=[], choices=list(dataset_info.keys()))
ranking = TRAINING_STAGES[training_stage] in ["rm", "dpo"]
datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
return gr.update(value=[], choices=datasets)

View File

@ -3,7 +3,7 @@ from transformers.trainer_utils import SchedulerType
import gradio as gr
from llmtuner.extras.constants import STAGES
from llmtuner.extras.constants import TRAINING_STAGES
from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.utils import can_preview, get_preview, gen_plot
@ -15,7 +15,9 @@ if TYPE_CHECKING:
def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
with gr.Row():
training_stage = gr.Dropdown(choices=STAGES, value=STAGES[0], scale=2)
training_stage = gr.Dropdown(
choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=2
)
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)
data_preview_btn = gr.Button(interactive=False, scale=1)

View File

@ -8,7 +8,7 @@ from transformers.trainer import TRAINING_ARGS_NAME
from typing import Any, Dict, Generator, List, Tuple
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import DEFAULT_MODULE
from llmtuner.extras.constants import DEFAULT_MODULE, TRAINING_STAGES
from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import run_exp
@ -106,7 +106,7 @@ class Runner:
output_dir = os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
args = dict(
stage="sft",
stage=TRAINING_STAGES[training_stage],
model_name_or_path=get_model_path(model_name),
do_train=True,
overwrite_cache=True,
@ -133,26 +133,20 @@ class Runner:
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
resume_lora_training=resume_lora_training,
resume_lora_training=(
False if TRAINING_STAGES[training_stage] in ["rm", "ppo", "dpo"] else resume_lora_training
),
output_dir=output_dir
)
args[compute_type] = True
if training_stage == "Reward Modeling":
args["stage"] = "rm"
args["resume_lora_training"] = False
elif training_stage == "PPO":
args["stage"] = "ppo"
args["resume_lora_training"] = False
if args["stage"] == "ppo":
args["reward_model"] = reward_model
args["padding_side"] = "left"
val_size = 0
elif training_stage == "DPO":
args["stage"] = "dpo"
args["resume_lora_training"] = False
if args["stage"] == "dpo":
args["dpo_beta"] = dpo_beta
elif training_stage == "Pre-Training":
args["stage"] = "pt"
if val_size > 1e-6:
args["val_size"] = val_size

View File

@ -3,10 +3,9 @@ import json
import gradio as gr
import matplotlib.figure
import matplotlib.pyplot as plt
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from datetime import datetime
from llmtuner.dsets.utils import EXT2TYPE
from llmtuner.extras.ploting import smooth
from llmtuner.tuner import export_model
from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
@ -37,6 +36,7 @@ def get_time() -> str:
def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f)
if (
len(dataset) > 0
and "file_name" in dataset_info[dataset[0]]
@ -47,25 +47,26 @@ def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
return gr.update(interactive=False)
def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, Dict[str, Any]]:
def get_preview(
dataset_dir: str, dataset: list, start: Optional[int] = 0, end: Optional[int] = 2
) -> Tuple[int, list, Dict[str, Any]]:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f)
data_file = dataset_info[dataset[0]]["file_name"]
data = []
data_format = EXT2TYPE.get(data_file.split(".")[-1], None)
if data_format == "text":
with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
for line in f:
data.append(line.strip())
elif data_format == "json":
with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
data_file: str = dataset_info[dataset[0]]["file_name"]
with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
if data_file.endswith(".json"):
data = json.load(f)
return len(data), data[:2], gr.update(visible=True)
elif data_file.endswith(".jsonl"):
data = [json.load(line) for line in f]
else:
data = [line for line in f]
return len(data), data[start:end], gr.update(visible=True)
def can_quantize(finetuning_type: str) -> Dict[str, Any]:
if finetuning_type != "lora":
return gr.update(value="", interactive=False)
return gr.update(value="None", interactive=False)
else:
return gr.update(interactive=True)
@ -73,7 +74,7 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
def gen_cmd(args: Dict[str, Any]) -> str:
if args.get("do_train", None):
args["plot_loss"] = True
cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python "]
cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python src/train_bash.py"]
for k, v in args.items():
if v is not None and v != "":
cmd_lines.append(" --{} {} ".format(k, str(v)))