refactor dataset_attr, add eos in pt, fix #757

Former-commit-id: 0feec9a830b917b36686b61938a66e842eccf930
hiyouga 2023-09-01 19:00:45 +08:00
parent 93be211f80
commit e5b72c6a77
19 changed files with 108 additions and 126 deletions

View File

@ -105,7 +105,7 @@
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat) - [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Ad Gen (zh)](https://arxiv.org/abs/1908.06605) - [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- For reward modeling or DPO training: - For reward modeling or DPO training:
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)

View File

@ -105,8 +105,8 @@
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat) - [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Ad Gen (zh)](https://arxiv.org/abs/1908.06605) - [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
- 用于奖励模型或 DPO 训练: - 用于训练奖励模型或 DPO 训练:
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf) - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)

View File

@ -6,13 +6,13 @@ If you are using a custom dataset, please provide your dataset definition in the
"script_url": "the name of the directory containing a dataset loading script. (if specified, ignore below 2 arguments)", "script_url": "the name of the directory containing a dataset loading script. (if specified, ignore below 2 arguments)",
"file_name": "the name of the dataset file in the this directory. (required if above are not specified)", "file_name": "the name of the dataset file in the this directory. (required if above are not specified)",
"file_sha1": "the SHA-1 hash value of the dataset file. (optional)", "file_sha1": "the SHA-1 hash value of the dataset file. (optional)",
"ranking": "whether the examples contains ranked responses or not. (default: false)",
"columns": { "columns": {
"prompt": "the name of the column in the datasets containing the prompts. (default: instruction)", "prompt": "the name of the column in the datasets containing the prompts. (default: instruction)",
"query": "the name of the column in the datasets containing the queries. (default: input)", "query": "the name of the column in the datasets containing the queries. (default: input)",
"response": "the name of the column in the datasets containing the responses. (default: output)", "response": "the name of the column in the datasets containing the responses. (default: output)",
"history": "the name of the column in the datasets containing the history of chat. (default: None)" "history": "the name of the column in the datasets containing the history of chat. (default: None)"
}, }
"stage": "The stage at which the data is being used: pt, sft, and rm, which correspond to pre-training, supervised fine-tuning(PPO), and reward model (DPO) training, respectively.(default: None)"
} }
``` ```
@ -27,7 +27,6 @@ For datasets used in reward modeling or DPO training, the `response` column shou
"output": [ "output": [
"Chosen answer", "Chosen answer",
"Rejected answer" "Rejected answer"
], ]
"stage": "rm"
} }
``` ```
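
Under the new schema, a preference dataset is declared with the `ranking` flag instead of the removed `stage` key. A minimal `dataset_info.json` entry might look like the following (the dataset name and file name are hypothetical):

```json
{
  "comparison_demo": {
    "file_name": "comparison_demo.json",
    "ranking": true,
    "columns": {
      "prompt": "instruction",
      "query": "input",
      "response": "output"
    }
  }
}
```

With `"ranking": true`, each example's `output` holds the chosen answer first and the rejected answer second, as in the example above.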

View File

@ -6,19 +6,19 @@
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略下列两个参数)", "script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略下列两个参数)",
"file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)", "file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
"file_sha1": "数据集文件的SHA-1哈希值可选", "file_sha1": "数据集文件的SHA-1哈希值可选",
"ranking": "数据集是否包含排序后的回答默认false",
"columns": { "columns": {
"prompt": "数据集代表提示词的表头名称默认instruction", "prompt": "数据集代表提示词的表头名称默认instruction",
"query": "数据集代表请求的表头名称默认input", "query": "数据集代表请求的表头名称默认input",
"response": "数据集代表回答的表头名称默认output", "response": "数据集代表回答的表头名称默认output",
"history": "数据集代表历史对话的表头名称默认None" "history": "数据集代表历史对话的表头名称默认None"
}, }
"stage": "数据所应用的训练阶段,可选值有 pt, sft, rm 三个,对应预训练,指令监督微调(PPO),奖励模型(DPO)训练, 默认为None表示不限制"
} }
``` ```
其中 `prompt``response` 列应当是非空的字符串。`query` 列的内容将会和 `prompt` 列拼接作为模型输入。`history` 列应当是一个列表,其中每个元素是一个字符串二元组,分别代表用户请求和模型答复。 其中 `prompt``response` 列应当是非空的字符串。`query` 列的内容将会和 `prompt` 列拼接作为模型输入。`history` 列应当是一个列表,其中每个元素是一个字符串二元组,分别代表用户请求和模型答复。
对于奖励模型或 DPO 训练的数据集,`response` 列应当是一个字符串列表,排在前面的代表更优的答案,例如: 对于训练奖励模型或 DPO 训练的数据集,`response` 列应当是一个字符串列表,排在前面的代表更优的答案,例如:
```json ```json
{ {
@ -27,7 +27,6 @@
"output": [ "output": [
"Chosen answer", "Chosen answer",
"Rejected answer" "Rejected answer"
], ]
"stage": "rm"
} }
``` ```

View File

@ -3,7 +3,7 @@ transformers>=4.29.1
datasets>=2.12.0 datasets>=2.12.0
accelerate>=0.21.0 accelerate>=0.21.0
peft>=0.4.0 peft>=0.4.0
trl>=0.5.0 trl>=0.7.1
scipy scipy
sentencepiece sentencepiece
tiktoken tiktoken

View File

@ -31,11 +31,13 @@ def preprocess_dataset(
yield query, response, history, system yield query, response, history, system
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]: def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
# build grouped texts with format `X1 X2 X3 ...` (without <eos>) # build grouped texts with format `X1 X2 X3 ...`
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen) if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding):
kwargs = dict(allowed_special="all") kwargs = dict(allowed_special="all") # for tiktoken tokenizer (Qwen)
else: else:
kwargs = dict(add_special_tokens=False) kwargs = dict(add_special_tokens=True)
if hasattr(tokenizer, "add_bos_token") and hasattr(tokenizer, "add_eos_token"):
setattr(tokenizer, "add_bos_token", True) # for LLaMA tokenizer
setattr(tokenizer, "add_eos_token", True)
tokenized_examples = tokenizer(examples["prompt"], **kwargs) tokenized_examples = tokenizer(examples["prompt"], **kwargs)
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
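
The practical effect of the EOS change is easiest to see on a toy packing example. The sketch below uses invented token ids and a stand-in block size rather than a real LLaMA tokenizer; with `add_eos_token=True`, every document contributes a trailing EOS, so document boundaries survive the `chain`-and-regroup step:

```python
from itertools import chain

EOS = 2  # assumed LLaMA-style eos token id, for illustration only

docs = [[5, 6, 7], [8, 9]]            # two tokenized documents
docs = [d + [EOS] for d in docs]      # what add_eos_token now appends
stream = list(chain(*docs))           # [5, 6, 7, 2, 8, 9, 2]

block_size = 4                        # stand-in for the real cutoff length
total = (len(stream) // block_size) * block_size
blocks = [stream[i: i + block_size] for i in range(0, total, block_size)]
print(blocks)                         # [[5, 6, 7, 2]] -- EOS marks the boundary
```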

View File

@ -10,20 +10,12 @@ LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]
METHODS = ["full", "freeze", "lora"] METHODS = ["full", "freeze", "lora"]
STAGES = [ TRAINING_STAGES = {
"SFT", "Supervised Fine-Tuning": "sft",
"Reward Modeling",
"PPO",
"DPO",
"Pre-Training"
]
DATASET_STAGE_MAP = {
"SFT": "sft",
"Pre-Training": "pt",
"Reward Modeling": "rm", "Reward Modeling": "rm",
"PPO": "sft", "PPO": "ppo",
"DPO": "rm" "DPO": "dpo",
"Pre-Training": "pt"
} }
SUPPORTED_MODELS = { SUPPORTED_MODELS = {

View File

@ -137,6 +137,8 @@ class Template:
token_ids = [] token_ids = []
for elem in context: for elem in context:
if isinstance(elem, str): if isinstance(elem, str):
if len(elem) == 0:
continue
elem = elem.replace("{{system}}", system, 1) if system is not None else elem elem = elem.replace("{{system}}", system, 1) if system is not None else elem
elem = elem.replace("{{query}}", query, 1) if query is not None else elem elem = elem.replace("{{query}}", query, 1) if query is not None else elem
elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem

View File

@ -11,17 +11,15 @@ class DatasetAttr:
dataset_name: Optional[str] = None dataset_name: Optional[str] = None
dataset_sha1: Optional[str] = None dataset_sha1: Optional[str] = None
system_prompt: Optional[str] = None system_prompt: Optional[str] = None
stage: Optional[str] = None ranking: Optional[bool] = False
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
response: Optional[str] = "output"
history: Optional[str] = None
def __repr__(self) -> str: def __repr__(self) -> str:
return self.dataset_name return self.dataset_name
def __post_init__(self):
self.prompt = "instruction"
self.query = "input"
self.response = "output"
self.history = None
@dataclass @dataclass
class DataArguments: class DataArguments:
@ -114,21 +112,14 @@ class DataArguments:
raise ValueError("Undefined dataset {} in dataset_info.json.".format(name)) raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
if "hf_hub_url" in dataset_info[name]: if "hf_hub_url" in dataset_info[name]:
dataset_attr = DatasetAttr( dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
"hf_hub",
dataset_name=dataset_info[name]["hf_hub_url"],
stage=dataset_info[name].get("stage", None))
elif "script_url" in dataset_info[name]: elif "script_url" in dataset_info[name]:
dataset_attr = DatasetAttr( dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
"script",
dataset_name=dataset_info[name]["script_url"],
stage=dataset_info[name].get("stage", None))
else: else:
dataset_attr = DatasetAttr( dataset_attr = DatasetAttr(
"file", "file",
dataset_name=dataset_info[name]["file_name"], dataset_name=dataset_info[name]["file_name"],
dataset_sha1=dataset_info[name].get("file_sha1", None), dataset_sha1=dataset_info[name].get("file_sha1", None)
stage=dataset_info[name].get("stage", None)
) )
if "columns" in dataset_info[name]: if "columns" in dataset_info[name]:
@ -137,5 +128,6 @@ class DataArguments:
dataset_attr.response = dataset_info[name]["columns"].get("response", None) dataset_attr.response = dataset_info[name]["columns"].get("response", None)
dataset_attr.history = dataset_info[name]["columns"].get("history", None) dataset_attr.history = dataset_info[name]["columns"].get("history", None)
dataset_attr.ranking = dataset_info[name].get("ranking", False)
dataset_attr.system_prompt = prompt_list[i] dataset_attr.system_prompt = prompt_list[i]
self.dataset_list.append(dataset_attr) self.dataset_list.append(dataset_attr)
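
Read together, the two hunks leave the dataclass looking roughly like the sketch below: plain field defaults replace the removed `__post_init__`, and `ranking` replaces the per-dataset `stage` tag. The name of the first positional field is not shown in the hunk, so `load_from` here is an assumption:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class DatasetAttr:
    load_from: str                    # "hf_hub", "script" or "file" (assumed field name)
    dataset_name: Optional[str] = None
    dataset_sha1: Optional[str] = None
    system_prompt: Optional[str] = None
    ranking: Optional[bool] = False
    prompt: Optional[str] = "instruction"
    query: Optional[str] = "input"
    response: Optional[str] = "output"
    history: Optional[str] = None

    def __repr__(self) -> str:
        return self.dataset_name

attr = DatasetAttr("hf_hub", dataset_name="Anthropic/hh-rlhf")
attr.ranking = True   # populated from dataset_info.json's "ranking" key
print(attr)           # Anthropic/hh-rlhf
```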

View File

@ -16,7 +16,7 @@ class ModelArguments:
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."} metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
) )
use_fast_tokenizer: Optional[bool] = field( use_fast_tokenizer: Optional[bool] = field(
default=False, default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."} metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
) )
use_auth_token: Optional[bool] = field( use_auth_token: Optional[bool] = field(

View File

@ -36,7 +36,7 @@ check_min_version("4.29.1")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0") require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0") require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0") require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
require_version("trl>=0.5.0", "To fix: pip install trl>=0.5.0") require_version("trl>=0.7.1", "To fix: pip install trl>=0.7.1")
def load_model_and_tokenizer( def load_model_and_tokenizer(

View File

@ -5,6 +5,7 @@ import datasets
import transformers import transformers
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple
from transformers import HfArgumentParser, Seq2SeqTrainingArguments from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.utils.versions import require_version
from transformers.trainer_utils import get_last_checkpoint from transformers.trainer_utils import get_last_checkpoint
from llmtuner.extras.logging import get_logger from llmtuner.extras.logging import get_logger
@ -110,6 +111,11 @@ def get_train_args(
if general_args.stage in ["ppo", "dpo"] and not training_args.do_train: if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
raise ValueError("PPO and DPO stages can only be performed at training.") raise ValueError("PPO and DPO stages can only be performed at training.")
if general_args.stage in ["rm", "dpo"]:
for dataset_attr in data_args.dataset_list:
if not dataset_attr.ranking:
raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
if general_args.stage == "ppo" and model_args.reward_model is None: if general_args.stage == "ppo" and model_args.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.") raise ValueError("Reward model is necessary for PPO training.")
@ -166,6 +172,7 @@ def get_train_args(
and os.path.isdir(training_args.output_dir) and os.path.isdir(training_args.output_dir)
and not training_args.overwrite_output_dir and not training_args.overwrite_output_dir
): ):
require_version("transformers>=4.31.0", "Resuming training requires transformers>=4.31.0.")
last_checkpoint = get_last_checkpoint(training_args.output_dir) last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.") raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.")
@ -186,18 +193,6 @@ def get_train_args(
else: else:
model_args.compute_dtype = torch.float16 model_args.compute_dtype = torch.float16
# transfer training stage to dataset stage
dataset_stage = general_args.stage
if general_args.stage == "ppo":
dataset_stage = "sft"
elif general_args.stage == "dpo":
dataset_stage = "rm"
for dataset_attr in data_args.dataset_list:
if dataset_attr.stage and dataset_attr.stage != dataset_stage:
raise ValueError("Dataset {} is not supported for the stage {}"
.format(dataset_attr.dataset_name, general_args.stage))
model_args.model_max_length = data_args.max_source_length + data_args.max_target_length model_args.model_max_length = data_args.max_source_length + data_args.max_target_length
# Log on each process the small summary: # Log on each process the small summary:
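
The relocated check is easy to exercise in isolation. This standalone sketch substitutes a stub for the real `DatasetAttr`; selecting an unranked dataset for an rm or dpo run now fails fast, replacing the old per-dataset `stage` comparison:

```python
from dataclasses import dataclass

@dataclass
class StubAttr:                      # stand-in for the real DatasetAttr
    dataset_name: str
    ranking: bool = False

def check_datasets(stage, dataset_list):
    if stage in ["rm", "dpo"]:
        for dataset_attr in dataset_list:
            if not dataset_attr.ranking:
                raise ValueError("Please use ranked datasets for reward modeling or DPO training.")

try:
    check_datasets("rm", [StubAttr("alpaca_en")])   # hypothetical unranked dataset
except ValueError as err:
    print(err)
```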

View File

@ -1,9 +1,9 @@
import torch import torch
from collections import defaultdict from collections import defaultdict
from peft import PeftModel
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
from transformers import BatchEncoding, Trainer from transformers import BatchEncoding, Trainer
from trl import DPOTrainer from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model
from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.tuner.core.trainer import PeftModelMixin from llmtuner.tuner.core.trainer import PeftModelMixin
@ -18,9 +18,16 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
def __init__( def __init__(
self, self,
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
model: Union["PreTrainedModel", torch.nn.Module],
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None, ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
disable_dropout: Optional[bool] = True,
**kwargs **kwargs
): ):
if disable_dropout:
disable_dropout_in_model(model)
if ref_model is not None:
disable_dropout_in_model(ref_model)
self.finetuning_args = finetuning_args self.finetuning_args = finetuning_args
self.ref_model = ref_model self.ref_model = ref_model
self.use_dpo_data_collator = True # hack to avoid warning self.use_dpo_data_collator = True # hack to avoid warning
@ -29,12 +36,16 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
self.beta = finetuning_args.dpo_beta self.beta = finetuning_args.dpo_beta
self._stored_metrics = defaultdict(lambda: defaultdict(list)) self._stored_metrics = defaultdict(lambda: defaultdict(list))
Trainer.__init__(self, **kwargs) Trainer.__init__(self, model=model, **kwargs)
if not hasattr(self, "accelerator"): if not hasattr(self, "accelerator"):
raise AttributeError("Please update `transformers`.") raise AttributeError("Please update `transformers`.")
if ref_model is not None: if ref_model is not None:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) if self.is_deepspeed_enabled:
self.ref_model = self.accelerator._prepare_deepspeed(self.ref_model)
self.ref_model.eval()
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
def concatenated_forward( def concatenated_forward(
self, self,
@ -42,27 +53,12 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
batch: Optional[Dict[str, torch.Tensor]] = None batch: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
if not torch.is_grad_enabled(): all_logits = model(
unwrapped_model.gradient_checkpointing_disable() input_ids=batch_copied["input_ids"],
attention_mask=batch_copied["attention_mask"],
if model is None and isinstance(unwrapped_model, PeftModel): # peft model has no ref_model return_dict=True
with unwrapped_model.disable_adapter(): ).logits.to(torch.float32)
all_logits = self.model(
input_ids=batch_copied["input_ids"],
attention_mask=batch_copied["attention_mask"],
return_dict=True
).logits.to(torch.float32)
else:
all_logits = model(
input_ids=batch_copied["input_ids"],
attention_mask=batch_copied["attention_mask"],
return_dict=True
).logits.to(torch.float32)
if not torch.is_grad_enabled():
unwrapped_model.gradient_checkpointing_enable()
all_logps = self._get_batch_logps( all_logps = self._get_batch_logps(
all_logits, all_logits,
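
Two of the moving parts here come from trl 0.7.1: `disable_dropout_in_model` and a `DPOTrainer` that expects the model to be passed through `Trainer.__init__`. The sketch below reimplements the dropout helper in spirit (it is not the trl source) to show why it makes the chosen/rejected log-probs deterministic:

```python
import torch

def disable_dropout(model: torch.nn.Module) -> None:
    # same idea as trl's disable_dropout_in_model: zero every Dropout's p
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0.0

net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(0.1))
disable_dropout(net)

x = torch.randn(2, 4)
net.train()                    # even in train mode, two passes now agree
assert torch.equal(net(x), net(x))
```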

View File

@ -202,7 +202,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
queries: torch.Tensor, queries: torch.Tensor,
responses: torch.Tensor, responses: torch.Tensor,
model_inputs: dict, model_inputs: dict,
return_logits: Optional[bool] = False return_logits: Optional[bool] = False,
response_masks: Optional[torch.Tensor] = None
): ):
r""" r"""
Calculates model outputs in multiple batches. Calculates model outputs in multiple batches.
@ -220,6 +221,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()} input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
query_batch = queries[i * fbs : (i + 1) * fbs] query_batch = queries[i * fbs : (i + 1) * fbs]
response_batch = responses[i * fbs : (i + 1) * fbs] response_batch = responses[i * fbs : (i + 1) * fbs]
if response_masks is not None:
response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
input_ids = input_kwargs["input_ids"] input_ids = input_kwargs["input_ids"]
attention_mask = input_kwargs["attention_mask"] attention_mask = input_kwargs["attention_mask"]
@ -239,8 +242,15 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
start += attention_mask[j, :].nonzero()[0] start += attention_mask[j, :].nonzero()[0]
end = start + len(response_batch[j]) end = start + len(response_batch[j])
if response_masks is not None:
response_masks_batch = torch.cat(
(torch.zeros_like(query_batch[j]), response_masks_batch[j])
)[1:]
masks[j, :start] = 0 masks[j, :start] = 0
masks[j, end:] = 0 masks[j, end:] = 0
if response_masks is not None:
masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]
if return_logits: if return_logits:
all_logits.append(logits) all_logits.append(logits)
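
The new `response_masks` path composes with the existing start/end masking by elementwise multiplication. A toy tensor makes the effect concrete; the shapes and values below are invented, and the zero-padding-and-shift step shown in the hunk is skipped:

```python
import torch

masks = torch.ones(1, 8, dtype=torch.long)
start, end = 3, 7                           # response span in the sequence
response_mask = torch.tensor([1, 1, 0, 1])  # extra per-token mask, len == end - start

masks[0, :start] = 0                        # prompt tokens never count
masks[0, end:] = 0                          # trailing padding never counts
masks[0, start:end] *= response_mask        # new: drop masked response tokens
print(masks)  # tensor([[0, 0, 0, 1, 1, 0, 1, 0]])
```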

View File

@ -44,7 +44,6 @@ def run_ppo(
) )
if finetuning_args.ppo_score_norm: if finetuning_args.ppo_score_norm:
require_version("trl>=0.5.1.dev0", "To fix: pip install git+https://github.com/huggingface/trl.git")
ppo_config.use_score_scaling = True ppo_config.use_score_scaling = True
ppo_config.use_score_norm = True ppo_config.use_score_norm = True

View File

@ -6,7 +6,7 @@ import gradio as gr
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, DATASET_STAGE_MAP from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES
DEFAULT_CACHE_DIR = "cache" DEFAULT_CACHE_DIR = "cache"
@ -78,11 +78,10 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
return {} return {}
def list_dataset(dataset_dir: Optional[str] = None, stage: Optional[str] = None) -> Dict[str, Any]: def list_dataset(
dataset_dir: Optional[str] = None, training_stage: Optional[str] = list(TRAINING_STAGES.keys())[0]
) -> Dict[str, Any]:
dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR) dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
if stage: ranking = TRAINING_STAGES[training_stage] in ["rm", "dpo"]
dataset_stage = DATASET_STAGE_MAP[stage] datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
dataset_info = {key: value for key, value in dataset_info.items() return gr.update(value=[], choices=datasets)
if ("stage" not in value) or value["stage"] == dataset_stage}
return gr.update(value=[], choices=list(dataset_info.keys()))
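
With `TRAINING_STAGES` and the `ranking` flag in place, the dropdown logic reduces to one comparison and one list comprehension. A self-contained version with toy `dataset_info` entries (the entries are invented; `hh_rlhf_en` stands in for any ranked dataset):

```python
TRAINING_STAGES = {
    "Supervised Fine-Tuning": "sft",
    "Reward Modeling": "rm",
    "PPO": "ppo",
    "DPO": "dpo",
    "Pre-Training": "pt",
}
dataset_info = {
    "alpaca_en": {"hf_hub_url": "tatsu-lab/alpaca"},
    "hh_rlhf_en": {"hf_hub_url": "Anthropic/hh-rlhf", "ranking": True},
}

training_stage = "DPO"
ranking = TRAINING_STAGES[training_stage] in ["rm", "dpo"]
datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
print(datasets)  # ['hh_rlhf_en']
```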

View File

@ -3,7 +3,7 @@ from transformers.trainer_utils import SchedulerType
import gradio as gr import gradio as gr
from llmtuner.extras.constants import STAGES from llmtuner.extras.constants import TRAINING_STAGES
from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.utils import can_preview, get_preview, gen_plot from llmtuner.webui.utils import can_preview, get_preview, gen_plot
@ -15,7 +15,9 @@ if TYPE_CHECKING:
def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]: def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
with gr.Row(): with gr.Row():
training_stage = gr.Dropdown(choices=STAGES, value=STAGES[0], scale=2) training_stage = gr.Dropdown(
choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=2
)
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2) dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4) dataset = gr.Dropdown(multiselect=True, scale=4)
data_preview_btn = gr.Button(interactive=False, scale=1) data_preview_btn = gr.Button(interactive=False, scale=1)

View File

@ -8,7 +8,7 @@ from transformers.trainer import TRAINING_ARGS_NAME
from typing import Any, Dict, Generator, List, Tuple from typing import Any, Dict, Generator, List, Tuple
from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import DEFAULT_MODULE from llmtuner.extras.constants import DEFAULT_MODULE, TRAINING_STAGES
from llmtuner.extras.logging import LoggerHandler from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import torch_gc from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import run_exp from llmtuner.tuner import run_exp
@ -106,7 +106,7 @@ class Runner:
output_dir = os.path.join(get_save_dir(model_name), finetuning_type, output_dir) output_dir = os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
args = dict( args = dict(
stage="sft", stage=TRAINING_STAGES[training_stage],
model_name_or_path=get_model_path(model_name), model_name_or_path=get_model_path(model_name),
do_train=True, do_train=True,
overwrite_cache=True, overwrite_cache=True,
@ -133,26 +133,20 @@ class Runner:
lora_rank=lora_rank, lora_rank=lora_rank,
lora_dropout=lora_dropout, lora_dropout=lora_dropout,
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"), lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
resume_lora_training=resume_lora_training, resume_lora_training=(
False if TRAINING_STAGES[training_stage] in ["rm", "ppo", "dpo"] else resume_lora_training
),
output_dir=output_dir output_dir=output_dir
) )
args[compute_type] = True args[compute_type] = True
if training_stage == "Reward Modeling": if args["stage"] == "ppo":
args["stage"] = "rm"
args["resume_lora_training"] = False
elif training_stage == "PPO":
args["stage"] = "ppo"
args["resume_lora_training"] = False
args["reward_model"] = reward_model args["reward_model"] = reward_model
args["padding_side"] = "left" args["padding_side"] = "left"
val_size = 0 val_size = 0
elif training_stage == "DPO":
args["stage"] = "dpo" if args["stage"] == "dpo":
args["resume_lora_training"] = False
args["dpo_beta"] = dpo_beta args["dpo_beta"] = dpo_beta
elif training_stage == "Pre-Training":
args["stage"] = "pt"
if val_size > 1e-6: if val_size > 1e-6:
args["val_size"] = val_size args["val_size"] = val_size

View File

@ -3,10 +3,9 @@ import json
import gradio as gr import gradio as gr
import matplotlib.figure import matplotlib.figure
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from datetime import datetime from datetime import datetime
from llmtuner.dsets.utils import EXT2TYPE
from llmtuner.extras.ploting import smooth from llmtuner.extras.ploting import smooth
from llmtuner.tuner import export_model from llmtuner.tuner import export_model
from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
@ -37,6 +36,7 @@ def get_time() -> str:
def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]: def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f) dataset_info = json.load(f)
if ( if (
len(dataset) > 0 len(dataset) > 0
and "file_name" in dataset_info[dataset[0]] and "file_name" in dataset_info[dataset[0]]
@ -47,25 +47,26 @@ def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
return gr.update(interactive=False) return gr.update(interactive=False)
def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, Dict[str, Any]]: def get_preview(
dataset_dir: str, dataset: list, start: Optional[int] = 0, end: Optional[int] = 2
) -> Tuple[int, list, Dict[str, Any]]:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f) dataset_info = json.load(f)
data_file = dataset_info[dataset[0]]["file_name"]
data = [] data_file: str = dataset_info[dataset[0]]["file_name"]
data_format = EXT2TYPE.get(data_file.split(".")[-1], None) with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
if data_format == "text": if data_file.endswith(".json"):
with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
for line in f:
data.append(line.strip())
elif data_format == "json":
with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
return len(data), data[:2], gr.update(visible=True) elif data_file.endswith(".jsonl"):
data = [json.loads(line) for line in f] elif data_file.endswith(".jsonl"):
else:
data = [line for line in f]
return len(data), data[start:end], gr.update(visible=True)
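
The extension dispatch (note `json.loads`, which parses a single line, versus `json.load`, which reads a whole file) is easy to check on a throwaway file. This snippet only exercises the loading idea, not the webui function itself:

```python
import json
import os
import tempfile

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "demo.jsonl")
    with open(path, "w", encoding="utf-8") as f:
        f.write('{"instruction": "hi", "output": "hello"}\n')

    with open(path, "r", encoding="utf-8") as f:
        if path.endswith(".json"):
            data = json.load(f)                      # whole-file JSON
        elif path.endswith(".jsonl"):
            data = [json.loads(line) for line in f]  # one object per line
        else:
            data = [line for line in f]              # plain-text fallback

print(len(data), data[0:2])
```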
def can_quantize(finetuning_type: str) -> Dict[str, Any]: def can_quantize(finetuning_type: str) -> Dict[str, Any]:
if finetuning_type != "lora": if finetuning_type != "lora":
return gr.update(value="", interactive=False) return gr.update(value="None", interactive=False)
else: else:
return gr.update(interactive=True) return gr.update(interactive=True)
@ -73,7 +74,7 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
def gen_cmd(args: Dict[str, Any]) -> str: def gen_cmd(args: Dict[str, Any]) -> str:
if args.get("do_train", None): if args.get("do_train", None):
args["plot_loss"] = True args["plot_loss"] = True
cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python "] cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python src/train_bash.py"]
for k, v in args.items(): for k, v in args.items():
if v is not None and v != "": if v is not None and v != "":
cmd_lines.append(" --{} {} ".format(k, str(v))) cmd_lines.append(" --{} {} ".format(k, str(v)))