From 249adacc4d6af207d6eef20fd1a814b71a34ac9a Mon Sep 17 00:00:00 2001 From: Shiyu Zhang <328574108@qq.com> Date: Thu, 18 Jul 2024 15:30:25 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BB=85=E4=BB=85=E8=AE=AD=E7=BB=83=E6=9C=80?= =?UTF-8?q?=E5=90=8E=E4=B8=80=E8=BD=AE=E5=AF=B9=E8=AF=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Former-commit-id: 1e7b396ff2489055574fd3365425d26360d73897 --- src/llamafactory/data/processors/supervised.py | 6 +++++- src/llamafactory/hparams/data_args.py | 4 ++++ src/llamafactory/hparams/parser.py | 3 +++ src/llamafactory/webui/components/train.py | 5 +++-- src/llamafactory/webui/locales.py | 14 ++++++++++++++ src/llamafactory/webui/runner.py | 1 + 6 files changed, 30 insertions(+), 3 deletions(-) diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py index 141054f4..169b31d3 100644 --- a/src/llamafactory/data/processors/supervised.py +++ b/src/llamafactory/data/processors/supervised.py @@ -70,7 +70,11 @@ def _encode_supervised_example( source_mask = [IGNORE_INDEX] * source_len input_ids += source_ids + target_ids - labels += source_mask + target_ids + + if data_args.train_last_turn_only and turn_idx != len(encoded_pairs) - 1: + labels += source_mask + [IGNORE_INDEX] * len(target_ids) + else: + labels += source_mask + target_ids if template.efficient_eos: input_ids += [tokenizer.eos_token_id] diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index 9ae15d2d..10630019 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -41,6 +41,10 @@ class DataArguments: default="data", metadata={"help": "Path to the folder containing the datasets."}, ) + train_last_turn_only: Optional[bool] = field( + default=False, + metadata={"help": "Whether or not to train the last turn only."}, + ) cutoff_len: int = field( default=1024, metadata={"help": "The cutoff length of the 
tokenized inputs in the dataset."}, diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index b3c87b76..65e26e6a 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -162,6 +162,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: # Check arguments if finetuning_args.stage != "pt" and data_args.template is None: raise ValueError("Please specify which `template` to use.") + + if finetuning_args.stage == "pt" and data_args.train_last_turn_only: + raise ValueError("PT stage does not support `train_last_turn_only`.") if finetuning_args.stage != "sft" and training_args.predict_with_generate: raise ValueError("`predict_with_generate` cannot be set as True except SFT.") diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index 9f7e0d2a..30a929c3 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -44,10 +44,11 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: ) dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=1) dataset = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=4) + train_last_turn_only = gr.Checkbox() preview_elems = create_preview_box(dataset_dir, dataset) - input_elems.update({training_stage, dataset_dir, dataset}) - elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems)) + input_elems.update({training_stage, dataset_dir, dataset, train_last_turn_only}) + elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, train_last_turn_only=train_last_turn_only, **preview_elems)) with gr.Row(): learning_rate = gr.Textbox(value="5e-5") diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index affc832f..2211a37f 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -536,6 +536,20 
@@ LOCALES = { "info": "更改分词器词表和嵌入层的大小。", }, }, + "train_last_turn_only": { + "en": { + "label": "Train last turn only", + "info": "Train the model on the last turn only in multi-turn conversations.", + }, + "ru": { + "label": "Обучать только последнюю реплику", + "info": "Обучать модель только на последней реплике в многораундовом диалоге.", + }, + "zh": { + "label": "仅最后一轮参与训练", + "info": "多轮对话仅使用最后一轮计算loss。", + }, + }, "use_llama_pro": { "en": { "label": "Enable LLaMA Pro", diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index 68edd48b..6a766abe 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -125,6 +125,7 @@ class Runner: visual_inputs=get("top.visual_inputs"), dataset_dir=get("train.dataset_dir"), dataset=",".join(get("train.dataset")), + train_last_turn_only=get("train.train_last_turn_only"), cutoff_len=get("train.cutoff_len"), learning_rate=float(get("train.learning_rate")), num_train_epochs=float(get("train.num_train_epochs")),