Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 11:42:49 +08:00)

commit 34f16cc635
parent c8e77c11d1
Former-commit-id: 779aae83d253de0a86201ff87543b5d695e28d23
@@ -63,18 +63,19 @@ def _encode_supervised_example(
         total_length += source_len + target_len

         if data_args.train_on_prompt:
-            source_mask = source_ids
+            source_label = source_ids
         elif turn_idx != 0 and template.efficient_eos:
-            source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
+            source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
         else:
-            source_mask = [IGNORE_INDEX] * source_len
+            source_label = [IGNORE_INDEX] * source_len

+        if data_args.mask_history and turn_idx != len(encoded_pairs) - 1:
+            target_label = [IGNORE_INDEX] * target_len
+        else:
+            target_label = target_ids
+
         input_ids += source_ids + target_ids
-        if data_args.train_last_turn_only and turn_idx != len(encoded_pairs) - 1:
-            labels += source_mask + [IGNORE_INDEX] * len(target_ids)
-        else:
-            labels += source_mask + target_ids
+        labels += source_label + target_label

     if template.efficient_eos:
         input_ids += [tokenizer.eos_token_id]
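Read together, the hunk above says: train_on_prompt keeps prompt tokens in the loss, and the new mask_history flag zeroes out the targets of every turn except the last. A minimal standalone sketch of that labeling rule, with made-up token ids and with truncation and the efficient_eos branch omitted (not the repository code):

# Standalone sketch of the label-masking rule above (not the repository code).
# IGNORE_INDEX marks positions that the cross-entropy loss skips.
IGNORE_INDEX = -100

def build_labels(encoded_pairs, train_on_prompt=False, mask_history=False):
    input_ids, labels = [], []
    for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
        # Prompt tokens are masked out of the loss unless train_on_prompt is set.
        source_label = source_ids if train_on_prompt else [IGNORE_INDEX] * len(source_ids)
        # With mask_history, every turn except the last contributes no target labels.
        if mask_history and turn_idx != len(encoded_pairs) - 1:
            target_label = [IGNORE_INDEX] * len(target_ids)
        else:
            target_label = target_ids
        input_ids += source_ids + target_ids
        labels += source_label + target_label
    return input_ids, labels

pairs = [([1, 2], [3, 4]), ([5], [6, 7])]  # ([prompt ids], [response ids]) per turn
print(build_labels(pairs, mask_history=True))
# -> ([1, 2, 3, 4, 5, 6, 7], [-100, -100, -100, -100, -100, 6, 7])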
@@ -41,17 +41,17 @@ class DataArguments:
         default="data",
         metadata={"help": "Path to the folder containing the datasets."},
     )
-    train_last_turn_only: Optional[bool] = field(
-        default=False,
-        metadata={"help": "Whether or not to train the last turn only."},
-    )
     cutoff_len: int = field(
         default=1024,
         metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
     )
     train_on_prompt: bool = field(
         default=False,
-        metadata={"help": "Whether to disable the mask on the prompt or not."},
+        metadata={"help": "Whether or not to disable the mask on the prompt."},
+    )
+    mask_history: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to mask the history and train on the last turn only."},
     )
     streaming: bool = field(
         default=False,
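Since these are ordinary dataclass fields, the flags can be set from the command line via transformers' HfArgumentParser. A minimal sketch using a trimmed-down stand-in for DataArguments (the real class has many more fields):

# Sketch: parsing boolean dataclass fields like these from the CLI.
# `MiniDataArguments` is a stand-in, not the real llamafactory DataArguments.
from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class MiniDataArguments:
    train_on_prompt: bool = field(
        default=False,
        metadata={"help": "Whether or not to disable the mask on the prompt."},
    )
    mask_history: bool = field(
        default=False,
        metadata={"help": "Whether or not to mask the history and train on the last turn only."},
    )

parser = HfArgumentParser(MiniDataArguments)
(data_args,) = parser.parse_args_into_dataclasses(args=["--mask_history", "true"])
print(data_args.mask_history)  # True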
@@ -162,9 +162,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     # Check arguments
     if finetuning_args.stage != "pt" and data_args.template is None:
         raise ValueError("Please specify which `template` to use.")

-    if finetuning_args.stage == "pt" and data_args.train_last_turn_only:
-        raise ValueError("PT stage does not support `train_last_turn_only`.")
-
     if finetuning_args.stage != "sft" and training_args.predict_with_generate:
         raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
@@ -44,11 +44,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         )
         dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1)
         dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4)
-        train_last_turn_only = gr.Checkbox()
         preview_elems = create_preview_box(dataset_dir, dataset)

-    input_elems.update({training_stage, dataset_dir, dataset,train_last_turn_only})
-    elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset,train_last_turn_only=train_last_turn_only, **preview_elems))
+    input_elems.update({training_stage, dataset_dir, dataset})
+    elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems))

     with gr.Row():
         learning_rate = gr.Textbox(value="5e-5")
@@ -99,6 +98,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
                 packing = gr.Checkbox()
                 neat_packing = gr.Checkbox()

+            with gr.Column():
+                train_on_prompt = gr.Checkbox()
+                mask_history = gr.Checkbox()
+
             with gr.Column():
                 resize_vocab = gr.Checkbox()
                 use_llama_pro = gr.Checkbox()
@@ -116,6 +119,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             optim,
             packing,
             neat_packing,
+            train_on_prompt,
+            mask_history,
             resize_vocab,
             use_llama_pro,
             shift_attn,
@@ -132,6 +137,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             optim=optim,
             packing=packing,
             neat_packing=neat_packing,
+            train_on_prompt=train_on_prompt,
+            mask_history=mask_history,
             resize_vocab=resize_vocab,
             use_llama_pro=use_llama_pro,
             shift_attn=shift_attn,
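Each new checkbox is registered twice: once in the input_elems set, whose live values the runner snapshots on submit, and once in elem_dict under the name that the LOCALES table uses to attach a label. A compressed sketch of that pattern (heavily simplified, plain gradio only; the real create_train_tab builds far more components):

# Sketch of the component-registration pattern above, not the repository code.
import gradio as gr

input_elems, elem_dict = set(), dict()

with gr.Blocks() as demo:
    with gr.Column():
        train_on_prompt = gr.Checkbox()
        mask_history = gr.Checkbox()

# The set is what the runner reads on submit; the dict keys must match the
# LOCALES entries so each component gets a localized label and tooltip.
input_elems.update({train_on_prompt, mask_history})
elem_dict.update(dict(train_on_prompt=train_on_prompt, mask_history=mask_history))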
@@ -32,7 +32,7 @@ if is_gradio_available():
     import gradio as gr


-def create_ui(demo_mode: bool = False) -> gr.Blocks:
+def create_ui(demo_mode: bool = False) -> "gr.Blocks":
     engine = Engine(demo_mode=demo_mode, pure_chat=False)

     with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
@@ -67,7 +67,7 @@ def create_ui(demo_mode: bool = False) -> gr.Blocks:
     return demo


-def create_web_demo() -> gr.Blocks:
+def create_web_demo() -> "gr.Blocks":
     engine = Engine(pure_chat=True)

     with gr.Blocks(title="Web Demo", css=CSS) as demo:
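The only change in this file is quoting the return annotations. A plausible reading: a string (forward-reference) annotation is never evaluated at runtime, so the module can be imported even when gradio, which is imported only conditionally here, is absent. A minimal sketch of the pattern, not the repository code:

# Sketch: a quoted annotation keeps `gr.Blocks` from being evaluated at
# import time, so this module imports cleanly without gradio installed;
# gradio is only needed when the function is actually called.
from typing import TYPE_CHECKING

if TYPE_CHECKING:  # seen by type checkers only, never at runtime
    import gradio as gr

def create_ui_sketch() -> "gr.Blocks":
    import gradio as gr  # deferred import, assumed available at call time

    with gr.Blocks(title="Sketch") as demo:
        gr.Markdown("hello")
    return demo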
@@ -522,6 +522,34 @@ LOCALES = {
             "info": "避免打包后的序列产生交叉注意力。",
         },
     },
+    "train_on_prompt": {
+        "en": {
+            "label": "Train on prompt",
+            "info": "Disable the label mask on the prompt (only for SFT).",
+        },
+        "ru": {
+            "label": "Тренировка на подсказке",
+            "info": "Отключить маску меток на подсказке (только для SFT).",
+        },
+        "zh": {
+            "label": "学习提示词",
+            "info": "不在提示词的部分添加掩码(仅适用于 SFT)。",
+        },
+    },
+    "mask_history": {
+        "en": {
+            "label": "Mask history",
+            "info": "Train on the last turn only (only for SFT).",
+        },
+        "ru": {
+            "label": "История масок",
+            "info": "Тренироваться только на последнем шаге (только для SFT).",
+        },
+        "zh": {
+            "label": "不学习历史对话",
+            "info": "仅学习最后一轮对话(仅适用于 SFT)。",
+        },
+    },
     "resize_vocab": {
         "en": {
             "label": "Resize token embeddings",
@@ -536,20 +564,6 @@ LOCALES = {
             "info": "更改分词器词表和嵌入层的大小。",
         },
     },
-    "train_last_turn_only": {
-        "en": {
-            "label": "Train last turn only",
-            "info": "Train the model with the last turn only in multi turn.",
-        },
-        "ru": {
-            "label": "Обучать только последний поворот",
-            "info": "Обучать модель только последним поворотом в многоповоротном диалоге.",
-        },
-        "zh": {
-            "label": "仅最后一轮参与训练",
-            "info": "多轮对话仅使用最后一轮计算loss。",
-        },
-    },
     "use_llama_pro": {
         "en": {
             "label": "Enable LLaMA Pro",
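LOCALES maps component name → language → display attributes, which is why the elem_dict keys in the train tab must match these entries. A toy sketch of consuming such a table; the localized_attrs helper is hypothetical, not the repository's API:

# Hypothetical helper, for illustration only: pull the attributes of one
# component in one language from a LOCALES-style table.
LOCALES_SKETCH = {
    "mask_history": {
        "en": {"label": "Mask history", "info": "Train on the last turn only (only for SFT)."},
        "zh": {"label": "不学习历史对话", "info": "仅学习最后一轮对话(仅适用于 SFT)。"},
    },
}

def localized_attrs(name: str, lang: str) -> dict:
    return LOCALES_SKETCH.get(name, {}).get(lang, {})

print(localized_attrs("mask_history", "en")["label"])  # Mask history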
@@ -125,7 +125,6 @@ class Runner:
             visual_inputs=get("top.visual_inputs"),
             dataset_dir=get("train.dataset_dir"),
             dataset=",".join(get("train.dataset")),
-            train_last_turn_only=get("train.train_last_turn_only"),
             cutoff_len=get("train.cutoff_len"),
             learning_rate=float(get("train.learning_rate")),
             num_train_epochs=float(get("train.num_train_epochs")),
@@ -141,6 +140,8 @@ class Runner:
             optim=get("train.optim"),
             packing=get("train.packing") or get("train.neat_packing"),
             neat_packing=get("train.neat_packing"),
+            train_on_prompt=get("train.train_on_prompt"),
+            mask_history=get("train.mask_history"),
             resize_vocab=get("train.resize_vocab"),
             use_llama_pro=get("train.use_llama_pro"),
             shift_attn=get("train.shift_attn"),
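Here get resolves a dotted UI path to a component's current value, and the resulting dict becomes the training arguments. A small sketch of that pattern with a plain dict standing in for the UI state; note how packing is forced on whenever neat_packing is requested, mirroring the hunk above:

# Sketch of the Runner pattern above; `ui_state.get` stands in for the
# real `get`, which resolves "tab.component" paths to live UI values.
def build_args_sketch(get):
    return dict(
        train_on_prompt=get("train.train_on_prompt"),
        mask_history=get("train.mask_history"),
        # packing is implied when neat_packing is requested:
        packing=get("train.packing") or get("train.neat_packing"),
        neat_packing=get("train.neat_packing"),
    )

ui_state = {
    "train.train_on_prompt": False,
    "train.mask_history": True,
    "train.packing": False,
    "train.neat_packing": True,
}
print(build_args_sketch(ui_state.get))
# {'train_on_prompt': False, 'mask_history': True, 'packing': True, 'neat_packing': True}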