refactor dataset_attr, add eos in pt, fix #757
Former-commit-id: a9d1fb72f791ae57a4d12f4e3a7e2abccf6a7077
parent 173022473d
commit a4fd976048
@@ -105,7 +105,7 @@
  - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
  - [UltraChat (en)](https://github.com/thunlp/UltraChat)
  - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
  - [Ad Gen (zh)](https://arxiv.org/abs/1908.06605)
  - [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- For reward modeling or DPO training:
  - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
  - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
@@ -105,8 +105,8 @@
  - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
  - [UltraChat (en)](https://github.com/thunlp/UltraChat)
  - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
  - [Ad Gen (zh)](https://arxiv.org/abs/1908.06605)
- 用于奖励模型或 DPO 训练:
  - [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- 用于训练奖励模型或 DPO 训练:
  - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
  - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
  - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
@@ -6,13 +6,13 @@ If you are using a custom dataset, please provide your dataset definition in the
    "script_url": "the name of the directory containing a dataset loading script. (if specified, ignore below 2 arguments)",
    "file_name": "the name of the dataset file in this directory. (required if above are not specified)",
    "file_sha1": "the SHA-1 hash value of the dataset file. (optional)",
    "ranking": "whether the examples contain ranked responses or not. (default: false)",
    "columns": {
      "prompt": "the name of the column in the dataset containing the prompts. (default: instruction)",
      "query": "the name of the column in the dataset containing the queries. (default: input)",
      "response": "the name of the column in the dataset containing the responses. (default: output)",
      "history": "the name of the column in the dataset containing the chat history. (default: None)"
    },
    "stage": "the stage at which the data is used: pt, sft, or rm, corresponding to pre-training, supervised fine-tuning (PPO), and reward model (DPO) training, respectively. (default: None)"
  }
}
```
@@ -27,7 +27,6 @@ For datasets used in reward modeling or DPO training, the `response` column shou
  "output": [
    "Chosen answer",
    "Rejected answer"
  ],
  "stage": "rm"
  ]
}
```
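As a quick illustration of the schema change above, the following sketch registers a hypothetical custom preference dataset: the `ranking` flag now marks pairwise data, replacing the old `stage` field. The `my_pairs` name and file name are placeholders, not part of this commit.

```python
import json

# Hypothetical entry for a custom preference dataset under the new schema:
# "ranking": true replaces the old "stage": "rm" marker.
new_entry = {
    "my_pairs": {
        "file_name": "my_pairs.json",
        "ranking": True,
        "columns": {
            "prompt": "instruction",
            "query": "input",
            "response": "output",  # for ranked data this column holds [chosen, rejected]
            "history": ""
        }
    }
}

with open("data/dataset_info.json", "r", encoding="utf-8") as f:
    dataset_info = json.load(f)

dataset_info.update(new_entry)

with open("data/dataset_info.json", "w", encoding="utf-8") as f:
    json.dump(dataset_info, f, indent=2, ensure_ascii=False)
```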
@@ -6,19 +6,19 @@
    "script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略下列两个参数)",
    "file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
    "file_sha1": "数据集文件的SHA-1哈希值(可选)",
    "ranking": "数据集是否包含排序后的回答(默认:false)",
    "columns": {
      "prompt": "数据集代表提示词的表头名称(默认:instruction)",
      "query": "数据集代表请求的表头名称(默认:input)",
      "response": "数据集代表回答的表头名称(默认:output)",
      "history": "数据集代表历史对话的表头名称(默认:None)"
    },
    "stage": "数据所应用的训练阶段,可选值有 pt, sft, rm 三个,对应预训练,指令监督微调(PPO),奖励模型(DPO)训练, 默认为None,表示不限制"
  }
}
```

其中 `prompt` 和 `response` 列应当是非空的字符串。`query` 列的内容将会和 `prompt` 列拼接作为模型输入。`history` 列应当是一个列表,其中每个元素是一个字符串二元组,分别代表用户请求和模型答复。

对于奖励模型或 DPO 训练的数据集,`response` 列应当是一个字符串列表,排在前面的代表更优的答案,例如:
对于训练奖励模型或 DPO 训练的数据集,`response` 列应当是一个字符串列表,排在前面的代表更优的答案,例如:

```json
{
@@ -27,7 +27,6 @@
  "output": [
    "Chosen answer",
    "Rejected answer"
  ],
  "stage": "rm"
  ]
}
```
@@ -1,28 +1,23 @@
{
  "alpaca_en": {
    "file_name": "alpaca_data_en_52k.json",
    "file_sha1": "607f94a7f581341e59685aef32f531095232cf23",
    "stage": "sft"
    "file_sha1": "607f94a7f581341e59685aef32f531095232cf23"
  },
  "alpaca_zh": {
    "file_name": "alpaca_data_zh_51k.json",
    "file_sha1": "e655af3db557a4197f7b0cf92e1986b08fae6311",
    "stage": "sft"
    "file_sha1": "e655af3db557a4197f7b0cf92e1986b08fae6311"
  },
  "alpaca_gpt4_en": {
    "file_name": "alpaca_gpt4_data_en.json",
    "file_sha1": "647f4ad447bd993e4b6b6223d1be15208bab694a",
    "stage": "sft"
    "file_sha1": "647f4ad447bd993e4b6b6223d1be15208bab694a"
  },
  "alpaca_gpt4_zh": {
    "file_name": "alpaca_gpt4_data_zh.json",
    "file_sha1": "3eaa3bda364ccdd59925d7448a698256c31ef845",
    "stage": "sft"
    "file_sha1": "3eaa3bda364ccdd59925d7448a698256c31ef845"
  },
  "self_cognition": {
    "file_name": "self_cognition.json",
    "file_sha1": "6287a730ada924fc5d9eadc6d8f865e01b7a6f67",
    "stage": "sft"
    "file_sha1": "6287a730ada924fc5d9eadc6d8f865e01b7a6f67"
  },
  "oaast_sft": {
    "file_name": "oaast_sft.json",
@@ -32,8 +27,7 @@
      "query": "input",
      "response": "output",
      "history": "history"
    },
    "stage": "sft"
    }
  },
  "oaast_sft_zh": {
    "file_name": "oaast_sft_zh.json",
@@ -43,8 +37,7 @@
      "query": "input",
      "response": "output",
      "history": "history"
    },
    "stage": "sft"
    }
  },
  "sharegpt_zh": {
    "file_name": "sharegpt_zh_27k.json",
@@ -54,8 +47,7 @@
      "query": "input",
      "response": "output",
      "history": "history"
    },
    "stage": "sft"
    }
  },
  "lima": {
    "file_name": "lima.json",
@@ -65,8 +57,7 @@
      "query": "input",
      "response": "output",
      "history": "history"
    },
    "stage": "sft"
    }
  },
  "example": {
    "script_url": "example_dataset",
@@ -75,32 +66,25 @@
      "query": "input",
      "response": "output",
      "history": "history"
    },
    "stage": "sft"
    }
  },
  "guanaco": {
    "hf_hub_url": "JosephusCheung/GuanacoDataset",
    "stage": "sft"
    "hf_hub_url": "JosephusCheung/GuanacoDataset"
  },
  "belle_0.5m": {
    "hf_hub_url": "BelleGroup/train_0.5M_CN",
    "stage": "sft"
    "hf_hub_url": "BelleGroup/train_0.5M_CN"
  },
  "belle_1m": {
    "hf_hub_url": "BelleGroup/train_1M_CN",
    "stage": "sft"
    "hf_hub_url": "BelleGroup/train_1M_CN"
  },
  "belle_2m": {
    "hf_hub_url": "BelleGroup/train_2M_CN",
    "stage": "sft"
    "hf_hub_url": "BelleGroup/train_2M_CN"
  },
  "belle_dialog": {
    "hf_hub_url": "BelleGroup/generated_chat_0.4M",
    "stage": "sft"
    "hf_hub_url": "BelleGroup/generated_chat_0.4M"
  },
  "belle_math": {
    "hf_hub_url": "BelleGroup/school_math_0.25M",
    "stage": "sft"
    "hf_hub_url": "BelleGroup/school_math_0.25M"
  },
  "belle_multiturn": {
    "script_url": "belle_multiturn",
@@ -109,8 +93,7 @@
      "query": "",
      "response": "output",
      "history": "history"
    },
    "stage": "sft"
    }
  },
  "firefly": {
    "hf_hub_url": "YeungNLP/firefly-train-1.1M",
@@ -119,16 +102,13 @@
      "query": "",
      "response": "target",
      "history": ""
    },
    "stage": "sft"
    }
  },
  "codealpaca": {
    "hf_hub_url": "sahil2801/CodeAlpaca-20k",
    "stage": "sft"
    "hf_hub_url": "sahil2801/CodeAlpaca-20k"
  },
  "alpaca_cot": {
    "hf_hub_url": "QingyiSi/Alpaca-CoT",
    "stage": "sft"
    "hf_hub_url": "QingyiSi/Alpaca-CoT"
  },
  "webqa": {
    "hf_hub_url": "suolyer/webqa",
@@ -137,8 +117,7 @@
      "query": "",
      "response": "output",
      "history": ""
    },
    "stage": "sft"
    }
  },
  "ultra_chat": {
    "script_url": "ultra_chat",
@@ -147,32 +126,29 @@
      "query": "",
      "response": "output",
      "history": "history"
    },
    "stage": "sft"
    }
  },
  "novel_tokens512_50k": {
    "hf_hub_url": "zxbsmk/webnovel_cn",
    "stage": "sft"
    "hf_hub_url": "zxbsmk/webnovel_cn"
  },
  "ad_gen": {
  "adgen": {
    "hf_hub_url": "HasturOfficial/adgen",
    "columns": {
      "prompt": "content",
      "query": "",
      "response": "summary",
      "history": ""
    },
    "stage": "sft"
    }
  },
  "comparison_gpt4_en": {
    "file_name": "comparison_gpt4_data_en.json",
    "file_sha1": "96fa18313544e22444fe20eead7754b17da452ae",
    "stage": "rm"
    "ranking": true
  },
  "comparison_gpt4_zh": {
    "file_name": "comparison_gpt4_data_zh.json",
    "file_sha1": "515b18ed497199131ddcc1af950345c11dc5c7fd",
    "stage": "rm"
    "ranking": true
  },
  "hh_rlhf_en": {
    "script_url": "hh_rlhf_en",
@@ -182,7 +158,7 @@
      "response": "output",
      "history": "history"
    },
    "stage": "rm"
    "ranking": true
  },
  "oaast_rm": {
    "file_name": "oaast_rm.json",
@@ -193,7 +169,7 @@
      "response": "output",
      "history": "history"
    },
    "stage": "rm"
    "ranking": true
  },
  "oaast_rm_zh": {
    "file_name": "oaast_rm_zh.json",
@@ -204,7 +180,7 @@
      "response": "output",
      "history": "history"
    },
    "stage": "rm"
    "ranking": true
  },
  "wiki_demo": {
    "file_name": "wiki_demo.txt",
@@ -214,8 +190,7 @@
      "query": "",
      "response": "",
      "history": ""
    },
    "stage": "pt"
    }
  },
  "refinedweb": {
    "hf_hub_url": "tiiuae/falcon-refinedweb",
@@ -224,18 +199,7 @@
      "query": "",
      "response": "",
      "history": ""
    },
    "stage": "pt"
  },
  "starcoder": {
    "hf_hub_url": "bigcode/starcoderdata",
    "columns": {
      "prompt": "content",
      "query": "",
      "response": "",
      "history": ""
    },
    "stage": "pt"
    }
  },
  "wikipedia_en": {
    "hf_hub_url": "olm/olm-wikipedia-20221220",
@@ -244,8 +208,7 @@
      "query": "",
      "response": "",
      "history": ""
    },
    "stage": "pt"
    }
  },
  "wikipedia_zh": {
    "hf_hub_url": "pleisto/wikipedia-cn-20230720-filtered",
@@ -254,7 +217,24 @@
      "query": "",
      "response": "",
      "history": ""
    },
    "stage": "pt"
    }
  },
  "the_stack": {
    "hf_hub_url": "bigcode/the-stack",
    "columns": {
      "prompt": "content",
      "query": "",
      "response": "",
      "history": ""
    }
  },
  "starcoder": {
    "hf_hub_url": "bigcode/starcoderdata",
    "columns": {
      "prompt": "content",
      "query": "",
      "response": "",
      "history": ""
    }
  }
}
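The `file_sha1` values kept in `dataset_info.json` let the loader verify local dataset files before training. A small sketch of that check, assuming the repository's `data/` layout and reusing the `alpaca_en` hash listed above:

```python
import hashlib

def sha1_of(path: str) -> str:
    # hash the file in 1 MiB chunks to avoid loading it all at once
    h = hashlib.sha1()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

expected = "607f94a7f581341e59685aef32f531095232cf23"  # value listed for alpaca_en above
print(sha1_of("data/alpaca_data_en_52k.json") == expected)
```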
@@ -3,7 +3,7 @@ transformers>=4.29.1
datasets>=2.12.0
accelerate>=0.21.0
peft>=0.4.0
trl>=0.5.0
trl>=0.7.1
scipy
sentencepiece
tiktoken
@@ -31,11 +31,13 @@ def preprocess_dataset(
            yield query, response, history, system

    def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
        # build grouped texts with format `X1 X2 X3 ...` (without <eos>)
        if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
            kwargs = dict(allowed_special="all")
        else:
            kwargs = dict(add_special_tokens=False)
        # build grouped texts with format `X1 X2 X3 ...`
        if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding):
            kwargs = dict(allowed_special="all") # for tiktoken tokenizer (Qwen)

        if hasattr(tokenizer, "add_bos_token") and hasattr(tokenizer, "add_eos_token"):
            setattr(tokenizer, "add_bos_token", True) # for LLaMA tokenizer
            setattr(tokenizer, "add_eos_token", True)

        tokenized_examples = tokenizer(examples["prompt"], **kwargs)
        concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
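The hunk above is the "add eos in pt" part of the commit: for LLaMA-style tokenizers, `add_bos_token`/`add_eos_token` are switched on so that an EOS token separates documents before they are concatenated and chunked. A standalone sketch of that behavior; the checkpoint path and `block_size` are placeholders:

```python
from itertools import chain
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder checkpoint
if hasattr(tokenizer, "add_bos_token") and hasattr(tokenizer, "add_eos_token"):
    tokenizer.add_bos_token = True  # each document starts with <s>
    tokenizer.add_eos_token = True  # and now also ends with </s>, separating documents

docs = ["first pre-training document", "second pre-training document"]
tokenized = tokenizer(docs, add_special_tokens=True)

# concatenate all documents, then cut the token stream into fixed-size blocks
concatenated = list(chain(*tokenized["input_ids"]))
block_size = 8
blocks = [concatenated[i: i + block_size] for i in range(0, len(concatenated), block_size)]
print(blocks[0])
```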
@@ -10,20 +10,12 @@ LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]

METHODS = ["full", "freeze", "lora"]

STAGES = [
    "SFT",
    "Reward Modeling",
    "PPO",
    "DPO",
    "Pre-Training"
]

DATASET_STAGE_MAP = {
    "SFT": "sft",
    "Pre-Training": "pt",
TRAINING_STAGES = {
    "Supervised Fine-Tuning": "sft",
    "Reward Modeling": "rm",
    "PPO": "sft",
    "DPO": "rm"
    "PPO": "ppo",
    "DPO": "dpo",
    "Pre-Training": "pt"
}

SUPPORTED_MODELS = {
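`TRAINING_STAGES` replaces both the old `STAGES` list and `DATASET_STAGE_MAP`: the dropdown label now maps directly onto the CLI stage, so PPO and DPO are no longer folded into `sft`/`rm`. A tiny sketch of how the mapping is assumed to be consumed; the `build_args` helper is illustrative, not part of the commit:

```python
TRAINING_STAGES = {
    "Supervised Fine-Tuning": "sft",
    "Reward Modeling": "rm",
    "PPO": "ppo",
    "DPO": "dpo",
    "Pre-Training": "pt",
}

def build_args(training_stage: str) -> dict:
    # the dropdown value is the dict key, the CLI stage is the value
    return {"stage": TRAINING_STAGES[training_stage], "do_train": True}

print(build_args("DPO"))  # {'stage': 'dpo', 'do_train': True}
```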
@@ -137,6 +137,8 @@ class Template:
        token_ids = []
        for elem in context:
            if isinstance(elem, str):
                if len(elem) == 0:
                    continue
                elem = elem.replace("{{system}}", system, 1) if system is not None else elem
                elem = elem.replace("{{query}}", query, 1) if query is not None else elem
                elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
@@ -11,17 +11,15 @@ class DatasetAttr:
    dataset_name: Optional[str] = None
    dataset_sha1: Optional[str] = None
    system_prompt: Optional[str] = None
    stage: Optional[str] = None
    ranking: Optional[bool] = False
    prompt: Optional[str] = "instruction"
    query: Optional[str] = "input"
    response: Optional[str] = "output"
    history: Optional[str] = None

    def __repr__(self) -> str:
        return self.dataset_name

    def __post_init__(self):
        self.prompt = "instruction"
        self.query = "input"
        self.response = "output"
        self.history = None


@dataclass
class DataArguments:
@@ -114,21 +112,14 @@ class DataArguments:
                raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))

            if "hf_hub_url" in dataset_info[name]:
                dataset_attr = DatasetAttr(
                    "hf_hub",
                    dataset_name=dataset_info[name]["hf_hub_url"],
                    stage=dataset_info[name].get("stage", None))
                dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
            elif "script_url" in dataset_info[name]:
                dataset_attr = DatasetAttr(
                    "script",
                    dataset_name=dataset_info[name]["script_url"],
                    stage=dataset_info[name].get("stage", None))
                dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
            else:
                dataset_attr = DatasetAttr(
                    "file",
                    dataset_name=dataset_info[name]["file_name"],
                    dataset_sha1=dataset_info[name].get("file_sha1", None),
                    stage=dataset_info[name].get("stage", None)
                    dataset_sha1=dataset_info[name].get("file_sha1", None)
                )

            if "columns" in dataset_info[name]:
@@ -137,5 +128,6 @@ class DataArguments:
                dataset_attr.response = dataset_info[name]["columns"].get("response", None)
                dataset_attr.history = dataset_info[name]["columns"].get("history", None)

            dataset_attr.ranking = dataset_info[name].get("ranking", False)
            dataset_attr.system_prompt = prompt_list[i]
            self.dataset_list.append(dataset_attr)
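A standalone sketch of the refactored parsing path: one entry from `dataset_info.json` is dispatched to a `DatasetAttr` depending on whether it names a Hub dataset, a loading script, or a local file, and `ranking` is copied over instead of the removed `stage`. The `load_from` field name and the `parse_entry` helper are assumptions for illustration.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class DatasetAttr:
    load_from: str                       # assumed name of the first positional field
    dataset_name: Optional[str] = None
    dataset_sha1: Optional[str] = None
    system_prompt: Optional[str] = None
    ranking: Optional[bool] = False      # replaces the old "stage" attribute
    prompt: Optional[str] = "instruction"
    query: Optional[str] = "input"
    response: Optional[str] = "output"
    history: Optional[str] = None

def parse_entry(info: dict) -> DatasetAttr:
    # illustrative dispatch: Hub dataset, loading script, or local file
    if "hf_hub_url" in info:
        attr = DatasetAttr("hf_hub", dataset_name=info["hf_hub_url"])
    elif "script_url" in info:
        attr = DatasetAttr("script", dataset_name=info["script_url"])
    else:
        attr = DatasetAttr("file", dataset_name=info["file_name"], dataset_sha1=info.get("file_sha1"))
    attr.ranking = info.get("ranking", False)
    return attr

print(parse_entry({"hf_hub_url": "HasturOfficial/adgen"}))
```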
@@ -16,7 +16,7 @@ class ModelArguments:
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
    )
    use_fast_tokenizer: Optional[bool] = field(
        default=False,
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
    )
    use_auth_token: Optional[bool] = field(
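With `use_fast_tokenizer` now defaulting to `True`, the loader is expected to request a Rust-backed tokenizer whenever one is available; a minimal sketch (the checkpoint path is a placeholder):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", use_fast=True)  # placeholder path
print(type(tokenizer).__name__)  # e.g. LlamaTokenizerFast when a fast implementation exists
```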
@@ -36,7 +36,7 @@ check_min_version("4.29.1")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
require_version("trl>=0.5.0", "To fix: pip install trl>=0.5.0")
require_version("trl>=0.7.1", "To fix: pip install trl>=0.7.1")


def load_model_and_tokenizer(
@@ -5,6 +5,7 @@ import datasets
import transformers
from typing import Any, Dict, Optional, Tuple
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.utils.versions import require_version
from transformers.trainer_utils import get_last_checkpoint

from llmtuner.extras.logging import get_logger
@@ -110,6 +111,11 @@ def get_train_args(
    if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
        raise ValueError("PPO and DPO stages can only be performed at training.")

    if general_args.stage in ["rm", "dpo"]:
        for dataset_attr in data_args.dataset_list:
            if not dataset_attr.ranking:
                raise ValueError("Please use ranked datasets for reward modeling or DPO training.")

    if general_args.stage == "ppo" and model_args.reward_model is None:
        raise ValueError("Reward model is necessary for PPO training.")

@@ -166,6 +172,7 @@ def get_train_args(
        and os.path.isdir(training_args.output_dir)
        and not training_args.overwrite_output_dir
    ):
        require_version("transformers>=4.31.0", "Resuming training requires transformers>=4.31.0.")
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.")
@@ -186,18 +193,6 @@ def get_train_args(
    else:
        model_args.compute_dtype = torch.float16

    # transfer training stage to dataset stage
    dataset_stage = general_args.stage
    if general_args.stage == "ppo":
        dataset_stage = "sft"
    elif general_args.stage == "dpo":
        dataset_stage = "rm"

    for dataset_attr in data_args.dataset_list:
        if dataset_attr.stage and dataset_attr.stage != dataset_stage:
            raise ValueError("Dataset {} is not supported for the stage {}"
                             .format(dataset_attr.dataset_name, general_args.stage))

    model_args.model_max_length = data_args.max_source_length + data_args.max_target_length

    # Log on each process the small summary:
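The new check in `get_train_args` fails fast when a reward-modeling or DPO run is pointed at data without ranked responses, replacing the removed stage-mapping logic. A standalone restatement of the rule; the `check_stage_datasets` helper is illustrative, not part of the commit:

```python
def check_stage_datasets(stage: str, entries: dict) -> None:
    # rm/dpo runs must only use datasets whose dataset_info.json entry sets "ranking": true
    if stage in ["rm", "dpo"]:
        for name, info in entries.items():
            if not info.get("ranking", False):
                raise ValueError("Please use ranked datasets for reward modeling or DPO training.")

check_stage_datasets("sft", {"webqa": {"hf_hub_url": "suolyer/webqa"}})                      # fine
check_stage_datasets("dpo", {"hh_rlhf_en": {"script_url": "hh_rlhf_en", "ranking": True}})   # fine
# check_stage_datasets("dpo", {"webqa": {"hf_hub_url": "suolyer/webqa"}})                    # would raise
```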
@@ -1,9 +1,9 @@
import torch
from collections import defaultdict
from peft import PeftModel
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
from transformers import BatchEncoding, Trainer
from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model

from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.tuner.core.trainer import PeftModelMixin
@@ -18,9 +18,16 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
    def __init__(
        self,
        finetuning_args: "FinetuningArguments",
        model: Union["PreTrainedModel", torch.nn.Module],
        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
        disable_dropout: Optional[bool] = True,
        **kwargs
    ):
        if disable_dropout:
            disable_dropout_in_model(model)
            if ref_model is not None:
                disable_dropout_in_model(ref_model)

        self.finetuning_args = finetuning_args
        self.ref_model = ref_model
        self.use_dpo_data_collator = True # hack to avoid warning
@@ -29,12 +36,16 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
        self.beta = finetuning_args.dpo_beta
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        Trainer.__init__(self, **kwargs)
        Trainer.__init__(self, model=model, **kwargs)
        if not hasattr(self, "accelerator"):
            raise AttributeError("Please update `transformers`.")

        if ref_model is not None:
            self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
            if self.is_deepspeed_enabled:
                self.ref_model = self.accelerator._prepare_deepspeed(self.ref_model)
                self.ref_model.eval()
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

    def concatenated_forward(
        self,
@@ -42,27 +53,12 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
        batch: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
        unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)

        if not torch.is_grad_enabled():
            unwrapped_model.gradient_checkpointing_disable()

        if model is None and isinstance(unwrapped_model, PeftModel): # peft model has no ref_model
            with unwrapped_model.disable_adapter():
                all_logits = self.model(
                    input_ids=batch_copied["input_ids"],
                    attention_mask=batch_copied["attention_mask"],
                    return_dict=True
                ).logits.to(torch.float32)
        else:
            all_logits = model(
                input_ids=batch_copied["input_ids"],
                attention_mask=batch_copied["attention_mask"],
                return_dict=True
            ).logits.to(torch.float32)

        if not torch.is_grad_enabled():
            unwrapped_model.gradient_checkpointing_enable()
        all_logits = model(
            input_ids=batch_copied["input_ids"],
            attention_mask=batch_copied["attention_mask"],
            return_dict=True
        ).logits.to(torch.float32)

        all_logps = self._get_batch_logps(
            all_logits,
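For context on what this trainer optimizes once the chosen/rejected log-probabilities are computed: the standard DPO objective with the `dpo_beta` coefficient. This is background, not code from the diff; the tensors below are toy values.

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             reference_chosen_logps, reference_rejected_logps, beta=0.1):
    # beta corresponds to finetuning_args.dpo_beta in the trainer above
    logits = (policy_chosen_logps - policy_rejected_logps) \
           - (reference_chosen_logps - reference_rejected_logps)
    return -F.logsigmoid(beta * logits).mean()

policy_chosen = torch.tensor([-10.0])
policy_rejected = torch.tensor([-12.0])
ref_chosen = torch.tensor([-11.0])
ref_rejected = torch.tensor([-11.5])
print(dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected))
```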
@@ -202,7 +202,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
        queries: torch.Tensor,
        responses: torch.Tensor,
        model_inputs: dict,
        return_logits: Optional[bool] = False
        return_logits: Optional[bool] = False,
        response_masks: Optional[torch.Tensor] = None
    ):
        r"""
        Calculates model outputs in multiple batches.
@@ -220,6 +221,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
            input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
            query_batch = queries[i * fbs : (i + 1) * fbs]
            response_batch = responses[i * fbs : (i + 1) * fbs]
            if response_masks is not None:
                response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
            input_ids = input_kwargs["input_ids"]
            attention_mask = input_kwargs["attention_mask"]

@@ -239,8 +242,15 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
                start += attention_mask[j, :].nonzero()[0]
                end = start + len(response_batch[j])

                if response_masks is not None:
                    response_masks_batch = torch.cat(
                        (torch.zeros_like(query_batch[j]), response_masks_batch[j])
                    )[1:]

                masks[j, :start] = 0
                masks[j, end:] = 0
                if response_masks is not None:
                    masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]

            if return_logits:
                all_logits.append(logits)
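The `response_masks` plumbing above intersects the usual response-span mask with a caller-provided per-token mask. A toy sketch of that masking step with illustrative shapes and values:

```python
import torch

seq_len = 8
masks = torch.ones(seq_len, dtype=torch.long)
start, end = 3, 7                           # response tokens live in [start, end)
masks[:start] = 0                           # ignore the query
masks[end:] = 0                             # ignore padding after the response

response_mask = torch.tensor([1, 1, 0, 1])  # e.g. drop the third response token
full_mask = torch.zeros(seq_len, dtype=torch.long)
full_mask[start:end] = response_mask
masks = masks * full_mask
print(masks)  # tensor([0, 0, 0, 1, 1, 0, 1, 0])
```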
@@ -44,7 +44,6 @@ def run_ppo(
    )

    if finetuning_args.ppo_score_norm:
        require_version("trl>=0.5.1.dev0", "To fix: pip install git+https://github.com/huggingface/trl.git")
        ppo_config.use_score_scaling = True
        ppo_config.use_score_norm = True
@@ -6,7 +6,7 @@ import gradio as gr
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME

from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, DATASET_STAGE_MAP
from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES


DEFAULT_CACHE_DIR = "cache"
@@ -78,11 +78,10 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
        return {}


def list_dataset(dataset_dir: Optional[str] = None, stage: Optional[str] = None) -> Dict[str, Any]:
def list_dataset(
    dataset_dir: Optional[str] = None, training_stage: Optional[str] = list(TRAINING_STAGES.keys())[0]
) -> Dict[str, Any]:
    dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
    if stage:
        dataset_stage = DATASET_STAGE_MAP[stage]
        dataset_info = {key: value for key, value in dataset_info.items()
                        if ("stage" not in value) or value["stage"] == dataset_stage}

    return gr.update(value=[], choices=list(dataset_info.keys()))
    ranking = TRAINING_STAGES[training_stage] in ["rm", "dpo"]
    datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
    return gr.update(value=[], choices=datasets)
@@ -3,7 +3,7 @@ from transformers.trainer_utils import SchedulerType

import gradio as gr

from llmtuner.extras.constants import STAGES
from llmtuner.extras.constants import TRAINING_STAGES
from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.utils import can_preview, get_preview, gen_plot
@@ -15,7 +15,9 @@ if TYPE_CHECKING:

def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
    with gr.Row():
        training_stage = gr.Dropdown(choices=STAGES, value=STAGES[0], scale=2)
        training_stage = gr.Dropdown(
            choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=2
        )
        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
        dataset = gr.Dropdown(multiselect=True, scale=4)
        data_preview_btn = gr.Button(interactive=False, scale=1)
@@ -8,7 +8,7 @@ from transformers.trainer import TRAINING_ARGS_NAME
from typing import Any, Dict, Generator, List, Tuple

from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import DEFAULT_MODULE
from llmtuner.extras.constants import DEFAULT_MODULE, TRAINING_STAGES
from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import run_exp
@@ -106,7 +106,7 @@ class Runner:
        output_dir = os.path.join(get_save_dir(model_name), finetuning_type, output_dir)

        args = dict(
            stage="sft",
            stage=TRAINING_STAGES[training_stage],
            model_name_or_path=get_model_path(model_name),
            do_train=True,
            overwrite_cache=True,
@@ -133,26 +133,20 @@ class Runner:
            lora_rank=lora_rank,
            lora_dropout=lora_dropout,
            lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
            resume_lora_training=resume_lora_training,
            resume_lora_training=(
                False if TRAINING_STAGES[training_stage] in ["rm", "ppo", "dpo"] else resume_lora_training
            ),
            output_dir=output_dir
        )
        args[compute_type] = True

        if training_stage == "Reward Modeling":
            args["stage"] = "rm"
            args["resume_lora_training"] = False
        elif training_stage == "PPO":
            args["stage"] = "ppo"
            args["resume_lora_training"] = False
        if args["stage"] == "ppo":
            args["reward_model"] = reward_model
            args["padding_side"] = "left"
            val_size = 0
        elif training_stage == "DPO":
            args["stage"] = "dpo"
            args["resume_lora_training"] = False

        if args["stage"] == "dpo":
            args["dpo_beta"] = dpo_beta
        elif training_stage == "Pre-Training":
            args["stage"] = "pt"

        if val_size > 1e-6:
            args["val_size"] = val_size
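The refactored runner derives its stage-specific arguments from the selected training stage instead of a chain of if/elif branches. A condensed sketch of that logic; the reward-model path and `dpo_beta` value are placeholders:

```python
def stage_specific_args(stage: str, resume_lora_training: bool) -> dict:
    args = {
        "stage": stage,
        # rm/ppo/dpo runs start a fresh LoRA adapter instead of resuming one
        "resume_lora_training": False if stage in ["rm", "ppo", "dpo"] else resume_lora_training,
    }
    if stage == "ppo":
        args["reward_model"] = "path/to/reward_model"  # placeholder checkpoint
        args["padding_side"] = "left"
    if stage == "dpo":
        args["dpo_beta"] = 0.1
    return args

print(stage_specific_args("ppo", resume_lora_training=True))
```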
@@ -3,10 +3,9 @@ import json
import gradio as gr
import matplotlib.figure
import matplotlib.pyplot as plt
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from datetime import datetime

from llmtuner.dsets.utils import EXT2TYPE
from llmtuner.extras.ploting import smooth
from llmtuner.tuner import export_model
from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
@@ -37,6 +36,7 @@ def get_time() -> str:
def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
    with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
        dataset_info = json.load(f)

    if (
        len(dataset) > 0
        and "file_name" in dataset_info[dataset[0]]
@@ -47,25 +47,26 @@ def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
    return gr.update(interactive=False)


def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, Dict[str, Any]]:
def get_preview(
    dataset_dir: str, dataset: list, start: Optional[int] = 0, end: Optional[int] = 2
) -> Tuple[int, list, Dict[str, Any]]:
    with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
        dataset_info = json.load(f)
    data_file = dataset_info[dataset[0]]["file_name"]
    data = []
    data_format = EXT2TYPE.get(data_file.split(".")[-1], None)
    if data_format == "text":
        with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
            for line in f:
                data.append(line.strip())
    elif data_format == "json":
        with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:

    data_file: str = dataset_info[dataset[0]]["file_name"]
    with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
        if data_file.endswith(".json"):
            data = json.load(f)
    return len(data), data[:2], gr.update(visible=True)
        elif data_file.endswith(".jsonl"):
            data = [json.loads(line) for line in f]  # parse one JSON object per line
        else:
            data = [line for line in f]
    return len(data), data[start:end], gr.update(visible=True)


def can_quantize(finetuning_type: str) -> Dict[str, Any]:
    if finetuning_type != "lora":
        return gr.update(value="", interactive=False)
        return gr.update(value="None", interactive=False)
    else:
        return gr.update(interactive=True)

@@ -73,7 +74,7 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
def gen_cmd(args: Dict[str, Any]) -> str:
    if args.get("do_train", None):
        args["plot_loss"] = True
    cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python "]
    cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python src/train_bash.py"]
    for k, v in args.items():
        if v is not None and v != "":
            cmd_lines.append(" --{} {} ".format(k, str(v)))